From c21c3b0befeb46a51b6bf3758ffa30813bea0ff0 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 9 Mar 2024 14:19:22 +0100 Subject: Adding upstream version 1.44.3. Signed-off-by: Daniel Baumann --- ml/dlib/tools/archive/train_face_5point_model.cpp | 159 ++ .../convert_dlib_nets_to_caffe/CMakeLists.txt | 25 + ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp | 792 ++++++++++ .../running_a_dlib_model_with_caffe_example.py | 77 + ml/dlib/tools/htmlify/CMakeLists.txt | 31 + ml/dlib/tools/htmlify/htmlify.cpp | 632 ++++++++ ml/dlib/tools/htmlify/to_xml.cpp | 1599 ++++++++++++++++++++ ml/dlib/tools/htmlify/to_xml.h | 22 + ml/dlib/tools/htmlify/to_xml_example/bigminus.gif | Bin 0 -> 91 bytes ml/dlib/tools/htmlify/to_xml_example/bigplus.gif | Bin 0 -> 99 bytes ml/dlib/tools/htmlify/to_xml_example/example.xml | 8 + ml/dlib/tools/htmlify/to_xml_example/minus.gif | Bin 0 -> 56 bytes ml/dlib/tools/htmlify/to_xml_example/output.xml | 49 + ml/dlib/tools/htmlify/to_xml_example/plus.gif | Bin 0 -> 59 bytes .../tools/htmlify/to_xml_example/stylesheet.xsl | 354 +++++ ml/dlib/tools/htmlify/to_xml_example/test.cpp | 78 + ml/dlib/tools/imglab/CMakeLists.txt | 41 + ml/dlib/tools/imglab/README.txt | 40 + .../tools/imglab/convert_imglab_paths_to_relative | 24 + ml/dlib/tools/imglab/copy_imglab_dataset | 22 + ml/dlib/tools/imglab/src/cluster.cpp | 260 ++++ ml/dlib/tools/imglab/src/cluster.h | 11 + ml/dlib/tools/imglab/src/common.cpp | 60 + ml/dlib/tools/imglab/src/common.h | 45 + ml/dlib/tools/imglab/src/convert_idl.cpp | 184 +++ ml/dlib/tools/imglab/src/convert_idl.h | 14 + ml/dlib/tools/imglab/src/convert_pascal_v1.cpp | 177 +++ ml/dlib/tools/imglab/src/convert_pascal_v1.h | 13 + ml/dlib/tools/imglab/src/convert_pascal_xml.cpp | 239 +++ ml/dlib/tools/imglab/src/convert_pascal_xml.h | 12 + ml/dlib/tools/imglab/src/flip_dataset.cpp | 249 +++ ml/dlib/tools/imglab/src/flip_dataset.h | 12 + ml/dlib/tools/imglab/src/main.cpp | 1145 ++++++++++++++ 
ml/dlib/tools/imglab/src/metadata_editor.cpp | 671 ++++++++ ml/dlib/tools/imglab/src/metadata_editor.h | 116 ++ ml/dlib/tools/python/CMakeLists.txt | 106 ++ ml/dlib/tools/python/src/basic.cpp | 272 ++++ ml/dlib/tools/python/src/cca.cpp | 137 ++ ml/dlib/tools/python/src/cnn_face_detector.cpp | 183 +++ ml/dlib/tools/python/src/conversion.h | 52 + ml/dlib/tools/python/src/correlation_tracker.cpp | 167 ++ ml/dlib/tools/python/src/decision_functions.cpp | 263 ++++ ml/dlib/tools/python/src/dlib.cpp | 110 ++ ml/dlib/tools/python/src/face_recognition.cpp | 245 +++ ml/dlib/tools/python/src/global_optimization.cpp | 442 ++++++ ml/dlib/tools/python/src/gui.cpp | 128 ++ ml/dlib/tools/python/src/image.cpp | 40 + .../tools/python/src/image_dataset_metadata.cpp | 279 ++++ ml/dlib/tools/python/src/indexing.h | 11 + ml/dlib/tools/python/src/matrix.cpp | 209 +++ ml/dlib/tools/python/src/numpy_returns.cpp | 158 ++ ml/dlib/tools/python/src/numpy_returns_stub.cpp | 59 + ml/dlib/tools/python/src/object_detection.cpp | 376 +++++ ml/dlib/tools/python/src/opaque_types.h | 55 + ml/dlib/tools/python/src/other.cpp | 268 ++++ ml/dlib/tools/python/src/rectangles.cpp | 268 ++++ ml/dlib/tools/python/src/sequence_segmenter.cpp | 827 ++++++++++ .../tools/python/src/serialize_object_detector.h | 49 + ml/dlib/tools/python/src/shape_predictor.cpp | 319 ++++ ml/dlib/tools/python/src/shape_predictor.h | 259 ++++ ml/dlib/tools/python/src/simple_object_detector.h | 318 ++++ .../tools/python/src/simple_object_detector_py.h | 290 ++++ ml/dlib/tools/python/src/svm_c_trainer.cpp | 311 ++++ ml/dlib/tools/python/src/svm_rank_trainer.cpp | 161 ++ ml/dlib/tools/python/src/svm_struct.cpp | 151 ++ ml/dlib/tools/python/src/testing_results.h | 50 + ml/dlib/tools/python/src/vector.cpp | 182 +++ ml/dlib/tools/python/test/.gitignore | 1 + ml/dlib/tools/python/test/test_array.py | 107 ++ .../tools/python/test/test_global_optimization.py | 69 + ml/dlib/tools/python/test/test_matrix.py | 100 ++ 
ml/dlib/tools/python/test/test_point.py | 48 + ml/dlib/tools/python/test/test_range.py | 97 ++ ml/dlib/tools/python/test/test_rgb_pixel.py | 26 + ml/dlib/tools/python/test/test_sparse_vector.py | 101 ++ ml/dlib/tools/python/test/test_svm_c_trainer.py | 65 + ml/dlib/tools/python/test/test_vector.py | 170 +++ ml/dlib/tools/visual_studio_natvis/README.txt | 12 + ml/dlib/tools/visual_studio_natvis/dlib.natvis | 51 + 79 files changed, 14753 insertions(+) create mode 100644 ml/dlib/tools/archive/train_face_5point_model.cpp create mode 100644 ml/dlib/tools/convert_dlib_nets_to_caffe/CMakeLists.txt create mode 100644 ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp create mode 100755 ml/dlib/tools/convert_dlib_nets_to_caffe/running_a_dlib_model_with_caffe_example.py create mode 100644 ml/dlib/tools/htmlify/CMakeLists.txt create mode 100644 ml/dlib/tools/htmlify/htmlify.cpp create mode 100644 ml/dlib/tools/htmlify/to_xml.cpp create mode 100644 ml/dlib/tools/htmlify/to_xml.h create mode 100644 ml/dlib/tools/htmlify/to_xml_example/bigminus.gif create mode 100644 ml/dlib/tools/htmlify/to_xml_example/bigplus.gif create mode 100644 ml/dlib/tools/htmlify/to_xml_example/example.xml create mode 100644 ml/dlib/tools/htmlify/to_xml_example/minus.gif create mode 100644 ml/dlib/tools/htmlify/to_xml_example/output.xml create mode 100644 ml/dlib/tools/htmlify/to_xml_example/plus.gif create mode 100644 ml/dlib/tools/htmlify/to_xml_example/stylesheet.xsl create mode 100644 ml/dlib/tools/htmlify/to_xml_example/test.cpp create mode 100644 ml/dlib/tools/imglab/CMakeLists.txt create mode 100644 ml/dlib/tools/imglab/README.txt create mode 100755 ml/dlib/tools/imglab/convert_imglab_paths_to_relative create mode 100755 ml/dlib/tools/imglab/copy_imglab_dataset create mode 100644 ml/dlib/tools/imglab/src/cluster.cpp create mode 100644 ml/dlib/tools/imglab/src/cluster.h create mode 100644 ml/dlib/tools/imglab/src/common.cpp create mode 100644 ml/dlib/tools/imglab/src/common.h create mode 100644 
ml/dlib/tools/imglab/src/convert_idl.cpp create mode 100644 ml/dlib/tools/imglab/src/convert_idl.h create mode 100644 ml/dlib/tools/imglab/src/convert_pascal_v1.cpp create mode 100644 ml/dlib/tools/imglab/src/convert_pascal_v1.h create mode 100644 ml/dlib/tools/imglab/src/convert_pascal_xml.cpp create mode 100644 ml/dlib/tools/imglab/src/convert_pascal_xml.h create mode 100644 ml/dlib/tools/imglab/src/flip_dataset.cpp create mode 100644 ml/dlib/tools/imglab/src/flip_dataset.h create mode 100644 ml/dlib/tools/imglab/src/main.cpp create mode 100644 ml/dlib/tools/imglab/src/metadata_editor.cpp create mode 100644 ml/dlib/tools/imglab/src/metadata_editor.h create mode 100644 ml/dlib/tools/python/CMakeLists.txt create mode 100644 ml/dlib/tools/python/src/basic.cpp create mode 100644 ml/dlib/tools/python/src/cca.cpp create mode 100644 ml/dlib/tools/python/src/cnn_face_detector.cpp create mode 100644 ml/dlib/tools/python/src/conversion.h create mode 100644 ml/dlib/tools/python/src/correlation_tracker.cpp create mode 100644 ml/dlib/tools/python/src/decision_functions.cpp create mode 100644 ml/dlib/tools/python/src/dlib.cpp create mode 100644 ml/dlib/tools/python/src/face_recognition.cpp create mode 100644 ml/dlib/tools/python/src/global_optimization.cpp create mode 100644 ml/dlib/tools/python/src/gui.cpp create mode 100644 ml/dlib/tools/python/src/image.cpp create mode 100644 ml/dlib/tools/python/src/image_dataset_metadata.cpp create mode 100644 ml/dlib/tools/python/src/indexing.h create mode 100644 ml/dlib/tools/python/src/matrix.cpp create mode 100644 ml/dlib/tools/python/src/numpy_returns.cpp create mode 100644 ml/dlib/tools/python/src/numpy_returns_stub.cpp create mode 100644 ml/dlib/tools/python/src/object_detection.cpp create mode 100644 ml/dlib/tools/python/src/opaque_types.h create mode 100644 ml/dlib/tools/python/src/other.cpp create mode 100644 ml/dlib/tools/python/src/rectangles.cpp create mode 100644 ml/dlib/tools/python/src/sequence_segmenter.cpp create mode 
100644 ml/dlib/tools/python/src/serialize_object_detector.h create mode 100644 ml/dlib/tools/python/src/shape_predictor.cpp create mode 100644 ml/dlib/tools/python/src/shape_predictor.h create mode 100644 ml/dlib/tools/python/src/simple_object_detector.h create mode 100644 ml/dlib/tools/python/src/simple_object_detector_py.h create mode 100644 ml/dlib/tools/python/src/svm_c_trainer.cpp create mode 100644 ml/dlib/tools/python/src/svm_rank_trainer.cpp create mode 100644 ml/dlib/tools/python/src/svm_struct.cpp create mode 100644 ml/dlib/tools/python/src/testing_results.h create mode 100644 ml/dlib/tools/python/src/vector.cpp create mode 100644 ml/dlib/tools/python/test/.gitignore create mode 100644 ml/dlib/tools/python/test/test_array.py create mode 100644 ml/dlib/tools/python/test/test_global_optimization.py create mode 100644 ml/dlib/tools/python/test/test_matrix.py create mode 100644 ml/dlib/tools/python/test/test_point.py create mode 100644 ml/dlib/tools/python/test/test_range.py create mode 100644 ml/dlib/tools/python/test/test_rgb_pixel.py create mode 100644 ml/dlib/tools/python/test/test_sparse_vector.py create mode 100644 ml/dlib/tools/python/test/test_svm_c_trainer.py create mode 100644 ml/dlib/tools/python/test/test_vector.py create mode 100644 ml/dlib/tools/visual_studio_natvis/README.txt create mode 100644 ml/dlib/tools/visual_studio_natvis/dlib.natvis (limited to 'ml/dlib/tools') diff --git a/ml/dlib/tools/archive/train_face_5point_model.cpp b/ml/dlib/tools/archive/train_face_5point_model.cpp new file mode 100644 index 000000000..0cd35467f --- /dev/null +++ b/ml/dlib/tools/archive/train_face_5point_model.cpp @@ -0,0 +1,159 @@ + +/* + + This is the program that created the http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2 model file. 
+ +*/ + + +#include +#include +#include +#include +#include +#include + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +std::vector > get_interocular_distances ( + const std::vector >& objects +); +/*! + ensures + - returns an object D such that: + - D[i][j] == the distance, in pixels, between the eyes for the face represented + by objects[i][j]. +!*/ + +// ---------------------------------------------------------------------------------------- + +template < + typename image_array_type, + typename T + > +void add_image_left_right_flips_5points ( + image_array_type& images, + std::vector >& objects +) +{ + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size(), + "\t void add_image_left_right_flips()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + + typename image_array_type::value_type temp; + std::vector rects; + + const unsigned long num = images.size(); + for (unsigned long j = 0; j < num; ++j) + { + const point_transform_affine tran = flip_image_left_right(images[j], temp); + + rects.clear(); + for (unsigned long i = 0; i < objects[j].size(); ++i) + { + rects.push_back(impl::tform_object(tran, objects[j][i])); + + DLIB_CASSERT(rects.back().num_parts() == 5); + swap(rects.back().part(0), rects.back().part(2)); + swap(rects.back().part(1), rects.back().part(3)); + } + + images.push_back(temp); + objects.push_back(rects); + } +} + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) +{ + try + { + if (argc != 2) + { + cout << "give the path to the training data folder" << endl; + return 0; + } + const std::string faces_directory = argv[1]; + dlib::array > images_train, images_test; + std::vector > faces_train, faces_test; + + std::vector parts_list; + 
load_image_dataset(images_train, faces_train, faces_directory+"/train_cleaned.xml", parts_list); + load_image_dataset(images_test, faces_test, faces_directory+"/test_cleaned.xml"); + + add_image_left_right_flips_5points(images_train, faces_train); + add_image_left_right_flips_5points(images_test, faces_test); + add_image_rotations(linspace(-20,20,3)*pi/180.0,images_train, faces_train); + + cout << "num training images: "<< images_train.size() << endl; + + for (auto& part : parts_list) + cout << part << endl; + + shape_predictor_trainer trainer; + trainer.set_oversampling_amount(40); + trainer.set_num_test_splits(150); + trainer.set_feature_pool_size(800); + trainer.set_num_threads(4); + trainer.set_cascade_depth(15); + trainer.be_verbose(); + + // Now finally generate the shape model + shape_predictor sp = trainer.train(images_train, faces_train); + + serialize("shape_predictor_5_face_landmarks.dat") << sp; + + cout << "mean training error: "<< + test_shape_predictor(sp, images_train, faces_train, get_interocular_distances(faces_train)) << endl; + + cout << "mean testing error: "<< + test_shape_predictor(sp, images_test, faces_test, get_interocular_distances(faces_test)) << endl; + + } + catch (exception& e) + { + cout << "\nexception thrown!" 
<< endl; + cout << e.what() << endl; + } +} + +// ---------------------------------------------------------------------------------------- + +double interocular_distance ( + const full_object_detection& det +) +{ + dlib::vector l, r; + // left eye + l = (det.part(0) + det.part(1))/2; + // right eye + r = (det.part(2) + det.part(3))/2; + + return length(l-r); +} + +std::vector > get_interocular_distances ( + const std::vector >& objects +) +{ + std::vector > temp(objects.size()); + for (unsigned long i = 0; i < objects.size(); ++i) + { + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + temp[i].push_back(interocular_distance(objects[i][j])); + } + } + return temp; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/convert_dlib_nets_to_caffe/CMakeLists.txt b/ml/dlib/tools/convert_dlib_nets_to_caffe/CMakeLists.txt new file mode 100644 index 000000000..f9518df21 --- /dev/null +++ b/ml/dlib/tools/convert_dlib_nets_to_caffe/CMakeLists.txt @@ -0,0 +1,25 @@ +# +# This is a CMake makefile. 
You can find the cmake utility and +# information about it at http://www.cmake.org +# + +cmake_minimum_required(VERSION 2.8.12) + +set (target_name dtoc) + +PROJECT(${target_name}) + +add_subdirectory(../../dlib dlib_build) + +add_executable(${target_name} + main.cpp + ) + +target_link_libraries(${target_name} dlib::dlib ) + + +INSTALL(TARGETS ${target_name} + RUNTIME DESTINATION bin + ) + + diff --git a/ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp b/ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp new file mode 100644 index 000000000..f5cc19748 --- /dev/null +++ b/ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp @@ -0,0 +1,792 @@ + +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace dlib; + + +// ---------------------------------------------------------------------------------------- + +// Only these computational layers have parameters +const std::set comp_tags_with_params = {"fc", "fc_no_bias", "con", "affine_con", "affine_fc", "affine", "prelu"}; + +struct layer +{ + string type; // comp, loss, or input + int idx; + + matrix output_tensor_shape; // (N,K,NR,NC) + + string detail_name; // The name of the tag inside the layer tag. e.g. fc, con, max_pool, input_rgb_image. + std::map attributes; + matrix params; + long tag_id = -1; // If this isn't -1 then it means this layer was tagged, e.g. wrapped with tag2<> giving tag_id==2 + long skip_id = -1; // If this isn't -1 then it means this layer draws its inputs from + // the most recent layer with tag_id==skip_id rather than its immediate predecessor. 
+ + double attribute (const string& key) const + { + auto i = attributes.find(key); + if (i != attributes.end()) + return i->second; + else + throw dlib::error("Layer doesn't have the requested attribute '" + key + "'."); + } + + string caffe_layer_name() const + { + if (type == "input") + return "data"; + else + return detail_name+to_string(idx); + } +}; + +// ---------------------------------------------------------------------------------------- + +std::vector parse_dlib_xml( + const matrix& input_tensor_shape, + const string& xml_filename +); + +// ---------------------------------------------------------------------------------------- + +template +const layer& find_layer ( + iterator i, + long tag_id +) +/*! + requires + - i is a reverse iterator pointing to a layer in the list of layers produced by parse_dlib_xml(). + - i is not an input layer. + ensures + - if (tag_id == -1) then + - returns the previous layer (i.e. closer to the input) to layer i. + - else + - returns the previous layer (i.e. closer to the input) to layer i with the + given tag_id. +!*/ +{ + if (tag_id == -1) + { + return *(i-1); + } + else + { + while(true) + { + i--; + // if we hit the end of the network before we found what we were looking for + if (i->tag_id == tag_id) + return *i; + if (i->type == "input") + throw dlib::error("Network definition is bad, a layer wanted to skip back to a non-existing layer."); + } + } +} + +template +const layer& find_input_layer (iterator i) { return find_layer(i, i->skip_id); } + +template +string find_layer_caffe_name ( + iterator i, + long tag_id +) +{ + return find_layer(i,tag_id).caffe_layer_name(); +} + +template +string find_input_layer_caffe_name (iterator i) { return find_input_layer(i).caffe_layer_name(); } + +// ---------------------------------------------------------------------------------------- + +template +void compute_caffe_padding_size_for_pooling_layer( + const iterator& i, + long& pad_x, + long& pad_y +) +/*! 
+ requires + - i is a reverse iterator pointing to a layer in the list of layers produced by parse_dlib_xml(). + - i is not an input layer. + ensures + - Caffe is funny about how it computes the output sizes from pooling layers. + Rather than using the normal formula for output row/column sizes used by all the + other layers (and what dlib uses everywhere), + floor((bottom_size + 2*pad - kernel_size) / stride) + 1 + it instead uses: + ceil((bottom_size + 2*pad - kernel_size) / stride) + 1 + + These are the same except when the stride!=1. In that case we need to figure out + how to change the padding value so that the output size of the caffe padding + layer will match the output size of the dlib padding layer. That is what this + function does. +!*/ +{ + const long dlib_output_nr = i->output_tensor_shape(2); + const long dlib_output_nc = i->output_tensor_shape(3); + const long bottom_nr = find_input_layer(i).output_tensor_shape(2); + const long bottom_nc = find_input_layer(i).output_tensor_shape(3); + const long padding_x = (long)i->attribute("padding_x"); + const long padding_y = (long)i->attribute("padding_y"); + const long stride_x = (long)i->attribute("stride_x"); + const long stride_y = (long)i->attribute("stride_y"); + long kernel_w = i->attribute("nc"); + long kernel_h = i->attribute("nr"); + + if (kernel_w == 0) + kernel_w = bottom_nc; + if (kernel_h == 0) + kernel_h = bottom_nr; + + + // The correct padding for caffe could be anything in the range [0,padding_x]. So + // check what gives the correct output size and use that. 
+ for (pad_x = 0; pad_x <= padding_x; ++pad_x) + { + long caffe_out_size = ceil((bottom_nc + 2.0*pad_x - kernel_w)/(double)stride_x) + 1; + if (caffe_out_size == dlib_output_nc) + break; + } + if (pad_x == padding_x+1) + { + std::ostringstream sout; + sout << "No conversion between dlib pooling layer parameters and caffe pooling layer parameters found for layer " << to_string(i->idx) << endl; + sout << "dlib_output_nc: " << dlib_output_nc << endl; + sout << "bottom_nc: " << bottom_nc << endl; + sout << "padding_x: " << padding_x << endl; + sout << "stride_x: " << stride_x << endl; + sout << "kernel_w: " << kernel_w << endl; + sout << "pad_x: " << pad_x << endl; + throw dlib::error(sout.str()); + } + + for (pad_y = 0; pad_y <= padding_y; ++pad_y) + { + long caffe_out_size = ceil((bottom_nr + 2.0*pad_y - kernel_h)/(double)stride_y) + 1; + if (caffe_out_size == dlib_output_nr) + break; + } + if (pad_y == padding_y+1) + { + std::ostringstream sout; + sout << "No conversion between dlib pooling layer parameters and caffe pooling layer parameters found for layer " << to_string(i->idx) << endl; + sout << "dlib_output_nr: " << dlib_output_nr << endl; + sout << "bottom_nr: " << bottom_nr << endl; + sout << "padding_y: " << padding_y << endl; + sout << "stride_y: " << stride_y << endl; + sout << "kernel_h: " << kernel_h << endl; + sout << "pad_y: " << pad_y << endl; + throw dlib::error(sout.str()); + } +} + +// ---------------------------------------------------------------------------------------- + +void convert_dlib_xml_to_caffe_python_code( + const string& xml_filename, + const long N, + const long K, + const long NR, + const long NC +) +{ + const string out_filename = left_substr(xml_filename,".") + "_dlib_to_caffe_model.py"; + const string out_weights_filename = left_substr(xml_filename,".") + "_dlib_to_caffe_model.weights"; + cout << "Writing python part of model to " << out_filename << endl; + cout << "Writing weights part of model to " << out_weights_filename << 
endl; + ofstream fout(out_filename); + fout.precision(9); + const auto layers = parse_dlib_xml({N,K,NR,NC}, xml_filename); + + + fout << "#\n"; + fout << "# !!! This file was automatically generated by dlib's tools/convert_dlib_nets_to_caffe utility. !!!\n"; + fout << "# !!! It contains all the information from a dlib DNN network and lets you save it as a cafe model. !!!\n"; + fout << "#\n"; + fout << "import caffe " << endl; + fout << "from caffe import layers as L, params as P" << endl; + fout << "import numpy as np" << endl; + + // dlib nets don't commit to a batch size, so just use 1 as the default + fout << "\n# Input tensor dimensions" << endl; + fout << "input_batch_size = " << N << ";" << endl; + if (layers.back().detail_name == "input_rgb_image") + { + fout << "input_num_channels = 3;" << endl; + fout << "input_num_rows = "<type == "loss" || i->type == "input") + continue; + + + if (i->detail_name == "con") + { + fout << " n." << i->caffe_layer_name() << " = L.Convolution(n." << find_input_layer_caffe_name(i); + fout << ", num_output=" << i->attribute("num_filters"); + fout << ", kernel_w=" << i->attribute("nc"); + fout << ", kernel_h=" << i->attribute("nr"); + fout << ", stride_w=" << i->attribute("stride_x"); + fout << ", stride_h=" << i->attribute("stride_y"); + fout << ", pad_w=" << i->attribute("padding_x"); + fout << ", pad_h=" << i->attribute("padding_y"); + fout << ");\n"; + } + else if (i->detail_name == "relu") + { + fout << " n." << i->caffe_layer_name() << " = L.ReLU(n." << find_input_layer_caffe_name(i); + fout << ");\n"; + } + else if (i->detail_name == "sig") + { + fout << " n." << i->caffe_layer_name() << " = L.Sigmoid(n." << find_input_layer_caffe_name(i); + fout << ");\n"; + } + else if (i->detail_name == "prelu") + { + fout << " n." << i->caffe_layer_name() << " = L.PReLU(n." << find_input_layer_caffe_name(i); + fout << ", channel_shared=True"; + fout << ");\n"; + } + else if (i->detail_name == "max_pool") + { + fout << " n." 
<< i->caffe_layer_name() << " = L.Pooling(n." << find_input_layer_caffe_name(i); + fout << ", pool=P.Pooling.MAX"; + if (i->attribute("nc")==0) + { + fout << ", global_pooling=True"; + } + else + { + fout << ", kernel_w=" << i->attribute("nc"); + fout << ", kernel_h=" << i->attribute("nr"); + } + + fout << ", stride_w=" << i->attribute("stride_x"); + fout << ", stride_h=" << i->attribute("stride_y"); + long pad_x, pad_y; + compute_caffe_padding_size_for_pooling_layer(i, pad_x, pad_y); + fout << ", pad_w=" << pad_x; + fout << ", pad_h=" << pad_y; + fout << ");\n"; + } + else if (i->detail_name == "avg_pool") + { + fout << " n." << i->caffe_layer_name() << " = L.Pooling(n." << find_input_layer_caffe_name(i); + fout << ", pool=P.Pooling.AVE"; + if (i->attribute("nc")==0) + { + fout << ", global_pooling=True"; + } + else + { + fout << ", kernel_w=" << i->attribute("nc"); + fout << ", kernel_h=" << i->attribute("nr"); + } + if (i->attribute("padding_x") != 0 || i->attribute("padding_y") != 0) + { + throw dlib::error("dlib and caffe implement pooling with non-zero padding differently, so you can't convert a " + "network with such pooling layers."); + } + + fout << ", stride_w=" << i->attribute("stride_x"); + fout << ", stride_h=" << i->attribute("stride_y"); + long pad_x, pad_y; + compute_caffe_padding_size_for_pooling_layer(i, pad_x, pad_y); + fout << ", pad_w=" << pad_x; + fout << ", pad_h=" << pad_y; + fout << ");\n"; + } + else if (i->detail_name == "fc") + { + fout << " n." << i->caffe_layer_name() << " = L.InnerProduct(n." << find_input_layer_caffe_name(i); + fout << ", num_output=" << i->attribute("num_outputs"); + fout << ", bias_term=True"; + fout << ");\n"; + } + else if (i->detail_name == "fc_no_bias") + { + fout << " n." << i->caffe_layer_name() << " = L.InnerProduct(n." 
<< find_input_layer_caffe_name(i); + fout << ", num_output=" << i->attribute("num_outputs"); + fout << ", bias_term=False"; + fout << ");\n"; + } + else if (i->detail_name == "bn_con" || i->detail_name == "bn_fc") + { + throw dlib::error("Conversion from dlib's batch norm layers to caffe's isn't supported. Instead, " + "you should put your dlib network into 'test mode' by switching batch norm layers to affine layers. " + "Then you can convert that 'test mode' network to caffe."); + } + else if (i->detail_name == "affine_con") + { + fout << " n." << i->caffe_layer_name() << " = L.Scale(n." << find_input_layer_caffe_name(i); + fout << ", bias_term=True"; + fout << ");\n"; + } + else if (i->detail_name == "affine_fc") + { + fout << " n." << i->caffe_layer_name() << " = L.Scale(n." << find_input_layer_caffe_name(i); + fout << ", bias_term=True"; + fout << ");\n"; + } + else if (i->detail_name == "add_prev") + { + auto in_shape1 = find_input_layer(i).output_tensor_shape; + auto in_shape2 = find_layer(i,i->attribute("tag")).output_tensor_shape; + if (in_shape1 != in_shape2) + { + // if only the number of channels differs then we will use a dummy layer to + // pad with zeros. But otherwise we will throw an error. + if (in_shape1(0) == in_shape2(0) && + in_shape1(2) == in_shape2(2) && + in_shape1(3) == in_shape2(3)) + { + fout << " n." << i->caffe_layer_name() << "_zeropad = L.DummyData(num=" << in_shape1(0); + fout << ", channels="<attribute("tag")); + if (in_shape1(1) > in_shape2(1)) + swap(smaller_layer, bigger_layer); + + fout << " n." << i->caffe_layer_name() << "_concat = L.Concat(n." << smaller_layer; + fout << ", n." << i->caffe_layer_name() << "_zeropad"; + fout << ");\n"; + + fout << " n." << i->caffe_layer_name() << " = L.Eltwise(n." << i->caffe_layer_name() << "_concat"; + fout << ", n." 
<< bigger_layer; + fout << ", operation=P.Eltwise.SUM"; + fout << ");\n"; + } + else + { + std::ostringstream sout; + sout << "The dlib network contained an add_prev layer (layer idx " << i->idx << ") that adds two previous "; + sout << "layers with different output tensor dimensions. Caffe's equivalent layer, Eltwise, doesn't support "; + sout << "adding layers together with different dimensions. In the special case where the only difference is "; + sout << "in the number of channels, this converter program will add a dummy layer that outputs a tensor full of zeros "; + sout << "and concat it appropriately so this will work. However, this network you are converting has tensor dimensions "; + sout << "different in values other than the number of channels. In particular, here are the two tensor shapes (batch size, channels, rows, cols): "; + std::ostringstream sout2; + sout2 << wrap_string(sout.str()) << endl; + sout2 << trans(in_shape1); + sout2 << trans(in_shape2); + throw dlib::error(sout2.str()); + } + } + else + { + fout << " n." << i->caffe_layer_name() << " = L.Eltwise(n." << find_input_layer_caffe_name(i); + fout << ", n." << find_layer_caffe_name(i, i->attribute("tag")); + fout << ", operation=P.Eltwise.SUM"; + fout << ");\n"; + } + } + else + { + throw dlib::error("No known transformation from dlib's " + i->detail_name + " layer to caffe."); + } + } + fout << " return n.to_proto();\n\n" << endl; + + + // ----------------------------------------------------------------------------------- + // The next block of code outputs python code that populates all the filter weights. 
+ // ----------------------------------------------------------------------------------- + + ofstream fweights(out_weights_filename, ios::binary); + fout << "def set_network_weights(net):\n"; + fout << " # populate network parameters\n"; + fout << " f = open('"<type == "loss" || i->type == "input") + continue; + + + if (i->detail_name == "con") + { + const long num_filters = i->attribute("num_filters"); + matrix weights = trans(rowm(i->params,range(0,i->params.size()-num_filters-1))); + matrix biases = trans(rowm(i->params,range(i->params.size()-num_filters, i->params.size()-1))); + fweights.write((char*)&weights(0,0), weights.size()*sizeof(float)); + fweights.write((char*)&biases(0,0), biases.size()*sizeof(float)); + + // main filter weights + fout << " p = np.fromfile(f, dtype='float32', count="<caffe_layer_name()<<"'][0].data.shape;\n"; + fout << " net.params['"<caffe_layer_name()<<"'][0].data[:] = p;\n"; + + // biases + fout << " p = np.fromfile(f, dtype='float32', count="<caffe_layer_name()<<"'][1].data.shape;\n"; + fout << " net.params['"<caffe_layer_name()<<"'][1].data[:] = p;\n"; + } + else if (i->detail_name == "fc") + { + matrix weights = trans(rowm(i->params, range(0,i->params.nr()-2))); + matrix biases = rowm(i->params, i->params.nr()-1); + fweights.write((char*)&weights(0,0), weights.size()*sizeof(float)); + fweights.write((char*)&biases(0,0), biases.size()*sizeof(float)); + + // main filter weights + fout << " p = np.fromfile(f, dtype='float32', count="<caffe_layer_name()<<"'][0].data.shape;\n"; + fout << " net.params['"<caffe_layer_name()<<"'][0].data[:] = p;\n"; + + // biases + fout << " p = np.fromfile(f, dtype='float32', count="<caffe_layer_name()<<"'][1].data.shape;\n"; + fout << " net.params['"<caffe_layer_name()<<"'][1].data[:] = p;\n"; + } + else if (i->detail_name == "fc_no_bias") + { + matrix weights = trans(i->params); + fweights.write((char*)&weights(0,0), weights.size()*sizeof(float)); + + // main filter weights + fout << " p = 
np.fromfile(f, dtype='float32', count="<caffe_layer_name()<<"'][0].data.shape;\n"; + fout << " net.params['"<caffe_layer_name()<<"'][0].data[:] = p;\n"; + } + else if (i->detail_name == "affine_con" || i->detail_name == "affine_fc") + { + const long dims = i->params.size()/2; + matrix gamma = trans(rowm(i->params,range(0,dims-1))); + matrix beta = trans(rowm(i->params,range(dims, 2*dims-1))); + fweights.write((char*)&gamma(0,0), gamma.size()*sizeof(float)); + fweights.write((char*)&beta(0,0), beta.size()*sizeof(float)); + + // set gamma weights + fout << " p = np.fromfile(f, dtype='float32', count="<caffe_layer_name()<<"'][0].data.shape;\n"; + fout << " net.params['"<caffe_layer_name()<<"'][0].data[:] = p;\n"; + + // set beta weights + fout << " p = np.fromfile(f, dtype='float32', count="<caffe_layer_name()<<"'][1].data.shape;\n"; + fout << " net.params['"<caffe_layer_name()<<"'][1].data[:] = p;\n"; + } + else if (i->detail_name == "prelu") + { + const double param = i->params(0); + + // main filter weights + fout << " tmp = net.params['"<caffe_layer_name()<<"'][0].data.view();\n"; + fout << " tmp.shape = 1;\n"; + fout << " tmp[0] = "< layers; + bool seen_first_tag = false; + + layer next_layer; + std::stack current_tag; + long tag_id = -1; + + + virtual void start_document ( + ) + { + layers.clear(); + seen_first_tag = false; + tag_id = -1; + } + + virtual void end_document ( + ) { } + + virtual void start_element ( + const unsigned long /*line_number*/, + const std::string& name, + const dlib::attribute_list& atts + ) + { + if (!seen_first_tag) + { + if (name != "net") + throw dlib::error("The top level XML tag must be a 'net' tag."); + seen_first_tag = true; + } + + if (name == "layer") + { + next_layer = layer(); + if (atts["type"] == "skip") + { + // Don't make a new layer, just apply the tag id to the previous layer + if (layers.size() == 0) + throw dlib::error("A skip layer was found as the first layer, but the first layer should be an input layer."); + 
layers.back().skip_id = sa = atts["id"]; + + // We intentionally leave next_layer empty so the end_element() callback + // don't add it as another layer when called. + } + else if (atts["type"] == "tag") + { + // Don't make a new layer, just remember the tag id so we can apply it on + // the next layer. + tag_id = sa = atts["id"]; + + // We intentionally leave next_layer empty so the end_element() callback + // don't add it as another layer when called. + } + else + { + next_layer.idx = sa = atts["idx"]; + next_layer.type = atts["type"]; + if (tag_id != -1) + { + next_layer.tag_id = tag_id; + tag_id = -1; + } + } + } + else if (current_tag.size() != 0 && current_tag.top() == "layer") + { + next_layer.detail_name = name; + // copy all the XML tag's attributes into the layer struct + atts.reset(); + while (atts.move_next()) + next_layer.attributes[atts.element().key()] = sa = atts.element().value(); + } + + current_tag.push(name); + } + + virtual void end_element ( + const unsigned long /*line_number*/, + const std::string& name + ) + { + current_tag.pop(); + if (name == "layer" && next_layer.type.size() != 0) + layers.push_back(next_layer); + } + + virtual void characters ( + const std::string& data + ) + { + if (current_tag.size() == 0) + return; + + if (comp_tags_with_params.count(current_tag.top()) != 0) + { + istringstream sin(data); + sin >> next_layer.params; + } + + } + + virtual void processing_instruction ( + const unsigned long /*line_number*/, + const std::string& /*target*/, + const std::string& /*data*/ + ) + { + } +}; + +// ---------------------------------------------------------------------------------------- + +void compute_output_tensor_shapes(const matrix& input_tensor_shape, std::vector& layers) +{ + DLIB_CASSERT(layers.back().type == "input"); + layers.back().output_tensor_shape = input_tensor_shape; + for (auto i = ++layers.rbegin(); i != layers.rend(); ++i) + { + const auto input_shape = find_input_layer(i).output_tensor_shape; + if (i->type 
== "comp") + { + if (i->detail_name == "fc" || i->detail_name == "fc_no_bias") + { + long num_outputs = i->attribute("num_outputs"); + i->output_tensor_shape = {input_shape(0), num_outputs, 1, 1}; + } + else if (i->detail_name == "con") + { + long num_filters = i->attribute("num_filters"); + long filter_nc = i->attribute("nc"); + long filter_nr = i->attribute("nr"); + long stride_x = i->attribute("stride_x"); + long stride_y = i->attribute("stride_y"); + long padding_x = i->attribute("padding_x"); + long padding_y = i->attribute("padding_y"); + long nr = 1+(input_shape(2) + 2*padding_y - filter_nr)/stride_y; + long nc = 1+(input_shape(3) + 2*padding_x - filter_nc)/stride_x; + i->output_tensor_shape = {input_shape(0), num_filters, nr, nc}; + } + else if (i->detail_name == "max_pool" || i->detail_name == "avg_pool") + { + long filter_nc = i->attribute("nc"); + long filter_nr = i->attribute("nr"); + long stride_x = i->attribute("stride_x"); + long stride_y = i->attribute("stride_y"); + long padding_x = i->attribute("padding_x"); + long padding_y = i->attribute("padding_y"); + if (filter_nc != 0) + { + long nr = 1+(input_shape(2) + 2*padding_y - filter_nr)/stride_y; + long nc = 1+(input_shape(3) + 2*padding_x - filter_nc)/stride_x; + i->output_tensor_shape = {input_shape(0), input_shape(1), nr, nc}; + } + else // if we are filtering the whole input down to one thing + { + i->output_tensor_shape = {input_shape(0), input_shape(1), 1, 1}; + } + } + else if (i->detail_name == "add_prev") + { + auto aux_shape = find_layer(i, i->attribute("tag")).output_tensor_shape; + for (long j = 0; j < input_shape.size(); ++j) + i->output_tensor_shape(j) = std::max(input_shape(j), aux_shape(j)); + } + else + { + i->output_tensor_shape = input_shape; + } + } + else + { + i->output_tensor_shape = input_shape; + } + + } +} + +// ---------------------------------------------------------------------------------------- + +std::vector parse_dlib_xml( + const matrix& input_tensor_shape, + const 
string& xml_filename +) +{ + doc_handler dh; + parse_xml(xml_filename, dh); + if (dh.layers.size() == 0) + throw dlib::error("No layers found in XML file!"); + + if (dh.layers.back().type != "input") + throw dlib::error("The network in the XML file is missing an input layer!"); + + compute_output_tensor_shapes(input_tensor_shape, dh.layers); + + return dh.layers; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/convert_dlib_nets_to_caffe/running_a_dlib_model_with_caffe_example.py b/ml/dlib/tools/convert_dlib_nets_to_caffe/running_a_dlib_model_with_caffe_example.py new file mode 100755 index 000000000..c03a7bf5c --- /dev/null +++ b/ml/dlib/tools/convert_dlib_nets_to_caffe/running_a_dlib_model_with_caffe_example.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +# This script takes the dlib lenet model trained by the +# examples/dnn_introduction_ex.cpp example program and runs it using caffe. + +import caffe +import numpy as np + +# Before you run this program, you need to run dnn_introduction_ex.cpp to get a +# dlib lenet model. Then you need to convert that model into a "dlib to caffe +# model" python script. You can do this using the command line program +# included with dlib: tools/convert_dlib_nets_to_caffe. That program will +# output a lenet_dlib_to_caffe_model.py file. You run that program like this: +# ./dtoc lenet.xml 1 1 28 28 +# and it will create the lenet_dlib_to_caffe_model.py file, which we import +# with the next line: +import lenet_dlib_to_caffe_model as dlib_model + +# lenet_dlib_to_caffe_model defines a function, save_as_caffe_model() that does +# the work of converting dlib's DNN model to a caffe model and saves it to disk +# in two files. These files are all you need to run the model with caffe. +dlib_model.save_as_caffe_model('dlib_model_def.prototxt', 'dlib_model.proto') + +# Now that we created the caffe model files, we can load them into a caffe Net object. 
+net = caffe.Net('dlib_model_def.prototxt', 'dlib_model.proto', caffe.TEST); + + +# Now lets do a test, we will run one of the MNIST images through the network. + +# An MNIST image of a 7, it is the very first testing image in MNIST (i.e. wrt dnn_introduction_ex.cpp, it is testing_images[0]) +data = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0,84,185,159,151,60,36, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0,222,254,254,254,254,241,198,198,198,198,198,198,198,198,170,52, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0,67,114,72,114,163,227,254,225,254,254,254,250,229,254,254,140, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,17,66,14,67,67,67,59,21,236,254,106, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,83,253,209,18, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,22,233,255,83, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,129,254,238,44, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,59,249,254,62, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,133,254,187,5, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,9,205,248,58, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,126,254,182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,75,251,240,57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,19,221,254,166, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,3,203,254,219,35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,38,254,254,77, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,31,224,254,115,1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,133,254,254,52, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,61,242,254,254,52, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,121,254,254,219,40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,121,254,207,18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='float32'); +data.shape = (dlib_model.input_batch_size, dlib_model.input_num_channels, dlib_model.input_num_rows, dlib_model.input_num_cols); + +# labels isn't logically needed but there doesn't seem to be a way to use +# caffe's Net interface without providing a superfluous input array. So we do +# that here. +labels = np.ones((dlib_model.input_batch_size), dtype='float32') +# Give the image to caffe +net.set_input_arrays(data/256, labels) +# Run the data through the network and get the results. +out = net.forward() + +# Print outputs, looping over minibatch. You should see that the network +# correctly classifies the image (it's the number 7). +for i in xrange(dlib_model.input_batch_size): + print i, 'net final layer = ', out['fc1'][i] + print i, 'predicted number =', np.argmax(out['fc1'][i]) + + + diff --git a/ml/dlib/tools/htmlify/CMakeLists.txt b/ml/dlib/tools/htmlify/CMakeLists.txt new file mode 100644 index 000000000..02cae2172 --- /dev/null +++ b/ml/dlib/tools/htmlify/CMakeLists.txt @@ -0,0 +1,31 @@ +# +# This is a CMake makefile. 
You can find the cmake utility and +# information about it at http://www.cmake.org +# + +cmake_minimum_required(VERSION 2.8.12) + +# create a variable called target_name and set it to the string "htmlify" +set (target_name htmlify) + +project(${target_name}) + +add_subdirectory(../../dlib dlib_build) + +# add all the cpp files we want to compile to this list. This tells +# cmake that they are part of our target (which is the executable named htmlify) +add_executable(${target_name} + htmlify.cpp + to_xml.cpp + ) + +# Tell cmake to link our target executable to dlib. +target_link_libraries(${target_name} dlib::dlib ) + + + +install(TARGETS ${target_name} + RUNTIME DESTINATION bin + ) + + diff --git a/ml/dlib/tools/htmlify/htmlify.cpp b/ml/dlib/tools/htmlify/htmlify.cpp new file mode 100644 index 000000000..e822a5aaa --- /dev/null +++ b/ml/dlib/tools/htmlify/htmlify.cpp @@ -0,0 +1,632 @@ +#include +#include +#include + + +#include "dlib/cpp_pretty_printer.h" +#include "dlib/cmd_line_parser.h" +#include "dlib/queue.h" +#include "dlib/misc_api.h" +#include "dlib/dir_nav.h" +#include "to_xml.h" + + +const char* VERSION = "3.5"; + +using namespace std; +using namespace dlib; + +typedef cpp_pretty_printer::kernel_1a cprinter; +typedef cpp_pretty_printer::kernel_2a bprinter; +typedef dlib::map::kernel_1a map_string_to_string; +typedef dlib::set::kernel_1a set_of_string; +typedef queue::kernel_1a queue_of_files; +typedef queue::kernel_1a queue_of_dirs; + +void print_manual ( +); +/*! + ensures + - prints detailed information about this program. +!*/ + +void htmlify ( + const map_string_to_string& file_map, + bool colored, + bool number_lines, + const std::string& title +); +/*! 
+ ensures + - for all valid out_file: + - the file out_file is the html transformed version of + file_map[out_file] + - if (number_lines) then + - the html version will have numbered lines + - if (colored) then + - the html version will have colors + - title will be the first part of the HTML title in the output file +!*/ + +void htmlify ( + istream& in, + ostream& out, + const std::string& title, + bool colored, + bool number_lines +); +/*! + ensures + - transforms in into html with the given title and writes it to out. + - if (number_lines) then + - the html version of in will have numbered lines + - if (colored) then + - the html version of in will have colors +!*/ + +void add_files ( + const directory& dir, + const std::string& out_dir, + map_string_to_string& file_map, + bool flatten, + bool cat, + const set_of_string& filter, + unsigned long search_depth, + unsigned long cur_depth = 0 +); +/*! + ensures + - searches the directory dir for files matching the filter and adds them + to the file_map. only looks search_depth deep. +!*/ + +int main(int argc, char** argv) +{ + if (argc == 1) + { + cout << "\nTry the -h option for more information.\n"; + return 0; + } + + string file; + try + { + command_line_parser parser; + parser.add_option("b","Pretty print in black and white. The default is to pretty print in color."); + parser.add_option("n","Number lines."); + parser.add_option("h","Displays this information."); + parser.add_option("index","Create an index."); + parser.add_option("v","Display version."); + parser.add_option("man","Display the manual."); + parser.add_option("f","Specifies a list of file extensions to process when using the -i option. The list elements should be separated by spaces. 
The default is \"cpp h c\".",1); + parser.add_option("i","Specifies an input directory.",1); + parser.add_option("cat","Puts all the output into a single html file with the given name.",1); + parser.add_option("depth","Specifies how many directories deep to search when using the i option. The default value is 30.",1); + parser.add_option("o","This option causes all the output files to be created inside the given directory. If this option is not given then all output goes to the current working directory.",1); + parser.add_option("flatten","When this option is given it prevents the input directory structure from being replicated."); + parser.add_option("title","This option specifies a string which is prepended onto the title of the generated HTML",1); + parser.add_option("to-xml","Instead of generating HTML output, create a single output file called output.xml that contains " + "a simple XML database which lists all documented classes and functions."); + parser.add_option("t", "When creating XML output, replace tabs in comments with spaces.", 1); + + + parser.parse(argc,argv); + + + parser.check_incompatible_options("cat","o"); + parser.check_incompatible_options("cat","flatten"); + parser.check_incompatible_options("cat","index"); + parser.check_option_arg_type("depth"); + parser.check_option_arg_range("t", 1, 100); + + parser.check_incompatible_options("to-xml", "b"); + parser.check_incompatible_options("to-xml", "n"); + parser.check_incompatible_options("to-xml", "index"); + parser.check_incompatible_options("to-xml", "cat"); + parser.check_incompatible_options("to-xml", "o"); + parser.check_incompatible_options("to-xml", "flatten"); + parser.check_incompatible_options("to-xml", "title"); + + const char* singles[] = {"b","n","h","index","v","man","f","cat","depth","o","flatten","title","to-xml", "t"}; + parser.check_one_time_options(singles); + + const char* i_sub_ops[] = {"f","depth","flatten"}; + parser.check_sub_options("i",i_sub_ops); + + const char* 
to_xml_sub_ops[] = {"t"}; + parser.check_sub_options("to-xml",to_xml_sub_ops); + + const command_line_parser::option_type& b_opt = parser.option("b"); + const command_line_parser::option_type& n_opt = parser.option("n"); + const command_line_parser::option_type& h_opt = parser.option("h"); + const command_line_parser::option_type& index_opt = parser.option("index"); + const command_line_parser::option_type& v_opt = parser.option("v"); + const command_line_parser::option_type& o_opt = parser.option("o"); + const command_line_parser::option_type& man_opt = parser.option("man"); + const command_line_parser::option_type& f_opt = parser.option("f"); + const command_line_parser::option_type& cat_opt = parser.option("cat"); + const command_line_parser::option_type& i_opt = parser.option("i"); + const command_line_parser::option_type& flatten_opt = parser.option("flatten"); + const command_line_parser::option_type& depth_opt = parser.option("depth"); + const command_line_parser::option_type& title_opt = parser.option("title"); + const command_line_parser::option_type& to_xml_opt = parser.option("to-xml"); + + + string filter = "cpp h c"; + + bool cat = false; + bool color = true; + bool number = false; + unsigned long search_depth = 30; + + string out_dir; // the name of the output directory if the o option is given. "" otherwise + string full_out_dir; // the full name of the output directory if the o option is given. 
"" otherwise + const char separator = directory::get_separator(); + + bool no_run = false; + if (v_opt) + { + cout << "Htmlify v" << VERSION + << "\nCompiled: " << __TIME__ << " " << __DATE__ + << "\nWritten by Davis King\n"; + cout << "Check for updates at http://dlib.net\n\n"; + no_run = true; + } + + if (h_opt) + { + cout << "This program pretty prints C or C++ source code to HTML.\n"; + cout << "Usage: htmlify [options] [file]...\n"; + parser.print_options(); + cout << "\n\n"; + no_run = true; + } + + if (man_opt) + { + print_manual(); + no_run = true; + } + + if (no_run) + return 0; + + if (f_opt) + { + filter = f_opt.argument(); + } + + if (cat_opt) + { + cat = true; + } + + if (depth_opt) + { + search_depth = string_cast(depth_opt.argument()); + } + + if (to_xml_opt) + { + unsigned long expand_tabs = 0; + if (parser.option("t")) + expand_tabs = string_cast(parser.option("t").argument()); + + generate_xml_markup(parser, filter, search_depth, expand_tabs); + return 0; + } + + if (o_opt) + { + // make sure this directory exists + out_dir = o_opt.argument(); + create_directory(out_dir); + directory dir(out_dir); + full_out_dir = dir.full_name(); + + // make sure the last character of out_dir is a separator + if (out_dir[out_dir.size()-1] != separator) + out_dir += separator; + if (full_out_dir[out_dir.size()-1] != separator) + full_out_dir += separator; + } + + if (b_opt) + color = false; + if (n_opt) + number = true; + + // this is a map of output file names to input file names. + map_string_to_string file_map; + + + // add all the files that are just given on the command line to the + // file_map. 
+ for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + string in_file, out_file; + in_file = parser[i]; + string::size_type pos = in_file.find_last_of(separator); + if (pos != string::npos) + { + out_file = out_dir + in_file.substr(pos+1) + ".html"; + } + else + { + out_file = out_dir + in_file + ".html"; + } + + if (file_map.is_in_domain(out_file)) + { + if (file_map[out_file] != in_file) + { + // there is a file name colision in the output folder. definitly a bad thing + cout << "Error: Two of the input files have the same name and would overwrite each\n"; + cout << "other. They are " << in_file << " and " << file_map[out_file] << ".\n" << endl; + return 1; + } + else + { + continue; + } + } + + file_map.add(out_file,in_file); + } + + // pick out the filter strings + set_of_string sfilter; + istringstream sin(filter); + string temp; + sin >> temp; + while (sin) + { + if (sfilter.is_member(temp) == false) + sfilter.add(temp); + sin >> temp; + } + + // now get all the files given by the i options + for (unsigned long i = 0; i < i_opt.count(); ++i) + { + directory dir(i_opt.argument(0,i)); + add_files(dir, out_dir, file_map, flatten_opt, cat, sfilter, search_depth); + } + + if (cat) + { + file_map.reset(); + ofstream fout(cat_opt.argument().c_str()); + if (!fout) + { + throw error("Error: unable to open file " + cat_opt.argument()); + } + fout << "" << cat_opt.argument() << ""; + + const char separator = directory::get_separator(); + string file; + while (file_map.move_next()) + { + ifstream fin(file_map.element().value().c_str()); + if (!fin) + { + throw error("Error: unable to open file " + file_map.element().value()); + } + + string::size_type pos = file_map.element().value().find_last_of(separator); + if (pos != string::npos) + file = file_map.element().value().substr(pos+1); + else + file = file_map.element().value(); + + std::string title; + if (title_opt) + title = title_opt.argument(); + htmlify(fin, fout, title + file, color, number); + } 
+ + } + else + { + std::string title; + if (title_opt) + title = title_opt.argument(); + htmlify(file_map,color,number,title); + } + + + + if (index_opt) + { + ofstream index((out_dir + "index.html").c_str()); + ofstream menu((out_dir + "menu.html").c_str()); + + if (!index) + { + cout << "Error: unable to create " << out_dir << "index.html\n\n"; + return 0; + } + + if (!menu) + { + cout << "Error: unable to create " << out_dir << "menu.html\n\n"; + return 0; + } + + + index << ""; + index << ""; + index << ""; + + menu << "
"; + + file_map.reset(); + while (file_map.move_next()) + { + if (o_opt) + { + file = file_map.element().key(); + if (file.find(full_out_dir) != string::npos) + file = file.substr(full_out_dir.size()); + else + file = file.substr(out_dir.size()); + } + else + { + file = file_map.element().key(); + } + // strip the .html from file + file = file.substr(0,file.size()-5); + menu << "" + << file << "
"; + } + + menu << ""; + + } + + } + catch (ios_base::failure&) + { + cout << "ERROR: unable to write to " << file << endl; + cout << endl; + } + catch (exception& e) + { + cout << e.what() << endl; + cout << "\nTry the -h option for more information.\n"; + cout << endl; + } +} + +// ------------------------------------------------------------------------------------------------- + +void htmlify ( + istream& in, + ostream& out, + const std::string& title, + bool colored, + bool number_lines +) +{ + if (colored) + { + static cprinter cp; + if (number_lines) + { + cp.print_and_number(in,out,title); + } + else + { + cp.print(in,out,title); + } + } + else + { + static bprinter bp; + if (number_lines) + { + bp.print_and_number(in,out,title); + } + else + { + bp.print(in,out,title); + } + } +} + +// ------------------------------------------------------------------------------------------------- + +void htmlify ( + const map_string_to_string& file_map, + bool colored, + bool number_lines, + const std::string& title +) +{ + file_map.reset(); + const char separator = directory::get_separator(); + string file; + while (file_map.move_next()) + { + ifstream fin(file_map.element().value().c_str()); + if (!fin) + { + throw error("Error: unable to open file " + file_map.element().value() ); + } + + ofstream fout(file_map.element().key().c_str()); + + if (!fout) + { + throw error("Error: unable to open file " + file_map.element().key()); + } + + string::size_type pos = file_map.element().value().find_last_of(separator); + if (pos != string::npos) + file = file_map.element().value().substr(pos+1); + else + file = file_map.element().value(); + + htmlify(fin, fout,title + file, colored, number_lines); + } +} + +// ------------------------------------------------------------------------------------------------- + +void add_files ( + const directory& dir, + const std::string& out_dir, + map_string_to_string& file_map, + bool flatten, + bool cat, + const set_of_string& filter, + 
unsigned long search_depth, + unsigned long cur_depth +) +{ + const char separator = directory::get_separator(); + + queue_of_files files; + queue_of_dirs dirs; + + dir.get_files(files); + + // look though all the files in the current directory and add the + // ones that match the filter to file_map + string name, ext, in_file, out_file; + files.reset(); + while (files.move_next()) + { + name = files.element().name(); + string::size_type pos = name.find_last_of('.'); + if (pos != string::npos && filter.is_member(name.substr(pos+1))) + { + in_file = files.element().full_name(); + + if (flatten) + { + pos = in_file.find_last_of(separator); + } + else + { + // figure out how much of the file's path we need to keep + // for the output file name + pos = in_file.size(); + for (unsigned long i = 0; i <= cur_depth && pos != string::npos; ++i) + { + pos = in_file.find_last_of(separator,pos-1); + } + } + + if (pos != string::npos) + { + out_file = out_dir + in_file.substr(pos+1) + ".html"; + } + else + { + out_file = out_dir + in_file + ".html"; + } + + if (file_map.is_in_domain(out_file)) + { + if (file_map[out_file] != in_file) + { + // there is a file name colision in the output folder. definitly a bad thing + ostringstream sout; + sout << "Error: Two of the input files have the same name and would overwrite each\n"; + sout << "other. They are " << in_file << " and " << file_map[out_file] << "."; + throw error(sout.str()); + } + else + { + continue; + } + } + + file_map.add(out_file,in_file); + + } + } // while (files.move_next()) + files.clear(); + + if (search_depth > cur_depth) + { + // search all the sub directories + dir.get_dirs(dirs); + dirs.reset(); + while (dirs.move_next()) + { + if (!flatten && !cat) + { + string d = dirs.element().full_name(); + + // figure out how much of the directorie's path we need to keep. 
+ string::size_type pos = d.size(); + for (unsigned long i = 0; i <= cur_depth && pos != string::npos; ++i) + { + pos = d.find_last_of(separator,pos-1); + } + + // make sure this directory exists in the output directory tree + d = d.substr(pos+1); + create_directory(out_dir + separator + d); + } + + add_files(dirs.element(), out_dir, file_map, flatten, cat, filter, search_depth, cur_depth+1); + } + } + +} + +// ------------------------------------------------------------------------------------------------- + +void print_manual ( +) +{ + ostringstream sout; + + const unsigned long indent = 2; + + cout << "\n"; + sout << "Htmlify v" << VERSION; + cout << wrap_string(sout.str(),indent,indent); sout.str(""); + + + sout << "This is a fairly simple program that takes source files and pretty prints them " + << "in HTML. There are two pretty printing styles, black and white or color. The " + << "black and white style is meant to look nice when printed out on paper. It looks " + << "a little funny on the screen but on paper it is pretty nice. The color version " + << "on the other hand has nonprintable HTML elements such as links and anchors."; + cout << "\n\n" << wrap_string(sout.str(),indent,indent); sout.str(""); + + + sout << "The colored style puts HTML anchors on class and function names. This means " + << "you can link directly to the part of the code that contains these names. For example, " + << "if you had a source file bar.cpp with a function called foo in it you could link " + << "directly to the function with a link address of \"bar.cpp.html#foo\". It is also " + << "possible to instruct Htmlify to place HTML anchors at arbitrary spots by using a " + << "special comment of the form /*!A anchor_name */. You can put other things in the " + << "comment but the important bit is to have it begin with /*!A then some white space " + << "then the anchor name you want then more white space and then you can add whatever " + << "you like. 
You would then refer to this anchor with a link address of " + << "\"file.html#anchor_name\"."; + cout << "\n\n" << wrap_string(sout.str(),indent,indent); sout.str(""); + + sout << "Htmlify also has the ability to create a simple index of all the files it is given. " + << "The --index option creates a file named index.html with a frame on the left side " + << "that contains links to all the files."; + cout << "\n\n" << wrap_string(sout.str(),indent,indent); sout.str(""); + + + sout << "Finally, Htmlify can produce annotated XML output instead of HTML. The output will " + << "contain all functions which are immediately followed by comments of the form /*! comment body !*/. " + << "Similarly, all classes or structs that immediately contain one of these comments following their " + << "opening { will also be output as annotated XML. Note also that if you wish to document a " + << "piece of code using one of these comments but don't want it to appear in the output XML then " + << "use either a comment like /* */ or /*!P !*/ to mark the code as \"private\"."; + cout << "\n\n" << wrap_string(sout.str(),indent,indent) << "\n\n"; sout.str(""); +} + +// ------------------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/htmlify/to_xml.cpp b/ml/dlib/tools/htmlify/to_xml.cpp new file mode 100644 index 000000000..7fae43380 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml.cpp @@ -0,0 +1,1599 @@ + +#include "to_xml.h" +#include "dlib/dir_nav.h" +#include +#include +#include +#include +#include +#include "dlib/cpp_tokenizer.h" +#include "dlib/string.h" + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +typedef cpp_tokenizer::kernel_1a_c tok_type; + +// ---------------------------------------------------------------------------------------- + +class file_filter +{ +public: + + file_filter( + const string& filter + ) + { + // pick 
out the filter strings + istringstream sin(filter); + string temp; + sin >> temp; + while (sin) + { + endings.push_back("." + temp); + sin >> temp; + } + } + + bool operator() ( const file& f) const + { + // check if any of the endings match + for (unsigned long i = 0; i < endings.size(); ++i) + { + // if the ending is bigger than f's name then it obviously doesn't match + if (endings[i].size() > f.name().size()) + continue; + + // now check if the actual characters that make up the end of the file name + // matches what is in endings[i]. + if ( std::equal(endings[i].begin(), endings[i].end(), f.name().end()-endings[i].size())) + return true; + } + + return false; + } + + std::vector endings; +}; + +// ---------------------------------------------------------------------------------------- + +void obtain_list_of_files ( + const cmd_line_parser::check_1a_c& parser, + const std::string& filter, + const unsigned long search_depth, + std::vector >& files +) +{ + for (unsigned long i = 0; i < parser.option("i").count(); ++i) + { + const directory dir(parser.option("i").argument(0,i)); + + const std::vector& temp = get_files_in_directory_tree(dir, file_filter(filter), search_depth); + + // figure out how many characters need to be removed from the path of each file + const string parent = dir.get_parent().full_name(); + unsigned long strip = parent.size(); + if (parent.size() > 0 && parent[parent.size()-1] != '\\' && parent[parent.size()-1] != '/') + strip += 1; + + for (unsigned long i = 0; i < temp.size(); ++i) + { + files.push_back(make_pair(temp[i].full_name().substr(strip), temp[i].full_name())); + } + } + + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + files.push_back(make_pair(parser[i], parser[i])); + } + + std::sort(files.begin(), files.end()); +} + +// ---------------------------------------------------------------------------------------- + +struct tok_function_record +{ + std::vector > declaration; + string scope; + string file; + 
string comment; +}; + +struct tok_method_record +{ + std::vector > declaration; + string comment; +}; + +struct tok_variable_record +{ + std::vector > declaration; +}; + +struct tok_typedef_record +{ + std::vector > declaration; +}; + +struct tok_class_record +{ + std::vector > declaration; + string name; + string scope; + string file; + string comment; + + std::vector public_methods; + std::vector protected_methods; + std::vector public_variables; + std::vector public_typedefs; + std::vector protected_variables; + std::vector protected_typedefs; + std::vector public_inner_classes; + std::vector protected_inner_classes; +}; + +// ---------------------------------------------------------------------------------------- + +struct function_record +{ + string name; + string scope; + string declaration; + string file; + string comment; +}; + +struct method_record +{ + string name; + string declaration; + string comment; +}; + +struct variable_record +{ + string declaration; +}; + +struct typedef_record +{ + string declaration; +}; + +struct class_record +{ + string name; + string scope; + string declaration; + string file; + string comment; + + std::vector public_methods; + std::vector public_variables; + std::vector public_typedefs; + + std::vector protected_methods; + std::vector protected_variables; + std::vector protected_typedefs; + + std::vector public_inner_classes; + std::vector protected_inner_classes; +}; + +// ---------------------------------------------------------------------------------------- + +unsigned long count_newlines ( + const string& str +) +/*! + ensures + - returns the number of '\n' characters inside str +!*/ +{ + unsigned long count = 0; + for (unsigned long i = 0; i < str.size(); ++i) + { + if (str[i] == '\n') + ++count; + } + return count; +} + +// ---------------------------------------------------------------------------------------- + +bool contains_unescaped_newline ( + const string& str +) +/*! 
+ ensures + - returns true if str contains a '\n' character that isn't preceded by a '\' + character. +!*/ +{ + if (str.size() == 0) + return false; + + if (str[0] == '\n') + return true; + + for (unsigned long i = 1; i < str.size(); ++i) + { + if (str[i] == '\n' && str[i-1] != '\\') + return true; + } + + return false; +} + +// ---------------------------------------------------------------------------------------- + +bool is_formal_comment ( + const string& str +) +{ + if (str.size() < 6) + return false; + + if (str[0] == '/' && + str[1] == '*' && + str[2] == '!' && + str[3] != 'P' && + str[3] != 'p' && + str[str.size()-3] == '!' && + str[str.size()-2] == '*' && + str[str.size()-1] == '/' ) + return true; + + return false; +} + +// ---------------------------------------------------------------------------------------- + +string make_scope_string ( + const std::vector& namespaces, + unsigned long exclude_last_num_scopes = 0 +) +{ + string temp; + for (unsigned long i = 0; i + exclude_last_num_scopes < namespaces.size(); ++i) + { + if (namespaces[i].size() == 0) + continue; + + if (temp.size() == 0) + temp = namespaces[i]; + else + temp += "::" + namespaces[i]; + } + return temp; +} + +// ---------------------------------------------------------------------------------------- + +bool looks_like_function_declaration ( + const std::vector >& declaration +) +{ + + // Check if declaration contains IDENTIFIER ( ) somewhere in it. 
+ bool seen_first_part = false; + bool seen_operator = false; + int local_paren_count = 0; + for (unsigned long i = 1; i < declaration.size(); ++i) + { + if (declaration[i].first == tok_type::KEYWORD && + declaration[i].second == "operator") + { + seen_operator = true; + } + + if (declaration[i].first == tok_type::OTHER && + declaration[i].second == "(" && + (declaration[i-1].first == tok_type::IDENTIFIER || seen_operator)) + { + seen_first_part = true; + } + + if (declaration[i].first == tok_type::OTHER) + { + if ( declaration[i].second == "(") + ++local_paren_count; + else if ( declaration[i].second == ")") + --local_paren_count; + } + } + + if (seen_first_part && local_paren_count == 0) + return true; + else + return false; +} + +// ---------------------------------------------------------------------------------------- + +enum scope_type +{ + public_scope, + protected_scope, + private_scope +}; + + +void process_file ( + istream& fin, + const string& file, + std::vector& functions, + std::vector& classes +) +/*! + ensures + - scans the given file for global functions and appends any found into functions. + - scans the given file for global classes and appends any found into classes. 
+!*/ +{ + tok_type tok; + tok.set_stream(fin); + + bool recently_seen_struct_keyword = false; + // true if we have seen the struct keyword and + // we have not seen any identifiers or { characters + + string last_struct_name; + // the name of the last struct we have seen + + bool recently_seen_class_keyword = false; + // true if we have seen the class keyword and + // we have not seen any identifiers or { characters + + string last_class_name; + // the name of the last class we have seen + + bool recently_seen_namespace_keyword = false; + // true if we have seen the namespace keyword and + // we have not seen any identifiers or { characters + + string last_namespace_name; + // the name of the last namespace we have seen + + bool recently_seen_pound_define = false; + // true if we have seen a #define and haven't seen an unescaped newline + + bool recently_seen_preprocessor = false; + // true if we have seen a preprocessor statement and haven't seen an unescaped newline + + bool recently_seen_typedef = false; + // true if we have seen a typedef keyword and haven't seen a ; + + bool recently_seen_paren_0 = false; + // true if we have seen paren_count transition to zero but haven't yet seen a ; or { or + // a new line if recently_seen_pound_define is true. + + bool recently_seen_slots = false; + // true if we have seen the identifier "slots" at a zero scope but haven't seen any + // other identifiers or the ';' or ':' characters. + + bool recently_seen_closing_bracket = false; + // true if we have seen a } and haven't yet seen an IDENTIFIER or ; + + bool recently_seen_new_scope = false; + // true if we have seen the keywords class, namespace, struct, or extern and + // we have not seen the characters {, ), or ; since then + + bool at_top_of_new_scope = false; + // true if we have seen the { that started a new scope but haven't seen anything yet but WHITE_SPACE + + std::vector namespaces; + // a stack to hold the names of the scopes we have entered. 
This is the classes, structs, and namespaces we enter. + namespaces.push_back(""); // this is the global namespace + + std::stack scope_access; + // If the stack isn't empty then we are inside a class or struct and the top value + // in the stack tells if we are in a public, protected, or private region. + + std::stack scopes; // a stack to hold current and old scope counts + // the top of the stack counts the number of new scopes (i.e. unmatched { } we have entered + // since we were at a scope where functions can be defined. + // We also maintain the invariant that scopes.size() == namespaces.size() + scopes.push(0); + + std::stack class_stack; + // This is a stack where class_stack.top() == the incomplete class record for the class declaration we are + // currently in. + + unsigned long paren_count = 0; + // this is the number of ( we have seen minus the number of ) we have + // seen. + + std::vector > token_accum; + // Used to accumulate tokens for function and class declarations + + std::vector > last_full_declaration; + // Once we determine that token_accum has a full declaration in it we copy it into last_full_declaration. 
+ + int type; + string token; + + tok.get_token(type, token); + + while (type != tok_type::END_OF_FILE) + { + switch(type) + { + case tok_type::KEYWORD: // ------------------------------------------ + { + token_accum.push_back(make_pair(type,token)); + + if (token[0] == '#') + recently_seen_preprocessor = true; + + if (token == "class") + { + recently_seen_class_keyword = true; + recently_seen_new_scope = true; + } + else if (token == "struct") + { + recently_seen_struct_keyword = true; + recently_seen_new_scope = true; + } + else if (token == "namespace") + { + recently_seen_namespace_keyword = true; + recently_seen_new_scope = true; + } + else if (token == "extern") + { + recently_seen_new_scope = true; + } + else if (token == "#define") + { + recently_seen_pound_define = true; + } + else if (token == "typedef") + { + recently_seen_typedef = true; + } + else if (recently_seen_pound_define == false) + { + // eat white space + int temp_type; + string temp_token; + if (tok.peek_type() == tok_type::WHITE_SPACE) + tok.get_token(temp_type, temp_token); + + const bool next_is_colon = (tok.peek_type() == tok_type::OTHER && tok.peek_token() == ":"); + if (next_is_colon) + { + // eat the colon + tok.get_token(temp_type, temp_token); + + if (scope_access.size() > 0 && token == "public") + { + scope_access.top() = public_scope; + token_accum.clear(); + last_full_declaration.clear(); + } + else if (scope_access.size() > 0 && token == "protected") + { + scope_access.top() = protected_scope; + token_accum.clear(); + last_full_declaration.clear(); + } + else if (scope_access.size() > 0 && token == "private") + { + scope_access.top() = private_scope; + token_accum.clear(); + last_full_declaration.clear(); + } + } + } + + at_top_of_new_scope = false; + + }break; + + case tok_type::COMMENT: // ------------------------------------------ + { + if (scopes.top() == 0 && last_full_declaration.size() > 0 && is_formal_comment(token) && + paren_count == 0) + { + + // if we are inside a 
class or struct + if (scope_access.size() > 0) + { + // if we are looking at a comment at the top of a class + if (at_top_of_new_scope) + { + // push an entry for this class into the class_stack + tok_class_record temp; + temp.declaration = last_full_declaration; + temp.file = file; + temp.name = namespaces.back(); + temp.scope = make_scope_string(namespaces,1); + temp.comment = token; + class_stack.push(temp); + } + else if (scope_access.top() == public_scope || scope_access.top() == protected_scope) + { + // This should be a member function. + // Only do anything if the class that contains this member function is + // in the class_stack. + if (class_stack.size() > 0 && class_stack.top().name == namespaces.back() && + looks_like_function_declaration(last_full_declaration)) + { + tok_method_record temp; + + // Check if there is an initialization list inside the declaration and if there is + // then find out where the starting : is located so we can avoid including it in + // the output. 
+ unsigned long pos = last_full_declaration.size(); + long temp_paren_count = 0; + for (unsigned long i = 0; i < last_full_declaration.size(); ++i) + { + if (last_full_declaration[i].first == tok_type::OTHER) + { + if (last_full_declaration[i].second == "(") + ++temp_paren_count; + else if (last_full_declaration[i].second == ")") + --temp_paren_count; + else if (temp_paren_count == 0 && last_full_declaration[i].second == ":") + { + // if this is a :: then ignore it + if (i > 0 && last_full_declaration[i-1].second == ":") + continue; + else if (i+1 < last_full_declaration.size() && last_full_declaration[i+1].second == ":") + continue; + else + { + pos = i; + break; + } + } + } + } + + temp.declaration.assign(last_full_declaration.begin(), last_full_declaration.begin()+pos); + temp.comment = token; + if (scope_access.top() == public_scope) + class_stack.top().public_methods.push_back(temp); + else + class_stack.top().protected_methods.push_back(temp); + } + } + } + else + { + // we should be looking at a global declaration of some kind. 
+ if (looks_like_function_declaration(last_full_declaration)) + { + tok_function_record temp; + + // make sure we never include anything beyond the first closing ) + // if we are looking at a #defined function + unsigned long pos = last_full_declaration.size(); + if (last_full_declaration[0].second == "#define") + { + long temp_paren_count = 0; + for (unsigned long i = 0; i < last_full_declaration.size(); ++i) + { + if (last_full_declaration[i].first == tok_type::OTHER) + { + if (last_full_declaration[i].second == "(") + { + ++temp_paren_count; + } + else if (last_full_declaration[i].second == ")") + { + --temp_paren_count; + if (temp_paren_count == 0) + { + pos = i+1; + break; + } + } + } + } + } + + temp.declaration.assign(last_full_declaration.begin(), last_full_declaration.begin()+pos); + temp.file = file; + temp.scope = make_scope_string(namespaces); + temp.comment = token; + functions.push_back(temp); + } + } + + token_accum.clear(); + last_full_declaration.clear(); + } + + at_top_of_new_scope = false; + }break; + + case tok_type::IDENTIFIER: // ------------------------------------------ + { + if (recently_seen_class_keyword) + { + last_class_name = token; + last_struct_name.clear(); + last_namespace_name.clear(); + } + else if (recently_seen_struct_keyword) + { + last_struct_name = token; + last_class_name.clear(); + last_namespace_name.clear(); + } + else if (recently_seen_namespace_keyword) + { + last_namespace_name = token; + last_class_name.clear(); + last_struct_name.clear(); + } + + if (scopes.top() == 0 && token == "slots") + recently_seen_slots = true; + else + recently_seen_slots = false; + + recently_seen_class_keyword = false; + recently_seen_struct_keyword = false; + recently_seen_namespace_keyword = false; + recently_seen_closing_bracket = false; + at_top_of_new_scope = false; + + token_accum.push_back(make_pair(type,token)); + }break; + + case tok_type::OTHER: // ------------------------------------------ + { + switch(token[0]) + { + case '{': 
+ // if we are entering a new scope + if (recently_seen_new_scope) + { + scopes.push(0); + at_top_of_new_scope = true; + + // if we are entering a class + if (last_class_name.size() > 0) + { + scope_access.push(private_scope); + namespaces.push_back(last_class_name); + } + else if (last_struct_name.size() > 0) + { + scope_access.push(public_scope); + namespaces.push_back(last_struct_name); + } + else if (last_namespace_name.size() > 0) + { + namespaces.push_back(last_namespace_name); + } + else + { + namespaces.push_back(""); + } + } + else + { + scopes.top() += 1; + } + recently_seen_new_scope = false; + recently_seen_class_keyword = false; + recently_seen_struct_keyword = false; + recently_seen_namespace_keyword = false; + recently_seen_paren_0 = false; + + // a { at function scope is an end of a potential declaration + if (scopes.top() == 0) + { + // put token_accum into last_full_declaration + token_accum.swap(last_full_declaration); + } + token_accum.clear(); + break; + + case '}': + if (scopes.top() > 0) + { + scopes.top() -= 1; + } + else if (scopes.size() > 1) + { + scopes.pop(); + + if (scope_access.size() > 0) + scope_access.pop(); + + // If the scope we are leaving is the top class on the class_stack + // then we need to either pop it into its containing class or put it + // into the classes output vector. + if (class_stack.size() > 0 && namespaces.back() == class_stack.top().name) + { + // If this class is a inner_class of another then push it into the + // public_inner_classes or protected_inner_classes field of it's containing class. 
+ if (class_stack.size() > 1) + { + tok_class_record temp = class_stack.top(); + class_stack.pop(); + if (scope_access.size() > 0) + { + if (scope_access.top() == public_scope) + class_stack.top().public_inner_classes.push_back(temp); + else if (scope_access.top() == protected_scope) + class_stack.top().protected_inner_classes.push_back(temp); + } + } + else if (class_stack.size() > 0) + { + classes.push_back(class_stack.top()); + class_stack.pop(); + } + } + + namespaces.pop_back(); + last_full_declaration.clear(); + } + + token_accum.clear(); + recently_seen_closing_bracket = true; + at_top_of_new_scope = false; + break; + + case ';': + // a ; at function scope is an end of a potential declaration + if (scopes.top() == 0) + { + // put token_accum into last_full_declaration + token_accum.swap(last_full_declaration); + } + token_accum.clear(); + + // if we are inside the public area of a class and this ; might be the end + // of a typedef or variable declaration + if (scopes.top() == 0 && scope_access.size() > 0 && + (scope_access.top() == public_scope || scope_access.top() == protected_scope) && + recently_seen_closing_bracket == false) + { + if (recently_seen_typedef) + { + // This should be a typedef inside the public area of a class or struct: + // Only do anything if the class that contains this typedef is in the class_stack. + if (class_stack.size() > 0 && class_stack.top().name == namespaces.back()) + { + tok_typedef_record temp; + temp.declaration = last_full_declaration; + if (scope_access.top() == public_scope) + class_stack.top().public_typedefs.push_back(temp); + else + class_stack.top().protected_typedefs.push_back(temp); + } + + } + else if (recently_seen_paren_0 == false && recently_seen_new_scope == false) + { + // This should be some kind of public variable declaration inside a class or struct: + // Only do anything if the class that contains this member variable is in the class_stack. 
+ if (class_stack.size() > 0 && class_stack.top().name == namespaces.back()) + { + tok_variable_record temp; + temp.declaration = last_full_declaration; + if (scope_access.top() == public_scope) + class_stack.top().public_variables.push_back(temp); + else + class_stack.top().protected_variables.push_back(temp); + } + + } + } + + recently_seen_new_scope = false; + recently_seen_typedef = false; + recently_seen_paren_0 = false; + recently_seen_closing_bracket = false; + recently_seen_slots = false; + at_top_of_new_scope = false; + break; + + case ':': + token_accum.push_back(make_pair(type,token)); + if (recently_seen_slots) + { + token_accum.clear(); + last_full_declaration.clear(); + recently_seen_slots = false; + } + break; + + case '(': + ++paren_count; + token_accum.push_back(make_pair(type,token)); + at_top_of_new_scope = false; + break; + + case ')': + token_accum.push_back(make_pair(type,token)); + + --paren_count; + if (paren_count == 0) + { + recently_seen_paren_0 = true; + if (scopes.top() == 0) + { + last_full_declaration = token_accum; + } + } + + recently_seen_new_scope = false; + at_top_of_new_scope = false; + break; + + default: + token_accum.push_back(make_pair(type,token)); + at_top_of_new_scope = false; + break; + } + }break; + + + case tok_type::WHITE_SPACE: // ------------------------------------------ + { + if (recently_seen_pound_define) + { + if (contains_unescaped_newline(token)) + { + recently_seen_pound_define = false; + recently_seen_paren_0 = false; + recently_seen_preprocessor = false; + + // this is an end of a potential declaration + token_accum.swap(last_full_declaration); + token_accum.clear(); + } + } + + if (recently_seen_preprocessor) + { + if (contains_unescaped_newline(token)) + { + recently_seen_preprocessor = false; + + last_full_declaration.clear(); + token_accum.clear(); + } + } + }break; + + default: // ------------------------------------------ + { + token_accum.push_back(make_pair(type,token)); + at_top_of_new_scope = 
false; + }break; + } + + + tok.get_token(type, token); + } +} + +// ---------------------------------------------------------------------------------------- + +string get_function_name ( + const std::vector >& declaration +) +{ + string name; + + bool contains_operator = false; + unsigned long operator_pos = 0; + for (unsigned long i = 0; i < declaration.size(); ++i) + { + if (declaration[i].first == tok_type::KEYWORD && + declaration[i].second == "operator") + { + contains_operator = true; + operator_pos = i; + break; + } + } + + + // find the opening ( for the function + unsigned long paren_pos = 0; + long paren_count = 0; + for (long i = declaration.size()-1; i >= 0; --i) + { + if (declaration[i].first == tok_type::OTHER && + declaration[i].second == ")") + { + ++paren_count; + } + else if (declaration[i].first == tok_type::OTHER && + declaration[i].second == "(") + { + --paren_count; + if (paren_count == 0) + { + paren_pos = i; + break; + } + } + } + + + if (contains_operator) + { + name = declaration[operator_pos].second; + for (unsigned long i = operator_pos+1; i < paren_pos; ++i) + { + if (declaration[i].first == tok_type::IDENTIFIER || declaration[i].first == tok_type::KEYWORD) + { + name += " "; + } + + name += declaration[i].second; + } + } + else + { + // if this is a destructor then include the ~ + if (paren_pos > 1 && declaration[paren_pos-2].second == "~") + name = "~" + declaration[paren_pos-1].second; + else if (paren_pos > 0) + name = declaration[paren_pos-1].second; + + + } + + return name; +} + +// ---------------------------------------------------------------------------------------- + +string pretty_print_declaration ( + const std::vector >& decl +) +{ + string temp; + long angle_count = 0; + long paren_count = 0; + + if (decl.size() == 0) + return temp; + + temp = decl[0].second; + + + bool just_closed_template = false; + bool in_template = false; + bool last_was_scope_res = false; + bool seen_operator = false; + + if (temp == "operator") + 
seen_operator = true; + + for (unsigned long i = 1; i < decl.size(); ++i) + { + bool last_was_less_than = false; + if (decl[i-1].first == tok_type::OTHER && decl[i-1].second == "<") + last_was_less_than = true; + + + if (decl[i].first == tok_type::OTHER && decl[i].second == "<" && + (decl[i-1].second != "operator" && ((i>1 && decl[i-2].second != "operator") || decl[i-1].second != "<") )) + ++angle_count; + + if (decl[i-1].first == tok_type::KEYWORD && decl[i-1].second == "template" && + decl[i].first == tok_type::OTHER && decl[i].second == "<") + { + in_template = true; + temp += " <\n "; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ">") + { + // don't count angle brackets when they are part of an operator + if (decl[i-1].second != "operator" && ((i>1 && decl[i-2].second != "operator") || decl[i-1].second != ">")) + --angle_count; + + if (angle_count == 0 && in_template) + { + temp += "\n >\n"; + just_closed_template = true; + in_template = false; + } + else + { + temp += ">"; + } + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "<") + { + temp += "<"; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ",") + { + if (in_template || (paren_count == 1 && angle_count == 0)) + temp += ",\n "; + else + temp += ","; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "&") + { + temp += "&"; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ".") + { + temp += "."; + } + else if (decl[i].first == tok_type::SINGLE_QUOTED_TEXT) + { + temp += decl[i].second; + } + else if (decl[i].first == tok_type::DOUBLE_QUOTED_TEXT) + { + temp += decl[i].second; + } + else if (decl[i-1].first == tok_type::SINGLE_QUOTED_TEXT && decl[i].second == "'") + { + temp += decl[i].second; + } + else if (decl[i-1].first == tok_type::DOUBLE_QUOTED_TEXT && decl[i].second == "\"") + { + temp += decl[i].second; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "[") + { + temp += "["; + } + 
else if (decl[i].first == tok_type::OTHER && decl[i].second == "]") + { + temp += "]"; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "-") + { + temp += "-"; + } + else if (decl[i].first == tok_type::NUMBER) + { + if (decl[i-1].second == "=") + temp += " " + decl[i].second; + else + temp += decl[i].second; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "*") + { + temp += "*"; + } + else if (decl[i].first == tok_type::KEYWORD && decl[i].second == "operator") + { + temp += "\noperator"; + seen_operator = true; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ":" && + (decl[i-1].second == ":" || (i+10 && decl[i-1].second == "(")) + temp += decl[i].second; + else + temp += " " + decl[i].second; + + just_closed_template = false; + last_was_scope_res = false; + } + + + + } + + return temp; +} + +// ---------------------------------------------------------------------------------------- + +string format_comment ( + const string& comment, + const unsigned long expand_tabs +) +{ + if (comment.size() <= 6) + return ""; + + string temp = trim(trim(comment.substr(3,comment.size()-6), " \t"), "\n\r"); + + + // if we should expand tabs to spaces + if (expand_tabs != 0) + { + unsigned long column = 0; + string str; + for (unsigned long i = 0; i < temp.size(); ++i) + { + if (temp[i] == '\t') + { + const unsigned long num_spaces = expand_tabs - column%expand_tabs; + column += num_spaces; + str.insert(str.end(), num_spaces, ' '); + } + else if (temp[i] == '\n' || temp[i] == '\r') + { + str += temp[i]; + column = 0; + } + else + { + str += temp[i]; + ++column; + } + } + + // put str into temp + str.swap(temp); + } + + // now figure out what the smallest amount of leading white space is and remove it from each line. 
+ unsigned long num_whitespace = 100000; + + string::size_type pos1 = 0, pos2 = 0; + + while (pos1 != string::npos) + { + // find start of non-white-space + pos2 = temp.find_first_not_of(" \t",pos1); + + // if this is a line of just white space then ignore it + if (pos2 != string::npos && temp[pos2] != '\n' && temp[pos2] != '\r') + { + if (pos2-pos1 < num_whitespace) + num_whitespace = pos2-pos1; + } + + // find end-of-line + pos1 = temp.find_first_of("\n\r", pos2); + // find start of next line + pos2 = temp.find_first_not_of("\n\r", pos1); + pos1 = pos2; + } + + // now remove the leading white space + string temp2; + unsigned long counter = 0; + for (unsigned long i = 0; i < temp.size(); ++i) + { + // if we are looking at a new line + if (temp[i] == '\n' || temp[i] == '\r') + { + counter = 0; + } + else if (counter < num_whitespace) + { + ++counter; + continue; + } + + temp2 += temp[i]; + } + + return temp2; +} + +// ---------------------------------------------------------------------------------------- + +typedef_record convert_tok_typedef_record ( + const tok_typedef_record& rec +) +{ + typedef_record temp; + temp.declaration = pretty_print_declaration(rec.declaration); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +variable_record convert_tok_variable_record ( + const tok_variable_record& rec +) +{ + variable_record temp; + temp.declaration = pretty_print_declaration(rec.declaration); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +method_record convert_tok_method_record ( + const tok_method_record& rec, + const unsigned long expand_tabs +) +{ + method_record temp; + + temp.comment = format_comment(rec.comment, expand_tabs); + temp.name = get_function_name(rec.declaration); + temp.declaration = pretty_print_declaration(rec.declaration); + return temp; +} + +// 
---------------------------------------------------------------------------------------- + +class_record convert_tok_class_record ( + const tok_class_record& rec, + const unsigned long expand_tabs +) +{ + class_record crec; + + + crec.scope = rec.scope; + crec.file = rec.file; + crec.comment = format_comment(rec.comment, expand_tabs); + + crec.name.clear(); + + // find the first class token + for (unsigned long i = 0; i+1 < rec.declaration.size(); ++i) + { + if (rec.declaration[i].first == tok_type::KEYWORD && + (rec.declaration[i].second == "class" || + rec.declaration[i].second == "struct" ) + ) + { + crec.name = rec.declaration[i+1].second; + break; + } + } + + crec.declaration = pretty_print_declaration(rec.declaration); + + for (unsigned long i = 0; i < rec.public_typedefs.size(); ++i) + crec.public_typedefs.push_back(convert_tok_typedef_record(rec.public_typedefs[i])); + + for (unsigned long i = 0; i < rec.public_variables.size(); ++i) + crec.public_variables.push_back(convert_tok_variable_record(rec.public_variables[i])); + + for (unsigned long i = 0; i < rec.protected_typedefs.size(); ++i) + crec.protected_typedefs.push_back(convert_tok_typedef_record(rec.protected_typedefs[i])); + + for (unsigned long i = 0; i < rec.protected_variables.size(); ++i) + crec.protected_variables.push_back(convert_tok_variable_record(rec.protected_variables[i])); + + for (unsigned long i = 0; i < rec.public_methods.size(); ++i) + crec.public_methods.push_back(convert_tok_method_record(rec.public_methods[i], expand_tabs)); + + for (unsigned long i = 0; i < rec.protected_methods.size(); ++i) + crec.protected_methods.push_back(convert_tok_method_record(rec.protected_methods[i], expand_tabs)); + + for (unsigned long i = 0; i < rec.public_inner_classes.size(); ++i) + crec.public_inner_classes.push_back(convert_tok_class_record(rec.public_inner_classes[i], expand_tabs)); + + for (unsigned long i = 0; i < rec.protected_inner_classes.size(); ++i) + 
crec.protected_inner_classes.push_back(convert_tok_class_record(rec.protected_inner_classes[i], expand_tabs)); + + + return crec; +} + +// ---------------------------------------------------------------------------------------- + +function_record convert_tok_function_record ( + const tok_function_record& rec, + const unsigned long expand_tabs +) +{ + function_record temp; + + temp.scope = rec.scope; + temp.file = rec.file; + temp.comment = format_comment(rec.comment, expand_tabs); + temp.name = get_function_name(rec.declaration); + temp.declaration = pretty_print_declaration(rec.declaration); + + return temp; +} + +// ---------------------------------------------------------------------------------------- + +void convert_to_normal_records ( + const std::vector& tok_functions, + const std::vector& tok_classes, + const unsigned long expand_tabs, + std::vector& functions, + std::vector& classes +) +{ + functions.clear(); + classes.clear(); + + + for (unsigned long i = 0; i < tok_functions.size(); ++i) + { + functions.push_back(convert_tok_function_record(tok_functions[i], expand_tabs)); + } + + + for (unsigned long i = 0; i < tok_classes.size(); ++i) + { + classes.push_back(convert_tok_class_record(tok_classes[i], expand_tabs)); + } + + +} + +// ---------------------------------------------------------------------------------------- + +string add_entity_ref (const string& str) +{ + string temp; + for (unsigned long i = 0; i < str.size(); ++i) + { + if (str[i] == '&') + temp += "&"; + else if (str[i] == '<') + temp += "<"; + else if (str[i] == '>') + temp += ">"; + else + temp += str[i]; + } + return temp; +} + +// ---------------------------------------------------------------------------------------- + +string flip_slashes (string str) +{ + for (unsigned long i = 0; i < str.size(); ++i) + { + if (str[i] == '\\') + str[i] = '/'; + } + return str; +} + +// ---------------------------------------------------------------------------------------- + +void write_as_xml ( + 
const function_record& rec, + ostream& fout +) +{ + fout << " \n"; + fout << " " << add_entity_ref(rec.name) << "\n"; + fout << " " << add_entity_ref(rec.scope) << "\n"; + fout << " " << add_entity_ref(rec.declaration) << "\n"; + fout << " " << flip_slashes(add_entity_ref(rec.file)) << "\n"; + fout << " " << add_entity_ref(rec.comment) << "\n"; + fout << " \n"; +} + +// ---------------------------------------------------------------------------------------- + +void write_as_xml ( + const class_record& rec, + ostream& fout, + unsigned long indent +) +{ + const string pad(indent, ' '); + + fout << pad << "\n"; + fout << pad << " " << add_entity_ref(rec.name) << "\n"; + fout << pad << " " << add_entity_ref(rec.scope) << "\n"; + fout << pad << " " << add_entity_ref(rec.declaration) << "\n"; + fout << pad << " " << flip_slashes(add_entity_ref(rec.file)) << "\n"; + fout << pad << " " << add_entity_ref(rec.comment) << "\n"; + + + if (rec.public_typedefs.size() > 0) + { + fout << pad << " \n"; + for (unsigned long i = 0; i < rec.public_typedefs.size(); ++i) + { + fout << pad << " " << add_entity_ref(rec.public_typedefs[i].declaration) << "\n"; + } + fout << pad << " \n"; + } + + + if (rec.public_variables.size() > 0) + { + fout << pad << " \n"; + for (unsigned long i = 0; i < rec.public_variables.size(); ++i) + { + fout << pad << " " << add_entity_ref(rec.public_variables[i].declaration) << "\n"; + } + fout << pad << " \n"; + } + + if (rec.protected_typedefs.size() > 0) + { + fout << pad << " \n"; + for (unsigned long i = 0; i < rec.protected_typedefs.size(); ++i) + { + fout << pad << " " << add_entity_ref(rec.protected_typedefs[i].declaration) << "\n"; + } + fout << pad << " \n"; + } + + + if (rec.protected_variables.size() > 0) + { + fout << pad << " \n"; + for (unsigned long i = 0; i < rec.protected_variables.size(); ++i) + { + fout << pad << " " << add_entity_ref(rec.protected_variables[i].declaration) << "\n"; + } + fout << pad << " \n"; + } + + + if 
(rec.public_methods.size() > 0) + { + fout << pad << " \n"; + for (unsigned long i = 0; i < rec.public_methods.size(); ++i) + { + fout << pad << " \n"; + fout << pad << " " << add_entity_ref(rec.public_methods[i].name) << "\n"; + fout << pad << " " << add_entity_ref(rec.public_methods[i].declaration) << "\n"; + fout << pad << " " << add_entity_ref(rec.public_methods[i].comment) << "\n"; + fout << pad << " \n"; + } + fout << pad << " \n"; + } + + + if (rec.protected_methods.size() > 0) + { + fout << pad << " \n"; + for (unsigned long i = 0; i < rec.protected_methods.size(); ++i) + { + fout << pad << " \n"; + fout << pad << " " << add_entity_ref(rec.protected_methods[i].name) << "\n"; + fout << pad << " " << add_entity_ref(rec.protected_methods[i].declaration) << "\n"; + fout << pad << " " << add_entity_ref(rec.protected_methods[i].comment) << "\n"; + fout << pad << " \n"; + } + fout << pad << " \n"; + } + + + if (rec.public_inner_classes.size() > 0) + { + fout << pad << " \n"; + for (unsigned long i = 0; i < rec.public_inner_classes.size(); ++i) + { + write_as_xml(rec.public_inner_classes[i], fout, indent+4); + } + fout << pad << " \n"; + } + + if (rec.protected_inner_classes.size() > 0) + { + fout << pad << " \n"; + for (unsigned long i = 0; i < rec.protected_inner_classes.size(); ++i) + { + write_as_xml(rec.protected_inner_classes[i], fout, indent+4); + } + fout << pad << " \n"; + } + + + fout << pad << "\n"; +} + +// ---------------------------------------------------------------------------------------- + +void save_to_xml_file ( + const std::vector& functions, + const std::vector& classes +) +{ + ofstream fout("output.xml"); + + fout << "" << endl; + fout << "" << endl; + + fout << " " << endl; + for (unsigned long i = 0; i < classes.size(); ++i) + { + write_as_xml(classes[i], fout, 4); + fout << "\n"; + } + fout << " \n\n" << endl; + + + fout << " " << endl; + for (unsigned long i = 0; i < functions.size(); ++i) + { + write_as_xml(functions[i], fout); + fout 
<< "\n"; + } + fout << " " << endl; + + fout << "" << endl; +} + +// ---------------------------------------------------------------------------------------- + +void generate_xml_markup( + const cmd_line_parser::check_1a_c& parser, + const std::string& filter, + const unsigned long search_depth, + const unsigned long expand_tabs +) +{ + + // first figure out which files should be processed + std::vector > files; + obtain_list_of_files(parser, filter, search_depth, files); + + + std::vector tok_functions; + std::vector tok_classes; + + for (unsigned long i = 0; i < files.size(); ++i) + { + ifstream fin(files[i].second.c_str()); + if (!fin) + { + cerr << "Error opening file: " << files[i].second << endl; + return; + } + process_file(fin, files[i].first, tok_functions, tok_classes); + } + + std::vector functions; + std::vector classes; + + convert_to_normal_records(tok_functions, tok_classes, expand_tabs, functions, classes); + + save_to_xml_file(functions, classes); +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/htmlify/to_xml.h b/ml/dlib/tools/htmlify/to_xml.h new file mode 100644 index 000000000..4bdf3f00e --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml.h @@ -0,0 +1,22 @@ +#ifndef DLIB_HTMLIFY_TO_XmL_H__ +#define DLIB_HTMLIFY_TO_XmL_H__ + +#include "dlib/cmd_line_parser.h" +#include + +void generate_xml_markup( + const dlib::cmd_line_parser::check_1a_c& parser, + const std::string& filter, + const unsigned long search_depth, + const unsigned long expand_tabs +); +/*! + ensures + - reads all the files indicated by the parser arguments and converts them + to XML. The output will be stored in the output.xml file. 
+ - if (expand_tabs != 0) then + - tabs will be replaced with expand_tabs spaces inside comment blocks +!*/ + +#endif // DLIB_HTMLIFY_TO_XmL_H__ + diff --git a/ml/dlib/tools/htmlify/to_xml_example/bigminus.gif b/ml/dlib/tools/htmlify/to_xml_example/bigminus.gif new file mode 100644 index 000000000..aea8e5c01 Binary files /dev/null and b/ml/dlib/tools/htmlify/to_xml_example/bigminus.gif differ diff --git a/ml/dlib/tools/htmlify/to_xml_example/bigplus.gif b/ml/dlib/tools/htmlify/to_xml_example/bigplus.gif new file mode 100644 index 000000000..6bee68e21 Binary files /dev/null and b/ml/dlib/tools/htmlify/to_xml_example/bigplus.gif differ diff --git a/ml/dlib/tools/htmlify/to_xml_example/example.xml b/ml/dlib/tools/htmlify/to_xml_example/example.xml new file mode 100644 index 000000000..472a4a5e1 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/example.xml @@ -0,0 +1,8 @@ + + + + + Documented Code + + + diff --git a/ml/dlib/tools/htmlify/to_xml_example/minus.gif b/ml/dlib/tools/htmlify/to_xml_example/minus.gif new file mode 100644 index 000000000..1deac2fe1 Binary files /dev/null and b/ml/dlib/tools/htmlify/to_xml_example/minus.gif differ diff --git a/ml/dlib/tools/htmlify/to_xml_example/output.xml b/ml/dlib/tools/htmlify/to_xml_example/output.xml new file mode 100644 index 000000000..95e4de6ae --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/output.xml @@ -0,0 +1,49 @@ + + + + + test + + class test + test.cpp + WHAT THIS OBJECT REPRESENTS + This is a simple test class that doesn't do anything + + typedef int type + + + + test + test() + ensures + - constructs a test object + + + print + void +print() const + ensures + - prints a message to the screen + + + + + + + + + + add_numbers + + int +add_numbers ( + int a, + int b +) + test.cpp + ensures + - returns a + b + + + + diff --git a/ml/dlib/tools/htmlify/to_xml_example/plus.gif b/ml/dlib/tools/htmlify/to_xml_example/plus.gif new file mode 100644 index 000000000..2d15c1417 Binary files /dev/null and 
b/ml/dlib/tools/htmlify/to_xml_example/plus.gif differ diff --git a/ml/dlib/tools/htmlify/to_xml_example/stylesheet.xsl b/ml/dlib/tools/htmlify/to_xml_example/stylesheet.xsl new file mode 100644 index 000000000..7a44862a3 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/stylesheet.xsl @@ -0,0 +1,354 @@ + + + + + + + + + + + abcdefghijklmnopqrstuvwxyz + ABCDEFGHIJKLMNOPQRSTUVWXYZ + + + + + + + + <xsl:if test="title"> + <xsl:value-of select="title" /> + </xsl:if> + + + + + + + + + + +

+
+ + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Classes and Structs:

+ + + + + +

Global Functions:

+ + +
+ + () +
+
+ + Scope:
+
+ File:

+
+
;
+
+
+
+
+
+
+ +
+ + + + +
+ + +
+
+ + Scope:
+
+ File:

+
+
;

+

+
+ + + + Protected Typedefs +
+
    + +
  • ;
  • +
    +
+
+
+
+ + + + Public Typedefs +
+
    + +
  • ;
  • +
    +
+
+
+
+ + + + Protected Variables +
+
    + +
  • ;
  • +
    +
+
+
+
+ + + + Public Variables +
+
    + +
  • ;
  • +
    +
+
+
+
+ + + + Protected Methods +
+ +
+ Method Name:

+
+
;
+

+
+
+
+
+
+
+ + + + Public Methods +
+ +
+ Method Name:

+
+
;
+

+
+
+
+
+
+
+ + + + Protected Inner Classes +
+ + + +
+
+
+ + + + Public Inner Classes +
+ + + +
+
+
+ +
+
+
+ + + + + + + + + + +
diff --git a/ml/dlib/tools/htmlify/to_xml_example/test.cpp b/ml/dlib/tools/htmlify/to_xml_example/test.cpp new file mode 100644 index 000000000..edbdfff54 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/test.cpp @@ -0,0 +1,78 @@ +#include + +// ---------------------------------------------------------------------------------------- + +using namespace std; + +// ---------------------------------------------------------------------------------------- + +class test +{ + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple test class that doesn't do anything + !*/ + +public: + + typedef int type; + + test (); + /*! + ensures + - constructs a test object + !*/ + + void print () const; + /*! + ensures + - prints a message to the screen + !*/ + +}; + +// ---------------------------------------------------------------------------------------- + +test::test() {} + +void test::print() const +{ + cout << "A message!" << endl; +} + +// ---------------------------------------------------------------------------------------- + +int add_numbers ( + int a, + int b +) +/*! + ensures + - returns a + b +!*/ +{ + return a + b; +} + +// ---------------------------------------------------------------------------------------- + +void a_function ( +) +/*!P + This is a function which won't show up in the output of htmlify --to-xml + because of the presence of the P in the above /*!P above. +!*/ +{ +} + +// ---------------------------------------------------------------------------------------- + +int main() +{ + test a; + a.print(); +} + +// ---------------------------------------------------------------------------------------- + + diff --git a/ml/dlib/tools/imglab/CMakeLists.txt b/ml/dlib/tools/imglab/CMakeLists.txt new file mode 100644 index 000000000..46c64fb3e --- /dev/null +++ b/ml/dlib/tools/imglab/CMakeLists.txt @@ -0,0 +1,41 @@ +# +# This is a CMake makefile. 
You can find the cmake utility and +# information about it at http://www.cmake.org +# + +cmake_minimum_required(VERSION 2.8.12) + +# create a variable called target_name and set it to the string "imglab" +set (target_name imglab) + +PROJECT(${target_name}) +add_subdirectory(../../dlib dlib_build) + +# add all the cpp files we want to compile to this list. This tells +# cmake that they are part of our target (which is the executable named imglab) +add_executable(${target_name} + src/main.cpp + src/metadata_editor.h + src/metadata_editor.cpp + src/convert_pascal_xml.h + src/convert_pascal_xml.cpp + src/convert_pascal_v1.h + src/convert_pascal_v1.cpp + src/convert_idl.h + src/convert_idl.cpp + src/common.h + src/common.cpp + src/cluster.cpp + src/flip_dataset.cpp +) + + +# Tell cmake to link our target executable to dlib. +target_link_libraries(${target_name} dlib::dlib ) + + +install(TARGETS ${target_name} + RUNTIME DESTINATION bin + ) +install(PROGRAMS convert_imglab_paths_to_relative copy_imglab_dataset DESTINATION bin ) + diff --git a/ml/dlib/tools/imglab/README.txt b/ml/dlib/tools/imglab/README.txt new file mode 100644 index 000000000..3f0ca92a1 --- /dev/null +++ b/ml/dlib/tools/imglab/README.txt @@ -0,0 +1,40 @@ +imglab is a simple graphical tool for annotating images with object bounding +boxes and optionally their part locations. Generally, you use it when you want +to train an object detector (e.g. a face detector) since it allows you to +easily create the needed training dataset. + +You can compile imglab with the following commands: + cd dlib/tools/imglab + mkdir build + cd build + cmake .. + cmake --build . --config Release +Note that you may need to install CMake (www.cmake.org) for this to work. On a +unix system you can also install imglab into /usr/local/bin by running + sudo make install +This will make running it more convenient. + +Next, to use it, lets assume you have a folder of images called /tmp/images. 
+These images should contain examples of the objects you want to learn to +detect. You will use the imglab tool to label these objects. Do this by +typing the following command: + ./imglab -c mydataset.xml /tmp/images +This will create a file called mydataset.xml which simply lists the images in +/tmp/images. To add bounding boxes to the objects you run: + ./imglab mydataset.xml +and a window will appear showing all the images. You can use the up and down +arrow keys to cycle though the images and the mouse to label objects. In +particular, holding the shift key, left clicking, and dragging the mouse will +allow you to draw boxes around the objects you wish to detect. + +Once you finish labeling objects go to the file menu, click save, and then +close the program. This will save the object boxes back to mydataset.xml. You +can verify this by opening the tool again with: + ./imglab mydataset.xml +and observing that the boxes are present. + + +imglab can do a few additional things. To see these run: + imglab -h +and also read the instructions in the About->Help menu. + diff --git a/ml/dlib/tools/imglab/convert_imglab_paths_to_relative b/ml/dlib/tools/imglab/convert_imglab_paths_to_relative new file mode 100755 index 000000000..09c5ef7a5 --- /dev/null +++ b/ml/dlib/tools/imglab/convert_imglab_paths_to_relative @@ -0,0 +1,24 @@ +#!/usr/bin/perl + +use File::Spec; + +die "This script converts all the file names in an imglab XML file to have paths relative to the current folder. 
Call it like this: ./convert_imglab_paths_to_relative some_file.xml" if @ARGV != 1; + +$file = @ARGV[0]; +open(INFO, $file) or die('Could not open file.'); + +foreach $line () +{ + if (index($line, 'file=\'') != -1) + { + $line =~ /file='(.*)'/; + $relpath = File::Spec->abs2rel($1); + $line =~ s/$1/$relpath/; + print $line + } + else + { + print $line + } +} + diff --git a/ml/dlib/tools/imglab/copy_imglab_dataset b/ml/dlib/tools/imglab/copy_imglab_dataset new file mode 100755 index 000000000..8b44ed166 --- /dev/null +++ b/ml/dlib/tools/imglab/copy_imglab_dataset @@ -0,0 +1,22 @@ +#!/bin/bash + +if [ "$#" -ne 2 ]; then + echo "This script copies an imglab XML file and its associated images to a new folder." + echo "Notably, it will avoid copying unnecessary images." + echo "Call this script like this:" + echo " ./copy_dataset some_file.xml dest_dir" + exit 1 +fi + +XML_FILE=$1 +DEST=$2 + + + +mkdir -p $DEST + +# Get the list of files we need to copy, then build the cp statements with 1000 files at most in each statement, then tell bash to run them all. +imglab --files $XML_FILE | xargs perl -e 'use File::Spec; foreach (@ARGV) {print File::Spec->abs2rel($_) . "\n"}' | sort | uniq | xargs -L1000 echo | xargs -I{} echo cp -a --parents {} $DEST | bash + +convert_imglab_paths_to_relative $XML_FILE > $DEST/$(basename $XML_FILE) + diff --git a/ml/dlib/tools/imglab/src/cluster.cpp b/ml/dlib/tools/imglab/src/cluster.cpp new file mode 100644 index 000000000..23b289a7f --- /dev/null +++ b/ml/dlib/tools/imglab/src/cluster.cpp @@ -0,0 +1,260 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. 
+ +#include "cluster.h" +#include +#include +#include +#include +#include +#include +#include +#include + +// ---------------------------------------------------------------------------------------- + +using namespace std; +using namespace dlib; + +// ---------------------------------------------------------------------------- + +struct assignment +{ + unsigned long c; + double dist; + unsigned long idx; + + bool operator<(const assignment& item) const + { return dist < item.dist; } +}; + +std::vector angular_cluster ( + std::vector > feats, + const unsigned long num_clusters +) +{ + DLIB_CASSERT(feats.size() != 0, "The dataset can't be empty"); + for (unsigned long i = 0; i < feats.size(); ++i) + { + DLIB_CASSERT(feats[i].size() == feats[0].size(), "All feature vectors must have the same length."); + } + + // find the centroid of feats + matrix m; + for (unsigned long i = 0; i < feats.size(); ++i) + m += feats[i]; + m /= feats.size(); + + // Now center feats and then project onto the unit sphere. The reason for projecting + // onto the unit sphere is so pick_initial_centers() works in a sensible way. 
+ for (unsigned long i = 0; i < feats.size(); ++i) + { + feats[i] -= m; + double len = length(feats[i]); + if (len != 0) + feats[i] /= len; + } + + // now do angular clustering of the points + std::vector > centers; + pick_initial_centers(num_clusters, centers, feats, linear_kernel >(), 0.05); + find_clusters_using_angular_kmeans(feats, centers); + + // and then report the resulting assignments + std::vector assignments; + for (unsigned long i = 0; i < feats.size(); ++i) + { + assignment temp; + temp.c = nearest_center(centers, feats[i]); + temp.dist = length(feats[i] - centers[temp.c]); + temp.idx = i; + assignments.push_back(temp); + } + return assignments; +} + +// ---------------------------------------------------------------------------------------- + +bool compare_first ( + const std::pair& a, + const std::pair& b +) +{ + return a.first < b.first; +} + +// ---------------------------------------------------------------------------------------- + +double mean_aspect_ratio ( + const image_dataset_metadata::dataset& data +) +{ + double sum = 0; + double cnt = 0; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + rectangle rect = data.images[i].boxes[j].rect; + if (rect.area() == 0 || data.images[i].boxes[j].ignore) + continue; + sum += rect.width()/(double)rect.height(); + ++cnt; + } + } + + if (cnt != 0) + return sum/cnt; + else + return 0; +} + +// ---------------------------------------------------------------------------------------- + +bool has_non_ignored_boxes (const image_dataset_metadata::image& img) +{ + for (auto&& b : img.boxes) + { + if (!b.ignore) + return true; + } + return false; +} + +// ---------------------------------------------------------------------------------------- + +int cluster_dataset( + const dlib::command_line_parser& parser +) +{ + // make sure the user entered an argument to this program + if (parser.number_of_arguments() != 1) + { + cerr << "The 
--cluster option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + const unsigned long num_clusters = get_option(parser, "cluster", 2); + const unsigned long chip_size = get_option(parser, "size", 8000); + + image_dataset_metadata::dataset data; + + image_dataset_metadata::load_image_dataset_metadata(data, parser[0]); + set_current_dir(get_parent_directory(file(parser[0]))); + + const double aspect_ratio = mean_aspect_ratio(data); + + dlib::array > images; + std::vector > feats; + console_progress_indicator pbar(data.images.size()); + // extract all the object chips and HOG features. + cout << "Loading image data..." << endl; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + pbar.print_status(i); + if (!has_non_ignored_boxes(data.images[i])) + continue; + + array2d img, chip; + load_image(img, data.images[i].filename); + + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].ignore || data.images[i].boxes[j].rect.area() < 10) + continue; + drectangle rect = data.images[i].boxes[j].rect; + rect = set_aspect_ratio(rect, aspect_ratio); + extract_image_chip(img, chip_details(rect, chip_size), chip); + feats.push_back(extract_fhog_features(chip)); + images.push_back(chip); + } + } + + if (feats.size() == 0) + { + cerr << "No non-ignored object boxes found in the XML dataset. You can't cluster an empty dataset." << endl; + return EXIT_FAILURE; + } + + cout << "\nClustering objects..." << endl; + std::vector assignments = angular_cluster(feats, num_clusters); + + + // Now output each cluster to disk as an XML file. + for (unsigned long c = 0; c < num_clusters; ++c) + { + // We are going to accumulate all the image metadata for cluster c. We put it + // into idata so we can sort the images such that images with central chips + // come before less central chips. 
The idea being to get the good chips to + // show up first in the listing, making it easy to manually remove bad ones if + // that is desired. + std::vector > idata(data.images.size()); + unsigned long idx = 0; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + idata[i].first = std::numeric_limits::infinity(); + idata[i].second.filename = data.images[i].filename; + if (!has_non_ignored_boxes(data.images[i])) + continue; + + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + idata[i].second.boxes.push_back(data.images[i].boxes[j]); + + if (data.images[i].boxes[j].ignore || data.images[i].boxes[j].rect.area() < 10) + continue; + + // If this box goes into cluster c then update the score for the whole + // image based on this boxes' score. Otherwise, mark the box as + // ignored. + if (assignments[idx].c == c) + idata[i].first = std::min(idata[i].first, assignments[idx].dist); + else + idata[i].second.boxes.back().ignore = true; + + ++idx; + } + } + + // now save idata to an xml file. + std::sort(idata.begin(), idata.end(), compare_first); + image_dataset_metadata::dataset cdata; + cdata.comment = data.comment + "\n\n This file contains objects which were clustered into group " + + cast_to_string(c+1) + " of " + cast_to_string(num_clusters) + " groups with a chip size of " + + cast_to_string(chip_size) + " by imglab."; + cdata.name = data.name; + for (unsigned long i = 0; i < idata.size(); ++i) + { + // if this image has non-ignored boxes in it then include it in the output. + if (idata[i].first != std::numeric_limits::infinity()) + cdata.images.push_back(idata[i].second); + } + + string outfile = "cluster_"+pad_int_with_zeros(c+1, 3) + ".xml"; + cout << "Saving " << outfile << endl; + save_image_dataset_metadata(cdata, outfile); + } + + // Now output each cluster to disk as a big tiled jpeg file. Sort everything so, just + // like in the xml file above, the best objects come first in the tiling. 
+ std::sort(assignments.begin(), assignments.end()); + for (unsigned long c = 0; c < num_clusters; ++c) + { + dlib::array > temp; + for (unsigned long i = 0; i < assignments.size(); ++i) + { + if (assignments[i].c == c) + temp.push_back(images[assignments[i].idx]); + } + + string outfile = "cluster_"+pad_int_with_zeros(c+1, 3) + ".jpg"; + cout << "Saving " << outfile << endl; + save_jpeg(tile_images(temp), outfile); + } + + + return EXIT_SUCCESS; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/cluster.h b/ml/dlib/tools/imglab/src/cluster.h new file mode 100644 index 000000000..6cb41a373 --- /dev/null +++ b/ml/dlib/tools/imglab/src/cluster.h @@ -0,0 +1,11 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_ClUSTER_H_ +#define DLIB_IMGLAB_ClUSTER_H_ + +#include + +int cluster_dataset(const dlib::command_line_parser& parser); + +#endif //DLIB_IMGLAB_ClUSTER_H_ + diff --git a/ml/dlib/tools/imglab/src/common.cpp b/ml/dlib/tools/imglab/src/common.cpp new file mode 100644 index 000000000..d9cc1dca4 --- /dev/null +++ b/ml/dlib/tools/imglab/src/common.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. 
+ +#include "common.h" +#include +#include + +// ---------------------------------------------------------------------------------------- + +std::string strip_path ( + const std::string& str, + const std::string& prefix +) +{ + unsigned long i; + for (i = 0; i < str.size() && i < prefix.size(); ++i) + { + if (str[i] != prefix[i]) + return str; + } + + if (i < str.size() && (str[i] == '/' || str[i] == '\\')) + ++i; + + return str.substr(i); +} + +// ---------------------------------------------------------------------------------------- + +void make_empty_file ( + const std::string& filename +) +{ + std::ofstream fout(filename.c_str()); + if (!fout) + throw dlib::error("ERROR: Unable to open " + filename + " for writing."); +} + +// ---------------------------------------------------------------------------------------- + +std::string to_png_name (const std::string& filename) +{ + std::string::size_type pos = filename.find_last_of("."); + if (pos == std::string::npos) + throw dlib::error("invalid filename: " + filename); + return filename.substr(0,pos) + ".png"; +} + +// ---------------------------------------------------------------------------------------- + +std::string to_jpg_name (const std::string& filename) +{ + std::string::size_type pos = filename.find_last_of("."); + if (pos == std::string::npos) + throw dlib::error("invalid filename: " + filename); + return filename.substr(0,pos) + ".jpg"; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/common.h b/ml/dlib/tools/imglab/src/common.h new file mode 100644 index 000000000..42e905bc3 --- /dev/null +++ b/ml/dlib/tools/imglab/src/common.h @@ -0,0 +1,45 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. 
+#ifndef DLIB_IMGLAB_COmMON_H__ +#define DLIB_IMGLAB_COmMON_H__ + +#include + +// ---------------------------------------------------------------------------------------- + +std::string strip_path ( + const std::string& str, + const std::string& prefix +); +/*! + ensures + - if (prefix is a prefix of str) then + - returns the part of str after the prefix + (additionally, str will not begin with a / or \ character) + - else + - return str +!*/ + +// ---------------------------------------------------------------------------------------- + +void make_empty_file ( + const std::string& filename +); +/*! + ensures + - creates an empty file of the given name +!*/ + +// ---------------------------------------------------------------------------------------- + +std::string to_png_name (const std::string& filename); +std::string to_jpg_name (const std::string& filename); + +// ---------------------------------------------------------------------------------------- + +const int JPEG_QUALITY = 90; + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_IMGLAB_COmMON_H__ + diff --git a/ml/dlib/tools/imglab/src/convert_idl.cpp b/ml/dlib/tools/imglab/src/convert_idl.cpp new file mode 100644 index 000000000..7ff601d0c --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_idl.cpp @@ -0,0 +1,184 @@ + +#include "convert_idl.h" +#include "dlib/data_io.h" +#include +#include +#include +#include +#include + +using namespace std; +using namespace dlib; + +namespace +{ + using namespace dlib::image_dataset_metadata; + +// ---------------------------------------------------------------------------------------- + + inline bool next_is_number(std::istream& in) + { + return ('0' <= in.peek() && in.peek() <= '9') || in.peek() == '-' || in.peek() == '+'; + } + + int read_int(std::istream& in) + { + bool is_neg = false; + if (in.peek() == '-') + { + is_neg = true; + in.get(); + } + if (in.peek() == '+') + in.get(); + + int val = 0; + while 
('0' <= in.peek() && in.peek() <= '9') + { + val = 10*val + in.get()-'0'; + } + + if (is_neg) + return -val; + else + return val; + } + +// ---------------------------------------------------------------------------------------- + + void parse_annotation_file( + const std::string& file, + dlib::image_dataset_metadata::dataset& data + ) + { + ifstream fin(file.c_str()); + if (!fin) + throw dlib::error("Unable to open file " + file); + + + bool in_quote = false; + int point_count = 0; + bool in_point_list = false; + bool saw_any_points = false; + + image img; + string label; + point p1,p2; + while (fin.peek() != EOF) + { + if (in_point_list && next_is_number(fin)) + { + const int val = read_int(fin); + switch (point_count) + { + case 0: p1.x() = val; break; + case 1: p1.y() = val; break; + case 2: p2.x() = val; break; + case 3: p2.y() = val; break; + default: + throw dlib::error("parse error in file " + file); + } + + ++point_count; + } + + char ch = fin.get(); + + if (ch == ':') + continue; + + if (ch == '"') + { + in_quote = !in_quote; + continue; + } + + if (in_quote) + { + img.filename += ch; + continue; + } + + + if (ch == '(') + { + in_point_list = true; + point_count = 0; + label.clear(); + saw_any_points = true; + } + if (ch == ')') + { + in_point_list = false; + + label.clear(); + while (fin.peek() != EOF && + fin.peek() != ';' && + fin.peek() != ',') + { + char ch = fin.get(); + if (ch == ':') + continue; + + label += ch; + } + } + + if (ch == ',' && !in_point_list) + { + + box b; + b.rect = rectangle(p1,p2); + b.label = label; + img.boxes.push_back(b); + } + + + if (ch == ';') + { + + if (saw_any_points) + { + box b; + b.rect = rectangle(p1,p2); + b.label = label; + img.boxes.push_back(b); + saw_any_points = false; + } + data.images.push_back(img); + + + img.filename.clear(); + img.boxes.clear(); + } + + + } + + + + } + +// ---------------------------------------------------------------------------------------- + +} + +void convert_idl( + const 
command_line_parser& parser +) +{ + cout << "Convert from IDL annotation format..." << endl; + + dlib::image_dataset_metadata::dataset dataset; + + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + parse_annotation_file(parser[i], dataset); + } + + const std::string filename = parser.option("c").argument(); + save_image_dataset_metadata(dataset, filename); +} + + + diff --git a/ml/dlib/tools/imglab/src/convert_idl.h b/ml/dlib/tools/imglab/src/convert_idl.h new file mode 100644 index 000000000..d8c33d961 --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_idl.h @@ -0,0 +1,14 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_CONVErT_IDL_H__ +#define DLIB_IMGLAB_CONVErT_IDL_H__ + +#include "common.h" +#include + +void convert_idl(const dlib::command_line_parser& parser); + +#endif // DLIB_IMGLAB_CONVErT_IDL_H__ + + + diff --git a/ml/dlib/tools/imglab/src/convert_pascal_v1.cpp b/ml/dlib/tools/imglab/src/convert_pascal_v1.cpp new file mode 100644 index 000000000..8eaf5e2bb --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_pascal_v1.cpp @@ -0,0 +1,177 @@ + +#include "convert_pascal_v1.h" +#include "dlib/data_io.h" +#include +#include +#include +#include + +using namespace std; +using namespace dlib; + +namespace +{ + using namespace dlib::image_dataset_metadata; + +// ---------------------------------------------------------------------------------------- + + std::string pick_out_quoted_string ( + const std::string& str + ) + { + std::string temp; + bool in_quotes = false; + for (unsigned long i = 0; i < str.size(); ++i) + { + if (str[i] == '"') + { + in_quotes = !in_quotes; + } + else if (in_quotes) + { + temp += str[i]; + } + } + + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void parse_annotation_file( + const std::string& file, + dlib::image_dataset_metadata::image& img, + 
std::string& dataset_name + ) + { + ifstream fin(file.c_str()); + if (!fin) + throw dlib::error("Unable to open file " + file); + + img = dlib::image_dataset_metadata::image(); + + string str, line; + std::vector words; + while (fin.peek() != EOF) + { + getline(fin, line); + words = split(line, " \r\n\t:(,-)\""); + if (words.size() > 2) + { + if (words[0] == "#") + continue; + + if (words[0] == "Image" && words[1] == "filename") + { + img.filename = pick_out_quoted_string(line); + } + else if (words[0] == "Database") + { + dataset_name = pick_out_quoted_string(line); + } + else if (words[0] == "Objects" && words[1] == "with" && words.size() >= 5) + { + const int num = sa = words[4]; + img.boxes.resize(num); + } + else if (words.size() > 4 && (words[2] == "for" || words[2] == "on") && words[3] == "object") + { + long idx = sa = words[4]; + --idx; + if (idx >= (long)img.boxes.size()) + throw dlib::error("Invalid object id number of " + words[4]); + + if (words[0] == "Center" && words[1] == "point" && words.size() > 9) + { + const long x = sa = words[8]; + const long y = sa = words[9]; + img.boxes[idx].parts["head"] = point(x,y); + } + else if (words[0] == "Bounding" && words[1] == "box" && words.size() > 13) + { + rectangle rect; + img.boxes[idx].rect.left() = sa = words[10]; + img.boxes[idx].rect.top() = sa = words[11]; + img.boxes[idx].rect.right() = sa = words[12]; + img.boxes[idx].rect.bottom() = sa = words[13]; + } + else if (words[0] == "Original" && words[1] == "label" && words.size() > 6) + { + img.boxes[idx].label = words[6]; + } + } + } + + } + } + +// ---------------------------------------------------------------------------------------- + + std::string figure_out_full_path_to_image ( + const std::string& annotation_file, + const std::string& image_name + ) + { + directory parent = get_parent_directory(file(annotation_file)); + + + string temp; + while (true) + { + if (parent.is_root()) + temp = parent.full_name() + image_name; + else + temp = 
parent.full_name() + directory::get_separator() + image_name; + + if (file_exists(temp)) + return temp; + + if (parent.is_root()) + throw dlib::error("Can't figure out where the file " + image_name + " is located."); + parent = get_parent_directory(parent); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +void convert_pascal_v1( + const command_line_parser& parser +) +{ + cout << "Convert from PASCAL v1.00 annotation format..." << endl; + + dlib::image_dataset_metadata::dataset dataset; + + std::string name; + dlib::image_dataset_metadata::image img; + + const std::string filename = parser.option("c").argument(); + // make sure the file exists so we can use the get_parent_directory() command to + // figure out it's parent directory. + make_empty_file(filename); + const std::string parent_dir = get_parent_directory(file(filename)).full_name(); + + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + try + { + parse_annotation_file(parser[i], img, name); + + dataset.name = name; + img.filename = strip_path(figure_out_full_path_to_image(parser[i], img.filename), parent_dir); + dataset.images.push_back(img); + + } + catch (exception& ) + { + cout << "Error while processing file " << parser[i] << endl << endl; + throw; + } + } + + save_image_dataset_metadata(dataset, filename); +} + + diff --git a/ml/dlib/tools/imglab/src/convert_pascal_v1.h b/ml/dlib/tools/imglab/src/convert_pascal_v1.h new file mode 100644 index 000000000..3553d03a7 --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_pascal_v1.h @@ -0,0 +1,13 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. 
+#ifndef DLIB_IMGLAB_CONVERT_PASCAl_V1_H__ +#define DLIB_IMGLAB_CONVERT_PASCAl_V1_H__ + +#include "common.h" +#include + +void convert_pascal_v1(const dlib::command_line_parser& parser); + +#endif // DLIB_IMGLAB_CONVERT_PASCAl_V1_H__ + + diff --git a/ml/dlib/tools/imglab/src/convert_pascal_xml.cpp b/ml/dlib/tools/imglab/src/convert_pascal_xml.cpp new file mode 100644 index 000000000..c699d7777 --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_pascal_xml.cpp @@ -0,0 +1,239 @@ + +#include "convert_pascal_xml.h" +#include "dlib/data_io.h" +#include +#include +#include +#include +#include + +using namespace std; +using namespace dlib; + +namespace +{ + using namespace dlib::image_dataset_metadata; + +// ---------------------------------------------------------------------------------------- + + class doc_handler : public document_handler + { + image& temp_image; + std::string& dataset_name; + + std::vector ts; + box temp_box; + + public: + + doc_handler( + image& temp_image_, + std::string& dataset_name_ + ): + temp_image(temp_image_), + dataset_name(dataset_name_) + {} + + + virtual void start_document ( + ) + { + ts.clear(); + temp_image = image(); + temp_box = box(); + dataset_name.clear(); + } + + virtual void end_document ( + ) + { + } + + virtual void start_element ( + const unsigned long , + const std::string& name, + const dlib::attribute_list& + ) + { + if (ts.size() == 0 && name != "annotation") + { + std::ostringstream sout; + sout << "Invalid XML document. Root tag must be . 
Found <" << name << "> instead."; + throw dlib::error(sout.str()); + } + + + ts.push_back(name); + } + + virtual void end_element ( + const unsigned long , + const std::string& name + ) + { + ts.pop_back(); + if (ts.size() == 0) + return; + + if (name == "object" && ts.back() == "annotation") + { + temp_image.boxes.push_back(temp_box); + temp_box = box(); + } + } + + virtual void characters ( + const std::string& data + ) + { + if (ts.size() == 2 && ts[1] == "filename") + { + temp_image.filename = trim(data); + } + else if (ts.size() == 3 && ts[2] == "database" && ts[1] == "source") + { + dataset_name = trim(data); + } + else if (ts.size() >= 3) + { + if (ts[ts.size()-2] == "bndbox" && ts[ts.size()-3] == "object") + { + if (ts.back() == "xmin") temp_box.rect.left() = string_cast(data); + else if (ts.back() == "ymin") temp_box.rect.top() = string_cast(data); + else if (ts.back() == "xmax") temp_box.rect.right() = string_cast(data); + else if (ts.back() == "ymax") temp_box.rect.bottom() = string_cast(data); + } + else if (ts.back() == "name" && ts[ts.size()-2] == "object") + { + temp_box.label = trim(data); + } + else if (ts.back() == "difficult" && ts[ts.size()-2] == "object") + { + if (trim(data) == "0" || trim(data) == "false") + { + temp_box.difficult = false; + } + else + { + temp_box.difficult = true; + } + } + else if (ts.back() == "truncated" && ts[ts.size()-2] == "object") + { + if (trim(data) == "0" || trim(data) == "false") + { + temp_box.truncated = false; + } + else + { + temp_box.truncated = true; + } + } + else if (ts.back() == "occluded" && ts[ts.size()-2] == "object") + { + if (trim(data) == "0" || trim(data) == "false") + { + temp_box.occluded = false; + } + else + { + temp_box.occluded = true; + } + } + + } + } + + virtual void processing_instruction ( + const unsigned long , + const std::string& , + const std::string& + ) + { + } + }; + +// ---------------------------------------------------------------------------------------- + + class 
xml_error_handler : public error_handler + { + public: + virtual void error ( + const unsigned long + ) { } + + virtual void fatal_error ( + const unsigned long line_number + ) + { + std::ostringstream sout; + sout << "There is a fatal error on line " << line_number << " so parsing will now halt."; + throw dlib::error(sout.str()); + } + }; + +// ---------------------------------------------------------------------------------------- + + void parse_annotation_file( + const std::string& file, + dlib::image_dataset_metadata::image& img, + std::string& dataset_name + ) + { + doc_handler dh(img, dataset_name); + xml_error_handler eh; + + xml_parser::kernel_1a parser; + parser.add_document_handler(dh); + parser.add_error_handler(eh); + + ifstream fin(file.c_str()); + if (!fin) + throw dlib::error("Unable to open file " + file); + parser.parse(fin); + } + +// ---------------------------------------------------------------------------------------- + +} + +void convert_pascal_xml( + const command_line_parser& parser +) +{ + cout << "Convert from PASCAL XML annotation format..." << endl; + + dlib::image_dataset_metadata::dataset dataset; + + std::string name; + dlib::image_dataset_metadata::image img; + + const std::string filename = parser.option("c").argument(); + // make sure the file exists so we can use the get_parent_directory() command to + // figure out it's parent directory. 
+ make_empty_file(filename); + const std::string parent_dir = get_parent_directory(file(filename)).full_name(); + + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + try + { + parse_annotation_file(parser[i], img, name); + const string root = get_parent_directory(get_parent_directory(file(parser[i]))).full_name(); + const string img_path = root + directory::get_separator() + "JPEGImages" + directory::get_separator(); + + dataset.name = name; + img.filename = strip_path(img_path + img.filename, parent_dir); + dataset.images.push_back(img); + + } + catch (exception& ) + { + cout << "Error while processing file " << parser[i] << endl << endl; + throw; + } + } + + save_image_dataset_metadata(dataset, filename); +} + diff --git a/ml/dlib/tools/imglab/src/convert_pascal_xml.h b/ml/dlib/tools/imglab/src/convert_pascal_xml.h new file mode 100644 index 000000000..01ee1e82f --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_pascal_xml.h @@ -0,0 +1,12 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_CONVERT_PASCAl_XML_H__ +#define DLIB_IMGLAB_CONVERT_PASCAl_XML_H__ + +#include "common.h" +#include + +void convert_pascal_xml(const dlib::command_line_parser& parser); + +#endif // DLIB_IMGLAB_CONVERT_PASCAl_XML_H__ + diff --git a/ml/dlib/tools/imglab/src/flip_dataset.cpp b/ml/dlib/tools/imglab/src/flip_dataset.cpp new file mode 100644 index 000000000..e072dc790 --- /dev/null +++ b/ml/dlib/tools/imglab/src/flip_dataset.cpp @@ -0,0 +1,249 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. 
+ +#include "flip_dataset.h" +#include +#include +#include +#include "common.h" +#include +#include +#include + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +std::vector align_points( + const std::vector& from, + const std::vector& to, + double min_angle = -90*pi/180.0, + double max_angle = 90*pi/180.0, + long num_angles = 181 +) +/*! + ensures + - Figures out how to align the points in from with the points in to. Returns an + assignment array A that indicates that from[i] matches with to[A[i]]. + + We use the Hungarian algorithm with a search over reasonable angles. This method + works because we just need to account for a translation and a mild rotation and + nothing else. If there is any other more complex mapping then you probably don't + have landmarks that make sense to flip. +!*/ +{ + DLIB_CASSERT(from.size() == to.size()); + + std::vector best_assignment; + double best_assignment_cost = std::numeric_limits::infinity(); + + matrix dists(from.size(), to.size()); + matrix idists; + + for (auto angle : linspace(min_angle, max_angle, num_angles)) + { + auto rot = rotation_matrix(angle); + for (long r = 0; r < dists.nr(); ++r) + { + for (long c = 0; c < dists.nc(); ++c) + { + dists(r,c) = length_squared(rot*from[r]-to[c]); + } + } + + idists = matrix_cast(-round(std::numeric_limits::max()*(dists/max(dists)))); + + auto assignment = max_cost_assignment(idists); + auto cost = assignment_cost(dists, assignment); + if (cost < best_assignment_cost) + { + best_assignment_cost = cost; + best_assignment = std::move(assignment); + } + } + + + // Now compute the alignment error in terms of average distance moved by each part. We + // do this so we can give the user a warning if it's impossible to make a good + // alignment. 
+ running_stats rs; + std::vector tmp(to.size()); + for (size_t i = 0; i < to.size(); ++i) + tmp[best_assignment[i]] = to[i]; + auto tform = find_similarity_transform(from, tmp); + for (size_t i = 0; i < from.size(); ++i) + rs.add(length(tform(from[i])-tmp[i])); + if (rs.mean() > 0.05) + { + cout << "WARNING, your dataset has object part annotations and you asked imglab to " << endl; + cout << "flip the data. Imglab tried to adjust the part labels so that the average" << endl; + cout << "part layout in the flipped dataset is the same as the source dataset. " << endl; + cout << "However, the part annotation scheme doesn't seem to be left-right symmetric." << endl; + cout << "You should manually review the output to make sure the part annotations are " << endl; + cout << "labeled as you expect." << endl; + } + + + return best_assignment; +} + +// ---------------------------------------------------------------------------------------- + +std::map normalized_parts ( + const image_dataset_metadata::box& b +) +{ + auto tform = dlib::impl::normalizing_tform(b.rect); + std::map temp; + for (auto& p : b.parts) + temp[p.first] = tform(p.second); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +std::map average_parts ( + const image_dataset_metadata::dataset& data +) +/*! + ensures + - returns the average part layout over all objects in data. This is done by + centering the parts inside their rects and then averaging all the objects. 
+!*/ +{ + std::map psum; + std::map pcnt; + for (auto& image : data.images) + { + for (auto& box : image.boxes) + { + for (auto& p : normalized_parts(box)) + { + psum[p.first] += p.second; + pcnt[p.first] += 1; + } + } + } + + // make into an average + for (auto& p : psum) + p.second /= pcnt[p.first]; + + return psum; +} + +// ---------------------------------------------------------------------------------------- + +void make_part_labeling_match_target_dataset ( + const image_dataset_metadata::dataset& target, + image_dataset_metadata::dataset& data +) +/*! + This function tries to adjust the part labels in data so that the average part layout + in data is the same as target, according to the string labels. Therefore, it doesn't + adjust part positions, instead it changes the string labels on the parts to achieve + this. This really only makes sense when you flipped a dataset that contains left-right + symmetric objects and you want to remap the part labels of the flipped data so that + they match the unflipped data's annotation scheme. +!*/ +{ + auto target_parts = average_parts(target); + auto data_parts = average_parts(data); + + // Convert to a form align_points() understands. We also need to keep track of the + // labels for later. 
+ std::vector from, to; + std::vector from_labels, to_labels; + for (auto& p : target_parts) + { + from_labels.emplace_back(p.first); + from.emplace_back(p.second); + } + for (auto& p : data_parts) + { + to_labels.emplace_back(p.first); + to.emplace_back(p.second); + } + + auto assignment = align_points(from, to); + // so now we know that from_labels[i] should replace to_labels[assignment[i]] + std::map label_mapping; + for (size_t i = 0; i < assignment.size(); ++i) + label_mapping[to_labels[assignment[i]]] = from_labels[i]; + + // now apply the label mapping to the dataset + for (auto& image : data.images) + { + for (auto& box : image.boxes) + { + std::map temp; + for (auto& p : box.parts) + temp[label_mapping[p.first]] = p.second; + box.parts = std::move(temp); + } + } +} + +// ---------------------------------------------------------------------------------------- + +void flip_dataset(const command_line_parser& parser) +{ + image_dataset_metadata::dataset metadata, orig_metadata; + string datasource; + if (parser.option("flip")) + datasource = parser.option("flip").argument(); + else + datasource = parser.option("flip-basic").argument(); + load_image_dataset_metadata(metadata,datasource); + orig_metadata = metadata; + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. 
+ set_current_dir(get_parent_directory(file(datasource))); + + const string metadata_filename = get_parent_directory(file(datasource)).full_name() + + directory::get_separator() + "flipped_" + file(datasource).name(); + + + array2d img, temp; + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + file f(metadata.images[i].filename); + string filename = get_parent_directory(f).full_name() + directory::get_separator() + "flipped_" + to_png_name(f.name()); + + load_image(img, metadata.images[i].filename); + flip_image_left_right(img, temp); + if (parser.option("jpg")) + { + filename = to_jpg_name(filename); + save_jpeg(temp, filename,JPEG_QUALITY); + } + else + { + save_png(temp, filename); + } + + for (unsigned long j = 0; j < metadata.images[i].boxes.size(); ++j) + { + metadata.images[i].boxes[j].rect = impl::flip_rect_left_right(metadata.images[i].boxes[j].rect, get_rect(img)); + + // flip all the object parts + for (auto& part : metadata.images[i].boxes[j].parts) + { + part.second = impl::flip_rect_left_right(rectangle(part.second,part.second), get_rect(img)).tl_corner(); + } + } + + metadata.images[i].filename = filename; + } + + if (!parser.option("flip-basic")) + make_part_labeling_match_target_dataset(orig_metadata, metadata); + + save_image_dataset_metadata(metadata, metadata_filename); +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/flip_dataset.h b/ml/dlib/tools/imglab/src/flip_dataset.h new file mode 100644 index 000000000..8ac5db3e8 --- /dev/null +++ b/ml/dlib/tools/imglab/src/flip_dataset.h @@ -0,0 +1,12 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. 
+#ifndef DLIB_IMGLAB_FLIP_DaTASET_H__ +#define DLIB_IMGLAB_FLIP_DaTASET_H__ + + +#include + +void flip_dataset(const dlib::command_line_parser& parser); + +#endif // DLIB_IMGLAB_FLIP_DaTASET_H__ + diff --git a/ml/dlib/tools/imglab/src/main.cpp b/ml/dlib/tools/imglab/src/main.cpp new file mode 100644 index 000000000..060c2c870 --- /dev/null +++ b/ml/dlib/tools/imglab/src/main.cpp @@ -0,0 +1,1145 @@ + +#include "dlib/data_io.h" +#include "dlib/string.h" +#include "metadata_editor.h" +#include "convert_pascal_xml.h" +#include "convert_pascal_v1.h" +#include "convert_idl.h" +#include "cluster.h" +#include "flip_dataset.h" +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + + +const char* VERSION = "1.13"; + + + +using namespace std; +using namespace dlib; + +// ---------------------------------------------------------------------------------------- + +void create_new_dataset ( + const command_line_parser& parser +) +{ + using namespace dlib::image_dataset_metadata; + + const std::string filename = parser.option("c").argument(); + // make sure the file exists so we can use the get_parent_directory() command to + // figure out it's parent directory. 
+ make_empty_file(filename); + const std::string parent_dir = get_parent_directory(file(filename)); + + unsigned long depth = 0; + if (parser.option("r")) + depth = 30; + + dataset meta; + meta.name = "imglab dataset"; + meta.comment = "Created by imglab tool."; + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + try + { + const string temp = strip_path(file(parser[i]), parent_dir); + meta.images.push_back(image(temp)); + } + catch (dlib::file::file_not_found&) + { + // then parser[i] should be a directory + + std::vector files = get_files_in_directory_tree(parser[i], + match_endings(".png .PNG .jpeg .JPEG .jpg .JPG .bmp .BMP .dng .DNG .gif .GIF"), + depth); + sort(files.begin(), files.end()); + + for (unsigned long j = 0; j < files.size(); ++j) + { + meta.images.push_back(image(strip_path(files[j], parent_dir))); + } + } + } + + save_image_dataset_metadata(meta, filename); +} + +// ---------------------------------------------------------------------------------------- + +int split_dataset ( + const command_line_parser& parser +) +{ + if (parser.number_of_arguments() != 1) + { + cerr << "The --split option requires you to give one XML file on the command line." 
<< endl; + return EXIT_FAILURE; + } + + const std::string label = parser.option("split").argument(); + + dlib::image_dataset_metadata::dataset data, data_with, data_without; + load_image_dataset_metadata(data, parser[0]); + + data_with.name = data.name; + data_with.comment = data.comment; + data_without.name = data.name; + data_without.comment = data.comment; + + for (unsigned long i = 0; i < data.images.size(); ++i) + { + auto&& temp = data.images[i]; + + bool has_the_label = false; + // check for the label we are looking for + for (unsigned long j = 0; j < temp.boxes.size(); ++j) + { + if (temp.boxes[j].label == label) + { + has_the_label = true; + break; + } + } + + if (has_the_label) + data_with.images.push_back(temp); + else + data_without.images.push_back(temp); + } + + + save_image_dataset_metadata(data_with, left_substr(parser[0],".") + "_with_"+label + ".xml"); + save_image_dataset_metadata(data_without, left_substr(parser[0],".") + "_without_"+label + ".xml"); + + return EXIT_SUCCESS; +} + +// ---------------------------------------------------------------------------------------- + +void print_all_labels ( + const dlib::image_dataset_metadata::dataset& data +) +{ + std::set labels; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + labels.insert(data.images[i].boxes[j].label); + } + } + + for (std::set::iterator i = labels.begin(); i != labels.end(); ++i) + { + if (i->size() != 0) + { + cout << *i << endl; + } + } +} + +// ---------------------------------------------------------------------------------------- + +void print_all_label_stats ( + const dlib::image_dataset_metadata::dataset& data +) +{ + std::map > area_stats, aspect_ratio; + std::map image_hits; + std::set labels; + unsigned long num_unignored_boxes = 0; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + std::set temp; + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + 
labels.insert(data.images[i].boxes[j].label); + temp.insert(data.images[i].boxes[j].label); + + area_stats[data.images[i].boxes[j].label].add(data.images[i].boxes[j].rect.area()); + aspect_ratio[data.images[i].boxes[j].label].add(data.images[i].boxes[j].rect.width()/ + (double)data.images[i].boxes[j].rect.height()); + + if (!data.images[i].boxes[j].ignore) + ++num_unignored_boxes; + } + + // count the number of images for each label + for (std::set::iterator i = temp.begin(); i != temp.end(); ++i) + image_hits[*i] += 1; + } + + cout << "Number of images: "<< data.images.size() << endl; + cout << "Number of different labels: "<< labels.size() << endl; + cout << "Number of non-ignored boxes: " << num_unignored_boxes << endl << endl; + + for (std::set::iterator i = labels.begin(); i != labels.end(); ++i) + { + if (i->size() == 0) + cout << "Unlabeled Boxes:" << endl; + else + cout << "Label: "<< *i << endl; + cout << " number of images: " << image_hits[*i] << endl; + cout << " number of occurrences: " << area_stats[*i].current_n() << endl; + cout << " min box area: " << area_stats[*i].min() << endl; + cout << " max box area: " << area_stats[*i].max() << endl; + cout << " mean box area: " << area_stats[*i].mean() << endl; + cout << " stddev box area: " << area_stats[*i].stddev() << endl; + cout << " mean width/height ratio: " << aspect_ratio[*i].mean() << endl; + cout << " stddev width/height ratio: " << aspect_ratio[*i].stddev() << endl; + cout << endl; + } +} + +// ---------------------------------------------------------------------------------------- + +void rename_labels ( + dlib::image_dataset_metadata::dataset& data, + const std::string& from, + const std::string& to +) +{ + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].label == from) + data.images[i].boxes[j].label = to; + } + } + +} + +// 
---------------------------------------------------------------------------------------- + +void ignore_labels ( + dlib::image_dataset_metadata::dataset& data, + const std::string& label +) +{ + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].label == label) + data.images[i].boxes[j].ignore = true; + } + } +} + +// ---------------------------------------------------------------------------------------- + +void merge_metadata_files ( + const command_line_parser& parser +) +{ + image_dataset_metadata::dataset src, dest; + load_image_dataset_metadata(src, parser.option("add").argument(0)); + load_image_dataset_metadata(dest, parser.option("add").argument(1)); + + std::map merged_data; + for (unsigned long i = 0; i < dest.images.size(); ++i) + merged_data[dest.images[i].filename] = dest.images[i]; + // now add in the src data and overwrite anything if there are duplicate entries. + for (unsigned long i = 0; i < src.images.size(); ++i) + merged_data[src.images[i].filename] = src.images[i]; + + // copy merged data into dest + dest.images.clear(); + for (std::map::const_iterator i = merged_data.begin(); + i != merged_data.end(); ++i) + { + dest.images.push_back(i->second); + } + + save_image_dataset_metadata(dest, "merged.xml"); +} + +// ---------------------------------------------------------------------------------------- + +void rotate_dataset(const command_line_parser& parser) +{ + image_dataset_metadata::dataset metadata; + const string datasource = parser[0]; + load_image_dataset_metadata(metadata,datasource); + + double angle = get_option(parser, "rotate", 0); + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. 
+ set_current_dir(get_parent_directory(file(datasource))); + + const string file_prefix = "rotated_"+ cast_to_string(angle) + "_"; + const string metadata_filename = get_parent_directory(file(datasource)).full_name() + + directory::get_separator() + file_prefix + file(datasource).name(); + + + array2d img, temp; + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + file f(metadata.images[i].filename); + string filename = get_parent_directory(f).full_name() + directory::get_separator() + file_prefix + to_png_name(f.name()); + + load_image(img, metadata.images[i].filename); + const point_transform_affine tran = rotate_image(img, temp, angle*pi/180); + if (parser.option("jpg")) + { + filename = to_jpg_name(filename); + save_jpeg(temp, filename,JPEG_QUALITY); + } + else + { + save_png(temp, filename); + } + + for (unsigned long j = 0; j < metadata.images[i].boxes.size(); ++j) + { + const rectangle rect = metadata.images[i].boxes[j].rect; + rectangle newrect; + newrect += tran(rect.tl_corner()); + newrect += tran(rect.tr_corner()); + newrect += tran(rect.bl_corner()); + newrect += tran(rect.br_corner()); + // now make newrect have the same area as the starting rect. + double ratio = std::sqrt(rect.area()/(double)newrect.area()); + newrect = centered_rect(newrect, newrect.width()*ratio, newrect.height()*ratio); + metadata.images[i].boxes[j].rect = newrect; + + // rotate all the object parts + std::map::iterator k; + for (k = metadata.images[i].boxes[j].parts.begin(); k != metadata.images[i].boxes[j].parts.end(); ++k) + { + k->second = tran(k->second); + } + } + + metadata.images[i].filename = filename; + } + + save_image_dataset_metadata(metadata, metadata_filename); +} + +// ---------------------------------------------------------------------------------------- + +int resample_dataset(const command_line_parser& parser) +{ + if (parser.number_of_arguments() != 1) + { + cerr << "The --resample option requires you to give one XML file on the command line." 
<< endl; + return EXIT_FAILURE; + } + + const size_t obj_size = get_option(parser,"cropped-object-size",100*100); + const double margin_scale = get_option(parser,"crop-size",2.5); // cropped image will be this times wider than the object. + const unsigned long min_object_size = get_option(parser,"min-object-size",1); + const bool one_object_per_image = parser.option("one-object-per-image"); + + dlib::image_dataset_metadata::dataset data, resampled_data; + std::ostringstream sout; + sout << "\nThe --resample parameters which generated this dataset were:" << endl; + sout << " cropped-object-size: "<< obj_size << endl; + sout << " crop-size: "<< margin_scale << endl; + sout << " min-object-size: "<< min_object_size << endl; + if (one_object_per_image) + sout << " one_object_per_image: true" << endl; + resampled_data.comment = data.comment + sout.str(); + resampled_data.name = data.name + " RESAMPLED"; + + load_image_dataset_metadata(data, parser[0]); + locally_change_current_dir chdir(get_parent_directory(file(parser[0]))); + dlib::rand rnd; + + const size_t image_size = std::round(std::sqrt(obj_size*margin_scale*margin_scale)); + const chip_dims cdims(image_size, image_size); + + console_progress_indicator pbar(data.images.size()); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + // don't even bother loading images that don't have objects. 
+ if (data.images[i].boxes.size() == 0) + continue; + + pbar.print_status(i); + array2d img, chip; + load_image(img, data.images[i].filename); + + + // figure out what chips we want to take from this image + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + const rectangle rect = data.images[i].boxes[j].rect; + if (data.images[i].boxes[j].ignore || rect.area() < min_object_size) + continue; + + const auto max_dim = std::max(rect.width(), rect.height()); + + const double rand_scale_perturb = 1 - 0.3*(rnd.get_random_double()-0.5); + const rectangle crop_rect = centered_rect(rect, max_dim*margin_scale*rand_scale_perturb, max_dim*margin_scale*rand_scale_perturb); + + const rectangle_transform tform = get_mapping_to_chip(chip_details(crop_rect, cdims)); + extract_image_chip(img, chip_details(crop_rect, cdims), chip); + + image_dataset_metadata::image dimg; + // Now transform the boxes to the crop and also mark them as ignored if they + // have already been cropped out or are outside the crop. + for (size_t k = 0; k < data.images[i].boxes.size(); ++k) + { + image_dataset_metadata::box box = data.images[i].boxes[k]; + // ignore boxes outside the cropped image + if (crop_rect.intersect(box.rect).area() == 0) + continue; + + // mark boxes we include in the crop as ignored. Also mark boxes that + // aren't totally within the crop as ignored. + if (crop_rect.contains(grow_rect(box.rect,10)) && (!one_object_per_image || k==j)) + data.images[i].boxes[k].ignore = true; + else + box.ignore = true; + + if (box.rect.area() < min_object_size) + box.ignore = true; + + box.rect = tform(box.rect); + for (auto&& p : box.parts) + p.second = tform.get_tform()(p.second); + dimg.boxes.push_back(box); + } + // Put a 64bit hash of the image data into the name to make sure there are no + // file name conflicts. 
+ std::ostringstream sout; + sout << hex << murmur_hash3_128bit(&chip[0][0], chip.size()*sizeof(chip[0][0])).second; + dimg.filename = data.images[i].filename + "_RESAMPLED_"+sout.str()+".png"; + + if (parser.option("jpg")) + { + dimg.filename = to_jpg_name(dimg.filename); + save_jpeg(chip,dimg.filename, JPEG_QUALITY); + } + else + { + save_png(chip,dimg.filename); + } + resampled_data.images.push_back(dimg); + } + } + + save_image_dataset_metadata(resampled_data, parser[0] + ".RESAMPLED.xml"); + + return EXIT_SUCCESS; +} + +// ---------------------------------------------------------------------------------------- + +int tile_dataset(const command_line_parser& parser) +{ + if (parser.number_of_arguments() != 1) + { + cerr << "The --tile option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + string out_image = parser.option("tile").argument(); + string ext = right_substr(out_image,"."); + if (ext != "png" && ext != "jpg") + { + cerr << "The output image file must have either .png or .jpg extension." << endl; + return EXIT_FAILURE; + } + + const unsigned long chip_size = get_option(parser, "size", 8000); + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + locally_change_current_dir chdir(get_parent_directory(file(parser[0]))); + dlib::array > images; + console_progress_indicator pbar(data.images.size()); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + // don't even bother loading images that don't have objects. 
+ if (data.images[i].boxes.size() == 0) + continue; + + pbar.print_status(i); + array2d img; + load_image(img, data.images[i].filename); + + // figure out what chips we want to take from this image + std::vector dets; + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].ignore) + continue; + + rectangle rect = data.images[i].boxes[j].rect; + dets.push_back(chip_details(rect, chip_size)); + } + // Now grab all those chips at once. + dlib::array > chips; + extract_image_chips(img, dets, chips); + // and put the chips into the output. + for (unsigned long j = 0; j < chips.size(); ++j) + images.push_back(chips[j]); + } + + chdir.revert(); + + if (ext == "png") + save_png(tile_images(images), out_image); + else + save_jpeg(tile_images(images), out_image); + + return EXIT_SUCCESS; +} + + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) +{ + try + { + + command_line_parser parser; + + parser.add_option("h","Displays this information."); + parser.add_option("v","Display version."); + + parser.set_group_name("Creating XML files"); + parser.add_option("c","Create an XML file named listing a set of images.",1); + parser.add_option("r","Search directories recursively for images."); + parser.add_option("convert","Convert foreign image Annotations from format to the imglab format. 
" + "Supported formats: pascal-xml, pascal-v1, idl.",1); + + parser.set_group_name("Viewing XML files"); + parser.add_option("tile","Chip out all the objects and save them as one big image called .",1); + parser.add_option("size","When using --tile or --cluster, make each extracted object contain " + "about pixels (default 8000).",1); + parser.add_option("l","List all the labels in the given XML file."); + parser.add_option("stats","List detailed statistics on the object labels in the given XML file."); + parser.add_option("files","List all the files in the given XML file."); + + parser.set_group_name("Editing/Transforming XML datasets"); + parser.add_option("rename", "Rename all labels of to .",2); + parser.add_option("parts","The display will allow image parts to be labeled. The set of allowable parts " + "is defined by which should be a space separated list of parts.",1); + parser.add_option("rmempty","Remove all images that don't contain non-ignored annotations and save the results to a new XML file."); + parser.add_option("rmdupes","Remove duplicate images from the dataset. This is done by comparing " + "the md5 hash of each image file and removing duplicate images. " ); + parser.add_option("rmdiff","Set the ignored flag to true for boxes marked as difficult."); + parser.add_option("rmtrunc","Set the ignored flag to true for boxes that are partially outside the image."); + parser.add_option("sort-num-objects","Sort the images listed an XML file so images with many objects are listed first."); + parser.add_option("sort","Alphabetically sort the images in an XML file."); + parser.add_option("shuffle","Randomly shuffle the order of the images listed in an XML file."); + parser.add_option("seed", "When using --shuffle, set the random seed to the string .",1); + parser.add_option("split", "Split the contents of an XML file into two separate files. One containing the " + "images with objects labeled and another file with all the other images. 
",1); + parser.add_option("add", "Add the image metadata from into . If any of the image " + "tags are in both files then the ones in are deleted and replaced with the " + "image tags from . The results are saved into merged.xml and neither or " + " files are modified.",2); + parser.add_option("flip", "Read an XML image dataset from the XML file and output a left-right flipped " + "version of the dataset and an accompanying flipped XML file named flipped_. " + "We also adjust object part labels after flipping so that the new flipped dataset " + "has the same average part layout as the source dataset." ,1); + parser.add_option("flip-basic", "This option is just like --flip, except we don't adjust any object part labels after flipping. " + "The parts are instead simply mirrored to the flipped dataset.", 1); + parser.add_option("rotate", "Read an XML image dataset and output a copy that is rotated counter clockwise by degrees. " + "The output is saved to an XML file prefixed with rotated_.",1); + parser.add_option("cluster", "Cluster all the objects in an XML file into different clusters and save " + "the results as cluster_###.xml and cluster_###.jpg files.",1); + parser.add_option("ignore", "Mark boxes labeled as as ignored. 
The resulting XML file is output as a separate file and the original is not modified.",1); + parser.add_option("rmlabel","Remove all boxes labeled and save the results to a new XML file.",1); + parser.add_option("rm-other-labels","Remove all boxes not labeled and save the results to a new XML file.",1); + parser.add_option("rmignore","Remove all boxes marked ignore and save the results to a new XML file."); + parser.add_option("rm-if-overlaps","Remove all boxes labeled if they overlap any box not labeled and save the results to a new XML file.",1); + parser.add_option("jpg", "When saving images to disk, write them as jpg files instead of png."); + + parser.set_group_name("Cropping sub images"); + parser.add_option("resample", "Crop out images that are centered on each object in the dataset. " + "The output is a new XML dataset."); + parser.add_option("cropped-object-size", "When doing --resample, make the cropped objects contain about pixels (default 10000).",1); + parser.add_option("min-object-size", "When doing --resample, skip objects that have fewer than pixels in them (default 1).",1); + parser.add_option("crop-size", "When doing --resample, the entire cropped image will be times wider than the object (default 2.5).",1); + parser.add_option("one-object-per-image", "When doing --resample, only include one non-ignored object per image (i.e. 
the central object)."); + + + + parser.parse(argc, argv); + + const char* singles[] = {"h","c","r","l","files","convert","parts","rmdiff", "rmtrunc", "rmdupes", "seed", "shuffle", "split", "add", + "flip-basic", "flip", "rotate", "tile", "size", "cluster", "resample", "min-object-size", "rmempty", + "crop-size", "cropped-object-size", "rmlabel", "rm-other-labels", "rm-if-overlaps", "sort-num-objects", + "one-object-per-image", "jpg", "rmignore", "sort"}; + parser.check_one_time_options(singles); + const char* c_sub_ops[] = {"r", "convert"}; + parser.check_sub_options("c", c_sub_ops); + parser.check_sub_option("shuffle", "seed"); + const char* resample_sub_ops[] = {"min-object-size", "crop-size", "cropped-object-size", "one-object-per-image"}; + parser.check_sub_options("resample", resample_sub_ops); + const char* size_parent_ops[] = {"tile", "cluster"}; + parser.check_sub_options(size_parent_ops, "size"); + parser.check_incompatible_options("c", "l"); + parser.check_incompatible_options("c", "files"); + parser.check_incompatible_options("c", "rmdiff"); + parser.check_incompatible_options("c", "rmempty"); + parser.check_incompatible_options("c", "rmlabel"); + parser.check_incompatible_options("c", "rm-other-labels"); + parser.check_incompatible_options("c", "rmignore"); + parser.check_incompatible_options("c", "rm-if-overlaps"); + parser.check_incompatible_options("c", "rmdupes"); + parser.check_incompatible_options("c", "rmtrunc"); + parser.check_incompatible_options("c", "add"); + parser.check_incompatible_options("c", "flip"); + parser.check_incompatible_options("c", "flip-basic"); + parser.check_incompatible_options("flip", "flip-basic"); + parser.check_incompatible_options("c", "rotate"); + parser.check_incompatible_options("c", "rename"); + parser.check_incompatible_options("c", "ignore"); + parser.check_incompatible_options("c", "parts"); + parser.check_incompatible_options("c", "tile"); + parser.check_incompatible_options("c", "cluster"); + 
parser.check_incompatible_options("c", "resample"); + parser.check_incompatible_options("l", "rename"); + parser.check_incompatible_options("l", "ignore"); + parser.check_incompatible_options("l", "add"); + parser.check_incompatible_options("l", "parts"); + parser.check_incompatible_options("l", "flip"); + parser.check_incompatible_options("l", "flip-basic"); + parser.check_incompatible_options("l", "rotate"); + parser.check_incompatible_options("files", "rename"); + parser.check_incompatible_options("files", "ignore"); + parser.check_incompatible_options("files", "add"); + parser.check_incompatible_options("files", "parts"); + parser.check_incompatible_options("files", "flip"); + parser.check_incompatible_options("files", "flip-basic"); + parser.check_incompatible_options("files", "rotate"); + parser.check_incompatible_options("add", "flip"); + parser.check_incompatible_options("add", "flip-basic"); + parser.check_incompatible_options("add", "rotate"); + parser.check_incompatible_options("add", "tile"); + parser.check_incompatible_options("flip", "tile"); + parser.check_incompatible_options("flip-basic", "tile"); + parser.check_incompatible_options("rotate", "tile"); + parser.check_incompatible_options("cluster", "tile"); + parser.check_incompatible_options("resample", "tile"); + parser.check_incompatible_options("flip", "cluster"); + parser.check_incompatible_options("flip-basic", "cluster"); + parser.check_incompatible_options("rotate", "cluster"); + parser.check_incompatible_options("add", "cluster"); + parser.check_incompatible_options("flip", "resample"); + parser.check_incompatible_options("flip-basic", "resample"); + parser.check_incompatible_options("rotate", "resample"); + parser.check_incompatible_options("add", "resample"); + parser.check_incompatible_options("shuffle", "tile"); + parser.check_incompatible_options("sort-num-objects", "tile"); + parser.check_incompatible_options("sort", "tile"); + parser.check_incompatible_options("convert", "l"); + 
parser.check_incompatible_options("convert", "files"); + parser.check_incompatible_options("convert", "rename"); + parser.check_incompatible_options("convert", "ignore"); + parser.check_incompatible_options("convert", "parts"); + parser.check_incompatible_options("convert", "cluster"); + parser.check_incompatible_options("convert", "resample"); + parser.check_incompatible_options("rmdiff", "rename"); + parser.check_incompatible_options("rmdiff", "ignore"); + parser.check_incompatible_options("rmempty", "ignore"); + parser.check_incompatible_options("rmempty", "rename"); + parser.check_incompatible_options("rmlabel", "ignore"); + parser.check_incompatible_options("rmlabel", "rename"); + parser.check_incompatible_options("rm-other-labels", "ignore"); + parser.check_incompatible_options("rm-other-labels", "rename"); + parser.check_incompatible_options("rmignore", "ignore"); + parser.check_incompatible_options("rmignore", "rename"); + parser.check_incompatible_options("rm-if-overlaps", "ignore"); + parser.check_incompatible_options("rm-if-overlaps", "rename"); + parser.check_incompatible_options("rmdupes", "rename"); + parser.check_incompatible_options("rmdupes", "ignore"); + parser.check_incompatible_options("rmtrunc", "rename"); + parser.check_incompatible_options("rmtrunc", "ignore"); + const char* convert_args[] = {"pascal-xml","pascal-v1","idl"}; + parser.check_option_arg_range("convert", convert_args); + parser.check_option_arg_range("cluster", 2, 999); + parser.check_option_arg_range("rotate", -360, 360); + parser.check_option_arg_range("size", 10*10, 1000*1000); + parser.check_option_arg_range("min-object-size", 1, 10000*10000); + parser.check_option_arg_range("cropped-object-size", 4, 10000*10000); + parser.check_option_arg_range("crop-size", 1.0, 100.0); + + if (parser.option("h")) + { + cout << "Usage: imglab [options] \n"; + parser.print_options(cout); + cout << endl << endl; + return EXIT_SUCCESS; + } + + if (parser.option("add")) + { + 
merge_metadata_files(parser); + return EXIT_SUCCESS; + } + + if (parser.option("flip") || parser.option("flip-basic")) + { + flip_dataset(parser); + return EXIT_SUCCESS; + } + + if (parser.option("rotate")) + { + rotate_dataset(parser); + return EXIT_SUCCESS; + } + + if (parser.option("v")) + { + cout << "imglab v" << VERSION + << "\nCompiled: " << __TIME__ << " " << __DATE__ + << "\nWritten by Davis King\n"; + cout << "Check for updates at http://dlib.net\n\n"; + return EXIT_SUCCESS; + } + + if (parser.option("tile")) + { + return tile_dataset(parser); + } + + if (parser.option("cluster")) + { + return cluster_dataset(parser); + } + + if (parser.option("resample")) + { + return resample_dataset(parser); + } + + if (parser.option("c")) + { + if (parser.option("convert")) + { + if (parser.option("convert").argument() == "pascal-xml") + convert_pascal_xml(parser); + else if (parser.option("convert").argument() == "pascal-v1") + convert_pascal_v1(parser); + else if (parser.option("convert").argument() == "idl") + convert_idl(parser); + } + else + { + create_new_dataset(parser); + } + return EXIT_SUCCESS; + } + + if (parser.option("rmdiff")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmdiff option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].difficult) + data.images[i].boxes[j].ignore = true; + } + } + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("rmempty")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmempty option requires you to give one XML file on the command line." 
<< endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data, data2; + load_image_dataset_metadata(data, parser[0]); + + data2 = data; + data2.images.clear(); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + bool has_label = false; + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (!data.images[i].boxes[j].ignore) + has_label = true; + } + if (has_label) + data2.images.push_back(data.images[i]); + } + save_image_dataset_metadata(data2, parser[0] + ".rmempty.xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rmlabel")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmlabel option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + + const auto label = parser.option("rmlabel").argument(); + + for (auto&& img : data.images) + { + std::vector boxes; + for (auto&& b : img.boxes) + { + if (b.label != label) + boxes.push_back(b); + } + img.boxes = boxes; + } + + save_image_dataset_metadata(data, parser[0] + ".rmlabel-"+label+".xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rm-other-labels")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rm-other-labels option requires you to give one XML file on the command line." 
<< endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + + const auto labels = parser.option("rm-other-labels").argument(); + // replace comma by dash to form the file name + std::string strlabels = labels; + std::replace(strlabels.begin(), strlabels.end(), ',', '-'); + std::vector all_labels = split(labels, ","); + for (auto&& img : data.images) + { + std::vector boxes; + for (auto&& b : img.boxes) + { + if (std::find(all_labels.begin(), all_labels.end(), b.label) != all_labels.end()) + boxes.push_back(b); + } + img.boxes = boxes; + } + + save_image_dataset_metadata(data, parser[0] + ".rm-other-labels-"+ strlabels +".xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rmignore")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmignore option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + + for (auto&& img : data.images) + { + std::vector boxes; + for (auto&& b : img.boxes) + { + if (!b.ignore) + boxes.push_back(b); + } + img.boxes = boxes; + } + + save_image_dataset_metadata(data, parser[0] + ".rmignore.xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rm-if-overlaps")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rm-if-overlaps option requires you to give one XML file on the command line." 
<< endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + + const auto label = parser.option("rm-if-overlaps").argument(); + + test_box_overlap overlaps(0.5); + + for (auto&& img : data.images) + { + std::vector boxes; + for (auto&& b : img.boxes) + { + if (b.label != label) + { + boxes.push_back(b); + } + else + { + bool has_overlap = false; + for (auto&& b2 : img.boxes) + { + if (b2.label != label && overlaps(b2.rect, b.rect)) + { + has_overlap = true; + break; + } + } + if (!has_overlap) + boxes.push_back(b); + } + } + img.boxes = boxes; + } + + save_image_dataset_metadata(data, parser[0] + ".rm-if-overlaps-"+label+".xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rmdupes")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmdupes option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data, data_out; + std::set hashes; + load_image_dataset_metadata(data, parser[0]); + data_out = data; + data_out.images.clear(); + + for (unsigned long i = 0; i < data.images.size(); ++i) + { + ifstream fin(data.images[i].filename.c_str(), ios::binary); + string hash = md5(fin); + if (hashes.count(hash) == 0) + { + hashes.insert(hash); + data_out.images.push_back(data.images[i]); + } + } + save_image_dataset_metadata(data_out, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("rmtrunc")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmtrunc option requires you to give one XML file on the command line." 
<< endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + { + locally_change_current_dir chdir(get_parent_directory(file(parser[0]))); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + array2d img; + load_image(img, data.images[i].filename); + const rectangle area = get_rect(img); + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (!area.contains(data.images[i].boxes[j].rect)) + data.images[i].boxes[j].ignore = true; + } + } + } + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("l")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The -l option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + print_all_labels(data); + return EXIT_SUCCESS; + } + + if (parser.option("files")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --files option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + for (size_t i = 0; i < data.images.size(); ++i) + cout << data.images[i].filename << "\n"; + return EXIT_SUCCESS; + } + + if (parser.option("split")) + { + return split_dataset(parser); + } + + if (parser.option("shuffle")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --shuffle option requires you to give one XML file on the command line." 
<< endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + const string default_seed = cast_to_string(time(0)); + const string seed = get_option(parser, "seed", default_seed); + dlib::rand rnd(seed); + randomize_samples(data.images, rnd); + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("sort-num-objects")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --sort-num-objects option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + std::sort(data.images.rbegin(), data.images.rend(), + [](const image_dataset_metadata::image& a, const image_dataset_metadata::image& b) { return a.boxes.size() < b.boxes.size(); }); + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("sort")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --sort option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + std::sort(data.images.begin(), data.images.end(), + [](const image_dataset_metadata::image& a, const image_dataset_metadata::image& b) { return a.filename < b.filename; }); + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("stats")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --stats option requires you to give one XML file on the command line." 
<< endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + print_all_label_stats(data); + return EXIT_SUCCESS; + } + + if (parser.option("rename")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rename option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + for (unsigned long i = 0; i < parser.option("rename").count(); ++i) + { + rename_labels(data, parser.option("rename").argument(0,i), parser.option("rename").argument(1,i)); + } + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("ignore")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --ignore option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + for (unsigned long i = 0; i < parser.option("ignore").count(); ++i) + { + ignore_labels(data, parser.option("ignore").argument()); + } + save_image_dataset_metadata(data, parser[0]+".ignored.xml"); + return EXIT_SUCCESS; + } + + if (parser.number_of_arguments() == 1) + { + metadata_editor editor(parser[0]); + if (parser.option("parts")) + { + std::vector parts = split(parser.option("parts").argument()); + for (unsigned long i = 0; i < parts.size(); ++i) + { + editor.add_labelable_part_name(parts[i]); + } + } + editor.wait_until_closed(); + return EXIT_SUCCESS; + } + + cout << "Invalid command, give -h to see options." 
<< endl; + return EXIT_FAILURE; + } + catch (exception& e) + { + cerr << e.what() << endl; + return EXIT_FAILURE; + } +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/metadata_editor.cpp b/ml/dlib/tools/imglab/src/metadata_editor.cpp new file mode 100644 index 000000000..76177e893 --- /dev/null +++ b/ml/dlib/tools/imglab/src/metadata_editor.cpp @@ -0,0 +1,671 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "metadata_editor.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace dlib; + +extern const char* VERSION; + +// ---------------------------------------------------------------------------------------- + +metadata_editor:: +metadata_editor( + const std::string& filename_ +) : + mbar(*this), + lb_images(*this), + image_pos(0), + display(*this), + overlay_label_name(*this), + overlay_label(*this), + keyboard_jump_pos(0), + last_keyboard_jump_pos_update(0) +{ + file metadata_file(filename_); + filename = metadata_file.full_name(); + // Make our current directory be the one that contains the metadata file. We + // do this because that file might contain relative paths to the image files + // we are supposed to be loading. 
+ set_current_dir(get_parent_directory(metadata_file).full_name()); + + load_image_dataset_metadata(metadata, filename); + + dlib::array::expand_1a files; + files.resize(metadata.images.size()); + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + files[i] = metadata.images[i].filename; + } + lb_images.load(files); + lb_images.enable_multiple_select(); + + lb_images.set_click_handler(*this, &metadata_editor::on_lb_images_clicked); + + overlay_label_name.set_text("Next Label: "); + overlay_label.set_width(200); + + display.set_image_clicked_handler(*this, &metadata_editor::on_image_clicked); + display.set_overlay_rects_changed_handler(*this, &metadata_editor::on_overlay_rects_changed); + display.set_overlay_rect_selected_handler(*this, &metadata_editor::on_overlay_rect_selected); + overlay_label.set_text_modified_handler(*this, &metadata_editor::on_overlay_label_changed); + + mbar.set_number_of_menus(2); + mbar.set_menu_name(0,"File",'F'); + mbar.set_menu_name(1,"Help",'H'); + + + mbar.menu(0).add_menu_item(menu_item_text("Save",*this,&metadata_editor::file_save,'S')); + mbar.menu(0).add_menu_item(menu_item_text("Save As",*this,&metadata_editor::file_save_as,'A')); + mbar.menu(0).add_menu_item(menu_item_separator()); + mbar.menu(0).add_menu_item(menu_item_text("Remove Selected Images",*this,&metadata_editor::remove_selected_images,'R')); + mbar.menu(0).add_menu_item(menu_item_separator()); + mbar.menu(0).add_menu_item(menu_item_text("Exit",static_cast(*this),&drawable_window::close_window,'x')); + + mbar.menu(1).add_menu_item(menu_item_text("About",*this,&metadata_editor::display_about,'A')); + + // set the size of this window. + on_window_resized(); + load_image_and_set_size(0); + on_window_resized(); + if (image_pos < lb_images.size() ) + lb_images.select(image_pos); + + // make sure the window is centered on the screen. 
+ unsigned long width, height; + get_size(width, height); + unsigned long screen_width, screen_height; + get_display_size(screen_width, screen_height); + set_pos((screen_width-width)/2, (screen_height-height)/2); + + show(); +} + +// ---------------------------------------------------------------------------------------- + +metadata_editor:: +~metadata_editor( +) +{ + close_window(); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +add_labelable_part_name ( + const std::string& name +) +{ + display.add_labelable_part_name(name); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +file_save() +{ + save_metadata_to_file(filename); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +save_metadata_to_file ( + const std::string& file +) +{ + try + { + save_image_dataset_metadata(metadata, file); + } + catch (dlib::error& e) + { + message_box("Error saving file", e.what()); + } +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +file_save_as() +{ + save_file_box(*this, &metadata_editor::save_metadata_to_file); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +remove_selected_images() +{ + dlib::queue::kernel_1a list; + lb_images.get_selected(list); + list.reset(); + unsigned long min_idx = lb_images.size(); + while (list.move_next()) + { + lb_images.unselect(list.element()); + min_idx = std::min(min_idx, list.element()); + } + + + // remove all the selected items from metadata.images + dlib::static_set::kernel_1a to_remove; + to_remove.load(list); + std::vector images; + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + if (to_remove.is_member(i) == false) + { + 
images.push_back(metadata.images[i]); + } + } + images.swap(metadata.images); + + + // reload metadata into lb_images + dlib::array::expand_1a files; + files.resize(metadata.images.size()); + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + files[i] = metadata.images[i].filename; + } + lb_images.load(files); + + + if (min_idx != 0) + min_idx--; + select_image(min_idx); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_window_resized( +) +{ + drawable_window::on_window_resized(); + + unsigned long width, height; + get_size(width, height); + + lb_images.set_pos(0,mbar.bottom()+1); + lb_images.set_size(180, height - mbar.height()); + + overlay_label_name.set_pos(lb_images.right()+10, mbar.bottom() + (overlay_label.height()-overlay_label_name.height())/2+1); + overlay_label.set_pos(overlay_label_name.right(), mbar.bottom()+1); + display.set_pos(lb_images.right(), overlay_label.bottom()+3); + + display.set_size(width - display.left(), height - display.top()); +} + +// ---------------------------------------------------------------------------------------- + +void propagate_boxes( + dlib::image_dataset_metadata::dataset& data, + unsigned long prev, + unsigned long next +) +{ + if (prev == next || next >= data.images.size()) + return; + + array2d img1, img2; + dlib::load_image(img1, data.images[prev].filename); + dlib::load_image(img2, data.images[next].filename); + for (unsigned long i = 0; i < data.images[prev].boxes.size(); ++i) + { + correlation_tracker tracker; + tracker.start_track(img1, data.images[prev].boxes[i].rect); + tracker.update(img2); + dlib::image_dataset_metadata::box box = data.images[prev].boxes[i]; + box.rect = tracker.get_position(); + data.images[next].boxes.push_back(box); + } +} + +// ---------------------------------------------------------------------------------------- + +void propagate_labels( + const std::string& label, + 
dlib::image_dataset_metadata::dataset& data, + unsigned long prev, + unsigned long next +) +{ + if (prev == next || next >= data.images.size()) + return; + + + for (unsigned long i = 0; i < data.images[prev].boxes.size(); ++i) + { + if (data.images[prev].boxes[i].label != label) + continue; + + // figure out which box in the next image matches the current one the best + const rectangle cur = data.images[prev].boxes[i].rect; + double best_overlap = 0; + unsigned long best_idx = 0; + for (unsigned long j = 0; j < data.images[next].boxes.size(); ++j) + { + const rectangle next_box = data.images[next].boxes[j].rect; + const double overlap = cur.intersect(next_box).area()/(double)(cur+next_box).area(); + if (overlap > best_overlap) + { + best_overlap = overlap; + best_idx = j; + } + } + + // If we found a matching rectangle in the next image and the best match doesn't + // already have a label. + if (best_overlap > 0.5 && data.images[next].boxes[best_idx].label == "") + { + data.images[next].boxes[best_idx].label = label; + } + } + +} + +// ---------------------------------------------------------------------------------------- + +bool has_label_or_all_boxes_labeled ( + const std::string& label, + const dlib::image_dataset_metadata::image& img +) +{ + if (label.size() == 0) + return true; + + bool all_boxes_labeled = true; + for (unsigned long i = 0; i < img.boxes.size(); ++i) + { + if (img.boxes[i].label == label) + return true; + if (img.boxes[i].label.size() == 0) + all_boxes_labeled = false; + } + + return all_boxes_labeled; +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state +) +{ + drawable_window::on_keydown(key, is_printable, state); + + if (is_printable) + { + if (key == '\t') + { + overlay_label.give_input_focus(); + overlay_label.select_all_text(); + } + + // If the user types a number then jump to that image. 
+ if ('0' <= key && key <= '9' && metadata.images.size() != 0 && !overlay_label.has_input_focus()) + { + time_t curtime = time(0); + // If it's been a while since the user typed numbers then forget the last jump + // position and start accumulating numbers over again. + if (curtime-last_keyboard_jump_pos_update >= 2) + keyboard_jump_pos = 0; + last_keyboard_jump_pos_update = curtime; + + keyboard_jump_pos *= 10; + keyboard_jump_pos += key-'0'; + if (keyboard_jump_pos >= metadata.images.size()) + keyboard_jump_pos = metadata.images.size()-1; + + image_pos = keyboard_jump_pos; + select_image(image_pos); + } + else + { + last_keyboard_jump_pos_update = 0; + } + + if (key == 'd' && (state&base_window::KBD_MOD_ALT)) + { + remove_selected_images(); + } + + if (key == 'e' && !overlay_label.has_input_focus()) + { + display_equialized_image = !display_equialized_image; + select_image(image_pos); + } + + // Make 'w' and 's' act like KEY_UP and KEY_DOWN + if ((key == 'w' || key == 'W') && !overlay_label.has_input_focus()) + { + key = base_window::KEY_UP; + } + else if ((key == 's' || key == 'S') && !overlay_label.has_input_focus()) + { + key = base_window::KEY_DOWN; + } + else + { + return; + } + } + + if (key == base_window::KEY_UP) + { + if ((state&KBD_MOD_CONTROL) && (state&KBD_MOD_SHIFT)) + { + // Don't do anything if there are no boxes in the current image. + if (metadata.images[image_pos].boxes.size() == 0) + return; + // Also don't do anything if there *are* boxes in the next image. + if (image_pos > 1 && metadata.images[image_pos-1].boxes.size() != 0) + return; + + propagate_boxes(metadata, image_pos, image_pos-1); + } + else if (state&base_window::KBD_MOD_CONTROL) + { + // If the label we are supposed to propagate doesn't exist in the current image + // then don't advance. 
+ if (!has_label_or_all_boxes_labeled(display.get_default_overlay_rect_label(),metadata.images[image_pos])) + return; + + // if the next image is going to be empty then fast forward to the next one + while (image_pos > 1 && metadata.images[image_pos-1].boxes.size() == 0) + --image_pos; + + propagate_labels(display.get_default_overlay_rect_label(), metadata, image_pos, image_pos-1); + } + select_image(image_pos-1); + } + else if (key == base_window::KEY_DOWN) + { + if ((state&KBD_MOD_CONTROL) && (state&KBD_MOD_SHIFT)) + { + // Don't do anything if there are no boxes in the current image. + if (metadata.images[image_pos].boxes.size() == 0) + return; + // Also don't do anything if there *are* boxes in the next image. + if (image_pos+1 < metadata.images.size() && metadata.images[image_pos+1].boxes.size() != 0) + return; + + propagate_boxes(metadata, image_pos, image_pos+1); + } + else if (state&base_window::KBD_MOD_CONTROL) + { + // If the label we are supposed to propagate doesn't exist in the current image + // then don't advance. 
+ if (!has_label_or_all_boxes_labeled(display.get_default_overlay_rect_label(),metadata.images[image_pos])) + return; + + // if the next image is going to be empty then fast forward to the next one + while (image_pos+1 < metadata.images.size() && metadata.images[image_pos+1].boxes.size() == 0) + ++image_pos; + + propagate_labels(display.get_default_overlay_rect_label(), metadata, image_pos, image_pos+1); + } + select_image(image_pos+1); + } +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +select_image( + unsigned long idx +) +{ + if (idx < lb_images.size()) + { + // unselect all currently selected images + dlib::queue::kernel_1a list; + lb_images.get_selected(list); + list.reset(); + while (list.move_next()) + { + lb_images.unselect(list.element()); + } + + + lb_images.select(idx); + load_image(idx); + } + else if (lb_images.size() == 0) + { + display.clear_overlay(); + array2d empty_img; + display.set_image(empty_img); + } +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_lb_images_clicked( + unsigned long idx +) +{ + load_image(idx); +} + +// ---------------------------------------------------------------------------------------- + +std::vector get_overlays ( + const dlib::image_dataset_metadata::image& data, + color_mapper& string_to_color +) +{ + std::vector temp(data.boxes.size()); + for (unsigned long i = 0; i < temp.size(); ++i) + { + temp[i].rect = data.boxes[i].rect; + temp[i].label = data.boxes[i].label; + temp[i].parts = data.boxes[i].parts; + temp[i].crossed_out = data.boxes[i].ignore; + temp[i].color = string_to_color(data.boxes[i].label); + } + return temp; +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +load_image( + unsigned long idx +) +{ + if (idx >= metadata.images.size()) + return; + + image_pos = idx; + + array2d img; 
+ display.clear_overlay(); + try + { + dlib::load_image(img, metadata.images[idx].filename); + set_title(metadata.name + " #"+cast_to_string(idx)+": " +metadata.images[idx].filename); + } + catch (exception& e) + { + message_box("Error loading image", e.what()); + } + + if (display_equialized_image) + equalize_histogram(img); + display.set_image(img); + display.add_overlay(get_overlays(metadata.images[idx], string_to_color)); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +load_image_and_set_size( + unsigned long idx +) +{ + if (idx >= metadata.images.size()) + return; + + image_pos = idx; + + array2d img; + display.clear_overlay(); + try + { + dlib::load_image(img, metadata.images[idx].filename); + set_title(metadata.name + " #"+cast_to_string(idx)+": " +metadata.images[idx].filename); + } + catch (exception& e) + { + message_box("Error loading image", e.what()); + } + + + unsigned long screen_width, screen_height; + get_display_size(screen_width, screen_height); + + + unsigned long needed_width = display.left() + img.nc() + 4; + unsigned long needed_height = display.top() + img.nr() + 4; + if (needed_width < 300) needed_width = 300; + if (needed_height < 300) needed_height = 300; + + if (needed_width > 100 + screen_width) + needed_width = screen_width - 100; + if (needed_height > 100 + screen_height) + needed_height = screen_height - 100; + + set_size(needed_width, needed_height); + + + if (display_equialized_image) + equalize_histogram(img); + display.set_image(img); + display.add_overlay(get_overlays(metadata.images[idx], string_to_color)); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_overlay_rects_changed( +) +{ + using namespace dlib::image_dataset_metadata; + if (image_pos < metadata.images.size()) + { + const std::vector& rects = display.get_overlay_rects(); + + std::vector& boxes = 
metadata.images[image_pos].boxes; + + boxes.clear(); + for (unsigned long i = 0; i < rects.size(); ++i) + { + box temp; + temp.label = rects[i].label; + temp.rect = rects[i].rect; + temp.parts = rects[i].parts; + temp.ignore = rects[i].crossed_out; + boxes.push_back(temp); + } + } +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_image_clicked( + const point& /*p*/, bool /*is_double_click*/, unsigned long /*btn*/ +) +{ + display.set_default_overlay_rect_color(string_to_color(trim(overlay_label.text()))); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_overlay_label_changed( +) +{ + display.set_default_overlay_rect_label(trim(overlay_label.text())); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_overlay_rect_selected( + const image_display::overlay_rect& orect +) +{ + overlay_label.set_text(orect.label); + display.set_default_overlay_rect_label(orect.label); + display.set_default_overlay_rect_color(string_to_color(orect.label)); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +display_about( +) +{ + std::ostringstream sout; + sout << wrap_string("Image Labeler v" + string(VERSION) + "." ,0,0) << endl << endl; + sout << wrap_string("This program is a tool for labeling images with rectangles. " ,0,0) << endl << endl; + + sout << wrap_string("You can add a new rectangle by holding the shift key, left clicking " + "the mouse, and dragging it. New rectangles are given the label from the \"Next Label\" " + "field at the top of the application. You can quickly edit the contents of the Next Label field " + "by hitting the tab key. Double clicking " + "a rectangle selects it and the delete key removes it. 
You can also mark " + "a rectangle as ignored by hitting the i or END keys when it is selected. Ignored " + "rectangles are visually displayed with an X through them. You can remove an image " + "entirely by selecting it in the list on the left and pressing alt+d." + ,0,0) << endl << endl; + + sout << wrap_string("It is also possible to label object parts by selecting a rectangle and " + "then right clicking. A popup menu will appear and you can select a part label. " + "Note that you must define the allowable part labels by giving --parts on the " + "command line. An example would be '--parts \"leye reye nose mouth\"'." + ,0,0) << endl << endl; + + sout << wrap_string("Press the down or s key to select the next image in the list and the up or w " + "key to select the previous one.",0,0) << endl << endl; + + sout << wrap_string("Additionally, you can hold ctrl and then scroll the mouse wheel to zoom. A normal left click " + "and drag allows you to navigate around the image. Holding ctrl and " + "left clicking a rectangle will give it the label from the Next Label field. " + "Holding shift + right click and then dragging allows you to move things around. " + "Holding ctrl and pressing the up or down keyboard keys will propagate " + "rectangle labels from one image to the next and also skip empty images. " + "Similarly, holding ctrl+shift will propagate entire boxes via a visual tracking " + "algorithm from one image to the next. " + "Finally, typing a number on the keyboard will jump you to a specific image.",0,0) << endl << endl; + + sout << wrap_string("You can also toggle image histogram equalization by pressing the e key." 
+ ,0,0) << endl; + + + message_box("About Image Labeler",sout.str()); +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/metadata_editor.h b/ml/dlib/tools/imglab/src/metadata_editor.h new file mode 100644 index 000000000..71aa14ace --- /dev/null +++ b/ml/dlib/tools/imglab/src/metadata_editor.h @@ -0,0 +1,116 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_METADATA_EdITOR_H__ +#define DLIB_METADATA_EdITOR_H__ + +#include +#include +#include +#include + +// ---------------------------------------------------------------------------------------- + +class color_mapper +{ +public: + + dlib::rgb_alpha_pixel operator() ( + const std::string& str + ) + { + auto i = colors.find(str); + if (i != colors.end()) + { + return i->second; + } + else + { + using namespace dlib; + hsi_pixel pix; + pix.h = reverse(colors.size()); + pix.s = 255; + pix.i = 150; + rgb_alpha_pixel result; + assign_pixel(result, pix); + colors[str] = result; + return result; + } + } + +private: + + // We use a bit reverse here because it causes us to evenly spread the colors as we + // allocated them. First the colors are maximally different, then become interleaved + // and progressively more similar as they are allocated. + unsigned char reverse(unsigned char b) + { + // reverse the order of the bits in b. 
+ b = ((b * 0x0802LU & 0x22110LU) | (b * 0x8020LU & 0x88440LU)) * 0x10101LU >> 16; + return b; + } + + std::map colors; +}; + +// ---------------------------------------------------------------------------------------- + +class metadata_editor : public dlib::drawable_window +{ +public: + metadata_editor( + const std::string& filename_ + ); + + ~metadata_editor(); + + void add_labelable_part_name ( + const std::string& name + ); + +private: + + void file_save(); + void file_save_as(); + void remove_selected_images(); + + virtual void on_window_resized(); + virtual void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + void on_lb_images_clicked(unsigned long idx); + void select_image(unsigned long idx); + void save_metadata_to_file (const std::string& file); + void load_image(unsigned long idx); + void load_image_and_set_size(unsigned long idx); + void on_image_clicked(const dlib::point& p, bool is_double_click, unsigned long btn); + void on_overlay_rects_changed(); + void on_overlay_label_changed(); + void on_overlay_rect_selected(const dlib::image_display::overlay_rect& orect); + + void display_about(); + + std::string filename; + dlib::image_dataset_metadata::dataset metadata; + + dlib::menu_bar mbar; + dlib::list_box lb_images; + unsigned long image_pos; + + dlib::image_display display; + dlib::label overlay_label_name; + dlib::text_field overlay_label; + + unsigned long keyboard_jump_pos; + time_t last_keyboard_jump_pos_update; + bool display_equialized_image = false; + color_mapper string_to_color; +}; + +// ---------------------------------------------------------------------------------------- + + +#endif // DLIB_METADATA_EdITOR_H__ + diff --git a/ml/dlib/tools/python/CMakeLists.txt b/ml/dlib/tools/python/CMakeLists.txt new file mode 100644 index 000000000..d3b947485 --- /dev/null +++ b/ml/dlib/tools/python/CMakeLists.txt @@ -0,0 +1,106 @@ + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12) + +set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL 
"Use SSE4 instructions") +# Make DLIB_ASSERT statements not abort the python interpreter, but just return an error. +add_definitions(-DDLIB_NO_ABORT_ON_2ND_FATAL_ERROR) + +# Set this to disable link time optimization. The only reason for +# doing this is to make the compile faster which is nice when developing +# new modules. +#set(PYBIND11_LTO_CXX_FLAGS "") + + +# Avoid cmake warnings about changes in behavior of some Mac OS X path +# variable we don't care about. +if (POLICY CMP0042) + cmake_policy(SET CMP0042 NEW) +endif() + + +if (CMAKE_COMPILER_IS_GNUCXX) + # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set + # -fPIC for GCC but sometimes it still doesn't get set, so make sure it + # does. + add_definitions("-fPIC") + set(CMAKE_POSITION_INDEPENDENT_CODE True) +else() + set(CMAKE_POSITION_INDEPENDENT_CODE True) +endif() + +# To avoid dll hell, always link everything statically when compiling in +# visual studio. This way, the resulting library won't depend on a bunch +# of other dll files and can be safely copied to someone else's computer +# and expected to run. +if (MSVC) + include(${CMAKE_CURRENT_LIST_DIR}/../../dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake) +endif() + +add_subdirectory(../../dlib/external/pybind11 ./pybind11_build) +add_subdirectory(../../dlib ./dlib_build) + +if (USING_OLD_VISUAL_STUDIO_COMPILER) + message(FATAL_ERROR "You have to use a version of Visual Studio that supports C++11. As of December 2017, the only versions that have good enough C++11 support to compile the dlib Python API is a fully updated Visual Studio 2015 or a fully updated Visual Studio 2017. Older versions of either of these compilers have bad C++11 support and will fail to compile the Python extension. 
***SO UPDATE YOUR VISUAL STUDIO TO MAKE THIS ERROR GO AWAY***") +endif() + + +# Test for numpy +find_package(PythonInterp) +if(PYTHONINTERP_FOUND) + execute_process( COMMAND ${PYTHON_EXECUTABLE} -c "import numpy" OUTPUT_QUIET ERROR_QUIET RESULT_VARIABLE NUMPYRC) + if(NUMPYRC EQUAL 1) + message(WARNING "Numpy not found. Functions that return numpy arrays will throw exceptions!") + else() + message(STATUS "Found Python with installed numpy package") + execute_process( COMMAND ${PYTHON_EXECUTABLE} -c "import sys; from numpy import get_include; sys.stdout.write(get_include())" OUTPUT_VARIABLE NUMPY_INCLUDE_PATH) + message(STATUS "Numpy include path '${NUMPY_INCLUDE_PATH}'") + include_directories(${NUMPY_INCLUDE_PATH}) + endif() +else() + message(WARNING "Numpy not found. Functions that return numpy arrays will throw exceptions!") + set(NUMPYRC 1) +endif() + +add_definitions(-DDLIB_VERSION=${DLIB_VERSION}) + +# Tell cmake to compile all these cpp files into a dlib python module. +set(python_srcs + src/dlib.cpp + src/matrix.cpp + src/vector.cpp + src/svm_c_trainer.cpp + src/svm_rank_trainer.cpp + src/decision_functions.cpp + src/other.cpp + src/basic.cpp + src/cca.cpp + src/sequence_segmenter.cpp + src/svm_struct.cpp + src/image.cpp + src/rectangles.cpp + src/object_detection.cpp + src/shape_predictor.cpp + src/correlation_tracker.cpp + src/face_recognition.cpp + src/cnn_face_detector.cpp + src/global_optimization.cpp + src/image_dataset_metadata.cpp +) + +# Only add the Numpy returning functions if Numpy is present +if(NUMPYRC EQUAL 1) + list(APPEND python_srcs src/numpy_returns_stub.cpp) +else() + list(APPEND python_srcs src/numpy_returns.cpp) +endif() + +# Only add the GUI module if requested +if(NOT ${DLIB_NO_GUI_SUPPORT}) + list(APPEND python_srcs src/gui.cpp) +endif() + +pybind11_add_module(dlib_python ${python_srcs}) +target_link_libraries(dlib_python PRIVATE dlib::dlib) +# Set the output library name to dlib because that's what setup.py and distutils expects. 
+set_target_properties(dlib_python PROPERTIES OUTPUT_NAME dlib) + diff --git a/ml/dlib/tools/python/src/basic.cpp b/ml/dlib/tools/python/src/basic.cpp new file mode 100644 index 000000000..d87a53cc3 --- /dev/null +++ b/ml/dlib/tools/python/src/basic.cpp @@ -0,0 +1,272 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include "opaque_types.h" + +#include +#include + +using namespace std; +using namespace dlib; +namespace py = pybind11; + + +std::shared_ptr > array_from_object(py::object obj) +{ + try { + long nr = obj.cast(); + return std::make_shared>(nr); + } catch (py::cast_error &e) { + py::list li = obj.cast(); + const long nr = len(li); + auto temp = std::make_shared>(nr); + for ( long r = 0; r < nr; ++r) + { + (*temp)[r] = li[r].cast(); + } + return temp; + } +} + +string array__str__ (const std::vector& v) +{ + std::ostringstream sout; + for (unsigned long i = 0; i < v.size(); ++i) + { + sout << v[i]; + if (i+1 < v.size()) + sout << "\n"; + } + return sout.str(); +} + +string array__repr__ (const std::vector& v) +{ + std::ostringstream sout; + sout << "dlib.array(["; + for (unsigned long i = 0; i < v.size(); ++i) + { + sout << v[i]; + if (i+1 < v.size()) + sout << ", "; + } + sout << "])"; + return sout.str(); +} + +string range__str__ (const std::pair& p) +{ + std::ostringstream sout; + sout << p.first << ", " << p.second; + return sout.str(); +} + +string range__repr__ (const std::pair& p) +{ + std::ostringstream sout; + sout << "dlib.range(" << p.first << ", " << p.second << ")"; + return sout.str(); +} + +struct range_iter +{ + std::pair range; + unsigned long cur; + + unsigned long next() + { + if (cur < range.second) + { + return cur++; + } + else + { + PyErr_SetString(PyExc_StopIteration, "No more data."); + throw py::error_already_set(); + } + } +}; + +range_iter make_range_iterator (const std::pair& p) +{ + range_iter temp; + 
temp.range = p; + temp.cur = p.first; + return temp; +} + +string pair__str__ (const std::pair& p) +{ + std::ostringstream sout; + sout << p.first << ": " << p.second; + return sout.str(); +} + +string pair__repr__ (const std::pair& p) +{ + std::ostringstream sout; + sout << "dlib.pair(" << p.first << ", " << p.second << ")"; + return sout.str(); +} + +string sparse_vector__str__ (const std::vector >& v) +{ + std::ostringstream sout; + for (unsigned long i = 0; i < v.size(); ++i) + { + sout << v[i].first << ": " << v[i].second; + if (i+1 < v.size()) + sout << "\n"; + } + return sout.str(); +} + +string sparse_vector__repr__ (const std::vector >& v) +{ + std::ostringstream sout; + sout << "< dlib.sparse_vector containing: \n" << sparse_vector__str__(v) << " >"; + return sout.str(); +} + +unsigned long range_len(const std::pair& r) +{ + if (r.second > r.first) + return r.second-r.first; + else + return 0; +} + +template +void resize(T& v, unsigned long n) { v.resize(n); } + +void bind_basic_types(py::module& m) +{ + { + typedef double item_type; + typedef std::vector type; + typedef std::shared_ptr type_ptr; + py::bind_vector(m, "array", "This object represents a 1D array of floating point numbers. 
" + "Moreover, it binds directly to the C++ type std::vector.") + .def(py::init(&array_from_object)) + .def("__str__", array__str__) + .def("__repr__", array__repr__) + .def("clear", &type::clear) + .def("resize", resize) + .def("extend", extend_vector_with_python_list) + .def(py::pickle(&getstate, &setstate)); + } + + { + typedef matrix item_type; + typedef std::vector type; + py::bind_vector(m, "vectors", "This object is an array of vector objects.") + .def("clear", &type::clear) + .def("resize", resize) + .def("extend", extend_vector_with_python_list) + .def(py::pickle(&getstate, &setstate)); + } + + { + typedef std::vector > item_type; + typedef std::vector type; + py::bind_vector(m, "vectorss", "This object is an array of arrays of vector objects.") + .def("clear", &type::clear) + .def("resize", resize) + .def("extend", extend_vector_with_python_list) + .def(py::pickle(&getstate, &setstate)); + } + + typedef pair range_type; + py::class_(m, "range", "This object is used to represent a range of elements in an array.") + .def(py::init()) + .def_readwrite("begin",&range_type::first, "The index of the first element in the range. This is represented using an unsigned integer.") + .def_readwrite("end",&range_type::second, "One past the index of the last element in the range. 
This is represented using an unsigned integer.") + .def("__str__", range__str__) + .def("__repr__", range__repr__) + .def("__iter__", &make_range_iterator) + .def("__len__", &range_len) + .def(py::pickle(&getstate, &setstate)); + + py::class_(m, "_range_iter") + .def("next", &range_iter::next) + .def("__next__", &range_iter::next); + + { + typedef std::pair item_type; + typedef std::vector type; + py::bind_vector(m, "ranges", "This object is an array of range objects.") + .def("clear", &type::clear) + .def("resize", resize) + .def("extend", extend_vector_with_python_list) + .def(py::pickle(&getstate, &setstate)); + } + + { + typedef std::vector > item_type; + typedef std::vector type; + py::bind_vector(m, "rangess", "This object is an array of arrays of range objects.") + .def("clear", &type::clear) + .def("resize", resize) + .def("extend", extend_vector_with_python_list) + .def(py::pickle(&getstate, &setstate)); + } + + + typedef pair pair_type; + py::class_(m, "pair", "This object is used to represent the elements of a sparse_vector.") + .def(py::init()) + .def_readwrite("first",&pair_type::first, "This field represents the index/dimension number.") + .def_readwrite("second",&pair_type::second, "This field contains the value in a vector at dimension specified by the first field.") + .def("__str__", pair__str__) + .def("__repr__", pair__repr__) + .def(py::pickle(&getstate, &setstate)); + + { + typedef std::vector type; + py::bind_vector(m, "sparse_vector", +"This object represents the mathematical idea of a sparse column vector. It is \n\ +simply an array of dlib.pair objects, each representing an index/value pair in \n\ +the vector. Any elements of the vector which are missing are implicitly set to \n\ +zero. \n\ + \n\ +Unless otherwise noted, any routines taking a sparse_vector assume the sparse \n\ +vector is sorted and has unique elements. 
That is, the index values of the \n\ +pairs in a sparse_vector should be listed in increasing order and there should \n\ +not be duplicates. However, some functions work with \"unsorted\" sparse \n\ +vectors. These are dlib.sparse_vector objects that have either duplicate \n\ +entries or non-sorted index values. Note further that you can convert an \n\ +\"unsorted\" sparse_vector into a properly sorted sparse vector by calling \n\ +dlib.make_sparse_vector() on it. " + ) + .def("__str__", sparse_vector__str__) + .def("__repr__", sparse_vector__repr__) + .def("clear", &type::clear) + .def("resize", resize) + .def("extend", extend_vector_with_python_list) + .def(py::pickle(&getstate, &setstate)); + } + + { + typedef std::vector item_type; + typedef std::vector type; + py::bind_vector(m, "sparse_vectors", "This object is an array of sparse_vector objects.") + .def("clear", &type::clear) + .def("resize", resize) + .def("extend", extend_vector_with_python_list) + .def(py::pickle(&getstate, &setstate)); + } + + { + typedef std::vector > item_type; + typedef std::vector type; + py::bind_vector(m, "sparse_vectorss", "This object is an array of arrays of sparse_vector objects.") + .def("clear", &type::clear) + .def("resize", resize) + .def("extend", extend_vector_with_python_list) + .def(py::pickle(&getstate, &setstate)); + } +} + diff --git a/ml/dlib/tools/python/src/cca.cpp b/ml/dlib/tools/python/src/cca.cpp new file mode 100644 index 000000000..dcf476522 --- /dev/null +++ b/ml/dlib/tools/python/src/cca.cpp @@ -0,0 +1,137 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. 
+ +#include "opaque_types.h" +#include +#include + +using namespace dlib; +namespace py = pybind11; + +typedef std::vector > sparse_vect; + +struct cca_outputs +{ + matrix correlations; + matrix Ltrans; + matrix Rtrans; +}; + +cca_outputs _cca1 ( + const std::vector& L, + const std::vector& R, + unsigned long num_correlations, + unsigned long extra_rank, + unsigned long q, + double regularization +) +{ + pyassert(num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.size() == R.size() && regularization >= 0, + "Invalid inputs"); + + cca_outputs temp; + temp.correlations = cca(L,R,temp.Ltrans,temp.Rtrans,num_correlations,extra_rank,q,regularization); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +unsigned long sparse_vector_max_index_plus_one ( + const sparse_vect& v +) +{ + return max_index_plus_one(v); +} + +matrix apply_cca_transform ( + const matrix& m, + const sparse_vect& v +) +{ + pyassert((long)max_index_plus_one(v) <= m.nr(), "Invalid Inputs"); + return sparse_matrix_vector_multiply(trans(m), v); +} + +void bind_cca(py::module& m) +{ + py::class_(m, "cca_outputs") + .def_readwrite("correlations", &cca_outputs::correlations) + .def_readwrite("Ltrans", &cca_outputs::Ltrans) + .def_readwrite("Rtrans", &cca_outputs::Rtrans); + + m.def("max_index_plus_one", sparse_vector_max_index_plus_one, py::arg("v"), +"ensures \n\ + - returns the dimensionality of the given sparse vector. That is, returns a \n\ + number one larger than the maximum index value in the vector. If the vector \n\ + is empty then returns 0. " + ); + + + m.def("apply_cca_transform", apply_cca_transform, py::arg("m"), py::arg("v"), +"requires \n\ + - max_index_plus_one(v) <= m.nr() \n\ +ensures \n\ + - returns trans(m)*v \n\ + (i.e. 
multiply m by the vector v and return the result) " + ); + + + m.def("cca", _cca1, py::arg("L"), py::arg("R"), py::arg("num_correlations"), py::arg("extra_rank")=5, py::arg("q")=2, py::arg("regularization")=0, +"requires \n\ + - num_correlations > 0 \n\ + - len(L) > 0 \n\ + - len(R) > 0 \n\ + - len(L) == len(R) \n\ + - regularization >= 0 \n\ + - L and R must be properly sorted sparse vectors. This means they must list their \n\ + elements in ascending index order and not contain duplicate index values. You can use \n\ + make_sparse_vector() to ensure this is true. \n\ +ensures \n\ + - This function performs a canonical correlation analysis between the vectors \n\ + in L and R. That is, it finds two transformation matrices, Ltrans and \n\ + Rtrans, such that row vectors in the transformed matrices L*Ltrans and \n\ + R*Rtrans are as correlated as possible (note that in this notation we \n\ + interpret L as a matrix with the input vectors in its rows). Note also that \n\ + this function tries to find transformations which produce num_correlations \n\ + dimensional output vectors. \n\ + - Note that you can easily apply the transformation to a vector using \n\ + apply_cca_transform(). So for example, like this: \n\ + - apply_cca_transform(Ltrans, some_sparse_vector) \n\ + - returns a structure containing the Ltrans and Rtrans transformation matrices \n\ + as well as the estimated correlations between elements of the transformed \n\ + vectors. \n\ + - This function assumes the data vectors in L and R have already been centered \n\ + (i.e. we assume the vectors have zero means). However, in many cases it is \n\ + fine to use uncentered data with cca(). But if it is important for your \n\ + problem then you should center your data before passing it to cca(). \n\ + - This function works with reduced rank approximations of the L and R matrices. \n\ + This makes it fast when working with large matrices. 
In particular, we use \n\ + the dlib::svd_fast() routine to find reduced rank representations of the input \n\ + matrices by calling it as follows: svd_fast(L, U,D,V, num_correlations+extra_rank, q) \n\ + and similarly for R. This means that you can use the extra_rank and q \n\ + arguments to cca() to influence the accuracy of the reduced rank \n\ + approximation. However, the default values should work fine for most \n\ + problems. \n\ + - The dimensions of the output vectors produced by L*#Ltrans or R*#Rtrans are \n\ + ordered such that the dimensions with the highest correlations come first. \n\ + That is, after applying the transforms produced by cca() to a set of vectors \n\ + you will find that dimension 0 has the highest correlation, then dimension 1 \n\ + has the next highest, and so on. This also means that the list of estimated \n\ + correlations returned from cca() will always be listed in decreasing order. \n\ + - This function performs the ridge regression version of Canonical Correlation \n\ + Analysis when regularization is set to a value > 0. In particular, larger \n\ + values indicate the solution should be more heavily regularized. This can be \n\ + useful when the dimensionality of the data is larger than the number of \n\ + samples. \n\ + - A good discussion of CCA can be found in the paper \"Canonical Correlation \n\ + Analysis\" by David Weenink. In particular, this function is implemented \n\ + using equations 29 and 30 from his paper. We also use the idea of doing CCA \n\ + on a reduced rank approximation of L and R as suggested by Paramveer S. \n\ + Dhillon in his paper \"Two Step CCA: A new spectral method for estimating \n\ + vector models of words\". " + + ); +} + + + diff --git a/ml/dlib/tools/python/src/cnn_face_detector.cpp b/ml/dlib/tools/python/src/cnn_face_detector.cpp new file mode 100644 index 000000000..f18d99d95 --- /dev/null +++ b/ml/dlib/tools/python/src/cnn_face_detector.cpp @@ -0,0 +1,183 @@ +// Copyright (C) 2017 Davis E. 
King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include +#include +#include +#include +#include "indexing.h" +#include + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + + +class cnn_face_detection_model_v1 +{ + +public: + + cnn_face_detection_model_v1(const std::string& model_filename) + { + deserialize(model_filename) >> net; + } + + std::vector detect ( + py::object pyimage, + const int upsample_num_times + ) + { + pyramid_down<2> pyr; + std::vector rects; + + // Copy the data into dlib based objects + matrix image; + if (is_gray_python_image(pyimage)) + assign_image(image, numpy_gray_image(pyimage)); + else if (is_rgb_python_image(pyimage)) + assign_image(image, numpy_rgb_image(pyimage)); + else + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + + // Upsampling the image will allow us to detect smaller faces but will cause the + // program to use more RAM and run longer. + unsigned int levels = upsample_num_times; + while (levels > 0) + { + levels--; + pyramid_up(image, pyr); + } + + auto dets = net(image); + + // Scale the detection locations back to the original image size + // if the image was upscaled. 
+ for (auto&& d : dets) { + d.rect = pyr.rect_down(d.rect, upsample_num_times); + rects.push_back(d); + } + + return rects; + } + + std::vector > detect_mult ( + py::list imgs, + const int upsample_num_times, + const int batch_size = 128 + ) + { + pyramid_down<2> pyr; + std::vector > dimgs; + dimgs.reserve(len(imgs)); + + for(int i = 0; i < len(imgs); i++) + { + // Copy the data into dlib based objects + matrix image; + py::object tmp = imgs[i].cast(); + if (is_gray_python_image(tmp)) + assign_image(image, numpy_gray_image(tmp)); + else if (is_rgb_python_image(tmp)) + assign_image(image, numpy_rgb_image(tmp)); + else + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + + for(int i = 0; i < upsample_num_times; i++) + { + pyramid_up(image); + } + dimgs.push_back(image); + } + + for(int i = 1; i < dimgs.size(); i++) + { + if + ( + dimgs[i - 1].nc() != dimgs[i].nc() || + dimgs[i - 1].nr() != dimgs[i].nr() + ) + throw dlib::error("Images in list must all have the same dimensions."); + + } + + auto dets = net(dimgs, batch_size); + std::vector > all_rects; + + for(auto&& im_dets : dets) + { + std::vector rects; + rects.reserve(im_dets.size()); + for (auto&& d : im_dets) { + d.rect = pyr.rect_down(d.rect, upsample_num_times); + rects.push_back(d); + } + all_rects.push_back(rects); + } + + return all_rects; + } + +private: + + template using con5d = con; + template using con5 = con; + + template using downsampler = relu>>>>>>>>; + template using rcon5 = relu>>; + + using net_type = loss_mmod>>>>>>>; + + net_type net; +}; + +// ---------------------------------------------------------------------------------------- + +void bind_cnn_face_detection(py::module& m) +{ + { + py::class_(m, "cnn_face_detection_model_v1", "This object detects human faces in an image. The constructor loads the face detection model from a file. 
You can download a pre-trained model from http://dlib.net/files/mmod_human_face_detector.dat.bz2.") + .def(py::init()) + .def( + "__call__", + &cnn_face_detection_model_v1::detect_mult, + py::arg("imgs"), py::arg("upsample_num_times")=0, py::arg("batch_size")=128, + "takes a list of images as input returning a 2d list of mmod rectangles" + ) + .def( + "__call__", + &cnn_face_detection_model_v1::detect, + py::arg("img"), py::arg("upsample_num_times")=0, + "Find faces in an image using a deep learning model.\n\ + - Upsamples the image upsample_num_times before running the face \n\ + detector." + ); + } + + m.def("set_dnn_prefer_smallest_algorithms", &set_dnn_prefer_smallest_algorithms, "Tells cuDNN to use slower algorithms that use less RAM."); + + auto cuda = m.def_submodule("cuda", "Routines for setting CUDA specific properties."); + cuda.def("set_device", &dlib::cuda::set_device, py::arg("device_id"), + "Set the active CUDA device. It is required that 0 <= device_id < get_num_devices()."); + cuda.def("get_device", &dlib::cuda::get_device, "Get the active CUDA device."); + cuda.def("get_num_devices", &dlib::cuda::get_num_devices, "Find out how many CUDA devices are available."); + + { + typedef mmod_rect type; + py::class_(m, "mmod_rectangle", "Wrapper around a rectangle object and a detection confidence score.") + .def_readwrite("rect", &type::rect) + .def_readwrite("confidence", &type::detection_confidence); + } + { + typedef std::vector type; + py::bind_vector(m, "mmod_rectangles", "An array of mmod rectangle objects.") + .def("extend", extend_vector_with_python_list); + } + { + typedef std::vector > type; + py::bind_vector(m, "mmod_rectangless", "A 2D array of mmod rectangle objects.") + .def("extend", extend_vector_with_python_list>); + } +} diff --git a/ml/dlib/tools/python/src/conversion.h b/ml/dlib/tools/python/src/conversion.h new file mode 100644 index 000000000..9ab2360a0 --- /dev/null +++ b/ml/dlib/tools/python/src/conversion.h @@ -0,0 +1,52 @@ +// 
Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYTHON_CONVERSION_H__ +#define DLIB_PYTHON_CONVERSION_H__ + +#include "opaque_types.h" +#include +#include + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +template +void pyimage_to_dlib_image(py::object img, dest_image_type& image) +{ + if (is_gray_python_image(img)) + assign_image(image, numpy_gray_image(img)); + else if (is_rgb_python_image(img)) + assign_image(image, numpy_rgb_image(img)); + else + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); +} + +template +void images_and_nested_params_to_dlib( + const py::object& pyimages, + const py::object& pyparams, + image_array& images, + std::vector >& params +) +{ + // Now copy the data into dlib based objects. + py::iterator image_it = pyimages.begin(); + py::iterator params_it = pyparams.begin(); + + for (unsigned long image_idx = 0; + image_it != pyimages.end() + && params_it != pyparams.end(); + ++image_it, ++params_it, ++image_idx) + { + for (py::iterator param_it = params_it->begin(); + param_it != params_it->end(); + ++param_it) + params[image_idx].push_back(param_it->cast()); + + pyimage_to_dlib_image(image_it->cast(), images[image_idx]); + } +} + +#endif // DLIB_PYTHON_CONVERSION_H__ diff --git a/ml/dlib/tools/python/src/correlation_tracker.cpp b/ml/dlib/tools/python/src/correlation_tracker.cpp new file mode 100644 index 000000000..1b17ba54c --- /dev/null +++ b/ml/dlib/tools/python/src/correlation_tracker.cpp @@ -0,0 +1,167 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. 
+ +#include "opaque_types.h" +#include +#include +#include + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +void start_track ( + correlation_tracker& tracker, + py::object img, + const drectangle& bounding_box +) +{ + if (is_gray_python_image(img)) + { + tracker.start_track(numpy_gray_image(img), bounding_box); + } + else if (is_rgb_python_image(img)) + { + tracker.start_track(numpy_rgb_image(img), bounding_box); + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } +} + +void start_track_rec ( + correlation_tracker& tracker, + py::object img, + const rectangle& bounding_box +) +{ + drectangle dbounding_box(bounding_box); + start_track(tracker, img, dbounding_box); +} + +double update ( + correlation_tracker& tracker, + py::object img +) +{ + if (is_gray_python_image(img)) + { + return tracker.update(numpy_gray_image(img)); + } + else if (is_rgb_python_image(img)) + { + return tracker.update(numpy_rgb_image(img)); + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } +} + +double update_guess ( + correlation_tracker& tracker, + py::object img, + const drectangle& bounding_box +) +{ + if (is_gray_python_image(img)) + { + return tracker.update(numpy_gray_image(img), bounding_box); + } + else if (is_rgb_python_image(img)) + { + return tracker.update(numpy_rgb_image(img), bounding_box); + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } +} + +double update_guess_rec ( + correlation_tracker& tracker, + py::object img, + const rectangle& bounding_box +) +{ + drectangle dbounding_box(bounding_box); + return update_guess(tracker, img, dbounding_box); +} + +drectangle get_position (const correlation_tracker& tracker) { return tracker.get_position(); } + +// 
---------------------------------------------------------------------------------------- + +void bind_correlation_tracker(py::module &m) +{ + { + typedef correlation_tracker type; + py::class_(m, "correlation_tracker", "This is a tool for tracking moving objects in a video stream. You give it \n\ + the bounding box of an object in the first frame and it attempts to track the \n\ + object in the box from frame to frame. \n\ + This tool is an implementation of the method described in the following paper: \n\ + Danelljan, Martin, et al. 'Accurate scale estimation for robust visual \n\ + tracking.' Proceedings of the British Machine Vision Conference BMVC. 2014.") + .def(py::init()) + .def("start_track", &::start_track, py::arg("image"), py::arg("bounding_box"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - bounding_box.is_empty() == false \n\ + ensures \n\ + - This object will start tracking the thing inside the bounding box in the \n\ + given image. That is, if you call update() with subsequent video frames \n\ + then it will try to keep track of the position of the object inside bounding_box. \n\ + - #get_position() == bounding_box") + .def("start_track", &::start_track_rec, py::arg("image"), py::arg("bounding_box"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - bounding_box.is_empty() == false \n\ + ensures \n\ + - This object will start tracking the thing inside the bounding box in the \n\ + given image. That is, if you call update() with subsequent video frames \n\ + then it will try to keep track of the position of the object inside bounding_box. \n\ + - #get_position() == bounding_box") + .def("update", &::update, py::arg("image"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - get_position().is_empty() == false \n\ + (i.e. 
you must have started tracking by calling start_track()) \n\ + ensures \n\ + - performs: return update(img, get_position())") + .def("update", &::update_guess, py::arg("image"), py::arg("guess"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - get_position().is_empty() == false \n\ + (i.e. you must have started tracking by calling start_track()) \n\ + ensures \n\ + - When searching for the object in img, we search in the area around the \n\ + provided guess. \n\ + - #get_position() == the new predicted location of the object in img. This \n\ + location will be a copy of guess that has been translated and scaled \n\ + appropriately based on the content of img so that it, hopefully, bounds \n\ + the object in img. \n\ + - Returns the peak to side-lobe ratio. This is a number that measures how \n\ + confident the tracker is that the object is inside #get_position(). \n\ + Larger values indicate higher confidence.") + .def("update", &::update_guess_rec, py::arg("image"), py::arg("guess"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - get_position().is_empty() == false \n\ + (i.e. you must have started tracking by calling start_track()) \n\ + ensures \n\ + - When searching for the object in img, we search in the area around the \n\ + provided guess. \n\ + - #get_position() == the new predicted location of the object in img. This \n\ + location will be a copy of guess that has been translated and scaled \n\ + appropriately based on the content of img so that it, hopefully, bounds \n\ + the object in img. \n\ + - Returns the peak to side-lobe ratio. This is a number that measures how \n\ + confident the tracker is that the object is inside #get_position(). 
\n\ + Larger values indicate higher confidence.") + .def("get_position", &::get_position, "returns the predicted position of the object under track."); + } +} diff --git a/ml/dlib/tools/python/src/decision_functions.cpp b/ml/dlib/tools/python/src/decision_functions.cpp new file mode 100644 index 000000000..a93fe49b9 --- /dev/null +++ b/ml/dlib/tools/python/src/decision_functions.cpp @@ -0,0 +1,263 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include +#include "testing_results.h" +#include + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +typedef matrix sample_type; +typedef std::vector > sparse_vect; + +template +double predict ( + const decision_function& df, + const typename decision_function::kernel_type::sample_type& samp +) +{ + typedef typename decision_function::kernel_type::sample_type T; + if (df.basis_vectors.size() == 0) + { + return 0; + } + else if (is_matrix::value && df.basis_vectors(0).size() != samp.size()) + { + std::ostringstream sout; + sout << "Input vector should have " << df.basis_vectors(0).size() + << " dimensions, not " << samp.size() << "."; + PyErr_SetString( PyExc_ValueError, sout.str().c_str() ); + throw py::error_already_set(); + } + return df(samp); +} + +template +void add_df ( + py::module& m, + const std::string name +) +{ + typedef decision_function df_type; + py::class_(m, name.c_str()) + .def("__call__", &predict) + .def(py::pickle(&getstate, &setstate)); +} + +template +typename df_type::sample_type get_weights( + const df_type& df +) +{ + if (df.basis_vectors.size() == 0) + { + PyErr_SetString( PyExc_ValueError, "Decision function is empty." 
); + throw py::error_already_set(); + } + df_type temp = simplify_linear_decision_function(df); + return temp.basis_vectors(0); +} + +template +typename df_type::scalar_type get_bias( + const df_type& df +) +{ + if (df.basis_vectors.size() == 0) + { + PyErr_SetString( PyExc_ValueError, "Decision function is empty." ); + throw py::error_already_set(); + } + return df.b; +} + +template +void set_bias( + df_type& df, + double b +) +{ + if (df.basis_vectors.size() == 0) + { + PyErr_SetString( PyExc_ValueError, "Decision function is empty." ); + throw py::error_already_set(); + } + df.b = b; +} + +template +void add_linear_df ( + py::module &m, + const std::string name +) +{ + typedef decision_function df_type; + py::class_(m, name.c_str()) + .def("__call__", predict) + .def_property_readonly("weights", &get_weights) + .def_property("bias", get_bias, set_bias) + .def(py::pickle(&getstate, &setstate)); +} + +// ---------------------------------------------------------------------------------------- + +std::string binary_test__str__(const binary_test& item) +{ + std::ostringstream sout; + sout << "class1_accuracy: "<< item.class1_accuracy << " class2_accuracy: "<< item.class2_accuracy; + return sout.str(); +} +std::string binary_test__repr__(const binary_test& item) { return "< " + binary_test__str__(item) + " >";} + +std::string regression_test__str__(const regression_test& item) +{ + std::ostringstream sout; + sout << "mean_squared_error: "<< item.mean_squared_error << " R_squared: "<< item.R_squared; + sout << " mean_average_error: "<< item.mean_average_error << " mean_error_stddev: "<< item.mean_error_stddev; + return sout.str(); +} +std::string regression_test__repr__(const regression_test& item) { return "< " + regression_test__str__(item) + " >";} + +std::string ranking_test__str__(const ranking_test& item) +{ + std::ostringstream sout; + sout << "ranking_accuracy: "<< item.ranking_accuracy << " mean_ap: "<< item.mean_ap; + return sout.str(); +} +std::string 
ranking_test__repr__(const ranking_test& item) { return "< " + ranking_test__str__(item) + " >";} + +// ---------------------------------------------------------------------------------------- + +template +binary_test _test_binary_decision_function ( + const decision_function& dec_funct, + const std::vector& x_test, + const std::vector& y_test +) { return binary_test(test_binary_decision_function(dec_funct, x_test, y_test)); } + +template +regression_test _test_regression_function ( + const decision_function& reg_funct, + const std::vector& x_test, + const std::vector& y_test +) { return regression_test(test_regression_function(reg_funct, x_test, y_test)); } + +template < typename K > +ranking_test _test_ranking_function1 ( + const decision_function& funct, + const std::vector >& samples +) { return ranking_test(test_ranking_function(funct, samples)); } + +template < typename K > +ranking_test _test_ranking_function2 ( + const decision_function& funct, + const ranking_pair& sample +) { return ranking_test(test_ranking_function(funct, sample)); } + + +void bind_decision_functions(py::module &m) +{ + add_linear_df >(m, "_decision_function_linear"); + add_linear_df >(m, "_decision_function_sparse_linear"); + + add_df >(m, "_decision_function_histogram_intersection"); + add_df >(m, "_decision_function_sparse_histogram_intersection"); + + add_df >(m, "_decision_function_polynomial"); + add_df >(m, "_decision_function_sparse_polynomial"); + + add_df >(m, "_decision_function_radial_basis"); + add_df >(m, "_decision_function_sparse_radial_basis"); + + add_df >(m, "_decision_function_sigmoid"); + add_df >(m, "_decision_function_sparse_sigmoid"); + + + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", 
_test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + 
m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + + m.def("test_ranking_function", _test_ranking_function1 >, + py::arg("function"), py::arg("samples")); + m.def("test_ranking_function", _test_ranking_function1 >, + py::arg("function"), py::arg("samples")); + m.def("test_ranking_function", _test_ranking_function2 >, + py::arg("function"), py::arg("sample")); + m.def("test_ranking_function", _test_ranking_function2 >, + py::arg("function"), py::arg("sample")); + + + py::class_(m, "_binary_test") + .def("__str__", binary_test__str__) + .def("__repr__", binary_test__repr__) + .def_readwrite("class1_accuracy", &binary_test::class1_accuracy, + "A value between 0 and 1, measures accuracy on the +1 class.") + .def_readwrite("class2_accuracy", &binary_test::class2_accuracy, + "A value between 0 and 1, measures accuracy on the -1 class."); + + py::class_(m, "_ranking_test") + .def("__str__", ranking_test__str__) + .def("__repr__", ranking_test__repr__) + .def_readwrite("ranking_accuracy", &ranking_test::ranking_accuracy, + "A value between 0 and 1, measures the fraction of times a relevant sample was ordered before a non-relevant sample.") + .def_readwrite("mean_ap", &ranking_test::mean_ap, + "A value between 0 and 1, measures the mean average precision of the ranking."); + + py::class_(m, "_regression_test") + .def("__str__", regression_test__str__) + .def("__repr__", regression_test__repr__) + .def_readwrite("mean_average_error", ®ression_test::mean_average_error, + "The mean average error of a regression function on a dataset.") + .def_readwrite("mean_error_stddev", ®ression_test::mean_error_stddev, + "The standard deviation of the absolute 
value of the error of a regression function on a dataset.") + .def_readwrite("mean_squared_error", ®ression_test::mean_squared_error, + "The mean squared error of a regression function on a dataset.") + .def_readwrite("R_squared", ®ression_test::R_squared, + "A value between 0 and 1, measures the squared correlation between the output of a \n" + "regression function and the target values."); +} + + + diff --git a/ml/dlib/tools/python/src/dlib.cpp b/ml/dlib/tools/python/src/dlib.cpp new file mode 100644 index 000000000..ac6fea0db --- /dev/null +++ b/ml/dlib/tools/python/src/dlib.cpp @@ -0,0 +1,110 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include +#include +#include + +namespace py = pybind11; + +void bind_matrix(py::module& m); +void bind_vector(py::module& m); +void bind_svm_c_trainer(py::module& m); +void bind_decision_functions(py::module& m); +void bind_basic_types(py::module& m); +void bind_other(py::module& m); +void bind_svm_rank_trainer(py::module& m); +void bind_cca(py::module& m); +void bind_sequence_segmenter(py::module& m); +void bind_svm_struct(py::module& m); +void bind_image_classes(py::module& m); +void bind_rectangles(py::module& m); +void bind_object_detection(py::module& m); +void bind_shape_predictors(py::module& m); +void bind_correlation_tracker(py::module& m); +void bind_face_recognition(py::module& m); +void bind_cnn_face_detection(py::module& m); +void bind_global_optimization(py::module& m); +void bind_numpy_returns(py::module& m); +void bind_image_dataset_metadata(py::module& m); + +#ifndef DLIB_NO_GUI_SUPPORT +void bind_gui(py::module& m); +#endif + +PYBIND11_MODULE(dlib, m) +{ + warn_about_unavailable_but_used_cpu_instructions(); + + +#define DLIB_QUOTE_STRING(x) DLIB_QUOTE_STRING2(x) +#define DLIB_QUOTE_STRING2(x) #x + m.attr("__version__") = DLIB_QUOTE_STRING(DLIB_VERSION); + m.attr("__time_compiled__") = 
std::string(__DATE__) + " " + std::string(__TIME__); + +#ifdef DLIB_USE_CUDA + m.attr("DLIB_USE_CUDA") = true; +#else + m.attr("DLIB_USE_CUDA") = false; +#endif +#ifdef DLIB_USE_BLAS + m.attr("DLIB_USE_BLAS") = true; +#else + m.attr("DLIB_USE_BLAS") = false; +#endif +#ifdef DLIB_USE_LAPACK + m.attr("DLIB_USE_LAPACK") = true; +#else + m.attr("DLIB_USE_LAPACK") = false; +#endif +#ifdef DLIB_HAVE_AVX + m.attr("USE_AVX_INSTRUCTIONS") = true; +#else + m.attr("USE_AVX_INSTRUCTIONS") = false; +#endif +#ifdef DLIB_HAVE_NEON + m.attr("USE_NEON_INSTRUCTIONS") = true; +#else + m.attr("USE_NEON_INSTRUCTIONS") = false; +#endif + + + + // Note that the order here matters. We need to do the basic types first. If we don't + // then what happens is the documentation created by sphinx will use horrible big + // template names to refer to C++ objects rather than the python names python users + // will expect. For instance, if bind_basic_types() isn't called early then when + // routines take a std::vector, rather than saying dlib.array in the python + // docs it will say "std::vector >" which is awful and + // confusing to python users. + // + // So when adding new things always add them to the end of the list. 
+ bind_matrix(m); + bind_vector(m); + bind_basic_types(m); + bind_other(m); + + bind_svm_rank_trainer(m); + bind_decision_functions(m); + bind_cca(m); + bind_sequence_segmenter(m); + bind_svm_struct(m); + bind_image_classes(m); + bind_rectangles(m); + bind_object_detection(m); + bind_shape_predictors(m); + bind_correlation_tracker(m); + bind_face_recognition(m); + bind_cnn_face_detection(m); + bind_global_optimization(m); + bind_numpy_returns(m); + bind_svm_c_trainer(m); +#ifndef DLIB_NO_GUI_SUPPORT + bind_gui(m); +#endif + + bind_image_dataset_metadata(m); + + +} diff --git a/ml/dlib/tools/python/src/face_recognition.cpp b/ml/dlib/tools/python/src/face_recognition.cpp new file mode 100644 index 000000000..8d5dee678 --- /dev/null +++ b/ml/dlib/tools/python/src/face_recognition.cpp @@ -0,0 +1,245 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include +#include +#include +#include +#include +#include "indexing.h" +#include +#include +#include + + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + + +typedef matrix cv; + +class face_recognition_model_v1 +{ + +public: + + face_recognition_model_v1(const std::string& model_filename) + { + deserialize(model_filename) >> net; + } + + matrix compute_face_descriptor ( + py::object img, + const full_object_detection& face, + const int num_jitters + ) + { + std::vector faces(1, face); + return compute_face_descriptors(img, faces, num_jitters)[0]; + } + + std::vector> compute_face_descriptors ( + py::object img, + const std::vector& faces, + const int num_jitters + ) + { + if (!is_rgb_python_image(img)) + throw dlib::error("Unsupported image type, must be RGB image."); + + for (auto& f : faces) + { + if (f.num_parts() != 68 && f.num_parts() != 5) + throw dlib::error("The full_object_detection must use the iBUG 300W 68 point face landmark style or dlib's 5 point style."); + } + + + 
std::vector dets; + for (auto& f : faces) + dets.push_back(get_face_chip_details(f, 150, 0.25)); + dlib::array> face_chips; + extract_image_chips(numpy_rgb_image(img), dets, face_chips); + + std::vector> face_descriptors; + face_descriptors.reserve(face_chips.size()); + + if (num_jitters <= 1) + { + // extract descriptors and convert from float vectors to double vectors + for (auto& d : net(face_chips,16)) + face_descriptors.push_back(matrix_cast(d)); + } + else + { + for (auto& fimg : face_chips) + face_descriptors.push_back(matrix_cast(mean(mat(net(jitter_image(fimg,num_jitters),16))))); + } + + return face_descriptors; + } + +private: + + dlib::rand rnd; + + std::vector> jitter_image( + const matrix& img, + const int num_jitters + ) + { + std::vector> crops; + for (int i = 0; i < num_jitters; ++i) + crops.push_back(dlib::jitter_image(img,rnd)); + return crops; + } + + + template