diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-09 13:19:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-09 13:20:02 +0000 |
commit | 58daab21cd043e1dc37024a7f99b396788372918 (patch) | |
tree | 96771e43bb69f7c1c2b0b4f7374cb74d7866d0cb /ml/dlib/tools | |
parent | Releasing debian version 1.43.2-1. (diff) | |
download | netdata-58daab21cd043e1dc37024a7f99b396788372918.tar.xz netdata-58daab21cd043e1dc37024a7f99b396788372918.zip |
Merging upstream version 1.44.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ml/dlib/tools')
79 files changed, 14753 insertions, 0 deletions
diff --git a/ml/dlib/tools/archive/train_face_5point_model.cpp b/ml/dlib/tools/archive/train_face_5point_model.cpp new file mode 100644 index 000000000..0cd35467f --- /dev/null +++ b/ml/dlib/tools/archive/train_face_5point_model.cpp @@ -0,0 +1,159 @@ + +/* + + This is the program that created the http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2 model file. + +*/ + + +#include <dlib/image_processing/frontal_face_detector.h> +#include <dlib/image_processing.h> +#include <dlib/console_progress_indicator.h> +#include <dlib/data_io.h> +#include <dlib/statistics.h> +#include <iostream> + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +std::vector<std::vector<double> > get_interocular_distances ( + const std::vector<std::vector<full_object_detection> >& objects +); +/*! + ensures + - returns an object D such that: + - D[i][j] == the distance, in pixels, between the eyes for the face represented + by objects[i][j]. +!*/ + +// ---------------------------------------------------------------------------------------- + +template < + typename image_array_type, + typename T + > +void add_image_left_right_flips_5points ( + image_array_type& images, + std::vector<std::vector<T> >& objects +) +{ + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size(), + "\t void add_image_left_right_flips()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + + typename image_array_type::value_type temp; + std::vector<T> rects; + + const unsigned long num = images.size(); + for (unsigned long j = 0; j < num; ++j) + { + const point_transform_affine tran = flip_image_left_right(images[j], temp); + + rects.clear(); + for (unsigned long i = 0; i < objects[j].size(); ++i) + { + rects.push_back(impl::tform_object(tran, objects[j][i])); + + DLIB_CASSERT(rects.back().num_parts() == 5); + swap(rects.back().part(0), rects.back().part(2)); + swap(rects.back().part(1), rects.back().part(3)); + } + + images.push_back(temp); + objects.push_back(rects); + } +} + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) +{ + try + { + if (argc != 2) + { + cout << "give the path to the training data folder" << endl; + return 0; + } + const std::string faces_directory = argv[1]; + dlib::array<array2d<unsigned char> > images_train, images_test; + std::vector<std::vector<full_object_detection> > faces_train, faces_test; + + std::vector<std::string> parts_list; + load_image_dataset(images_train, faces_train, faces_directory+"/train_cleaned.xml", parts_list); + load_image_dataset(images_test, faces_test, faces_directory+"/test_cleaned.xml"); + + add_image_left_right_flips_5points(images_train, faces_train); + add_image_left_right_flips_5points(images_test, faces_test); + add_image_rotations(linspace(-20,20,3)*pi/180.0,images_train, faces_train); + + cout << "num training images: "<< images_train.size() << endl; + + for (auto& part : parts_list) + cout << part << endl; + + shape_predictor_trainer trainer; + trainer.set_oversampling_amount(40); + trainer.set_num_test_splits(150); + trainer.set_feature_pool_size(800); + trainer.set_num_threads(4); + trainer.set_cascade_depth(15); + trainer.be_verbose(); + + // Now finally generate the shape model + shape_predictor sp = trainer.train(images_train, faces_train); + + serialize("shape_predictor_5_face_landmarks.dat") << sp; + + cout << "mean training error: "<< + test_shape_predictor(sp, images_train, faces_train, get_interocular_distances(faces_train)) << endl; + + cout << "mean testing error: "<< + test_shape_predictor(sp, images_test, faces_test, get_interocular_distances(faces_test)) << endl; + + } + catch (exception& e) + { + cout << "\nexception thrown!" << endl; + cout << e.what() << endl; + } +} + +// ---------------------------------------------------------------------------------------- + +double interocular_distance ( + const full_object_detection& det +) +{ + dlib::vector<double,2> l, r; + // left eye + l = (det.part(0) + det.part(1))/2; + // right eye + r = (det.part(2) + det.part(3))/2; + + return length(l-r); +} + +std::vector<std::vector<double> > get_interocular_distances ( + const std::vector<std::vector<full_object_detection> >& objects +) +{ + std::vector<std::vector<double> > temp(objects.size()); + for (unsigned long i = 0; i < objects.size(); ++i) + { + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + temp[i].push_back(interocular_distance(objects[i][j])); + } + } + return temp; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/convert_dlib_nets_to_caffe/CMakeLists.txt b/ml/dlib/tools/convert_dlib_nets_to_caffe/CMakeLists.txt new file mode 100644 index 000000000..f9518df21 --- /dev/null +++ b/ml/dlib/tools/convert_dlib_nets_to_caffe/CMakeLists.txt @@ -0,0 +1,25 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + +cmake_minimum_required(VERSION 2.8.12) + +set (target_name dtoc) + +PROJECT(${target_name}) + +add_subdirectory(../../dlib dlib_build) + +add_executable(${target_name} + main.cpp + ) + +target_link_libraries(${target_name} dlib::dlib ) + + +INSTALL(TARGETS ${target_name} + RUNTIME DESTINATION bin + ) + + diff --git a/ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp b/ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp new file mode 100644 index 000000000..f5cc19748 --- /dev/null +++ b/ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp @@ -0,0 +1,792 @@ + +#include <dlib/xml_parser.h> +#include <dlib/matrix.h> +#include <fstream> +#include <vector> +#include <stack> +#include <set> +#include <dlib/string.h> + +using namespace std; +using namespace dlib; + + +// ---------------------------------------------------------------------------------------- + +// Only these computational layers have parameters +const std::set<string> comp_tags_with_params = {"fc", "fc_no_bias", "con", "affine_con", "affine_fc", "affine", "prelu"}; + +struct layer +{ + string type; // comp, loss, or input + int idx; + + matrix<long,4,1> output_tensor_shape; // (N,K,NR,NC) + + string detail_name; // The name of the tag inside the layer tag. e.g. fc, con, max_pool, input_rgb_image. + std::map<string,double> attributes; + matrix<float> params; + long tag_id = -1; // If this isn't -1 then it means this layer was tagged, e.g. wrapped with tag2<> giving tag_id==2 + long skip_id = -1; // If this isn't -1 then it means this layer draws its inputs from + // the most recent layer with tag_id==skip_id rather than its immediate predecessor. + + double attribute (const string& key) const + { + auto i = attributes.find(key); + if (i != attributes.end()) + return i->second; + else + throw dlib::error("Layer doesn't have the requested attribute '" + key + "'."); + } + + string caffe_layer_name() const + { + if (type == "input") + return "data"; + else + return detail_name+to_string(idx); + } +}; + +// ---------------------------------------------------------------------------------------- + +std::vector<layer> parse_dlib_xml( + const matrix<long,4,1>& input_tensor_shape, + const string& xml_filename +); + +// ---------------------------------------------------------------------------------------- + +template <typename iterator> +const layer& find_layer ( + iterator i, + long tag_id +) +/*! + requires + - i is a reverse iterator pointing to a layer in the list of layers produced by parse_dlib_xml(). + - i is not an input layer. + ensures + - if (tag_id == -1) then + - returns the previous layer (i.e. closer to the input) to layer i. + - else + - returns the previous layer (i.e. closer to the input) to layer i with the + given tag_id. +!*/ +{ + if (tag_id == -1) + { + return *(i-1); + } + else + { + while(true) + { + i--; + // if we hit the end of the network before we found what we were looking for + if (i->tag_id == tag_id) + return *i; + if (i->type == "input") + throw dlib::error("Network definition is bad, a layer wanted to skip back to a non-existing layer."); + } + } +} + +template <typename iterator> +const layer& find_input_layer (iterator i) { return find_layer(i, i->skip_id); } + +template <typename iterator> +string find_layer_caffe_name ( + iterator i, + long tag_id +) +{ + return find_layer(i,tag_id).caffe_layer_name(); +} + +template <typename iterator> +string find_input_layer_caffe_name (iterator i) { return find_input_layer(i).caffe_layer_name(); } + +// ---------------------------------------------------------------------------------------- + +template <typename iterator> +void compute_caffe_padding_size_for_pooling_layer( + const iterator& i, + long& pad_x, + long& pad_y +) +/*! + requires + - i is a reverse iterator pointing to a layer in the list of layers produced by parse_dlib_xml(). + - i is not an input layer. + ensures + - Caffe is funny about how it computes the output sizes from pooling layers. + Rather than using the normal formula for output row/column sizes used by all the + other layers (and what dlib uses everywhere), + floor((bottom_size + 2*pad - kernel_size) / stride) + 1 + it instead uses: + ceil((bottom_size + 2*pad - kernel_size) / stride) + 1 + + These are the same except when the stride!=1. In that case we need to figure out + how to change the padding value so that the output size of the caffe padding + layer will match the output size of the dlib padding layer. That is what this + function does. +!*/ +{ + const long dlib_output_nr = i->output_tensor_shape(2); + const long dlib_output_nc = i->output_tensor_shape(3); + const long bottom_nr = find_input_layer(i).output_tensor_shape(2); + const long bottom_nc = find_input_layer(i).output_tensor_shape(3); + const long padding_x = (long)i->attribute("padding_x"); + const long padding_y = (long)i->attribute("padding_y"); + const long stride_x = (long)i->attribute("stride_x"); + const long stride_y = (long)i->attribute("stride_y"); + long kernel_w = i->attribute("nc"); + long kernel_h = i->attribute("nr"); + + if (kernel_w == 0) + kernel_w = bottom_nc; + if (kernel_h == 0) + kernel_h = bottom_nr; + + + // The correct padding for caffe could be anything in the range [0,padding_x]. So + // check what gives the correct output size and use that. + for (pad_x = 0; pad_x <= padding_x; ++pad_x) + { + long caffe_out_size = ceil((bottom_nc + 2.0*pad_x - kernel_w)/(double)stride_x) + 1; + if (caffe_out_size == dlib_output_nc) + break; + } + if (pad_x == padding_x+1) + { + std::ostringstream sout; + sout << "No conversion between dlib pooling layer parameters and caffe pooling layer parameters found for layer " << to_string(i->idx) << endl; + sout << "dlib_output_nc: " << dlib_output_nc << endl; + sout << "bottom_nc: " << bottom_nc << endl; + sout << "padding_x: " << padding_x << endl; + sout << "stride_x: " << stride_x << endl; + sout << "kernel_w: " << kernel_w << endl; + sout << "pad_x: " << pad_x << endl; + throw dlib::error(sout.str()); + } + + for (pad_y = 0; pad_y <= padding_y; ++pad_y) + { + long caffe_out_size = ceil((bottom_nr + 2.0*pad_y - kernel_h)/(double)stride_y) + 1; + if (caffe_out_size == dlib_output_nr) + break; + } + if (pad_y == padding_y+1) + { + std::ostringstream sout; + sout << "No conversion between dlib pooling layer parameters and caffe pooling layer parameters found for layer " << to_string(i->idx) << endl; + sout << "dlib_output_nr: " << dlib_output_nr << endl; + sout << "bottom_nr: " << bottom_nr << endl; + sout << "padding_y: " << padding_y << endl; + sout << "stride_y: " << stride_y << endl; + sout << "kernel_h: " << kernel_h << endl; + sout << "pad_y: " << pad_y << endl; + throw dlib::error(sout.str()); + } +} + +// ---------------------------------------------------------------------------------------- + +void convert_dlib_xml_to_caffe_python_code( + const string& xml_filename, + const long N, + const long K, + const long NR, + const long NC +) +{ + const string out_filename = left_substr(xml_filename,".") + "_dlib_to_caffe_model.py"; + const string out_weights_filename = left_substr(xml_filename,".") + "_dlib_to_caffe_model.weights"; + cout << "Writing python part of model to " << out_filename << endl; + cout << "Writing weights part of model to " << out_weights_filename << endl; + ofstream fout(out_filename); + fout.precision(9); + const auto layers = parse_dlib_xml({N,K,NR,NC}, xml_filename); + + + fout << "#\n"; + fout << "# !!! This file was automatically generated by dlib's tools/convert_dlib_nets_to_caffe utility. !!!\n"; + fout << "# !!! It contains all the information from a dlib DNN network and lets you save it as a cafe model. !!!\n"; + fout << "#\n"; + fout << "import caffe " << endl; + fout << "from caffe import layers as L, params as P" << endl; + fout << "import numpy as np" << endl; + + // dlib nets don't commit to a batch size, so just use 1 as the default + fout << "\n# Input tensor dimensions" << endl; + fout << "input_batch_size = " << N << ";" << endl; + if (layers.back().detail_name == "input_rgb_image") + { + fout << "input_num_channels = 3;" << endl; + fout << "input_num_rows = "<<NR<<";" << endl; + fout << "input_num_cols = "<<NC<<";" << endl; + if (K != 3) + throw dlib::error("The dlib model requires input tensors with NUM_CHANNELS==3, but the dtoc command line specified NUM_CHANNELS=="+to_string(K)); + } + else if (layers.back().detail_name == "input_rgb_image_sized") + { + fout << "input_num_channels = 3;" << endl; + fout << "input_num_rows = " << layers.back().attribute("nr") << ";" << endl; + fout << "input_num_cols = " << layers.back().attribute("nc") << ";" << endl; + if (NR != layers.back().attribute("nr")) + throw dlib::error("The dlib model requires input tensors with NUM_ROWS=="+to_string((long)layers.back().attribute("nr"))+", but the dtoc command line specified NUM_ROWS=="+to_string(NR)); + if (NC != layers.back().attribute("nc")) + throw dlib::error("The dlib model requires input tensors with NUM_COLUMNS=="+to_string((long)layers.back().attribute("nc"))+", but the dtoc command line specified NUM_COLUMNS=="+to_string(NC)); + if (K != 3) + throw dlib::error("The dlib model requires input tensors with NUM_CHANNELS==3, but the dtoc command line specified NUM_CHANNELS=="+to_string(K)); + } + else if (layers.back().detail_name == "input") + { + fout << "input_num_channels = 1;" << endl; + fout << "input_num_rows = "<<NR<<";" << endl; + fout << "input_num_cols = "<<NC<<";" << endl; + if (K != 1) + throw dlib::error("The dlib model requires input tensors with NUM_CHANNELS==1, but the dtoc command line specified NUM_CHANNELS=="+to_string(K)); + } + else + { + throw dlib::error("No known transformation from dlib's " + layers.back().detail_name + " layer to caffe."); + } + fout << endl; + fout << "# Call this function to write the dlib DNN model out to file as a pair of caffe\n"; + fout << "# definition and weight files. You can then use the network by loading it with\n"; + fout << "# this statement: \n"; + fout << "# net = caffe.Net(def_file, weights_file, caffe.TEST);\n"; + fout << "#\n"; + fout << "def save_as_caffe_model(def_file, weights_file):\n"; + fout << " with open(def_file, 'w') as f: f.write(str(make_netspec()));\n"; + fout << " net = caffe.Net(def_file, caffe.TEST);\n"; + fout << " set_network_weights(net);\n"; + fout << " net.save(weights_file);\n\n"; + fout << "###############################################################################\n"; + fout << "# EVERYTHING BELOW HERE DEFINES THE DLIB MODEL PARAMETERS #\n"; + fout << "###############################################################################\n\n\n"; + + + // ----------------------------------------------------------------------------------- + // The next block of code outputs python code that defines the network architecture. + // ----------------------------------------------------------------------------------- + + fout << "def make_netspec():" << endl; + fout << " # For reference, the only \"documentation\" about caffe layer parameters seems to be this page:\n"; + fout << " # https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto\n" << endl; + fout << " n = caffe.NetSpec(); " << endl; + fout << " n.data,n.label = L.MemoryData(batch_size=input_batch_size, channels=input_num_channels, height=input_num_rows, width=input_num_cols, ntop=2)" << endl; + // iterate the layers starting with the input layer + for (auto i = layers.rbegin(); i != layers.rend(); ++i) + { + // skip input and loss layers + if (i->type == "loss" || i->type == "input") + continue; + + + if (i->detail_name == "con") + { + fout << " n." << i->caffe_layer_name() << " = L.Convolution(n." << find_input_layer_caffe_name(i); + fout << ", num_output=" << i->attribute("num_filters"); + fout << ", kernel_w=" << i->attribute("nc"); + fout << ", kernel_h=" << i->attribute("nr"); + fout << ", stride_w=" << i->attribute("stride_x"); + fout << ", stride_h=" << i->attribute("stride_y"); + fout << ", pad_w=" << i->attribute("padding_x"); + fout << ", pad_h=" << i->attribute("padding_y"); + fout << ");\n"; + } + else if (i->detail_name == "relu") + { + fout << " n." << i->caffe_layer_name() << " = L.ReLU(n." << find_input_layer_caffe_name(i); + fout << ");\n"; + } + else if (i->detail_name == "sig") + { + fout << " n." << i->caffe_layer_name() << " = L.Sigmoid(n." << find_input_layer_caffe_name(i); + fout << ");\n"; + } + else if (i->detail_name == "prelu") + { + fout << " n." << i->caffe_layer_name() << " = L.PReLU(n." << find_input_layer_caffe_name(i); + fout << ", channel_shared=True"; + fout << ");\n"; + } + else if (i->detail_name == "max_pool") + { + fout << " n." << i->caffe_layer_name() << " = L.Pooling(n." << find_input_layer_caffe_name(i); + fout << ", pool=P.Pooling.MAX"; + if (i->attribute("nc")==0) + { + fout << ", global_pooling=True"; + } + else + { + fout << ", kernel_w=" << i->attribute("nc"); + fout << ", kernel_h=" << i->attribute("nr"); + } + + fout << ", stride_w=" << i->attribute("stride_x"); + fout << ", stride_h=" << i->attribute("stride_y"); + long pad_x, pad_y; + compute_caffe_padding_size_for_pooling_layer(i, pad_x, pad_y); + fout << ", pad_w=" << pad_x; + fout << ", pad_h=" << pad_y; + fout << ");\n"; + } + else if (i->detail_name == "avg_pool") + { + fout << " n." << i->caffe_layer_name() << " = L.Pooling(n." << find_input_layer_caffe_name(i); + fout << ", pool=P.Pooling.AVE"; + if (i->attribute("nc")==0) + { + fout << ", global_pooling=True"; + } + else + { + fout << ", kernel_w=" << i->attribute("nc"); + fout << ", kernel_h=" << i->attribute("nr"); + } + if (i->attribute("padding_x") != 0 || i->attribute("padding_y") != 0) + { + throw dlib::error("dlib and caffe implement pooling with non-zero padding differently, so you can't convert a " + "network with such pooling layers."); + } + + fout << ", stride_w=" << i->attribute("stride_x"); + fout << ", stride_h=" << i->attribute("stride_y"); + long pad_x, pad_y; + compute_caffe_padding_size_for_pooling_layer(i, pad_x, pad_y); + fout << ", pad_w=" << pad_x; + fout << ", pad_h=" << pad_y; + fout << ");\n"; + } + else if (i->detail_name == "fc") + { + fout << " n." << i->caffe_layer_name() << " = L.InnerProduct(n." << find_input_layer_caffe_name(i); + fout << ", num_output=" << i->attribute("num_outputs"); + fout << ", bias_term=True"; + fout << ");\n"; + } + else if (i->detail_name == "fc_no_bias") + { + fout << " n." << i->caffe_layer_name() << " = L.InnerProduct(n." << find_input_layer_caffe_name(i); + fout << ", num_output=" << i->attribute("num_outputs"); + fout << ", bias_term=False"; + fout << ");\n"; + } + else if (i->detail_name == "bn_con" || i->detail_name == "bn_fc") + { + throw dlib::error("Conversion from dlib's batch norm layers to caffe's isn't supported. Instead, " + "you should put your dlib network into 'test mode' by switching batch norm layers to affine layers. " + "Then you can convert that 'test mode' network to caffe."); + } + else if (i->detail_name == "affine_con") + { + fout << " n." << i->caffe_layer_name() << " = L.Scale(n." << find_input_layer_caffe_name(i); + fout << ", bias_term=True"; + fout << ");\n"; + } + else if (i->detail_name == "affine_fc") + { + fout << " n." << i->caffe_layer_name() << " = L.Scale(n." << find_input_layer_caffe_name(i); + fout << ", bias_term=True"; + fout << ");\n"; + } + else if (i->detail_name == "add_prev") + { + auto in_shape1 = find_input_layer(i).output_tensor_shape; + auto in_shape2 = find_layer(i,i->attribute("tag")).output_tensor_shape; + if (in_shape1 != in_shape2) + { + // if only the number of channels differs then we will use a dummy layer to + // pad with zeros. But otherwise we will throw an error. + if (in_shape1(0) == in_shape2(0) && + in_shape1(2) == in_shape2(2) && + in_shape1(3) == in_shape2(3)) + { + fout << " n." << i->caffe_layer_name() << "_zeropad = L.DummyData(num=" << in_shape1(0); + fout << ", channels="<<std::abs(in_shape1(1)-in_shape2(1)); + fout << ", height="<<in_shape1(2); + fout << ", width="<<in_shape1(3); + fout << ");\n"; + + string smaller_layer = find_input_layer_caffe_name(i); + string bigger_layer = find_layer_caffe_name(i, i->attribute("tag")); + if (in_shape1(1) > in_shape2(1)) + swap(smaller_layer, bigger_layer); + + fout << " n." << i->caffe_layer_name() << "_concat = L.Concat(n." << smaller_layer; + fout << ", n." << i->caffe_layer_name() << "_zeropad"; + fout << ");\n"; + + fout << " n." << i->caffe_layer_name() << " = L.Eltwise(n." << i->caffe_layer_name() << "_concat"; + fout << ", n." << bigger_layer; + fout << ", operation=P.Eltwise.SUM"; + fout << ");\n"; + } + else + { + std::ostringstream sout; + sout << "The dlib network contained an add_prev layer (layer idx " << i->idx << ") that adds two previous "; + sout << "layers with different output tensor dimensions. Caffe's equivalent layer, Eltwise, doesn't support "; + sout << "adding layers together with different dimensions. In the special case where the only difference is "; + sout << "in the number of channels, this converter program will add a dummy layer that outputs a tensor full of zeros "; + sout << "and concat it appropriately so this will work. However, this network you are converting has tensor dimensions "; + sout << "different in values other than the number of channels. In particular, here are the two tensor shapes (batch size, channels, rows, cols): "; + std::ostringstream sout2; + sout2 << wrap_string(sout.str()) << endl; + sout2 << trans(in_shape1); + sout2 << trans(in_shape2); + throw dlib::error(sout2.str()); + } + } + else + { + fout << " n." << i->caffe_layer_name() << " = L.Eltwise(n." << find_input_layer_caffe_name(i); + fout << ", n." << find_layer_caffe_name(i, i->attribute("tag")); + fout << ", operation=P.Eltwise.SUM"; + fout << ");\n"; + } + } + else + { + throw dlib::error("No known transformation from dlib's " + i->detail_name + " layer to caffe."); + } + } + fout << " return n.to_proto();\n\n" << endl; + + + // ----------------------------------------------------------------------------------- + // The next block of code outputs python code that populates all the filter weights. + // ----------------------------------------------------------------------------------- + + ofstream fweights(out_weights_filename, ios::binary); + fout << "def set_network_weights(net):\n"; + fout << " # populate network parameters\n"; + fout << " f = open('"<<out_weights_filename<<"', 'rb');\n"; + // iterate the layers starting with the input layer + for (auto i = layers.rbegin(); i != layers.rend(); ++i) + { + // skip input and loss layers + if (i->type == "loss" || i->type == "input") + continue; + + + if (i->detail_name == "con") + { + const long num_filters = i->attribute("num_filters"); + matrix<float> weights = trans(rowm(i->params,range(0,i->params.size()-num_filters-1))); + matrix<float> biases = trans(rowm(i->params,range(i->params.size()-num_filters, i->params.size()-1))); + fweights.write((char*)&weights(0,0), weights.size()*sizeof(float)); + fweights.write((char*)&biases(0,0), biases.size()*sizeof(float)); + + // main filter weights + fout << " p = np.fromfile(f, dtype='float32', count="<<weights.size()<<");\n"; + fout << " p.shape = net.params['"<<i->caffe_layer_name()<<"'][0].data.shape;\n"; + fout << " net.params['"<<i->caffe_layer_name()<<"'][0].data[:] = p;\n"; + + // biases + fout << " p = np.fromfile(f, dtype='float32', count="<<biases.size()<<");\n"; + fout << " p.shape = net.params['"<<i->caffe_layer_name()<<"'][1].data.shape;\n"; + fout << " net.params['"<<i->caffe_layer_name()<<"'][1].data[:] = p;\n"; + } + else if (i->detail_name == "fc") + { + matrix<float> weights = trans(rowm(i->params, range(0,i->params.nr()-2))); + matrix<float> biases = rowm(i->params, i->params.nr()-1); + fweights.write((char*)&weights(0,0), weights.size()*sizeof(float)); + fweights.write((char*)&biases(0,0), biases.size()*sizeof(float)); + + // main filter weights + fout << " p = np.fromfile(f, dtype='float32', count="<<weights.size()<<");\n"; + fout << " p.shape = net.params['"<<i->caffe_layer_name()<<"'][0].data.shape;\n"; + fout << " net.params['"<<i->caffe_layer_name()<<"'][0].data[:] = p;\n"; + + // biases + fout << " p = np.fromfile(f, dtype='float32', count="<<biases.size()<<");\n"; + fout << " p.shape = net.params['"<<i->caffe_layer_name()<<"'][1].data.shape;\n"; + fout << " net.params['"<<i->caffe_layer_name()<<"'][1].data[:] = p;\n"; + } + else if (i->detail_name == "fc_no_bias") + { + matrix<float> weights = trans(i->params); + fweights.write((char*)&weights(0,0), weights.size()*sizeof(float)); + + // main filter weights + fout << " p = np.fromfile(f, dtype='float32', count="<<weights.size()<<");\n"; + fout << " p.shape = net.params['"<<i->caffe_layer_name()<<"'][0].data.shape;\n"; + fout << " net.params['"<<i->caffe_layer_name()<<"'][0].data[:] = p;\n"; + } + else if (i->detail_name == "affine_con" || i->detail_name == "affine_fc") + { + const long dims = i->params.size()/2; + matrix<float> gamma = trans(rowm(i->params,range(0,dims-1))); + matrix<float> beta = trans(rowm(i->params,range(dims, 2*dims-1))); + fweights.write((char*)&gamma(0,0), gamma.size()*sizeof(float)); + fweights.write((char*)&beta(0,0), beta.size()*sizeof(float)); + + // set gamma weights + fout << " p = np.fromfile(f, dtype='float32', count="<<gamma.size()<<");\n"; + fout << " p.shape = net.params['"<<i->caffe_layer_name()<<"'][0].data.shape;\n"; + fout << " net.params['"<<i->caffe_layer_name()<<"'][0].data[:] = p;\n"; + + // set beta weights + fout << " p = np.fromfile(f, dtype='float32', count="<<beta.size()<<");\n"; + fout << " p.shape = net.params['"<<i->caffe_layer_name()<<"'][1].data.shape;\n"; + fout << " net.params['"<<i->caffe_layer_name()<<"'][1].data[:] = p;\n"; + } + else if (i->detail_name == "prelu") + { + const double param = i->params(0); + + // main filter weights + fout << " tmp = net.params['"<<i->caffe_layer_name()<<"'][0].data.view();\n"; + fout << " tmp.shape = 1;\n"; + fout << " tmp[0] = "<<param<<";\n"; + } + } + +} + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) try +{ + if (argc != 6) + { + cout << "To use this program, give it an xml file generated by dlib::net_to_xml() " << endl; + cout << "and then 4 numbers that indicate the input tensor size. It will convert " << endl; + cout << "the xml file into a python file that outputs a caffe model containing the dlib model." << endl; + cout << "For example, you might run this program like this: " << endl; + cout << " ./dtoc lenet.xml 1 1 28 28" << endl; + cout << "would convert the lenet.xml model into a caffe model with an input tensor of shape(1,1,28,28)" << endl; + cout << "where the shape values are (num samples in batch, num channels, num rows, num columns)." << endl; + return 0; + } + + const long N = sa = argv[2]; + const long K = sa = argv[3]; + const long NR = sa = argv[4]; + const long NC = sa = argv[5]; + + convert_dlib_xml_to_caffe_python_code(argv[1], N, K, NR, NC); + + return 0; +} +catch(std::exception& e) +{ + cout << "\n\n*************** ERROR CONVERTING TO CAFFE ***************\n" << e.what() << endl; + return 1; +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +class doc_handler : public document_handler +{ +public: + std::vector<layer> layers; + bool seen_first_tag = false; + + layer next_layer; + std::stack<string> current_tag; + long tag_id = -1; + + + virtual void start_document ( + ) + { + layers.clear(); + seen_first_tag = false; + tag_id = -1; + } + + virtual void end_document ( + ) { } + + virtual void start_element ( + const unsigned long /*line_number*/, + const std::string& name, + const dlib::attribute_list& atts + ) + { + if (!seen_first_tag) + { + if (name != "net") + throw dlib::error("The top level XML tag must be a 'net' tag."); + seen_first_tag = true; + } + + if (name == "layer") + { + next_layer = layer(); + if (atts["type"] == "skip") + { + // Don't make a new layer, just apply the tag id to the previous layer + if (layers.size() == 0) + throw dlib::error("A skip layer was found as the first layer, but the first layer should be an input layer."); + layers.back().skip_id = sa = atts["id"]; + + // We intentionally leave next_layer empty so the end_element() callback + // don't add it as another layer when called. + } + else if (atts["type"] == "tag") + { + // Don't make a new layer, just remember the tag id so we can apply it on + // the next layer. + tag_id = sa = atts["id"]; + + // We intentionally leave next_layer empty so the end_element() callback + // don't add it as another layer when called. + } + else + { + next_layer.idx = sa = atts["idx"]; + next_layer.type = atts["type"]; + if (tag_id != -1) + { + next_layer.tag_id = tag_id; + tag_id = -1; + } + } + } + else if (current_tag.size() != 0 && current_tag.top() == "layer") + { + next_layer.detail_name = name; + // copy all the XML tag's attributes into the layer struct + atts.reset(); + while (atts.move_next()) + next_layer.attributes[atts.element().key()] = sa = atts.element().value(); + } + + current_tag.push(name); + } + + virtual void end_element ( + const unsigned long /*line_number*/, + const std::string& name + ) + { + current_tag.pop(); + if (name == "layer" && next_layer.type.size() != 0) + layers.push_back(next_layer); + } + + virtual void characters ( + const std::string& data + ) + { + if (current_tag.size() == 0) + return; + + if (comp_tags_with_params.count(current_tag.top()) != 0) + { + istringstream sin(data); + sin >> next_layer.params; + } + + } + + virtual void processing_instruction ( + const unsigned long /*line_number*/, + const std::string& /*target*/, + const std::string& /*data*/ + ) + { + } +}; + +// ---------------------------------------------------------------------------------------- + +void compute_output_tensor_shapes(const matrix<long,4,1>& input_tensor_shape, std::vector<layer>& layers) +{ + DLIB_CASSERT(layers.back().type == "input"); + layers.back().output_tensor_shape = input_tensor_shape; + for (auto i = ++layers.rbegin(); i != layers.rend(); ++i) + { + const auto input_shape = find_input_layer(i).output_tensor_shape; + if (i->type == "comp") + { + if (i->detail_name == "fc" || i->detail_name == "fc_no_bias") + { + long num_outputs = i->attribute("num_outputs"); + i->output_tensor_shape = {input_shape(0), num_outputs, 1, 1}; + } + else if (i->detail_name == "con") + { + long num_filters = i->attribute("num_filters"); + long filter_nc = i->attribute("nc"); + long filter_nr = i->attribute("nr"); + long stride_x = i->attribute("stride_x"); + long stride_y = i->attribute("stride_y"); + long padding_x = i->attribute("padding_x"); + long padding_y = i->attribute("padding_y"); + long nr = 1+(input_shape(2) + 2*padding_y - filter_nr)/stride_y; + long nc = 1+(input_shape(3) + 2*padding_x - filter_nc)/stride_x; + i->output_tensor_shape = {input_shape(0), num_filters, nr, nc}; + } + else if (i->detail_name == "max_pool" || i->detail_name == "avg_pool") + { + long filter_nc = i->attribute("nc"); + long filter_nr = i->attribute("nr"); + long stride_x = i->attribute("stride_x"); + long stride_y = i->attribute("stride_y"); + long padding_x = i->attribute("padding_x"); + long padding_y = i->attribute("padding_y"); + if (filter_nc != 0) + { + long nr = 1+(input_shape(2) + 2*padding_y - filter_nr)/stride_y; + long nc = 1+(input_shape(3) + 2*padding_x - filter_nc)/stride_x; + i->output_tensor_shape = {input_shape(0), input_shape(1), nr, nc}; + } + else // if we are filtering the whole input down to one thing + { + i->output_tensor_shape = {input_shape(0), input_shape(1), 1, 1}; + } + } + else if (i->detail_name == "add_prev") + { + auto aux_shape = find_layer(i, i->attribute("tag")).output_tensor_shape; + for (long j = 0; j < input_shape.size(); ++j) + i->output_tensor_shape(j) = std::max(input_shape(j), aux_shape(j)); + } + else + { + i->output_tensor_shape = input_shape; + } + } + else + { + i->output_tensor_shape = input_shape; + } + + } +} + +// ---------------------------------------------------------------------------------------- + +std::vector<layer> parse_dlib_xml( + const matrix<long,4,1>& input_tensor_shape, + const string& xml_filename +) +{ + doc_handler dh; + parse_xml(xml_filename, dh); + if (dh.layers.size() == 0) + throw dlib::error("No layers found in XML file!"); + + if (dh.layers.back().type != "input") + throw dlib::error("The network in the XML file is missing an input layer!"); + + compute_output_tensor_shapes(input_tensor_shape, dh.layers); + + return dh.layers; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/convert_dlib_nets_to_caffe/running_a_dlib_model_with_caffe_example.py b/ml/dlib/tools/convert_dlib_nets_to_caffe/running_a_dlib_model_with_caffe_example.py new file mode 100755 index 000000000..c03a7bf5c --- /dev/null +++ b/ml/dlib/tools/convert_dlib_nets_to_caffe/running_a_dlib_model_with_caffe_example.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +# This script takes the dlib lenet model trained by the +# examples/dnn_introduction_ex.cpp example program and runs it using caffe. + +import caffe +import numpy as np + +# Before you run this program, you need to run dnn_introduction_ex.cpp to get a +# dlib lenet model. Then you need to convert that model into a "dlib to caffe +# model" python script. You can do this using the command line program +# included with dlib: tools/convert_dlib_nets_to_caffe. That program will +# output a lenet_dlib_to_caffe_model.py file. You run that program like this: +# ./dtoc lenet.xml 1 1 28 28 +# and it will create the lenet_dlib_to_caffe_model.py file, which we import +# with the next line: +import lenet_dlib_to_caffe_model as dlib_model + +# lenet_dlib_to_caffe_model defines a function, save_as_caffe_model() that does +# the work of converting dlib's DNN model to a caffe model and saves it to disk +# in two files. These files are all you need to run the model with caffe. +dlib_model.save_as_caffe_model('dlib_model_def.prototxt', 'dlib_model.proto') + +# Now that we created the caffe model files, we can load them into a caffe Net object. +net = caffe.Net('dlib_model_def.prototxt', 'dlib_model.proto', caffe.TEST); + + +# Now lets do a test, we will run one of the MNIST images through the network. + +# An MNIST image of a 7, it is the very first testing image in MNIST (i.e. wrt dnn_introduction_ex.cpp, it is testing_images[0]) +data = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0,84,185,159,151,60,36, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0,222,254,254,254,254,241,198,198,198,198,198,198,198,198,170,52, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0,67,114,72,114,163,227,254,225,254,254,254,250,229,254,254,140, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,17,66,14,67,67,67,59,21,236,254,106, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,83,253,209,18, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,22,233,255,83, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,129,254,238,44, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,59,249,254,62, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,133,254,187,5, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,9,205,248,58, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,126,254,182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,75,251,240,57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,19,221,254,166, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,3,203,254,219,35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,38,254,254,77, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,31,224,254,115,1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,133,254,254,52, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,61,242,254,254,52, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,121,254,254,219,40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,121,254,207,18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='float32'); +data.shape = (dlib_model.input_batch_size, dlib_model.input_num_channels, dlib_model.input_num_rows, dlib_model.input_num_cols); + +# labels isn't logically needed but there doesn't seem to be a way to use +# caffe's Net interface without providing a superfluous input array. So we do +# that here. +labels = np.ones((dlib_model.input_batch_size), dtype='float32') +# Give the image to caffe +net.set_input_arrays(data/256, labels) +# Run the data through the network and get the results. +out = net.forward() + +# Print outputs, looping over minibatch. You should see that the network +# correctly classifies the image (it's the number 7). +for i in xrange(dlib_model.input_batch_size): + print i, 'net final layer = ', out['fc1'][i] + print i, 'predicted number =', np.argmax(out['fc1'][i]) + + + diff --git a/ml/dlib/tools/htmlify/CMakeLists.txt b/ml/dlib/tools/htmlify/CMakeLists.txt new file mode 100644 index 000000000..02cae2172 --- /dev/null +++ b/ml/dlib/tools/htmlify/CMakeLists.txt @@ -0,0 +1,31 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + +cmake_minimum_required(VERSION 2.8.12) + +# create a variable called target_name and set it to the string "htmlify" +set (target_name htmlify) + +project(${target_name}) + +add_subdirectory(../../dlib dlib_build) + +# add all the cpp files we want to compile to this list. This tells +# cmake that they are part of our target (which is the executable named htmlify) +add_executable(${target_name} + htmlify.cpp + to_xml.cpp + ) + +# Tell cmake to link our target executable to dlib. +target_link_libraries(${target_name} dlib::dlib ) + + + +install(TARGETS ${target_name} + RUNTIME DESTINATION bin + ) + + diff --git a/ml/dlib/tools/htmlify/htmlify.cpp b/ml/dlib/tools/htmlify/htmlify.cpp new file mode 100644 index 000000000..e822a5aaa --- /dev/null +++ b/ml/dlib/tools/htmlify/htmlify.cpp @@ -0,0 +1,632 @@ +#include <fstream> +#include <iostream> +#include <string> + + +#include "dlib/cpp_pretty_printer.h" +#include "dlib/cmd_line_parser.h" +#include "dlib/queue.h" +#include "dlib/misc_api.h" +#include "dlib/dir_nav.h" +#include "to_xml.h" + + +const char* VERSION = "3.5"; + +using namespace std; +using namespace dlib; + +typedef cpp_pretty_printer::kernel_1a cprinter; +typedef cpp_pretty_printer::kernel_2a bprinter; +typedef dlib::map<string,string>::kernel_1a map_string_to_string; +typedef dlib::set<string>::kernel_1a set_of_string; +typedef queue<file>::kernel_1a queue_of_files; +typedef queue<directory>::kernel_1a queue_of_dirs; + +void print_manual ( +); +/*! + ensures + - prints detailed information about this program. +!*/ + +void htmlify ( + const map_string_to_string& file_map, + bool colored, + bool number_lines, + const std::string& title +); +/*! + ensures + - for all valid out_file: + - the file out_file is the html transformed version of + file_map[out_file] + - if (number_lines) then + - the html version will have numbered lines + - if (colored) then + - the html version will have colors + - title will be the first part of the HTML title in the output file +!*/ + +void htmlify ( + istream& in, + ostream& out, + const std::string& title, + bool colored, + bool number_lines +); +/*! + ensures + - transforms in into html with the given title and writes it to out. + - if (number_lines) then + - the html version of in will have numbered lines + - if (colored) then + - the html version of in will have colors +!*/ + +void add_files ( + const directory& dir, + const std::string& out_dir, + map_string_to_string& file_map, + bool flatten, + bool cat, + const set_of_string& filter, + unsigned long search_depth, + unsigned long cur_depth = 0 +); +/*! + ensures + - searches the directory dir for files matching the filter and adds them + to the file_map. only looks search_depth deep. +!*/ + +int main(int argc, char** argv) +{ + if (argc == 1) + { + cout << "\nTry the -h option for more information.\n"; + return 0; + } + + string file; + try + { + command_line_parser parser; + parser.add_option("b","Pretty print in black and white. The default is to pretty print in color."); + parser.add_option("n","Number lines."); + parser.add_option("h","Displays this information."); + parser.add_option("index","Create an index."); + parser.add_option("v","Display version."); + parser.add_option("man","Display the manual."); + parser.add_option("f","Specifies a list of file extensions to process when using the -i option. The list elements should be separated by spaces. The default is \"cpp h c\".",1); + parser.add_option("i","Specifies an input directory.",1); + parser.add_option("cat","Puts all the output into a single html file with the given name.",1); + parser.add_option("depth","Specifies how many directories deep to search when using the i option. The default value is 30.",1); + parser.add_option("o","This option causes all the output files to be created inside the given directory. If this option is not given then all output goes to the current working directory.",1); + parser.add_option("flatten","When this option is given it prevents the input directory structure from being replicated."); + parser.add_option("title","This option specifies a string which is prepended onto the title of the generated HTML",1); + parser.add_option("to-xml","Instead of generating HTML output, create a single output file called output.xml that contains " + "a simple XML database which lists all documented classes and functions."); + parser.add_option("t", "When creating XML output, replace tabs in comments with <arg> spaces.", 1); + + + parser.parse(argc,argv); + + + parser.check_incompatible_options("cat","o"); + parser.check_incompatible_options("cat","flatten"); + parser.check_incompatible_options("cat","index"); + parser.check_option_arg_type<unsigned long>("depth"); + parser.check_option_arg_range("t", 1, 100); + + parser.check_incompatible_options("to-xml", "b"); + parser.check_incompatible_options("to-xml", "n"); + parser.check_incompatible_options("to-xml", "index"); + parser.check_incompatible_options("to-xml", "cat"); + parser.check_incompatible_options("to-xml", "o"); + parser.check_incompatible_options("to-xml", "flatten"); + parser.check_incompatible_options("to-xml", "title"); + + const char* singles[] = {"b","n","h","index","v","man","f","cat","depth","o","flatten","title","to-xml", "t"}; + parser.check_one_time_options(singles); + + const char* i_sub_ops[] = {"f","depth","flatten"}; + parser.check_sub_options("i",i_sub_ops); + + const char* to_xml_sub_ops[] = {"t"}; + parser.check_sub_options("to-xml",to_xml_sub_ops); + + const command_line_parser::option_type& b_opt = parser.option("b"); + const command_line_parser::option_type& n_opt = parser.option("n"); + const command_line_parser::option_type& h_opt = parser.option("h"); + const command_line_parser::option_type& index_opt = parser.option("index"); + const command_line_parser::option_type& v_opt = parser.option("v"); + const command_line_parser::option_type& o_opt = parser.option("o"); + const command_line_parser::option_type& man_opt = parser.option("man"); + const command_line_parser::option_type& f_opt = parser.option("f"); + const command_line_parser::option_type& cat_opt = parser.option("cat"); + const command_line_parser::option_type& i_opt = parser.option("i"); + const command_line_parser::option_type& flatten_opt = parser.option("flatten"); + const command_line_parser::option_type& depth_opt = parser.option("depth"); + const command_line_parser::option_type& title_opt = parser.option("title"); + const command_line_parser::option_type& to_xml_opt = parser.option("to-xml"); + + + string filter = "cpp h c"; + + bool cat = false; + bool color = true; + bool number = false; + unsigned long search_depth = 30; + + string out_dir; // the name of the output directory if the o option is given. "" otherwise + string full_out_dir; // the full name of the output directory if the o option is given. "" otherwise + const char separator = directory::get_separator(); + + bool no_run = false; + if (v_opt) + { + cout << "Htmlify v" << VERSION + << "\nCompiled: " << __TIME__ << " " << __DATE__ + << "\nWritten by Davis King\n"; + cout << "Check for updates at http://dlib.net\n\n"; + no_run = true; + } + + if (h_opt) + { + cout << "This program pretty prints C or C++ source code to HTML.\n"; + cout << "Usage: htmlify [options] [file]...\n"; + parser.print_options(); + cout << "\n\n"; + no_run = true; + } + + if (man_opt) + { + print_manual(); + no_run = true; + } + + if (no_run) + return 0; + + if (f_opt) + { + filter = f_opt.argument(); + } + + if (cat_opt) + { + cat = true; + } + + if (depth_opt) + { + search_depth = string_cast<unsigned long>(depth_opt.argument()); + } + + if (to_xml_opt) + { + unsigned long expand_tabs = 0; + if (parser.option("t")) + expand_tabs = string_cast<unsigned long>(parser.option("t").argument()); + + generate_xml_markup(parser, filter, search_depth, expand_tabs); + return 0; + } + + if (o_opt) + { + // make sure this directory exists + out_dir = o_opt.argument(); + create_directory(out_dir); + directory dir(out_dir); + full_out_dir = dir.full_name(); + + // make sure the last character of out_dir is a separator + if (out_dir[out_dir.size()-1] != separator) + out_dir += separator; + if (full_out_dir[out_dir.size()-1] != separator) + full_out_dir += separator; + } + + if (b_opt) + color = false; + if (n_opt) + number = true; + + // this is a map of output file names to input file names. + map_string_to_string file_map; + + + // add all the files that are just given on the command line to the + // file_map. + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + string in_file, out_file; + in_file = parser[i]; + string::size_type pos = in_file.find_last_of(separator); + if (pos != string::npos) + { + out_file = out_dir + in_file.substr(pos+1) + ".html"; + } + else + { + out_file = out_dir + in_file + ".html"; + } + + if (file_map.is_in_domain(out_file)) + { + if (file_map[out_file] != in_file) + { + // there is a file name colision in the output folder. definitly a bad thing + cout << "Error: Two of the input files have the same name and would overwrite each\n"; + cout << "other. They are " << in_file << " and " << file_map[out_file] << ".\n" << endl; + return 1; + } + else + { + continue; + } + } + + file_map.add(out_file,in_file); + } + + // pick out the filter strings + set_of_string sfilter; + istringstream sin(filter); + string temp; + sin >> temp; + while (sin) + { + if (sfilter.is_member(temp) == false) + sfilter.add(temp); + sin >> temp; + } + + // now get all the files given by the i options + for (unsigned long i = 0; i < i_opt.count(); ++i) + { + directory dir(i_opt.argument(0,i)); + add_files(dir, out_dir, file_map, flatten_opt, cat, sfilter, search_depth); + } + + if (cat) + { + file_map.reset(); + ofstream fout(cat_opt.argument().c_str()); + if (!fout) + { + throw error("Error: unable to open file " + cat_opt.argument()); + } + fout << "<html><title>" << cat_opt.argument() << "</title></html>"; + + const char separator = directory::get_separator(); + string file; + while (file_map.move_next()) + { + ifstream fin(file_map.element().value().c_str()); + if (!fin) + { + throw error("Error: unable to open file " + file_map.element().value()); + } + + string::size_type pos = file_map.element().value().find_last_of(separator); + if (pos != string::npos) + file = file_map.element().value().substr(pos+1); + else + file = file_map.element().value(); + + std::string title; + if (title_opt) + title = title_opt.argument(); + htmlify(fin, fout, title + file, color, number); + } + + } + else + { + std::string title; + if (title_opt) + title = title_opt.argument(); + htmlify(file_map,color,number,title); + } + + + + if (index_opt) + { + ofstream index((out_dir + "index.html").c_str()); + ofstream menu((out_dir + "menu.html").c_str()); + + if (!index) + { + cout << "Error: unable to create " << out_dir << "index.html\n\n"; + return 0; + } + + if (!menu) + { + cout << "Error: unable to create " << out_dir << "menu.html\n\n"; + return 0; + } + + + index << "<html><frameset cols='200,*'>"; + index << "<frame src='menu.html' name='menu'>"; + index << "<frame name='main'></frameset></html>"; + + menu << "<html><body><br>"; + + file_map.reset(); + while (file_map.move_next()) + { + if (o_opt) + { + file = file_map.element().key(); + if (file.find(full_out_dir) != string::npos) + file = file.substr(full_out_dir.size()); + else + file = file.substr(out_dir.size()); + } + else + { + file = file_map.element().key(); + } + // strip the .html from file + file = file.substr(0,file.size()-5); + menu << "<a href='" << file << ".html' target='main'>" + << file << "</a><br>"; + } + + menu << "</body></html>"; + + } + + } + catch (ios_base::failure&) + { + cout << "ERROR: unable to write to " << file << endl; + cout << endl; + } + catch (exception& e) + { + cout << e.what() << endl; + cout << "\nTry the -h option for more information.\n"; + cout << endl; + } +} + +// ------------------------------------------------------------------------------------------------- + +void htmlify ( + istream& in, + ostream& out, + const std::string& title, + bool colored, + bool number_lines +) +{ + if (colored) + { + static cprinter cp; + if (number_lines) + { + cp.print_and_number(in,out,title); + } + else + { + cp.print(in,out,title); + } + } + else + { + static bprinter bp; + if (number_lines) + { + bp.print_and_number(in,out,title); + } + else + { + bp.print(in,out,title); + } + } +} + +// ------------------------------------------------------------------------------------------------- + +void htmlify ( + const map_string_to_string& file_map, + bool colored, + bool number_lines, + const std::string& title +) +{ + file_map.reset(); + const char separator = directory::get_separator(); + string file; + while (file_map.move_next()) + { + ifstream fin(file_map.element().value().c_str()); + if (!fin) + { + throw error("Error: unable to open file " + file_map.element().value() ); + } + + ofstream fout(file_map.element().key().c_str()); + + if (!fout) + { + throw error("Error: unable to open file " + file_map.element().key()); + } + + string::size_type pos = file_map.element().value().find_last_of(separator); + if (pos != string::npos) + file = file_map.element().value().substr(pos+1); + else + file = file_map.element().value(); + + htmlify(fin, fout,title + file, colored, number_lines); + } +} + +// ------------------------------------------------------------------------------------------------- + +void add_files ( + const directory& dir, + const std::string& out_dir, + map_string_to_string& file_map, + bool flatten, + bool cat, + const set_of_string& filter, + unsigned long search_depth, + unsigned long cur_depth +) +{ + const char separator = directory::get_separator(); + + queue_of_files files; + queue_of_dirs dirs; + + dir.get_files(files); + + // look though all the files in the current directory and add the + // ones that match the filter to file_map + string name, ext, in_file, out_file; + files.reset(); + while (files.move_next()) + { + name = files.element().name(); + string::size_type pos = name.find_last_of('.'); + if (pos != string::npos && filter.is_member(name.substr(pos+1))) + { + in_file = files.element().full_name(); + + if (flatten) + { + pos = in_file.find_last_of(separator); + } + else + { + // figure out how much of the file's path we need to keep + // for the output file name + pos = in_file.size(); + for (unsigned long i = 0; i <= cur_depth && pos != string::npos; ++i) + { + pos = in_file.find_last_of(separator,pos-1); + } + } + + if (pos != string::npos) + { + out_file = out_dir + in_file.substr(pos+1) + ".html"; + } + else + { + out_file = out_dir + in_file + ".html"; + } + + if (file_map.is_in_domain(out_file)) + { + if (file_map[out_file] != in_file) + { + // there is a file name colision in the output folder. definitly a bad thing + ostringstream sout; + sout << "Error: Two of the input files have the same name and would overwrite each\n"; + sout << "other. They are " << in_file << " and " << file_map[out_file] << "."; + throw error(sout.str()); + } + else + { + continue; + } + } + + file_map.add(out_file,in_file); + + } + } // while (files.move_next()) + files.clear(); + + if (search_depth > cur_depth) + { + // search all the sub directories + dir.get_dirs(dirs); + dirs.reset(); + while (dirs.move_next()) + { + if (!flatten && !cat) + { + string d = dirs.element().full_name(); + + // figure out how much of the directorie's path we need to keep. + string::size_type pos = d.size(); + for (unsigned long i = 0; i <= cur_depth && pos != string::npos; ++i) + { + pos = d.find_last_of(separator,pos-1); + } + + // make sure this directory exists in the output directory tree + d = d.substr(pos+1); + create_directory(out_dir + separator + d); + } + + add_files(dirs.element(), out_dir, file_map, flatten, cat, filter, search_depth, cur_depth+1); + } + } + +} + +// ------------------------------------------------------------------------------------------------- + +void print_manual ( +) +{ + ostringstream sout; + + const unsigned long indent = 2; + + cout << "\n"; + sout << "Htmlify v" << VERSION; + cout << wrap_string(sout.str(),indent,indent); sout.str(""); + + + sout << "This is a fairly simple program that takes source files and pretty prints them " + << "in HTML. There are two pretty printing styles, black and white or color. The " + << "black and white style is meant to look nice when printed out on paper. It looks " + << "a little funny on the screen but on paper it is pretty nice. The color version " + << "on the other hand has nonprintable HTML elements such as links and anchors."; + cout << "\n\n" << wrap_string(sout.str(),indent,indent); sout.str(""); + + + sout << "The colored style puts HTML anchors on class and function names. This means " + << "you can link directly to the part of the code that contains these names. For example, " + << "if you had a source file bar.cpp with a function called foo in it you could link " + << "directly to the function with a link address of \"bar.cpp.html#foo\". It is also " + << "possible to instruct Htmlify to place HTML anchors at arbitrary spots by using a " + << "special comment of the form /*!A anchor_name */. You can put other things in the " + << "comment but the important bit is to have it begin with /*!A then some white space " + << "then the anchor name you want then more white space and then you can add whatever " + << "you like. You would then refer to this anchor with a link address of " + << "\"file.html#anchor_name\"."; + cout << "\n\n" << wrap_string(sout.str(),indent,indent); sout.str(""); + + sout << "Htmlify also has the ability to create a simple index of all the files it is given. " + << "The --index option creates a file named index.html with a frame on the left side " + << "that contains links to all the files."; + cout << "\n\n" << wrap_string(sout.str(),indent,indent); sout.str(""); + + + sout << "Finally, Htmlify can produce annotated XML output instead of HTML. The output will " + << "contain all functions which are immediately followed by comments of the form /*! comment body !*/. " + << "Similarly, all classes or structs that immediately contain one of these comments following their " + << "opening { will also be output as annotated XML. Note also that if you wish to document a " + << "piece of code using one of these comments but don't want it to appear in the output XML then " + << "use either a comment like /* */ or /*!P !*/ to mark the code as \"private\"."; + cout << "\n\n" << wrap_string(sout.str(),indent,indent) << "\n\n"; sout.str(""); +} + +// ------------------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/htmlify/to_xml.cpp b/ml/dlib/tools/htmlify/to_xml.cpp new file mode 100644 index 000000000..7fae43380 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml.cpp @@ -0,0 +1,1599 @@ + +#include "to_xml.h" +#include "dlib/dir_nav.h" +#include <vector> +#include <sstream> +#include <iostream> +#include <fstream> +#include <stack> +#include "dlib/cpp_tokenizer.h" +#include "dlib/string.h" + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +typedef cpp_tokenizer::kernel_1a_c tok_type; + +// ---------------------------------------------------------------------------------------- + +class file_filter +{ +public: + + file_filter( + const string& filter + ) + { + // pick out the filter strings + istringstream sin(filter); + string temp; + sin >> temp; + while (sin) + { + endings.push_back("." + temp); + sin >> temp; + } + } + + bool operator() ( const file& f) const + { + // check if any of the endings match + for (unsigned long i = 0; i < endings.size(); ++i) + { + // if the ending is bigger than f's name then it obviously doesn't match + if (endings[i].size() > f.name().size()) + continue; + + // now check if the actual characters that make up the end of the file name + // matches what is in endings[i]. + if ( std::equal(endings[i].begin(), endings[i].end(), f.name().end()-endings[i].size())) + return true; + } + + return false; + } + + std::vector<string> endings; +}; + +// ---------------------------------------------------------------------------------------- + +void obtain_list_of_files ( + const cmd_line_parser<char>::check_1a_c& parser, + const std::string& filter, + const unsigned long search_depth, + std::vector<std::pair<string,string> >& files +) +{ + for (unsigned long i = 0; i < parser.option("i").count(); ++i) + { + const directory dir(parser.option("i").argument(0,i)); + + const std::vector<file>& temp = get_files_in_directory_tree(dir, file_filter(filter), search_depth); + + // figure out how many characters need to be removed from the path of each file + const string parent = dir.get_parent().full_name(); + unsigned long strip = parent.size(); + if (parent.size() > 0 && parent[parent.size()-1] != '\\' && parent[parent.size()-1] != '/') + strip += 1; + + for (unsigned long i = 0; i < temp.size(); ++i) + { + files.push_back(make_pair(temp[i].full_name().substr(strip), temp[i].full_name())); + } + } + + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + files.push_back(make_pair(parser[i], parser[i])); + } + + std::sort(files.begin(), files.end()); +} + +// ---------------------------------------------------------------------------------------- + +struct tok_function_record +{ + std::vector<std::pair<int,string> > declaration; + string scope; + string file; + string comment; +}; + +struct tok_method_record +{ + std::vector<std::pair<int,string> > declaration; + string comment; +}; + +struct tok_variable_record +{ + std::vector<std::pair<int,string> > declaration; +}; + +struct tok_typedef_record +{ + std::vector<std::pair<int,string> > declaration; +}; + +struct tok_class_record +{ + std::vector<std::pair<int,string> > declaration; + string name; + string scope; + string file; + string comment; + + std::vector<tok_method_record> public_methods; + std::vector<tok_method_record> protected_methods; + std::vector<tok_variable_record> public_variables; + std::vector<tok_typedef_record> public_typedefs; + std::vector<tok_variable_record> protected_variables; + std::vector<tok_typedef_record> protected_typedefs; + std::vector<tok_class_record> public_inner_classes; + std::vector<tok_class_record> protected_inner_classes; +}; + +// ---------------------------------------------------------------------------------------- + +struct function_record +{ + string name; + string scope; + string declaration; + string file; + string comment; +}; + +struct method_record +{ + string name; + string declaration; + string comment; +}; + +struct variable_record +{ + string declaration; +}; + +struct typedef_record +{ + string declaration; +}; + +struct class_record +{ + string name; + string scope; + string declaration; + string file; + string comment; + + std::vector<method_record> public_methods; + std::vector<variable_record> public_variables; + std::vector<typedef_record> public_typedefs; + + std::vector<method_record> protected_methods; + std::vector<variable_record> protected_variables; + std::vector<typedef_record> protected_typedefs; + + std::vector<class_record> public_inner_classes; + std::vector<class_record> protected_inner_classes; +}; + +// ---------------------------------------------------------------------------------------- + +unsigned long count_newlines ( + const string& str +) +/*! + ensures + - returns the number of '\n' characters inside str +!*/ +{ + unsigned long count = 0; + for (unsigned long i = 0; i < str.size(); ++i) + { + if (str[i] == '\n') + ++count; + } + return count; +} + +// ---------------------------------------------------------------------------------------- + +bool contains_unescaped_newline ( + const string& str +) +/*! + ensures + - returns true if str contains a '\n' character that isn't preceded by a '\' + character. +!*/ +{ + if (str.size() == 0) + return false; + + if (str[0] == '\n') + return true; + + for (unsigned long i = 1; i < str.size(); ++i) + { + if (str[i] == '\n' && str[i-1] != '\\') + return true; + } + + return false; +} + +// ---------------------------------------------------------------------------------------- + +bool is_formal_comment ( + const string& str +) +{ + if (str.size() < 6) + return false; + + if (str[0] == '/' && + str[1] == '*' && + str[2] == '!' && + str[3] != 'P' && + str[3] != 'p' && + str[str.size()-3] == '!' && + str[str.size()-2] == '*' && + str[str.size()-1] == '/' ) + return true; + + return false; +} + +// ---------------------------------------------------------------------------------------- + +string make_scope_string ( + const std::vector<string>& namespaces, + unsigned long exclude_last_num_scopes = 0 +) +{ + string temp; + for (unsigned long i = 0; i + exclude_last_num_scopes < namespaces.size(); ++i) + { + if (namespaces[i].size() == 0) + continue; + + if (temp.size() == 0) + temp = namespaces[i]; + else + temp += "::" + namespaces[i]; + } + return temp; +} + +// ---------------------------------------------------------------------------------------- + +bool looks_like_function_declaration ( + const std::vector<std::pair<int,string> >& declaration +) +{ + + // Check if declaration contains IDENTIFIER ( ) somewhere in it. + bool seen_first_part = false; + bool seen_operator = false; + int local_paren_count = 0; + for (unsigned long i = 1; i < declaration.size(); ++i) + { + if (declaration[i].first == tok_type::KEYWORD && + declaration[i].second == "operator") + { + seen_operator = true; + } + + if (declaration[i].first == tok_type::OTHER && + declaration[i].second == "(" && + (declaration[i-1].first == tok_type::IDENTIFIER || seen_operator)) + { + seen_first_part = true; + } + + if (declaration[i].first == tok_type::OTHER) + { + if ( declaration[i].second == "(") + ++local_paren_count; + else if ( declaration[i].second == ")") + --local_paren_count; + } + } + + if (seen_first_part && local_paren_count == 0) + return true; + else + return false; +} + +// ---------------------------------------------------------------------------------------- + +enum scope_type +{ + public_scope, + protected_scope, + private_scope +}; + + +void process_file ( + istream& fin, + const string& file, + std::vector<tok_function_record>& functions, + std::vector<tok_class_record>& classes +) +/*! + ensures + - scans the given file for global functions and appends any found into functions. + - scans the given file for global classes and appends any found into classes. +!*/ +{ + tok_type tok; + tok.set_stream(fin); + + bool recently_seen_struct_keyword = false; + // true if we have seen the struct keyword and + // we have not seen any identifiers or { characters + + string last_struct_name; + // the name of the last struct we have seen + + bool recently_seen_class_keyword = false; + // true if we have seen the class keyword and + // we have not seen any identifiers or { characters + + string last_class_name; + // the name of the last class we have seen + + bool recently_seen_namespace_keyword = false; + // true if we have seen the namespace keyword and + // we have not seen any identifiers or { characters + + string last_namespace_name; + // the name of the last namespace we have seen + + bool recently_seen_pound_define = false; + // true if we have seen a #define and haven't seen an unescaped newline + + bool recently_seen_preprocessor = false; + // true if we have seen a preprocessor statement and haven't seen an unescaped newline + + bool recently_seen_typedef = false; + // true if we have seen a typedef keyword and haven't seen a ; + + bool recently_seen_paren_0 = false; + // true if we have seen paren_count transition to zero but haven't yet seen a ; or { or + // a new line if recently_seen_pound_define is true. + + bool recently_seen_slots = false; + // true if we have seen the identifier "slots" at a zero scope but haven't seen any + // other identifiers or the ';' or ':' characters. + + bool recently_seen_closing_bracket = false; + // true if we have seen a } and haven't yet seen an IDENTIFIER or ; + + bool recently_seen_new_scope = false; + // true if we have seen the keywords class, namespace, struct, or extern and + // we have not seen the characters {, ), or ; since then + + bool at_top_of_new_scope = false; + // true if we have seen the { that started a new scope but haven't seen anything yet but WHITE_SPACE + + std::vector<string> namespaces; + // a stack to hold the names of the scopes we have entered. This is the classes, structs, and namespaces we enter. + namespaces.push_back(""); // this is the global namespace + + std::stack<scope_type> scope_access; + // If the stack isn't empty then we are inside a class or struct and the top value + // in the stack tells if we are in a public, protected, or private region. + + std::stack<unsigned long> scopes; // a stack to hold current and old scope counts + // the top of the stack counts the number of new scopes (i.e. unmatched { } we have entered + // since we were at a scope where functions can be defined. + // We also maintain the invariant that scopes.size() == namespaces.size() + scopes.push(0); + + std::stack<tok_class_record> class_stack; + // This is a stack where class_stack.top() == the incomplete class record for the class declaration we are + // currently in. + + unsigned long paren_count = 0; + // this is the number of ( we have seen minus the number of ) we have + // seen. + + std::vector<std::pair<int,string> > token_accum; + // Used to accumulate tokens for function and class declarations + + std::vector<std::pair<int,string> > last_full_declaration; + // Once we determine that token_accum has a full declaration in it we copy it into last_full_declaration. + + int type; + string token; + + tok.get_token(type, token); + + while (type != tok_type::END_OF_FILE) + { + switch(type) + { + case tok_type::KEYWORD: // ------------------------------------------ + { + token_accum.push_back(make_pair(type,token)); + + if (token[0] == '#') + recently_seen_preprocessor = true; + + if (token == "class") + { + recently_seen_class_keyword = true; + recently_seen_new_scope = true; + } + else if (token == "struct") + { + recently_seen_struct_keyword = true; + recently_seen_new_scope = true; + } + else if (token == "namespace") + { + recently_seen_namespace_keyword = true; + recently_seen_new_scope = true; + } + else if (token == "extern") + { + recently_seen_new_scope = true; + } + else if (token == "#define") + { + recently_seen_pound_define = true; + } + else if (token == "typedef") + { + recently_seen_typedef = true; + } + else if (recently_seen_pound_define == false) + { + // eat white space + int temp_type; + string temp_token; + if (tok.peek_type() == tok_type::WHITE_SPACE) + tok.get_token(temp_type, temp_token); + + const bool next_is_colon = (tok.peek_type() == tok_type::OTHER && tok.peek_token() == ":"); + if (next_is_colon) + { + // eat the colon + tok.get_token(temp_type, temp_token); + + if (scope_access.size() > 0 && token == "public") + { + scope_access.top() = public_scope; + token_accum.clear(); + last_full_declaration.clear(); + } + else if (scope_access.size() > 0 && token == "protected") + { + scope_access.top() = protected_scope; + token_accum.clear(); + last_full_declaration.clear(); + } + else if (scope_access.size() > 0 && token == "private") + { + scope_access.top() = private_scope; + token_accum.clear(); + last_full_declaration.clear(); + } + } + } + + at_top_of_new_scope = false; + + }break; + + case tok_type::COMMENT: // ------------------------------------------ + { + if (scopes.top() == 0 && last_full_declaration.size() > 0 && is_formal_comment(token) && + paren_count == 0) + { + + // if we are inside a class or struct + if (scope_access.size() > 0) + { + // if we are looking at a comment at the top of a class + if (at_top_of_new_scope) + { + // push an entry for this class into the class_stack + tok_class_record temp; + temp.declaration = last_full_declaration; + temp.file = file; + temp.name = namespaces.back(); + temp.scope = make_scope_string(namespaces,1); + temp.comment = token; + class_stack.push(temp); + } + else if (scope_access.top() == public_scope || scope_access.top() == protected_scope) + { + // This should be a member function. + // Only do anything if the class that contains this member function is + // in the class_stack. + if (class_stack.size() > 0 && class_stack.top().name == namespaces.back() && + looks_like_function_declaration(last_full_declaration)) + { + tok_method_record temp; + + // Check if there is an initialization list inside the declaration and if there is + // then find out where the starting : is located so we can avoid including it in + // the output. + unsigned long pos = last_full_declaration.size(); + long temp_paren_count = 0; + for (unsigned long i = 0; i < last_full_declaration.size(); ++i) + { + if (last_full_declaration[i].first == tok_type::OTHER) + { + if (last_full_declaration[i].second == "(") + ++temp_paren_count; + else if (last_full_declaration[i].second == ")") + --temp_paren_count; + else if (temp_paren_count == 0 && last_full_declaration[i].second == ":") + { + // if this is a :: then ignore it + if (i > 0 && last_full_declaration[i-1].second == ":") + continue; + else if (i+1 < last_full_declaration.size() && last_full_declaration[i+1].second == ":") + continue; + else + { + pos = i; + break; + } + } + } + } + + temp.declaration.assign(last_full_declaration.begin(), last_full_declaration.begin()+pos); + temp.comment = token; + if (scope_access.top() == public_scope) + class_stack.top().public_methods.push_back(temp); + else + class_stack.top().protected_methods.push_back(temp); + } + } + } + else + { + // we should be looking at a global declaration of some kind. + if (looks_like_function_declaration(last_full_declaration)) + { + tok_function_record temp; + + // make sure we never include anything beyond the first closing ) + // if we are looking at a #defined function + unsigned long pos = last_full_declaration.size(); + if (last_full_declaration[0].second == "#define") + { + long temp_paren_count = 0; + for (unsigned long i = 0; i < last_full_declaration.size(); ++i) + { + if (last_full_declaration[i].first == tok_type::OTHER) + { + if (last_full_declaration[i].second == "(") + { + ++temp_paren_count; + } + else if (last_full_declaration[i].second == ")") + { + --temp_paren_count; + if (temp_paren_count == 0) + { + pos = i+1; + break; + } + } + } + } + } + + temp.declaration.assign(last_full_declaration.begin(), last_full_declaration.begin()+pos); + temp.file = file; + temp.scope = make_scope_string(namespaces); + temp.comment = token; + functions.push_back(temp); + } + } + + token_accum.clear(); + last_full_declaration.clear(); + } + + at_top_of_new_scope = false; + }break; + + case tok_type::IDENTIFIER: // ------------------------------------------ + { + if (recently_seen_class_keyword) + { + last_class_name = token; + last_struct_name.clear(); + last_namespace_name.clear(); + } + else if (recently_seen_struct_keyword) + { + last_struct_name = token; + last_class_name.clear(); + last_namespace_name.clear(); + } + else if (recently_seen_namespace_keyword) + { + last_namespace_name = token; + last_class_name.clear(); + last_struct_name.clear(); + } + + if (scopes.top() == 0 && token == "slots") + recently_seen_slots = true; + else + recently_seen_slots = false; + + recently_seen_class_keyword = false; + recently_seen_struct_keyword = false; + recently_seen_namespace_keyword = false; + recently_seen_closing_bracket = false; + at_top_of_new_scope = false; + + token_accum.push_back(make_pair(type,token)); + }break; + + case tok_type::OTHER: // ------------------------------------------ + { + switch(token[0]) + { + case '{': + // if we are entering a new scope + if (recently_seen_new_scope) + { + scopes.push(0); + at_top_of_new_scope = true; + + // if we are entering a class + if (last_class_name.size() > 0) + { + scope_access.push(private_scope); + namespaces.push_back(last_class_name); + } + else if (last_struct_name.size() > 0) + { + scope_access.push(public_scope); + namespaces.push_back(last_struct_name); + } + else if (last_namespace_name.size() > 0) + { + namespaces.push_back(last_namespace_name); + } + else + { + namespaces.push_back(""); + } + } + else + { + scopes.top() += 1; + } + recently_seen_new_scope = false; + recently_seen_class_keyword = false; + recently_seen_struct_keyword = false; + recently_seen_namespace_keyword = false; + recently_seen_paren_0 = false; + + // a { at function scope is an end of a potential declaration + if (scopes.top() == 0) + { + // put token_accum into last_full_declaration + token_accum.swap(last_full_declaration); + } + token_accum.clear(); + break; + + case '}': + if (scopes.top() > 0) + { + scopes.top() -= 1; + } + else if (scopes.size() > 1) + { + scopes.pop(); + + if (scope_access.size() > 0) + scope_access.pop(); + + // If the scope we are leaving is the top class on the class_stack + // then we need to either pop it into its containing class or put it + // into the classes output vector. + if (class_stack.size() > 0 && namespaces.back() == class_stack.top().name) + { + // If this class is a inner_class of another then push it into the + // public_inner_classes or protected_inner_classes field of it's containing class. + if (class_stack.size() > 1) + { + tok_class_record temp = class_stack.top(); + class_stack.pop(); + if (scope_access.size() > 0) + { + if (scope_access.top() == public_scope) + class_stack.top().public_inner_classes.push_back(temp); + else if (scope_access.top() == protected_scope) + class_stack.top().protected_inner_classes.push_back(temp); + } + } + else if (class_stack.size() > 0) + { + classes.push_back(class_stack.top()); + class_stack.pop(); + } + } + + namespaces.pop_back(); + last_full_declaration.clear(); + } + + token_accum.clear(); + recently_seen_closing_bracket = true; + at_top_of_new_scope = false; + break; + + case ';': + // a ; at function scope is an end of a potential declaration + if (scopes.top() == 0) + { + // put token_accum into last_full_declaration + token_accum.swap(last_full_declaration); + } + token_accum.clear(); + + // if we are inside the public area of a class and this ; might be the end + // of a typedef or variable declaration + if (scopes.top() == 0 && scope_access.size() > 0 && + (scope_access.top() == public_scope || scope_access.top() == protected_scope) && + recently_seen_closing_bracket == false) + { + if (recently_seen_typedef) + { + // This should be a typedef inside the public area of a class or struct: + // Only do anything if the class that contains this typedef is in the class_stack. + if (class_stack.size() > 0 && class_stack.top().name == namespaces.back()) + { + tok_typedef_record temp; + temp.declaration = last_full_declaration; + if (scope_access.top() == public_scope) + class_stack.top().public_typedefs.push_back(temp); + else + class_stack.top().protected_typedefs.push_back(temp); + } + + } + else if (recently_seen_paren_0 == false && recently_seen_new_scope == false) + { + // This should be some kind of public variable declaration inside a class or struct: + // Only do anything if the class that contains this member variable is in the class_stack. + if (class_stack.size() > 0 && class_stack.top().name == namespaces.back()) + { + tok_variable_record temp; + temp.declaration = last_full_declaration; + if (scope_access.top() == public_scope) + class_stack.top().public_variables.push_back(temp); + else + class_stack.top().protected_variables.push_back(temp); + } + + } + } + + recently_seen_new_scope = false; + recently_seen_typedef = false; + recently_seen_paren_0 = false; + recently_seen_closing_bracket = false; + recently_seen_slots = false; + at_top_of_new_scope = false; + break; + + case ':': + token_accum.push_back(make_pair(type,token)); + if (recently_seen_slots) + { + token_accum.clear(); + last_full_declaration.clear(); + recently_seen_slots = false; + } + break; + + case '(': + ++paren_count; + token_accum.push_back(make_pair(type,token)); + at_top_of_new_scope = false; + break; + + case ')': + token_accum.push_back(make_pair(type,token)); + + --paren_count; + if (paren_count == 0) + { + recently_seen_paren_0 = true; + if (scopes.top() == 0) + { + last_full_declaration = token_accum; + } + } + + recently_seen_new_scope = false; + at_top_of_new_scope = false; + break; + + default: + token_accum.push_back(make_pair(type,token)); + at_top_of_new_scope = false; + break; + } + }break; + + + case tok_type::WHITE_SPACE: // ------------------------------------------ + { + if (recently_seen_pound_define) + { + if (contains_unescaped_newline(token)) + { + recently_seen_pound_define = false; + recently_seen_paren_0 = false; + recently_seen_preprocessor = false; + + // this is an end of a potential declaration + token_accum.swap(last_full_declaration); + token_accum.clear(); + } + } + + if (recently_seen_preprocessor) + { + if (contains_unescaped_newline(token)) + { + recently_seen_preprocessor = false; + + last_full_declaration.clear(); + token_accum.clear(); + } + } + }break; + + default: // ------------------------------------------ + { + token_accum.push_back(make_pair(type,token)); + at_top_of_new_scope = false; + }break; + } + + + tok.get_token(type, token); + } +} + +// ---------------------------------------------------------------------------------------- + +string get_function_name ( + const std::vector<std::pair<int,string> >& declaration +) +{ + string name; + + bool contains_operator = false; + unsigned long operator_pos = 0; + for (unsigned long i = 0; i < declaration.size(); ++i) + { + if (declaration[i].first == tok_type::KEYWORD && + declaration[i].second == "operator") + { + contains_operator = true; + operator_pos = i; + break; + } + } + + + // find the opening ( for the function + unsigned long paren_pos = 0; + long paren_count = 0; + for (long i = declaration.size()-1; i >= 0; --i) + { + if (declaration[i].first == tok_type::OTHER && + declaration[i].second == ")") + { + ++paren_count; + } + else if (declaration[i].first == tok_type::OTHER && + declaration[i].second == "(") + { + --paren_count; + if (paren_count == 0) + { + paren_pos = i; + break; + } + } + } + + + if (contains_operator) + { + name = declaration[operator_pos].second; + for (unsigned long i = operator_pos+1; i < paren_pos; ++i) + { + if (declaration[i].first == tok_type::IDENTIFIER || declaration[i].first == tok_type::KEYWORD) + { + name += " "; + } + + name += declaration[i].second; + } + } + else + { + // if this is a destructor then include the ~ + if (paren_pos > 1 && declaration[paren_pos-2].second == "~") + name = "~" + declaration[paren_pos-1].second; + else if (paren_pos > 0) + name = declaration[paren_pos-1].second; + + + } + + return name; +} + +// ---------------------------------------------------------------------------------------- + +string pretty_print_declaration ( + const std::vector<std::pair<int,string> >& decl +) +{ + string temp; + long angle_count = 0; + long paren_count = 0; + + if (decl.size() == 0) + return temp; + + temp = decl[0].second; + + + bool just_closed_template = false; + bool in_template = false; + bool last_was_scope_res = false; + bool seen_operator = false; + + if (temp == "operator") + seen_operator = true; + + for (unsigned long i = 1; i < decl.size(); ++i) + { + bool last_was_less_than = false; + if (decl[i-1].first == tok_type::OTHER && decl[i-1].second == "<") + last_was_less_than = true; + + + if (decl[i].first == tok_type::OTHER && decl[i].second == "<" && + (decl[i-1].second != "operator" && ((i>1 && decl[i-2].second != "operator") || decl[i-1].second != "<") )) + ++angle_count; + + if (decl[i-1].first == tok_type::KEYWORD && decl[i-1].second == "template" && + decl[i].first == tok_type::OTHER && decl[i].second == "<") + { + in_template = true; + temp += " <\n "; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ">") + { + // don't count angle brackets when they are part of an operator + if (decl[i-1].second != "operator" && ((i>1 && decl[i-2].second != "operator") || decl[i-1].second != ">")) + --angle_count; + + if (angle_count == 0 && in_template) + { + temp += "\n >\n"; + just_closed_template = true; + in_template = false; + } + else + { + temp += ">"; + } + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "<") + { + temp += "<"; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ",") + { + if (in_template || (paren_count == 1 && angle_count == 0)) + temp += ",\n "; + else + temp += ","; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "&") + { + temp += "&"; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ".") + { + temp += "."; + } + else if (decl[i].first == tok_type::SINGLE_QUOTED_TEXT) + { + temp += decl[i].second; + } + else if (decl[i].first == tok_type::DOUBLE_QUOTED_TEXT) + { + temp += decl[i].second; + } + else if (decl[i-1].first == tok_type::SINGLE_QUOTED_TEXT && decl[i].second == "'") + { + temp += decl[i].second; + } + else if (decl[i-1].first == tok_type::DOUBLE_QUOTED_TEXT && decl[i].second == "\"") + { + temp += decl[i].second; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "[") + { + temp += "["; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "]") + { + temp += "]"; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "-") + { + temp += "-"; + } + else if (decl[i].first == tok_type::NUMBER) + { + if (decl[i-1].second == "=") + temp += " " + decl[i].second; + else + temp += decl[i].second; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "*") + { + temp += "*"; + } + else if (decl[i].first == tok_type::KEYWORD && decl[i].second == "operator") + { + temp += "\noperator"; + seen_operator = true; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ":" && + (decl[i-1].second == ":" || (i+1<decl.size() && decl[i+1].second == ":") ) ) + { + temp += ":"; + last_was_scope_res = true; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == "(") + { + const bool next_is_paren = (i+1 < decl.size() && decl[i+1].first == tok_type::OTHER && decl[i+1].second == ")"); + + if (paren_count == 0 && next_is_paren == false && in_template == false) + temp += " (\n "; + else + temp += "("; + + ++paren_count; + } + else if (decl[i].first == tok_type::OTHER && decl[i].second == ")") + { + --paren_count; + if (paren_count == 0 && decl[i-1].second != "(" && in_template == false) + temp += "\n)"; + else + temp += ")"; + } + else if (decl[i].first == tok_type::IDENTIFIER && i+1 < decl.size() && + decl[i+1].first == tok_type::OTHER && decl[i+1].second == "(") + { + if (just_closed_template || paren_count != 0 || decl[i-1].second == "~") + temp += decl[i].second; + else if (seen_operator) + temp += " " + decl[i].second; + else + temp += "\n" + decl[i].second; + + just_closed_template = false; + last_was_scope_res = false; + } + else + { + if (just_closed_template || last_was_scope_res || last_was_less_than || + (seen_operator && paren_count == 0 && decl[i].first == tok_type::OTHER ) || + ((decl[i].first == tok_type::KEYWORD || decl[i].first == tok_type::IDENTIFIER) && i>0 && decl[i-1].second == "(")) + temp += decl[i].second; + else + temp += " " + decl[i].second; + + just_closed_template = false; + last_was_scope_res = false; + } + + + + } + + return temp; +} + +// ---------------------------------------------------------------------------------------- + +string format_comment ( + const string& comment, + const unsigned long expand_tabs +) +{ + if (comment.size() <= 6) + return ""; + + string temp = trim(trim(comment.substr(3,comment.size()-6), " \t"), "\n\r"); + + + // if we should expand tabs to spaces + if (expand_tabs != 0) + { + unsigned long column = 0; + string str; + for (unsigned long i = 0; i < temp.size(); ++i) + { + if (temp[i] == '\t') + { + const unsigned long num_spaces = expand_tabs - column%expand_tabs; + column += num_spaces; + str.insert(str.end(), num_spaces, ' '); + } + else if (temp[i] == '\n' || temp[i] == '\r') + { + str += temp[i]; + column = 0; + } + else + { + str += temp[i]; + ++column; + } + } + + // put str into temp + str.swap(temp); + } + + // now figure out what the smallest amount of leading white space is and remove it from each line. + unsigned long num_whitespace = 100000; + + string::size_type pos1 = 0, pos2 = 0; + + while (pos1 != string::npos) + { + // find start of non-white-space + pos2 = temp.find_first_not_of(" \t",pos1); + + // if this is a line of just white space then ignore it + if (pos2 != string::npos && temp[pos2] != '\n' && temp[pos2] != '\r') + { + if (pos2-pos1 < num_whitespace) + num_whitespace = pos2-pos1; + } + + // find end-of-line + pos1 = temp.find_first_of("\n\r", pos2); + // find start of next line + pos2 = temp.find_first_not_of("\n\r", pos1); + pos1 = pos2; + } + + // now remove the leading white space + string temp2; + unsigned long counter = 0; + for (unsigned long i = 0; i < temp.size(); ++i) + { + // if we are looking at a new line + if (temp[i] == '\n' || temp[i] == '\r') + { + counter = 0; + } + else if (counter < num_whitespace) + { + ++counter; + continue; + } + + temp2 += temp[i]; + } + + return temp2; +} + +// ---------------------------------------------------------------------------------------- + +typedef_record convert_tok_typedef_record ( + const tok_typedef_record& rec +) +{ + typedef_record temp; + temp.declaration = pretty_print_declaration(rec.declaration); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +variable_record convert_tok_variable_record ( + const tok_variable_record& rec +) +{ + variable_record temp; + temp.declaration = pretty_print_declaration(rec.declaration); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +method_record convert_tok_method_record ( + const tok_method_record& rec, + const unsigned long expand_tabs +) +{ + method_record temp; + + temp.comment = format_comment(rec.comment, expand_tabs); + temp.name = get_function_name(rec.declaration); + temp.declaration = pretty_print_declaration(rec.declaration); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +class_record convert_tok_class_record ( + const tok_class_record& rec, + const unsigned long expand_tabs +) +{ + class_record crec; + + + crec.scope = rec.scope; + crec.file = rec.file; + crec.comment = format_comment(rec.comment, expand_tabs); + + crec.name.clear(); + + // find the first class token + for (unsigned long i = 0; i+1 < rec.declaration.size(); ++i) + { + if (rec.declaration[i].first == tok_type::KEYWORD && + (rec.declaration[i].second == "class" || + rec.declaration[i].second == "struct" ) + ) + { + crec.name = rec.declaration[i+1].second; + break; + } + } + + crec.declaration = pretty_print_declaration(rec.declaration); + + for (unsigned long i = 0; i < rec.public_typedefs.size(); ++i) + crec.public_typedefs.push_back(convert_tok_typedef_record(rec.public_typedefs[i])); + + for (unsigned long i = 0; i < rec.public_variables.size(); ++i) + crec.public_variables.push_back(convert_tok_variable_record(rec.public_variables[i])); + + for (unsigned long i = 0; i < rec.protected_typedefs.size(); ++i) + crec.protected_typedefs.push_back(convert_tok_typedef_record(rec.protected_typedefs[i])); + + for (unsigned long i = 0; i < rec.protected_variables.size(); ++i) + crec.protected_variables.push_back(convert_tok_variable_record(rec.protected_variables[i])); + + for (unsigned long i = 0; i < rec.public_methods.size(); ++i) + crec.public_methods.push_back(convert_tok_method_record(rec.public_methods[i], expand_tabs)); + + for (unsigned long i = 0; i < rec.protected_methods.size(); ++i) + crec.protected_methods.push_back(convert_tok_method_record(rec.protected_methods[i], expand_tabs)); + + for (unsigned long i = 0; i < rec.public_inner_classes.size(); ++i) + crec.public_inner_classes.push_back(convert_tok_class_record(rec.public_inner_classes[i], expand_tabs)); + + for (unsigned long i = 0; i < rec.protected_inner_classes.size(); ++i) + crec.protected_inner_classes.push_back(convert_tok_class_record(rec.protected_inner_classes[i], expand_tabs)); + + + return crec; +} + +// ---------------------------------------------------------------------------------------- + +function_record convert_tok_function_record ( + const tok_function_record& rec, + const unsigned long expand_tabs +) +{ + function_record temp; + + temp.scope = rec.scope; + temp.file = rec.file; + temp.comment = format_comment(rec.comment, expand_tabs); + temp.name = get_function_name(rec.declaration); + temp.declaration = pretty_print_declaration(rec.declaration); + + return temp; +} + +// ---------------------------------------------------------------------------------------- + +void convert_to_normal_records ( + const std::vector<tok_function_record>& tok_functions, + const std::vector<tok_class_record>& tok_classes, + const unsigned long expand_tabs, + std::vector<function_record>& functions, + std::vector<class_record>& classes +) +{ + functions.clear(); + classes.clear(); + + + for (unsigned long i = 0; i < tok_functions.size(); ++i) + { + functions.push_back(convert_tok_function_record(tok_functions[i], expand_tabs)); + } + + + for (unsigned long i = 0; i < tok_classes.size(); ++i) + { + classes.push_back(convert_tok_class_record(tok_classes[i], expand_tabs)); + } + + +} + +// ---------------------------------------------------------------------------------------- + +string add_entity_ref (const string& str) +{ + string temp; + for (unsigned long i = 0; i < str.size(); ++i) + { + if (str[i] == '&') + temp += "&"; + else if (str[i] == '<') + temp += "<"; + else if (str[i] == '>') + temp += ">"; + else + temp += str[i]; + } + return temp; +} + +// ---------------------------------------------------------------------------------------- + +string flip_slashes (string str) +{ + for (unsigned long i = 0; i < str.size(); ++i) + { + if (str[i] == '\\') + str[i] = '/'; + } + return str; +} + +// ---------------------------------------------------------------------------------------- + +void write_as_xml ( + const function_record& rec, + ostream& fout +) +{ + fout << " <function>\n"; + fout << " <name>" << add_entity_ref(rec.name) << "</name>\n"; + fout << " <scope>" << add_entity_ref(rec.scope) << "</scope>\n"; + fout << " <declaration>" << add_entity_ref(rec.declaration) << "</declaration>\n"; + fout << " <file>" << flip_slashes(add_entity_ref(rec.file)) << "</file>\n"; + fout << " <comment>" << add_entity_ref(rec.comment) << "</comment>\n"; + fout << " </function>\n"; +} + +// ---------------------------------------------------------------------------------------- + +void write_as_xml ( + const class_record& rec, + ostream& fout, + unsigned long indent +) +{ + const string pad(indent, ' '); + + fout << pad << "<class>\n"; + fout << pad << " <name>" << add_entity_ref(rec.name) << "</name>\n"; + fout << pad << " <scope>" << add_entity_ref(rec.scope) << "</scope>\n"; + fout << pad << " <declaration>" << add_entity_ref(rec.declaration) << "</declaration>\n"; + fout << pad << " <file>" << flip_slashes(add_entity_ref(rec.file)) << "</file>\n"; + fout << pad << " <comment>" << add_entity_ref(rec.comment) << "</comment>\n"; + + + if (rec.public_typedefs.size() > 0) + { + fout << pad << " <public_typedefs>\n"; + for (unsigned long i = 0; i < rec.public_typedefs.size(); ++i) + { + fout << pad << " <typedef>" << add_entity_ref(rec.public_typedefs[i].declaration) << "</typedef>\n"; + } + fout << pad << " </public_typedefs>\n"; + } + + + if (rec.public_variables.size() > 0) + { + fout << pad << " <public_variables>\n"; + for (unsigned long i = 0; i < rec.public_variables.size(); ++i) + { + fout << pad << " <variable>" << add_entity_ref(rec.public_variables[i].declaration) << "</variable>\n"; + } + fout << pad << " </public_variables>\n"; + } + + if (rec.protected_typedefs.size() > 0) + { + fout << pad << " <protected_typedefs>\n"; + for (unsigned long i = 0; i < rec.protected_typedefs.size(); ++i) + { + fout << pad << " <typedef>" << add_entity_ref(rec.protected_typedefs[i].declaration) << "</typedef>\n"; + } + fout << pad << " </protected_typedefs>\n"; + } + + + if (rec.protected_variables.size() > 0) + { + fout << pad << " <protected_variables>\n"; + for (unsigned long i = 0; i < rec.protected_variables.size(); ++i) + { + fout << pad << " <variable>" << add_entity_ref(rec.protected_variables[i].declaration) << "</variable>\n"; + } + fout << pad << " </protected_variables>\n"; + } + + + if (rec.public_methods.size() > 0) + { + fout << pad << " <public_methods>\n"; + for (unsigned long i = 0; i < rec.public_methods.size(); ++i) + { + fout << pad << " <method>\n"; + fout << pad << " <name>" << add_entity_ref(rec.public_methods[i].name) << "</name>\n"; + fout << pad << " <declaration>" << add_entity_ref(rec.public_methods[i].declaration) << "</declaration>\n"; + fout << pad << " <comment>" << add_entity_ref(rec.public_methods[i].comment) << "</comment>\n"; + fout << pad << " </method>\n"; + } + fout << pad << " </public_methods>\n"; + } + + + if (rec.protected_methods.size() > 0) + { + fout << pad << " <protected_methods>\n"; + for (unsigned long i = 0; i < rec.protected_methods.size(); ++i) + { + fout << pad << " <method>\n"; + fout << pad << " <name>" << add_entity_ref(rec.protected_methods[i].name) << "</name>\n"; + fout << pad << " <declaration>" << add_entity_ref(rec.protected_methods[i].declaration) << "</declaration>\n"; + fout << pad << " <comment>" << add_entity_ref(rec.protected_methods[i].comment) << "</comment>\n"; + fout << pad << " </method>\n"; + } + fout << pad << " </protected_methods>\n"; + } + + + if (rec.public_inner_classes.size() > 0) + { + fout << pad << " <public_inner_classes>\n"; + for (unsigned long i = 0; i < rec.public_inner_classes.size(); ++i) + { + write_as_xml(rec.public_inner_classes[i], fout, indent+4); + } + fout << pad << " </public_inner_classes>\n"; + } + + if (rec.protected_inner_classes.size() > 0) + { + fout << pad << " <protected_inner_classes>\n"; + for (unsigned long i = 0; i < rec.protected_inner_classes.size(); ++i) + { + write_as_xml(rec.protected_inner_classes[i], fout, indent+4); + } + fout << pad << " </protected_inner_classes>\n"; + } + + + fout << pad << "</class>\n"; +} + +// ---------------------------------------------------------------------------------------- + +void save_to_xml_file ( + const std::vector<function_record>& functions, + const std::vector<class_record>& classes +) +{ + ofstream fout("output.xml"); + + fout << "<!-- This XML file was generated using the htmlify tool available from http://dlib.net. -->" << endl; + fout << "<code>" << endl; + + fout << " <classes>" << endl; + for (unsigned long i = 0; i < classes.size(); ++i) + { + write_as_xml(classes[i], fout, 4); + fout << "\n"; + } + fout << " </classes>\n\n" << endl; + + + fout << " <global_functions>" << endl; + for (unsigned long i = 0; i < functions.size(); ++i) + { + write_as_xml(functions[i], fout); + fout << "\n"; + } + fout << " </global_functions>" << endl; + + fout << "</code>" << endl; +} + +// ---------------------------------------------------------------------------------------- + +void generate_xml_markup( + const cmd_line_parser<char>::check_1a_c& parser, + const std::string& filter, + const unsigned long search_depth, + const unsigned long expand_tabs +) +{ + + // first figure out which files should be processed + std::vector<std::pair<string,string> > files; + obtain_list_of_files(parser, filter, search_depth, files); + + + std::vector<tok_function_record> tok_functions; + std::vector<tok_class_record> tok_classes; + + for (unsigned long i = 0; i < files.size(); ++i) + { + ifstream fin(files[i].second.c_str()); + if (!fin) + { + cerr << "Error opening file: " << files[i].second << endl; + return; + } + process_file(fin, files[i].first, tok_functions, tok_classes); + } + + std::vector<function_record> functions; + std::vector<class_record> classes; + + convert_to_normal_records(tok_functions, tok_classes, expand_tabs, functions, classes); + + save_to_xml_file(functions, classes); +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/htmlify/to_xml.h b/ml/dlib/tools/htmlify/to_xml.h new file mode 100644 index 000000000..4bdf3f00e --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml.h @@ -0,0 +1,22 @@ +#ifndef DLIB_HTMLIFY_TO_XmL_H__ +#define DLIB_HTMLIFY_TO_XmL_H__ + +#include "dlib/cmd_line_parser.h" +#include <string> + +void generate_xml_markup( + const dlib::cmd_line_parser<char>::check_1a_c& parser, + const std::string& filter, + const unsigned long search_depth, + const unsigned long expand_tabs +); +/*! + ensures + - reads all the files indicated by the parser arguments and converts them + to XML. The output will be stored in the output.xml file. + - if (expand_tabs != 0) then + - tabs will be replaced with expand_tabs spaces inside comment blocks +!*/ + +#endif // DLIB_HTMLIFY_TO_XmL_H__ + diff --git a/ml/dlib/tools/htmlify/to_xml_example/bigminus.gif b/ml/dlib/tools/htmlify/to_xml_example/bigminus.gif Binary files differnew file mode 100644 index 000000000..aea8e5c01 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/bigminus.gif diff --git a/ml/dlib/tools/htmlify/to_xml_example/bigplus.gif b/ml/dlib/tools/htmlify/to_xml_example/bigplus.gif Binary files differnew file mode 100644 index 000000000..6bee68e21 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/bigplus.gif diff --git a/ml/dlib/tools/htmlify/to_xml_example/example.xml b/ml/dlib/tools/htmlify/to_xml_example/example.xml new file mode 100644 index 000000000..472a4a5e1 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/example.xml @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="ISO-8859-1"?> +<?xml-stylesheet type="text/xsl" href="stylesheet.xsl"?> + +<doc> + <title>Documented Code</title> + <body from_file="output.xml"/> +</doc> + diff --git a/ml/dlib/tools/htmlify/to_xml_example/minus.gif b/ml/dlib/tools/htmlify/to_xml_example/minus.gif Binary files differnew file mode 100644 index 000000000..1deac2fe1 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/minus.gif diff --git a/ml/dlib/tools/htmlify/to_xml_example/output.xml b/ml/dlib/tools/htmlify/to_xml_example/output.xml new file mode 100644 index 000000000..95e4de6ae --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/output.xml @@ -0,0 +1,49 @@ +<!-- This XML file was generated using the htmlify tool available from http://dlib.net. --> +<code> + <classes> + <class> + <name>test</name> + <scope></scope> + <declaration>class test</declaration> + <file>test.cpp</file> + <comment>WHAT THIS OBJECT REPRESENTS + This is a simple test class that doesn't do anything</comment> + <public_typedefs> + <typedef>typedef int type</typedef> + </public_typedefs> + <public_methods> + <method> + <name>test</name> + <declaration>test()</declaration> + <comment>ensures + - constructs a test object</comment> + </method> + <method> + <name>print</name> + <declaration>void +print() const</declaration> + <comment>ensures + - prints a message to the screen</comment> + </method> + </public_methods> + </class> + + </classes> + + + <global_functions> + <function> + <name>add_numbers</name> + <scope></scope> + <declaration>int +add_numbers ( + int a, + int b +)</declaration> + <file>test.cpp</file> + <comment>ensures + - returns a + b</comment> + </function> + + </global_functions> +</code> diff --git a/ml/dlib/tools/htmlify/to_xml_example/plus.gif b/ml/dlib/tools/htmlify/to_xml_example/plus.gif Binary files differnew file mode 100644 index 000000000..2d15c1417 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/plus.gif diff --git a/ml/dlib/tools/htmlify/to_xml_example/stylesheet.xsl b/ml/dlib/tools/htmlify/to_xml_example/stylesheet.xsl new file mode 100644 index 000000000..7a44862a3 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/stylesheet.xsl @@ -0,0 +1,354 @@ +<?xml version="1.0" encoding="ISO-8859-1" ?> + +<!-- + To the extent possible under law, Davis E King has waived all copyright and + related or neighboring rights to dlib documentation (XML, HTML, and XSLT files). + This work is published from United States. +--> + +<xsl:stylesheet version="1.0" xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> + <xsl:output method='html' version='1.0' encoding='UTF-8' indent='yes' /> + + + <!-- ************************************************************************* --> + + <xsl:variable name="lcletters">abcdefghijklmnopqrstuvwxyz </xsl:variable> + <xsl:variable name="ucletters">ABCDEFGHIJKLMNOPQRSTUVWXYZ </xsl:variable> + + <!-- ************************************************************************* --> + + <xsl:template match="/doc"> + <html> + <head> + <title> + <xsl:if test="title"> + <xsl:value-of select="title" /> + </xsl:if> + </title> + + + <!-- [client side code for collapsing and unfolding branches] --> + <script language="JavaScript"> + + // --------------------------------------------- + // --- Name: Easy DHTML Treeview -- + // --- Author: D.D. de Kerf -- + // --- Version: 0.2 Date: 13-6-2001 -- + // --------------------------------------------- + function Toggle(node) + { + // Unfold the branch if it isn't visible + var next_node = node.nextSibling; + if (next_node.style.display == 'none') + { + // Change the image (if there is an image) + if (node.childNodes.length > 0) + { + if (node.childNodes.length > 0) + { + if (node.childNodes.item(0).nodeName == "IMG") + { + node.childNodes.item(0).src = "minus.gif"; + } + } + } + + next_node.style.display = 'block'; + } + // Collapse the branch if it IS visible + else + { + // Change the image (if there is an image) + if (node.childNodes.length > 0) + { + if (node.childNodes.length > 0) + { + if (node.childNodes.item(0).nodeName == "IMG") + { + node.childNodes.item(0).src = "plus.gif"; + } + } + } + + next_node.style.display = 'none'; + } + + } + function BigToggle(node) + { + // Unfold the branch if it isn't visible + var next_node = node.nextSibling; + if (next_node.style.display == 'none') + { + // Change the image (if there is an image) + if (node.childNodes.length > 0) + { + if (node.childNodes.length > 0) + { + if (node.childNodes.item(0).nodeName == "IMG") + { + node.childNodes.item(0).src = "bigminus.gif"; + } + } + } + + next_node.style.display = 'block'; + } + // Collapse the branch if it IS visible + else + { + // Change the image (if there is an image) + if (node.childNodes.length > 0) + { + if (node.childNodes.length > 0) + { + if (node.childNodes.item(0).nodeName == "IMG") + { + node.childNodes.item(0).src = "bigplus.gif"; + } + } + } + + next_node.style.display = 'none'; + } + + } + </script> + + <style type="text/css"> + pre {margin:0px;} + + ul.tree li { list-style: none; margin-left:10px;} + ul.tree { margin:0px; padding:0px; margin-left:5px; font-size:0.95em; } + ul.tree li ul { margin-left:10px; padding:0px; } + + div#component { + background-color:white; + border: 2px solid rgb(102,102,102); + text-align:left; + margin-top: 1.5em; + padding: 0.7em; + } + + div#function { + background-color:white; + border: 2px solid rgb(102,102,255); + text-align:left; + margin-top: 0.3em; + padding: 0.3em; + } + + div#class { + background-color:white; + border: 2px solid rgb(255,102,102); + text-align:left; + margin-top: 0.3em; + padding: 0.3em; + } + + </style> + </head> + <body> + <xsl:if test="title"> + <center><h1> <xsl:value-of select="title" /> </h1></center> + </xsl:if> + <xsl:apply-templates select="body"/> + </body> + </html> + </xsl:template> + + + + + + <!-- ************************************************************************* --> + + <xsl:template match="body"> + <xsl:choose> + <xsl:when test="@from_file"> + <xsl:apply-templates select="document(@from_file)"/> + <xsl:apply-templates/> + </xsl:when> + <xsl:otherwise> + <xsl:apply-templates/> + </xsl:otherwise> + </xsl:choose> + </xsl:template> + + + <!-- ************************************************************************* --> + <!-- ************************************************************************* --> + <!-- XSLT for dealing with <code> blocks generated by the htmlify to-xml option --> + <!-- ************************************************************************* --> + <!-- ************************************************************************* --> + + <xsl:template match="code"> + + <h1>Classes and Structs:</h1> + <xsl:for-each select="classes/class"> + <xsl:sort select="translate(concat(name,.),$lcletters, $ucletters)"/> + <xsl:apply-templates select="."/> + </xsl:for-each> + + <h1>Global Functions:</h1> + <xsl:for-each select="global_functions/function"> + <xsl:sort select="translate(concat(name,.),$lcletters, $ucletters)"/> + <div id="function"> + <a onclick="Toggle(this)" style="cursor: pointer"><img src="plus.gif" border="0"/><font color="blue"> + <u><b><xsl:value-of select="name"/>()</b></u></font></a> + <div style="display:none;"> + <br/> + <xsl:if test="scope != ''"> + <u>Scope</u>: <xsl:value-of select="scope"/> <br/> + </xsl:if> + <u>File</u>: <xsl:value-of select="file"/> <br/><br/> + <div style="margin-left:1.5em"> + <pre style="font-size:1.1em;"><xsl:value-of select="declaration"/>;</pre> + <font color="#009900"><pre><xsl:value-of select="comment"/></pre></font> + </div> + <br/> + </div> + </div> + </xsl:for-each> + + </xsl:template> + + <!-- ************************************************************************* --> + + <xsl:template match="class"> + <div id="class"> + <a onclick="Toggle(this)" style="cursor: pointer"><img src="plus.gif" border="0"/><font color="blue"> + <u><b><xsl:value-of select="name"/></b></u></font></a> + <div style="display:none;"> + <br/> + <xsl:if test="scope != ''"> + <u>Scope</u>: <xsl:value-of select="scope"/> <br/> + </xsl:if> + <u>File</u>: <xsl:value-of select="file"/> <br/><br/> + <div style="margin-left:1.5em"> + <pre style="font-size:1.1em;"><xsl:value-of select="declaration"/>;</pre> <br/> + <font color="#009900"><pre><xsl:value-of select="comment"/></pre></font> <br/> + </div> + + <xsl:if test="protected_typedefs"> + <a onclick="BigToggle(this)" style="cursor: pointer"><img src="bigplus.gif" border="0"/><font color="blue"> + <u style="font-size:2em">Protected Typedefs</u></font></a> + <div style="display:none;"> + <ul> + <xsl:for-each select="protected_typedefs/typedef"> + <li><xsl:value-of select="."/>;</li> + </xsl:for-each> + </ul> + </div> + <br/> + </xsl:if> + + <xsl:if test="public_typedefs"> + <a onclick="BigToggle(this)" style="cursor: pointer"><img src="bigplus.gif" border="0" style="size:2em"/><font color="blue"> + <u style="font-size:2em">Public Typedefs</u></font></a> + <div style="display:none;"> + <ul> + <xsl:for-each select="public_typedefs/typedef"> + <li><xsl:value-of select="."/>;</li> + </xsl:for-each> + </ul> + </div> + <br/> + </xsl:if> + + <xsl:if test="protected_variables"> + <a onclick="BigToggle(this)" style="cursor: pointer"><img src="bigplus.gif" border="0"/><font color="blue"> + <u style="font-size:2em">Protected Variables</u></font></a> + <div style="display:none;"> + <ul> + <xsl:for-each select="protected_variables/variable"> + <li><xsl:value-of select="."/>;</li> + </xsl:for-each> + </ul> + </div> + <br/> + </xsl:if> + + <xsl:if test="public_variables"> + <a onclick="BigToggle(this)" style="cursor: pointer"><img src="bigplus.gif" border="0"/><font color="blue"> + <u style="font-size:2em">Public Variables</u></font></a> + <div style="display:none;"> + <ul> + <xsl:for-each select="public_variables/variable"> + <li><xsl:value-of select="."/>;</li> + </xsl:for-each> + </ul> + </div> + <br/> + </xsl:if> + + <xsl:if test="protected_methods"> + <a onclick="BigToggle(this)" style="cursor: pointer"><img src="bigplus.gif" border="0"/><font color="blue"> + <u style="font-size:2em">Protected Methods</u></font></a> + <div style="display:none;"> + <xsl:for-each select="protected_methods/method"> + <div id="function"> + <u>Method Name</u>: <b><xsl:value-of select="name"/></b> <br/><br/> + <div style="margin-left:1.5em"> + <pre style="font-size:1.1em;"><xsl:value-of select="declaration"/>;</pre> + <font color="#009900"><pre><xsl:value-of select="comment"/></pre></font> <br/> + </div> + </div> + </xsl:for-each> + </div> + <br/> + </xsl:if> + + <xsl:if test="public_methods"> + <a onclick="BigToggle(this)" style="cursor: pointer"><img src="bigplus.gif" border="0"/><font color="blue"> + <u style="font-size:2em">Public Methods</u></font></a> + <div style="display:none;"> + <xsl:for-each select="public_methods/method"> + <div id="function"> + <u>Method Name</u>: <b><xsl:value-of select="name"/></b> <br/><br/> + <div style="margin-left:1.5em"> + <pre style="font-size:1.1em;"><xsl:value-of select="declaration"/>;</pre> + <font color="#009900"><pre><xsl:value-of select="comment"/></pre></font> <br/> + </div> + </div> + </xsl:for-each> + </div> + <br/> + </xsl:if> + + <xsl:if test="protected_inner_classes"> + <a onclick="BigToggle(this)" style="cursor: pointer"><img src="bigplus.gif" border="0"/><font color="blue"> + <u style="font-size:2em">Protected Inner Classes</u></font></a> + <div style="display:none;"> + <xsl:for-each select="protected_inner_classes/class"> + <xsl:apply-templates select="."/> + </xsl:for-each> + </div> + <br/> + </xsl:if> + + <xsl:if test="public_inner_classes"> + <a onclick="BigToggle(this)" style="cursor: pointer"><img src="bigplus.gif" border="0"/><font color="blue"> + <u style="font-size:2em">Public Inner Classes</u></font></a> + <div style="display:none;"> + <xsl:for-each select="public_inner_classes/class"> + <xsl:apply-templates select="."/> + </xsl:for-each> + </div> + <br/> + </xsl:if> + + </div> + </div> + </xsl:template> + + + <!-- ************************************************************************* --> + <!-- ************************************************************************* --> + <!-- ************************************************************************* --> + <!-- ************************************************************************* --> + + + + +</xsl:stylesheet> diff --git a/ml/dlib/tools/htmlify/to_xml_example/test.cpp b/ml/dlib/tools/htmlify/to_xml_example/test.cpp new file mode 100644 index 000000000..edbdfff54 --- /dev/null +++ b/ml/dlib/tools/htmlify/to_xml_example/test.cpp @@ -0,0 +1,78 @@ +#include <iostream> + +// ---------------------------------------------------------------------------------------- + +using namespace std; + +// ---------------------------------------------------------------------------------------- + +class test +{ + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple test class that doesn't do anything + !*/ + +public: + + typedef int type; + + test (); + /*! + ensures + - constructs a test object + !*/ + + void print () const; + /*! + ensures + - prints a message to the screen + !*/ + +}; + +// ---------------------------------------------------------------------------------------- + +test::test() {} + +void test::print() const +{ + cout << "A message!" << endl; +} + +// ---------------------------------------------------------------------------------------- + +int add_numbers ( + int a, + int b +) +/*! + ensures + - returns a + b +!*/ +{ + return a + b; +} + +// ---------------------------------------------------------------------------------------- + +void a_function ( +) +/*!P + This is a function which won't show up in the output of htmlify --to-xml + because of the presence of the P in the above /*!P above. +!*/ +{ +} + +// ---------------------------------------------------------------------------------------- + +int main() +{ + test a; + a.print(); +} + +// ---------------------------------------------------------------------------------------- + + diff --git a/ml/dlib/tools/imglab/CMakeLists.txt b/ml/dlib/tools/imglab/CMakeLists.txt new file mode 100644 index 000000000..46c64fb3e --- /dev/null +++ b/ml/dlib/tools/imglab/CMakeLists.txt @@ -0,0 +1,41 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + +cmake_minimum_required(VERSION 2.8.12) + +# create a variable called target_name and set it to the string "imglab" +set (target_name imglab) + +PROJECT(${target_name}) +add_subdirectory(../../dlib dlib_build) + +# add all the cpp files we want to compile to this list. This tells +# cmake that they are part of our target (which is the executable named imglab) +add_executable(${target_name} + src/main.cpp + src/metadata_editor.h + src/metadata_editor.cpp + src/convert_pascal_xml.h + src/convert_pascal_xml.cpp + src/convert_pascal_v1.h + src/convert_pascal_v1.cpp + src/convert_idl.h + src/convert_idl.cpp + src/common.h + src/common.cpp + src/cluster.cpp + src/flip_dataset.cpp +) + + +# Tell cmake to link our target executable to dlib. +target_link_libraries(${target_name} dlib::dlib ) + + +install(TARGETS ${target_name} + RUNTIME DESTINATION bin + ) +install(PROGRAMS convert_imglab_paths_to_relative copy_imglab_dataset DESTINATION bin ) + diff --git a/ml/dlib/tools/imglab/README.txt b/ml/dlib/tools/imglab/README.txt new file mode 100644 index 000000000..3f0ca92a1 --- /dev/null +++ b/ml/dlib/tools/imglab/README.txt @@ -0,0 +1,40 @@ +imglab is a simple graphical tool for annotating images with object bounding +boxes and optionally their part locations. Generally, you use it when you want +to train an object detector (e.g. a face detector) since it allows you to +easily create the needed training dataset. + +You can compile imglab with the following commands: + cd dlib/tools/imglab + mkdir build + cd build + cmake .. + cmake --build . --config Release +Note that you may need to install CMake (www.cmake.org) for this to work. On a +unix system you can also install imglab into /usr/local/bin by running + sudo make install +This will make running it more convenient. + +Next, to use it, lets assume you have a folder of images called /tmp/images. +These images should contain examples of the objects you want to learn to +detect. You will use the imglab tool to label these objects. Do this by +typing the following command: + ./imglab -c mydataset.xml /tmp/images +This will create a file called mydataset.xml which simply lists the images in +/tmp/images. To add bounding boxes to the objects you run: + ./imglab mydataset.xml +and a window will appear showing all the images. You can use the up and down +arrow keys to cycle though the images and the mouse to label objects. In +particular, holding the shift key, left clicking, and dragging the mouse will +allow you to draw boxes around the objects you wish to detect. + +Once you finish labeling objects go to the file menu, click save, and then +close the program. This will save the object boxes back to mydataset.xml. You +can verify this by opening the tool again with: + ./imglab mydataset.xml +and observing that the boxes are present. + + +imglab can do a few additional things. To see these run: + imglab -h +and also read the instructions in the About->Help menu. + diff --git a/ml/dlib/tools/imglab/convert_imglab_paths_to_relative b/ml/dlib/tools/imglab/convert_imglab_paths_to_relative new file mode 100755 index 000000000..09c5ef7a5 --- /dev/null +++ b/ml/dlib/tools/imglab/convert_imglab_paths_to_relative @@ -0,0 +1,24 @@ +#!/usr/bin/perl + +use File::Spec; + +die "This script converts all the file names in an imglab XML file to have paths relative to the current folder. Call it like this: ./convert_imglab_paths_to_relative some_file.xml" if @ARGV != 1; + +$file = @ARGV[0]; +open(INFO, $file) or die('Could not open file.'); + +foreach $line (<INFO>) +{ + if (index($line, 'file=\'') != -1) + { + $line =~ /file='(.*)'/; + $relpath = File::Spec->abs2rel($1); + $line =~ s/$1/$relpath/; + print $line + } + else + { + print $line + } +} + diff --git a/ml/dlib/tools/imglab/copy_imglab_dataset b/ml/dlib/tools/imglab/copy_imglab_dataset new file mode 100755 index 000000000..8b44ed166 --- /dev/null +++ b/ml/dlib/tools/imglab/copy_imglab_dataset @@ -0,0 +1,22 @@ +#!/bin/bash + +if [ "$#" -ne 2 ]; then + echo "This script copies an imglab XML file and its associated images to a new folder." + echo "Notably, it will avoid copying unnecessary images." + echo "Call this script like this:" + echo " ./copy_dataset some_file.xml dest_dir" + exit 1 +fi + +XML_FILE=$1 +DEST=$2 + + + +mkdir -p $DEST + +# Get the list of files we need to copy, then build the cp statements with 1000 files at most in each statement, then tell bash to run them all. +imglab --files $XML_FILE | xargs perl -e 'use File::Spec; foreach (@ARGV) {print File::Spec->abs2rel($_) . "\n"}' | sort | uniq | xargs -L1000 echo | xargs -I{} echo cp -a --parents {} $DEST | bash + +convert_imglab_paths_to_relative $XML_FILE > $DEST/$(basename $XML_FILE) + diff --git a/ml/dlib/tools/imglab/src/cluster.cpp b/ml/dlib/tools/imglab/src/cluster.cpp new file mode 100644 index 000000000..23b289a7f --- /dev/null +++ b/ml/dlib/tools/imglab/src/cluster.cpp @@ -0,0 +1,260 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "cluster.h" +#include <dlib/console_progress_indicator.h> +#include <dlib/image_io.h> +#include <dlib/data_io.h> +#include <dlib/image_transforms.h> +#include <dlib/misc_api.h> +#include <dlib/dir_nav.h> +#include <dlib/clustering.h> +#include <dlib/svm.h> + +// ---------------------------------------------------------------------------------------- + +using namespace std; +using namespace dlib; + +// ---------------------------------------------------------------------------- + +struct assignment +{ + unsigned long c; + double dist; + unsigned long idx; + + bool operator<(const assignment& item) const + { return dist < item.dist; } +}; + +std::vector<assignment> angular_cluster ( + std::vector<matrix<double,0,1> > feats, + const unsigned long num_clusters +) +{ + DLIB_CASSERT(feats.size() != 0, "The dataset can't be empty"); + for (unsigned long i = 0; i < feats.size(); ++i) + { + DLIB_CASSERT(feats[i].size() == feats[0].size(), "All feature vectors must have the same length."); + } + + // find the centroid of feats + matrix<double,0,1> m; + for (unsigned long i = 0; i < feats.size(); ++i) + m += feats[i]; + m /= feats.size(); + + // Now center feats and then project onto the unit sphere. The reason for projecting + // onto the unit sphere is so pick_initial_centers() works in a sensible way. + for (unsigned long i = 0; i < feats.size(); ++i) + { + feats[i] -= m; + double len = length(feats[i]); + if (len != 0) + feats[i] /= len; + } + + // now do angular clustering of the points + std::vector<matrix<double,0,1> > centers; + pick_initial_centers(num_clusters, centers, feats, linear_kernel<matrix<double,0,1> >(), 0.05); + find_clusters_using_angular_kmeans(feats, centers); + + // and then report the resulting assignments + std::vector<assignment> assignments; + for (unsigned long i = 0; i < feats.size(); ++i) + { + assignment temp; + temp.c = nearest_center(centers, feats[i]); + temp.dist = length(feats[i] - centers[temp.c]); + temp.idx = i; + assignments.push_back(temp); + } + return assignments; +} + +// ---------------------------------------------------------------------------------------- + +bool compare_first ( + const std::pair<double,image_dataset_metadata::image>& a, + const std::pair<double,image_dataset_metadata::image>& b +) +{ + return a.first < b.first; +} + +// ---------------------------------------------------------------------------------------- + +double mean_aspect_ratio ( + const image_dataset_metadata::dataset& data +) +{ + double sum = 0; + double cnt = 0; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + rectangle rect = data.images[i].boxes[j].rect; + if (rect.area() == 0 || data.images[i].boxes[j].ignore) + continue; + sum += rect.width()/(double)rect.height(); + ++cnt; + } + } + + if (cnt != 0) + return sum/cnt; + else + return 0; +} + +// ---------------------------------------------------------------------------------------- + +bool has_non_ignored_boxes (const image_dataset_metadata::image& img) +{ + for (auto&& b : img.boxes) + { + if (!b.ignore) + return true; + } + return false; +} + +// ---------------------------------------------------------------------------------------- + +int cluster_dataset( + const dlib::command_line_parser& parser +) +{ + // make sure the user entered an argument to this program + if (parser.number_of_arguments() != 1) + { + cerr << "The --cluster option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + const unsigned long num_clusters = get_option(parser, "cluster", 2); + const unsigned long chip_size = get_option(parser, "size", 8000); + + image_dataset_metadata::dataset data; + + image_dataset_metadata::load_image_dataset_metadata(data, parser[0]); + set_current_dir(get_parent_directory(file(parser[0]))); + + const double aspect_ratio = mean_aspect_ratio(data); + + dlib::array<array2d<rgb_pixel> > images; + std::vector<matrix<double,0,1> > feats; + console_progress_indicator pbar(data.images.size()); + // extract all the object chips and HOG features. + cout << "Loading image data..." << endl; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + pbar.print_status(i); + if (!has_non_ignored_boxes(data.images[i])) + continue; + + array2d<rgb_pixel> img, chip; + load_image(img, data.images[i].filename); + + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].ignore || data.images[i].boxes[j].rect.area() < 10) + continue; + drectangle rect = data.images[i].boxes[j].rect; + rect = set_aspect_ratio(rect, aspect_ratio); + extract_image_chip(img, chip_details(rect, chip_size), chip); + feats.push_back(extract_fhog_features(chip)); + images.push_back(chip); + } + } + + if (feats.size() == 0) + { + cerr << "No non-ignored object boxes found in the XML dataset. You can't cluster an empty dataset." << endl; + return EXIT_FAILURE; + } + + cout << "\nClustering objects..." << endl; + std::vector<assignment> assignments = angular_cluster(feats, num_clusters); + + + // Now output each cluster to disk as an XML file. + for (unsigned long c = 0; c < num_clusters; ++c) + { + // We are going to accumulate all the image metadata for cluster c. We put it + // into idata so we can sort the images such that images with central chips + // come before less central chips. The idea being to get the good chips to + // show up first in the listing, making it easy to manually remove bad ones if + // that is desired. + std::vector<std::pair<double,image_dataset_metadata::image> > idata(data.images.size()); + unsigned long idx = 0; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + idata[i].first = std::numeric_limits<double>::infinity(); + idata[i].second.filename = data.images[i].filename; + if (!has_non_ignored_boxes(data.images[i])) + continue; + + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + idata[i].second.boxes.push_back(data.images[i].boxes[j]); + + if (data.images[i].boxes[j].ignore || data.images[i].boxes[j].rect.area() < 10) + continue; + + // If this box goes into cluster c then update the score for the whole + // image based on this boxes' score. Otherwise, mark the box as + // ignored. + if (assignments[idx].c == c) + idata[i].first = std::min(idata[i].first, assignments[idx].dist); + else + idata[i].second.boxes.back().ignore = true; + + ++idx; + } + } + + // now save idata to an xml file. + std::sort(idata.begin(), idata.end(), compare_first); + image_dataset_metadata::dataset cdata; + cdata.comment = data.comment + "\n\n This file contains objects which were clustered into group " + + cast_to_string(c+1) + " of " + cast_to_string(num_clusters) + " groups with a chip size of " + + cast_to_string(chip_size) + " by imglab."; + cdata.name = data.name; + for (unsigned long i = 0; i < idata.size(); ++i) + { + // if this image has non-ignored boxes in it then include it in the output. + if (idata[i].first != std::numeric_limits<double>::infinity()) + cdata.images.push_back(idata[i].second); + } + + string outfile = "cluster_"+pad_int_with_zeros(c+1, 3) + ".xml"; + cout << "Saving " << outfile << endl; + save_image_dataset_metadata(cdata, outfile); + } + + // Now output each cluster to disk as a big tiled jpeg file. Sort everything so, just + // like in the xml file above, the best objects come first in the tiling. + std::sort(assignments.begin(), assignments.end()); + for (unsigned long c = 0; c < num_clusters; ++c) + { + dlib::array<array2d<rgb_pixel> > temp; + for (unsigned long i = 0; i < assignments.size(); ++i) + { + if (assignments[i].c == c) + temp.push_back(images[assignments[i].idx]); + } + + string outfile = "cluster_"+pad_int_with_zeros(c+1, 3) + ".jpg"; + cout << "Saving " << outfile << endl; + save_jpeg(tile_images(temp), outfile); + } + + + return EXIT_SUCCESS; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/cluster.h b/ml/dlib/tools/imglab/src/cluster.h new file mode 100644 index 000000000..6cb41a373 --- /dev/null +++ b/ml/dlib/tools/imglab/src/cluster.h @@ -0,0 +1,11 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_ClUSTER_H_ +#define DLIB_IMGLAB_ClUSTER_H_ + +#include <dlib/cmd_line_parser.h> + +int cluster_dataset(const dlib::command_line_parser& parser); + +#endif //DLIB_IMGLAB_ClUSTER_H_ + diff --git a/ml/dlib/tools/imglab/src/common.cpp b/ml/dlib/tools/imglab/src/common.cpp new file mode 100644 index 000000000..d9cc1dca4 --- /dev/null +++ b/ml/dlib/tools/imglab/src/common.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "common.h" +#include <fstream> +#include <dlib/error.h> + +// ---------------------------------------------------------------------------------------- + +std::string strip_path ( + const std::string& str, + const std::string& prefix +) +{ + unsigned long i; + for (i = 0; i < str.size() && i < prefix.size(); ++i) + { + if (str[i] != prefix[i]) + return str; + } + + if (i < str.size() && (str[i] == '/' || str[i] == '\\')) + ++i; + + return str.substr(i); +} + +// ---------------------------------------------------------------------------------------- + +void make_empty_file ( + const std::string& filename +) +{ + std::ofstream fout(filename.c_str()); + if (!fout) + throw dlib::error("ERROR: Unable to open " + filename + " for writing."); +} + +// ---------------------------------------------------------------------------------------- + +std::string to_png_name (const std::string& filename) +{ + std::string::size_type pos = filename.find_last_of("."); + if (pos == std::string::npos) + throw dlib::error("invalid filename: " + filename); + return filename.substr(0,pos) + ".png"; +} + +// ---------------------------------------------------------------------------------------- + +std::string to_jpg_name (const std::string& filename) +{ + std::string::size_type pos = filename.find_last_of("."); + if (pos == std::string::npos) + throw dlib::error("invalid filename: " + filename); + return filename.substr(0,pos) + ".jpg"; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/common.h b/ml/dlib/tools/imglab/src/common.h new file mode 100644 index 000000000..42e905bc3 --- /dev/null +++ b/ml/dlib/tools/imglab/src/common.h @@ -0,0 +1,45 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_COmMON_H__ +#define DLIB_IMGLAB_COmMON_H__ + +#include <string> + +// ---------------------------------------------------------------------------------------- + +std::string strip_path ( + const std::string& str, + const std::string& prefix +); +/*! + ensures + - if (prefix is a prefix of str) then + - returns the part of str after the prefix + (additionally, str will not begin with a / or \ character) + - else + - return str +!*/ + +// ---------------------------------------------------------------------------------------- + +void make_empty_file ( + const std::string& filename +); +/*! + ensures + - creates an empty file of the given name +!*/ + +// ---------------------------------------------------------------------------------------- + +std::string to_png_name (const std::string& filename); +std::string to_jpg_name (const std::string& filename); + +// ---------------------------------------------------------------------------------------- + +const int JPEG_QUALITY = 90; + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_IMGLAB_COmMON_H__ + diff --git a/ml/dlib/tools/imglab/src/convert_idl.cpp b/ml/dlib/tools/imglab/src/convert_idl.cpp new file mode 100644 index 000000000..7ff601d0c --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_idl.cpp @@ -0,0 +1,184 @@ + +#include "convert_idl.h" +#include "dlib/data_io.h" +#include <iostream> +#include <string> +#include <dlib/dir_nav.h> +#include <dlib/time_this.h> +#include <dlib/cmd_line_parser.h> + +using namespace std; +using namespace dlib; + +namespace +{ + using namespace dlib::image_dataset_metadata; + +// ---------------------------------------------------------------------------------------- + + inline bool next_is_number(std::istream& in) + { + return ('0' <= in.peek() && in.peek() <= '9') || in.peek() == '-' || in.peek() == '+'; + } + + int read_int(std::istream& in) + { + bool is_neg = false; + if (in.peek() == '-') + { + is_neg = true; + in.get(); + } + if (in.peek() == '+') + in.get(); + + int val = 0; + while ('0' <= in.peek() && in.peek() <= '9') + { + val = 10*val + in.get()-'0'; + } + + if (is_neg) + return -val; + else + return val; + } + +// ---------------------------------------------------------------------------------------- + + void parse_annotation_file( + const std::string& file, + dlib::image_dataset_metadata::dataset& data + ) + { + ifstream fin(file.c_str()); + if (!fin) + throw dlib::error("Unable to open file " + file); + + + bool in_quote = false; + int point_count = 0; + bool in_point_list = false; + bool saw_any_points = false; + + image img; + string label; + point p1,p2; + while (fin.peek() != EOF) + { + if (in_point_list && next_is_number(fin)) + { + const int val = read_int(fin); + switch (point_count) + { + case 0: p1.x() = val; break; + case 1: p1.y() = val; break; + case 2: p2.x() = val; break; + case 3: p2.y() = val; break; + default: + throw dlib::error("parse error in file " + file); + } + + ++point_count; + } + + char ch = fin.get(); + + if (ch == ':') + continue; + + if (ch == '"') + { + in_quote = !in_quote; + continue; + } + + if (in_quote) + { + img.filename += ch; + continue; + } + + + if (ch == '(') + { + in_point_list = true; + point_count = 0; + label.clear(); + saw_any_points = true; + } + if (ch == ')') + { + in_point_list = false; + + label.clear(); + while (fin.peek() != EOF && + fin.peek() != ';' && + fin.peek() != ',') + { + char ch = fin.get(); + if (ch == ':') + continue; + + label += ch; + } + } + + if (ch == ',' && !in_point_list) + { + + box b; + b.rect = rectangle(p1,p2); + b.label = label; + img.boxes.push_back(b); + } + + + if (ch == ';') + { + + if (saw_any_points) + { + box b; + b.rect = rectangle(p1,p2); + b.label = label; + img.boxes.push_back(b); + saw_any_points = false; + } + data.images.push_back(img); + + + img.filename.clear(); + img.boxes.clear(); + } + + + } + + + + } + +// ---------------------------------------------------------------------------------------- + +} + +void convert_idl( + const command_line_parser& parser +) +{ + cout << "Convert from IDL annotation format..." << endl; + + dlib::image_dataset_metadata::dataset dataset; + + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + parse_annotation_file(parser[i], dataset); + } + + const std::string filename = parser.option("c").argument(); + save_image_dataset_metadata(dataset, filename); +} + + + diff --git a/ml/dlib/tools/imglab/src/convert_idl.h b/ml/dlib/tools/imglab/src/convert_idl.h new file mode 100644 index 000000000..d8c33d961 --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_idl.h @@ -0,0 +1,14 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_CONVErT_IDL_H__ +#define DLIB_IMGLAB_CONVErT_IDL_H__ + +#include "common.h" +#include <dlib/cmd_line_parser.h> + +void convert_idl(const dlib::command_line_parser& parser); + +#endif // DLIB_IMGLAB_CONVErT_IDL_H__ + + + diff --git a/ml/dlib/tools/imglab/src/convert_pascal_v1.cpp b/ml/dlib/tools/imglab/src/convert_pascal_v1.cpp new file mode 100644 index 000000000..8eaf5e2bb --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_pascal_v1.cpp @@ -0,0 +1,177 @@ + +#include "convert_pascal_v1.h" +#include "dlib/data_io.h" +#include <iostream> +#include <string> +#include <dlib/dir_nav.h> +#include <dlib/time_this.h> + +using namespace std; +using namespace dlib; + +namespace +{ + using namespace dlib::image_dataset_metadata; + +// ---------------------------------------------------------------------------------------- + + std::string pick_out_quoted_string ( + const std::string& str + ) + { + std::string temp; + bool in_quotes = false; + for (unsigned long i = 0; i < str.size(); ++i) + { + if (str[i] == '"') + { + in_quotes = !in_quotes; + } + else if (in_quotes) + { + temp += str[i]; + } + } + + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void parse_annotation_file( + const std::string& file, + dlib::image_dataset_metadata::image& img, + std::string& dataset_name + ) + { + ifstream fin(file.c_str()); + if (!fin) + throw dlib::error("Unable to open file " + file); + + img = dlib::image_dataset_metadata::image(); + + string str, line; + std::vector<string> words; + while (fin.peek() != EOF) + { + getline(fin, line); + words = split(line, " \r\n\t:(,-)\""); + if (words.size() > 2) + { + if (words[0] == "#") + continue; + + if (words[0] == "Image" && words[1] == "filename") + { + img.filename = pick_out_quoted_string(line); + } + else if (words[0] == "Database") + { + dataset_name = pick_out_quoted_string(line); + } + else if (words[0] == "Objects" && words[1] == "with" && words.size() >= 5) + { + const int num = sa = words[4]; + img.boxes.resize(num); + } + else if (words.size() > 4 && (words[2] == "for" || words[2] == "on") && words[3] == "object") + { + long idx = sa = words[4]; + --idx; + if (idx >= (long)img.boxes.size()) + throw dlib::error("Invalid object id number of " + words[4]); + + if (words[0] == "Center" && words[1] == "point" && words.size() > 9) + { + const long x = sa = words[8]; + const long y = sa = words[9]; + img.boxes[idx].parts["head"] = point(x,y); + } + else if (words[0] == "Bounding" && words[1] == "box" && words.size() > 13) + { + rectangle rect; + img.boxes[idx].rect.left() = sa = words[10]; + img.boxes[idx].rect.top() = sa = words[11]; + img.boxes[idx].rect.right() = sa = words[12]; + img.boxes[idx].rect.bottom() = sa = words[13]; + } + else if (words[0] == "Original" && words[1] == "label" && words.size() > 6) + { + img.boxes[idx].label = words[6]; + } + } + } + + } + } + +// ---------------------------------------------------------------------------------------- + + std::string figure_out_full_path_to_image ( + const std::string& annotation_file, + const std::string& image_name + ) + { + directory parent = get_parent_directory(file(annotation_file)); + + + string temp; + while (true) + { + if (parent.is_root()) + temp = parent.full_name() + image_name; + else + temp = parent.full_name() + directory::get_separator() + image_name; + + if (file_exists(temp)) + return temp; + + if (parent.is_root()) + throw dlib::error("Can't figure out where the file " + image_name + " is located."); + parent = get_parent_directory(parent); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +void convert_pascal_v1( + const command_line_parser& parser +) +{ + cout << "Convert from PASCAL v1.00 annotation format..." << endl; + + dlib::image_dataset_metadata::dataset dataset; + + std::string name; + dlib::image_dataset_metadata::image img; + + const std::string filename = parser.option("c").argument(); + // make sure the file exists so we can use the get_parent_directory() command to + // figure out it's parent directory. + make_empty_file(filename); + const std::string parent_dir = get_parent_directory(file(filename)).full_name(); + + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + try + { + parse_annotation_file(parser[i], img, name); + + dataset.name = name; + img.filename = strip_path(figure_out_full_path_to_image(parser[i], img.filename), parent_dir); + dataset.images.push_back(img); + + } + catch (exception& ) + { + cout << "Error while processing file " << parser[i] << endl << endl; + throw; + } + } + + save_image_dataset_metadata(dataset, filename); +} + + diff --git a/ml/dlib/tools/imglab/src/convert_pascal_v1.h b/ml/dlib/tools/imglab/src/convert_pascal_v1.h new file mode 100644 index 000000000..3553d03a7 --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_pascal_v1.h @@ -0,0 +1,13 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_CONVERT_PASCAl_V1_H__ +#define DLIB_IMGLAB_CONVERT_PASCAl_V1_H__ + +#include "common.h" +#include <dlib/cmd_line_parser.h> + +void convert_pascal_v1(const dlib::command_line_parser& parser); + +#endif // DLIB_IMGLAB_CONVERT_PASCAl_V1_H__ + + diff --git a/ml/dlib/tools/imglab/src/convert_pascal_xml.cpp b/ml/dlib/tools/imglab/src/convert_pascal_xml.cpp new file mode 100644 index 000000000..c699d7777 --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_pascal_xml.cpp @@ -0,0 +1,239 @@ + +#include "convert_pascal_xml.h" +#include "dlib/data_io.h" +#include <iostream> +#include <dlib/xml_parser.h> +#include <string> +#include <dlib/dir_nav.h> +#include <dlib/cmd_line_parser.h> + +using namespace std; +using namespace dlib; + +namespace +{ + using namespace dlib::image_dataset_metadata; + +// ---------------------------------------------------------------------------------------- + + class doc_handler : public document_handler + { + image& temp_image; + std::string& dataset_name; + + std::vector<std::string> ts; + box temp_box; + + public: + + doc_handler( + image& temp_image_, + std::string& dataset_name_ + ): + temp_image(temp_image_), + dataset_name(dataset_name_) + {} + + + virtual void start_document ( + ) + { + ts.clear(); + temp_image = image(); + temp_box = box(); + dataset_name.clear(); + } + + virtual void end_document ( + ) + { + } + + virtual void start_element ( + const unsigned long , + const std::string& name, + const dlib::attribute_list& + ) + { + if (ts.size() == 0 && name != "annotation") + { + std::ostringstream sout; + sout << "Invalid XML document. Root tag must be <annotation>. Found <" << name << "> instead."; + throw dlib::error(sout.str()); + } + + + ts.push_back(name); + } + + virtual void end_element ( + const unsigned long , + const std::string& name + ) + { + ts.pop_back(); + if (ts.size() == 0) + return; + + if (name == "object" && ts.back() == "annotation") + { + temp_image.boxes.push_back(temp_box); + temp_box = box(); + } + } + + virtual void characters ( + const std::string& data + ) + { + if (ts.size() == 2 && ts[1] == "filename") + { + temp_image.filename = trim(data); + } + else if (ts.size() == 3 && ts[2] == "database" && ts[1] == "source") + { + dataset_name = trim(data); + } + else if (ts.size() >= 3) + { + if (ts[ts.size()-2] == "bndbox" && ts[ts.size()-3] == "object") + { + if (ts.back() == "xmin") temp_box.rect.left() = string_cast<double>(data); + else if (ts.back() == "ymin") temp_box.rect.top() = string_cast<double>(data); + else if (ts.back() == "xmax") temp_box.rect.right() = string_cast<double>(data); + else if (ts.back() == "ymax") temp_box.rect.bottom() = string_cast<double>(data); + } + else if (ts.back() == "name" && ts[ts.size()-2] == "object") + { + temp_box.label = trim(data); + } + else if (ts.back() == "difficult" && ts[ts.size()-2] == "object") + { + if (trim(data) == "0" || trim(data) == "false") + { + temp_box.difficult = false; + } + else + { + temp_box.difficult = true; + } + } + else if (ts.back() == "truncated" && ts[ts.size()-2] == "object") + { + if (trim(data) == "0" || trim(data) == "false") + { + temp_box.truncated = false; + } + else + { + temp_box.truncated = true; + } + } + else if (ts.back() == "occluded" && ts[ts.size()-2] == "object") + { + if (trim(data) == "0" || trim(data) == "false") + { + temp_box.occluded = false; + } + else + { + temp_box.occluded = true; + } + } + + } + } + + virtual void processing_instruction ( + const unsigned long , + const std::string& , + const std::string& + ) + { + } + }; + +// ---------------------------------------------------------------------------------------- + + class xml_error_handler : public error_handler + { + public: + virtual void error ( + const unsigned long + ) { } + + virtual void fatal_error ( + const unsigned long line_number + ) + { + std::ostringstream sout; + sout << "There is a fatal error on line " << line_number << " so parsing will now halt."; + throw dlib::error(sout.str()); + } + }; + +// ---------------------------------------------------------------------------------------- + + void parse_annotation_file( + const std::string& file, + dlib::image_dataset_metadata::image& img, + std::string& dataset_name + ) + { + doc_handler dh(img, dataset_name); + xml_error_handler eh; + + xml_parser::kernel_1a parser; + parser.add_document_handler(dh); + parser.add_error_handler(eh); + + ifstream fin(file.c_str()); + if (!fin) + throw dlib::error("Unable to open file " + file); + parser.parse(fin); + } + +// ---------------------------------------------------------------------------------------- + +} + +void convert_pascal_xml( + const command_line_parser& parser +) +{ + cout << "Convert from PASCAL XML annotation format..." << endl; + + dlib::image_dataset_metadata::dataset dataset; + + std::string name; + dlib::image_dataset_metadata::image img; + + const std::string filename = parser.option("c").argument(); + // make sure the file exists so we can use the get_parent_directory() command to + // figure out it's parent directory. + make_empty_file(filename); + const std::string parent_dir = get_parent_directory(file(filename)).full_name(); + + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + try + { + parse_annotation_file(parser[i], img, name); + const string root = get_parent_directory(get_parent_directory(file(parser[i]))).full_name(); + const string img_path = root + directory::get_separator() + "JPEGImages" + directory::get_separator(); + + dataset.name = name; + img.filename = strip_path(img_path + img.filename, parent_dir); + dataset.images.push_back(img); + + } + catch (exception& ) + { + cout << "Error while processing file " << parser[i] << endl << endl; + throw; + } + } + + save_image_dataset_metadata(dataset, filename); +} + diff --git a/ml/dlib/tools/imglab/src/convert_pascal_xml.h b/ml/dlib/tools/imglab/src/convert_pascal_xml.h new file mode 100644 index 000000000..01ee1e82f --- /dev/null +++ b/ml/dlib/tools/imglab/src/convert_pascal_xml.h @@ -0,0 +1,12 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_CONVERT_PASCAl_XML_H__ +#define DLIB_IMGLAB_CONVERT_PASCAl_XML_H__ + +#include "common.h" +#include <dlib/cmd_line_parser.h> + +void convert_pascal_xml(const dlib::command_line_parser& parser); + +#endif // DLIB_IMGLAB_CONVERT_PASCAl_XML_H__ + diff --git a/ml/dlib/tools/imglab/src/flip_dataset.cpp b/ml/dlib/tools/imglab/src/flip_dataset.cpp new file mode 100644 index 000000000..e072dc790 --- /dev/null +++ b/ml/dlib/tools/imglab/src/flip_dataset.cpp @@ -0,0 +1,249 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "flip_dataset.h" +#include <dlib/data_io.h> +#include <dlib/dir_nav.h> +#include <string> +#include "common.h" +#include <dlib/image_transforms.h> +#include <dlib/optimization.h> +#include <dlib/image_processing.h> + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +std::vector<long> align_points( + const std::vector<dpoint>& from, + const std::vector<dpoint>& to, + double min_angle = -90*pi/180.0, + double max_angle = 90*pi/180.0, + long num_angles = 181 +) +/*! + ensures + - Figures out how to align the points in from with the points in to. Returns an + assignment array A that indicates that from[i] matches with to[A[i]]. + + We use the Hungarian algorithm with a search over reasonable angles. This method + works because we just need to account for a translation and a mild rotation and + nothing else. If there is any other more complex mapping then you probably don't + have landmarks that make sense to flip. +!*/ +{ + DLIB_CASSERT(from.size() == to.size()); + + std::vector<long> best_assignment; + double best_assignment_cost = std::numeric_limits<double>::infinity(); + + matrix<double> dists(from.size(), to.size()); + matrix<long long> idists; + + for (auto angle : linspace(min_angle, max_angle, num_angles)) + { + auto rot = rotation_matrix(angle); + for (long r = 0; r < dists.nr(); ++r) + { + for (long c = 0; c < dists.nc(); ++c) + { + dists(r,c) = length_squared(rot*from[r]-to[c]); + } + } + + idists = matrix_cast<long long>(-round(std::numeric_limits<long long>::max()*(dists/max(dists)))); + + auto assignment = max_cost_assignment(idists); + auto cost = assignment_cost(dists, assignment); + if (cost < best_assignment_cost) + { + best_assignment_cost = cost; + best_assignment = std::move(assignment); + } + } + + + // Now compute the alignment error in terms of average distance moved by each part. We + // do this so we can give the user a warning if it's impossible to make a good + // alignment. + running_stats<double> rs; + std::vector<dpoint> tmp(to.size()); + for (size_t i = 0; i < to.size(); ++i) + tmp[best_assignment[i]] = to[i]; + auto tform = find_similarity_transform(from, tmp); + for (size_t i = 0; i < from.size(); ++i) + rs.add(length(tform(from[i])-tmp[i])); + if (rs.mean() > 0.05) + { + cout << "WARNING, your dataset has object part annotations and you asked imglab to " << endl; + cout << "flip the data. Imglab tried to adjust the part labels so that the average" << endl; + cout << "part layout in the flipped dataset is the same as the source dataset. " << endl; + cout << "However, the part annotation scheme doesn't seem to be left-right symmetric." << endl; + cout << "You should manually review the output to make sure the part annotations are " << endl; + cout << "labeled as you expect." << endl; + } + + + return best_assignment; +} + +// ---------------------------------------------------------------------------------------- + +std::map<string,dpoint> normalized_parts ( + const image_dataset_metadata::box& b +) +{ + auto tform = dlib::impl::normalizing_tform(b.rect); + std::map<string,dpoint> temp; + for (auto& p : b.parts) + temp[p.first] = tform(p.second); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +std::map<string,dpoint> average_parts ( + const image_dataset_metadata::dataset& data +) +/*! + ensures + - returns the average part layout over all objects in data. This is done by + centering the parts inside their rects and then averaging all the objects. +!*/ +{ + std::map<string,dpoint> psum; + std::map<string,double> pcnt; + for (auto& image : data.images) + { + for (auto& box : image.boxes) + { + for (auto& p : normalized_parts(box)) + { + psum[p.first] += p.second; + pcnt[p.first] += 1; + } + } + } + + // make into an average + for (auto& p : psum) + p.second /= pcnt[p.first]; + + return psum; +} + +// ---------------------------------------------------------------------------------------- + +void make_part_labeling_match_target_dataset ( + const image_dataset_metadata::dataset& target, + image_dataset_metadata::dataset& data +) +/*! + This function tries to adjust the part labels in data so that the average part layout + in data is the same as target, according to the string labels. Therefore, it doesn't + adjust part positions, instead it changes the string labels on the parts to achieve + this. This really only makes sense when you flipped a dataset that contains left-right + symmetric objects and you want to remap the part labels of the flipped data so that + they match the unflipped data's annotation scheme. +!*/ +{ + auto target_parts = average_parts(target); + auto data_parts = average_parts(data); + + // Convert to a form align_points() understands. We also need to keep track of the + // labels for later. + std::vector<dpoint> from, to; + std::vector<string> from_labels, to_labels; + for (auto& p : target_parts) + { + from_labels.emplace_back(p.first); + from.emplace_back(p.second); + } + for (auto& p : data_parts) + { + to_labels.emplace_back(p.first); + to.emplace_back(p.second); + } + + auto assignment = align_points(from, to); + // so now we know that from_labels[i] should replace to_labels[assignment[i]] + std::map<string,string> label_mapping; + for (size_t i = 0; i < assignment.size(); ++i) + label_mapping[to_labels[assignment[i]]] = from_labels[i]; + + // now apply the label mapping to the dataset + for (auto& image : data.images) + { + for (auto& box : image.boxes) + { + std::map<string,point> temp; + for (auto& p : box.parts) + temp[label_mapping[p.first]] = p.second; + box.parts = std::move(temp); + } + } +} + +// ---------------------------------------------------------------------------------------- + +void flip_dataset(const command_line_parser& parser) +{ + image_dataset_metadata::dataset metadata, orig_metadata; + string datasource; + if (parser.option("flip")) + datasource = parser.option("flip").argument(); + else + datasource = parser.option("flip-basic").argument(); + load_image_dataset_metadata(metadata,datasource); + orig_metadata = metadata; + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. + set_current_dir(get_parent_directory(file(datasource))); + + const string metadata_filename = get_parent_directory(file(datasource)).full_name() + + directory::get_separator() + "flipped_" + file(datasource).name(); + + + array2d<rgb_pixel> img, temp; + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + file f(metadata.images[i].filename); + string filename = get_parent_directory(f).full_name() + directory::get_separator() + "flipped_" + to_png_name(f.name()); + + load_image(img, metadata.images[i].filename); + flip_image_left_right(img, temp); + if (parser.option("jpg")) + { + filename = to_jpg_name(filename); + save_jpeg(temp, filename,JPEG_QUALITY); + } + else + { + save_png(temp, filename); + } + + for (unsigned long j = 0; j < metadata.images[i].boxes.size(); ++j) + { + metadata.images[i].boxes[j].rect = impl::flip_rect_left_right(metadata.images[i].boxes[j].rect, get_rect(img)); + + // flip all the object parts + for (auto& part : metadata.images[i].boxes[j].parts) + { + part.second = impl::flip_rect_left_right(rectangle(part.second,part.second), get_rect(img)).tl_corner(); + } + } + + metadata.images[i].filename = filename; + } + + if (!parser.option("flip-basic")) + make_part_labeling_match_target_dataset(orig_metadata, metadata); + + save_image_dataset_metadata(metadata, metadata_filename); +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/flip_dataset.h b/ml/dlib/tools/imglab/src/flip_dataset.h new file mode 100644 index 000000000..8ac5db3e8 --- /dev/null +++ b/ml/dlib/tools/imglab/src/flip_dataset.h @@ -0,0 +1,12 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMGLAB_FLIP_DaTASET_H__ +#define DLIB_IMGLAB_FLIP_DaTASET_H__ + + +#include <dlib/cmd_line_parser.h> + +void flip_dataset(const dlib::command_line_parser& parser); + +#endif // DLIB_IMGLAB_FLIP_DaTASET_H__ + diff --git a/ml/dlib/tools/imglab/src/main.cpp b/ml/dlib/tools/imglab/src/main.cpp new file mode 100644 index 000000000..060c2c870 --- /dev/null +++ b/ml/dlib/tools/imglab/src/main.cpp @@ -0,0 +1,1145 @@ + +#include "dlib/data_io.h" +#include "dlib/string.h" +#include "metadata_editor.h" +#include "convert_pascal_xml.h" +#include "convert_pascal_v1.h" +#include "convert_idl.h" +#include "cluster.h" +#include "flip_dataset.h" +#include <dlib/cmd_line_parser.h> +#include <dlib/image_transforms.h> +#include <dlib/svm.h> +#include <dlib/console_progress_indicator.h> +#include <dlib/md5.h> + +#include <iostream> +#include <fstream> +#include <string> +#include <set> + +#include <dlib/dir_nav.h> + + +const char* VERSION = "1.13"; + + + +using namespace std; +using namespace dlib; + +// ---------------------------------------------------------------------------------------- + +void create_new_dataset ( + const command_line_parser& parser +) +{ + using namespace dlib::image_dataset_metadata; + + const std::string filename = parser.option("c").argument(); + // make sure the file exists so we can use the get_parent_directory() command to + // figure out it's parent directory. + make_empty_file(filename); + const std::string parent_dir = get_parent_directory(file(filename)); + + unsigned long depth = 0; + if (parser.option("r")) + depth = 30; + + dataset meta; + meta.name = "imglab dataset"; + meta.comment = "Created by imglab tool."; + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + { + try + { + const string temp = strip_path(file(parser[i]), parent_dir); + meta.images.push_back(image(temp)); + } + catch (dlib::file::file_not_found&) + { + // then parser[i] should be a directory + + std::vector<file> files = get_files_in_directory_tree(parser[i], + match_endings(".png .PNG .jpeg .JPEG .jpg .JPG .bmp .BMP .dng .DNG .gif .GIF"), + depth); + sort(files.begin(), files.end()); + + for (unsigned long j = 0; j < files.size(); ++j) + { + meta.images.push_back(image(strip_path(files[j], parent_dir))); + } + } + } + + save_image_dataset_metadata(meta, filename); +} + +// ---------------------------------------------------------------------------------------- + +int split_dataset ( + const command_line_parser& parser +) +{ + if (parser.number_of_arguments() != 1) + { + cerr << "The --split option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + const std::string label = parser.option("split").argument(); + + dlib::image_dataset_metadata::dataset data, data_with, data_without; + load_image_dataset_metadata(data, parser[0]); + + data_with.name = data.name; + data_with.comment = data.comment; + data_without.name = data.name; + data_without.comment = data.comment; + + for (unsigned long i = 0; i < data.images.size(); ++i) + { + auto&& temp = data.images[i]; + + bool has_the_label = false; + // check for the label we are looking for + for (unsigned long j = 0; j < temp.boxes.size(); ++j) + { + if (temp.boxes[j].label == label) + { + has_the_label = true; + break; + } + } + + if (has_the_label) + data_with.images.push_back(temp); + else + data_without.images.push_back(temp); + } + + + save_image_dataset_metadata(data_with, left_substr(parser[0],".") + "_with_"+label + ".xml"); + save_image_dataset_metadata(data_without, left_substr(parser[0],".") + "_without_"+label + ".xml"); + + return EXIT_SUCCESS; +} + +// ---------------------------------------------------------------------------------------- + +void print_all_labels ( + const dlib::image_dataset_metadata::dataset& data +) +{ + std::set<std::string> labels; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + labels.insert(data.images[i].boxes[j].label); + } + } + + for (std::set<std::string>::iterator i = labels.begin(); i != labels.end(); ++i) + { + if (i->size() != 0) + { + cout << *i << endl; + } + } +} + +// ---------------------------------------------------------------------------------------- + +void print_all_label_stats ( + const dlib::image_dataset_metadata::dataset& data +) +{ + std::map<std::string, running_stats<double> > area_stats, aspect_ratio; + std::map<std::string, int> image_hits; + std::set<std::string> labels; + unsigned long num_unignored_boxes = 0; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + std::set<std::string> temp; + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + labels.insert(data.images[i].boxes[j].label); + temp.insert(data.images[i].boxes[j].label); + + area_stats[data.images[i].boxes[j].label].add(data.images[i].boxes[j].rect.area()); + aspect_ratio[data.images[i].boxes[j].label].add(data.images[i].boxes[j].rect.width()/ + (double)data.images[i].boxes[j].rect.height()); + + if (!data.images[i].boxes[j].ignore) + ++num_unignored_boxes; + } + + // count the number of images for each label + for (std::set<std::string>::iterator i = temp.begin(); i != temp.end(); ++i) + image_hits[*i] += 1; + } + + cout << "Number of images: "<< data.images.size() << endl; + cout << "Number of different labels: "<< labels.size() << endl; + cout << "Number of non-ignored boxes: " << num_unignored_boxes << endl << endl; + + for (std::set<std::string>::iterator i = labels.begin(); i != labels.end(); ++i) + { + if (i->size() == 0) + cout << "Unlabeled Boxes:" << endl; + else + cout << "Label: "<< *i << endl; + cout << " number of images: " << image_hits[*i] << endl; + cout << " number of occurrences: " << area_stats[*i].current_n() << endl; + cout << " min box area: " << area_stats[*i].min() << endl; + cout << " max box area: " << area_stats[*i].max() << endl; + cout << " mean box area: " << area_stats[*i].mean() << endl; + cout << " stddev box area: " << area_stats[*i].stddev() << endl; + cout << " mean width/height ratio: " << aspect_ratio[*i].mean() << endl; + cout << " stddev width/height ratio: " << aspect_ratio[*i].stddev() << endl; + cout << endl; + } +} + +// ---------------------------------------------------------------------------------------- + +void rename_labels ( + dlib::image_dataset_metadata::dataset& data, + const std::string& from, + const std::string& to +) +{ + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].label == from) + data.images[i].boxes[j].label = to; + } + } + +} + +// ---------------------------------------------------------------------------------------- + +void ignore_labels ( + dlib::image_dataset_metadata::dataset& data, + const std::string& label +) +{ + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].label == label) + data.images[i].boxes[j].ignore = true; + } + } +} + +// ---------------------------------------------------------------------------------------- + +void merge_metadata_files ( + const command_line_parser& parser +) +{ + image_dataset_metadata::dataset src, dest; + load_image_dataset_metadata(src, parser.option("add").argument(0)); + load_image_dataset_metadata(dest, parser.option("add").argument(1)); + + std::map<string,image_dataset_metadata::image> merged_data; + for (unsigned long i = 0; i < dest.images.size(); ++i) + merged_data[dest.images[i].filename] = dest.images[i]; + // now add in the src data and overwrite anything if there are duplicate entries. + for (unsigned long i = 0; i < src.images.size(); ++i) + merged_data[src.images[i].filename] = src.images[i]; + + // copy merged data into dest + dest.images.clear(); + for (std::map<string,image_dataset_metadata::image>::const_iterator i = merged_data.begin(); + i != merged_data.end(); ++i) + { + dest.images.push_back(i->second); + } + + save_image_dataset_metadata(dest, "merged.xml"); +} + +// ---------------------------------------------------------------------------------------- + +void rotate_dataset(const command_line_parser& parser) +{ + image_dataset_metadata::dataset metadata; + const string datasource = parser[0]; + load_image_dataset_metadata(metadata,datasource); + + double angle = get_option(parser, "rotate", 0); + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. + set_current_dir(get_parent_directory(file(datasource))); + + const string file_prefix = "rotated_"+ cast_to_string(angle) + "_"; + const string metadata_filename = get_parent_directory(file(datasource)).full_name() + + directory::get_separator() + file_prefix + file(datasource).name(); + + + array2d<rgb_pixel> img, temp; + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + file f(metadata.images[i].filename); + string filename = get_parent_directory(f).full_name() + directory::get_separator() + file_prefix + to_png_name(f.name()); + + load_image(img, metadata.images[i].filename); + const point_transform_affine tran = rotate_image(img, temp, angle*pi/180); + if (parser.option("jpg")) + { + filename = to_jpg_name(filename); + save_jpeg(temp, filename,JPEG_QUALITY); + } + else + { + save_png(temp, filename); + } + + for (unsigned long j = 0; j < metadata.images[i].boxes.size(); ++j) + { + const rectangle rect = metadata.images[i].boxes[j].rect; + rectangle newrect; + newrect += tran(rect.tl_corner()); + newrect += tran(rect.tr_corner()); + newrect += tran(rect.bl_corner()); + newrect += tran(rect.br_corner()); + // now make newrect have the same area as the starting rect. + double ratio = std::sqrt(rect.area()/(double)newrect.area()); + newrect = centered_rect(newrect, newrect.width()*ratio, newrect.height()*ratio); + metadata.images[i].boxes[j].rect = newrect; + + // rotate all the object parts + std::map<std::string,point>::iterator k; + for (k = metadata.images[i].boxes[j].parts.begin(); k != metadata.images[i].boxes[j].parts.end(); ++k) + { + k->second = tran(k->second); + } + } + + metadata.images[i].filename = filename; + } + + save_image_dataset_metadata(metadata, metadata_filename); +} + +// ---------------------------------------------------------------------------------------- + +int resample_dataset(const command_line_parser& parser) +{ + if (parser.number_of_arguments() != 1) + { + cerr << "The --resample option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + const size_t obj_size = get_option(parser,"cropped-object-size",100*100); + const double margin_scale = get_option(parser,"crop-size",2.5); // cropped image will be this times wider than the object. + const unsigned long min_object_size = get_option(parser,"min-object-size",1); + const bool one_object_per_image = parser.option("one-object-per-image"); + + dlib::image_dataset_metadata::dataset data, resampled_data; + std::ostringstream sout; + sout << "\nThe --resample parameters which generated this dataset were:" << endl; + sout << " cropped-object-size: "<< obj_size << endl; + sout << " crop-size: "<< margin_scale << endl; + sout << " min-object-size: "<< min_object_size << endl; + if (one_object_per_image) + sout << " one_object_per_image: true" << endl; + resampled_data.comment = data.comment + sout.str(); + resampled_data.name = data.name + " RESAMPLED"; + + load_image_dataset_metadata(data, parser[0]); + locally_change_current_dir chdir(get_parent_directory(file(parser[0]))); + dlib::rand rnd; + + const size_t image_size = std::round(std::sqrt(obj_size*margin_scale*margin_scale)); + const chip_dims cdims(image_size, image_size); + + console_progress_indicator pbar(data.images.size()); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + // don't even bother loading images that don't have objects. + if (data.images[i].boxes.size() == 0) + continue; + + pbar.print_status(i); + array2d<rgb_pixel> img, chip; + load_image(img, data.images[i].filename); + + + // figure out what chips we want to take from this image + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + const rectangle rect = data.images[i].boxes[j].rect; + if (data.images[i].boxes[j].ignore || rect.area() < min_object_size) + continue; + + const auto max_dim = std::max(rect.width(), rect.height()); + + const double rand_scale_perturb = 1 - 0.3*(rnd.get_random_double()-0.5); + const rectangle crop_rect = centered_rect(rect, max_dim*margin_scale*rand_scale_perturb, max_dim*margin_scale*rand_scale_perturb); + + const rectangle_transform tform = get_mapping_to_chip(chip_details(crop_rect, cdims)); + extract_image_chip(img, chip_details(crop_rect, cdims), chip); + + image_dataset_metadata::image dimg; + // Now transform the boxes to the crop and also mark them as ignored if they + // have already been cropped out or are outside the crop. + for (size_t k = 0; k < data.images[i].boxes.size(); ++k) + { + image_dataset_metadata::box box = data.images[i].boxes[k]; + // ignore boxes outside the cropped image + if (crop_rect.intersect(box.rect).area() == 0) + continue; + + // mark boxes we include in the crop as ignored. Also mark boxes that + // aren't totally within the crop as ignored. + if (crop_rect.contains(grow_rect(box.rect,10)) && (!one_object_per_image || k==j)) + data.images[i].boxes[k].ignore = true; + else + box.ignore = true; + + if (box.rect.area() < min_object_size) + box.ignore = true; + + box.rect = tform(box.rect); + for (auto&& p : box.parts) + p.second = tform.get_tform()(p.second); + dimg.boxes.push_back(box); + } + // Put a 64bit hash of the image data into the name to make sure there are no + // file name conflicts. + std::ostringstream sout; + sout << hex << murmur_hash3_128bit(&chip[0][0], chip.size()*sizeof(chip[0][0])).second; + dimg.filename = data.images[i].filename + "_RESAMPLED_"+sout.str()+".png"; + + if (parser.option("jpg")) + { + dimg.filename = to_jpg_name(dimg.filename); + save_jpeg(chip,dimg.filename, JPEG_QUALITY); + } + else + { + save_png(chip,dimg.filename); + } + resampled_data.images.push_back(dimg); + } + } + + save_image_dataset_metadata(resampled_data, parser[0] + ".RESAMPLED.xml"); + + return EXIT_SUCCESS; +} + +// ---------------------------------------------------------------------------------------- + +int tile_dataset(const command_line_parser& parser) +{ + if (parser.number_of_arguments() != 1) + { + cerr << "The --tile option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + string out_image = parser.option("tile").argument(); + string ext = right_substr(out_image,"."); + if (ext != "png" && ext != "jpg") + { + cerr << "The output image file must have either .png or .jpg extension." << endl; + return EXIT_FAILURE; + } + + const unsigned long chip_size = get_option(parser, "size", 8000); + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + locally_change_current_dir chdir(get_parent_directory(file(parser[0]))); + dlib::array<array2d<rgb_pixel> > images; + console_progress_indicator pbar(data.images.size()); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + // don't even bother loading images that don't have objects. + if (data.images[i].boxes.size() == 0) + continue; + + pbar.print_status(i); + array2d<rgb_pixel> img; + load_image(img, data.images[i].filename); + + // figure out what chips we want to take from this image + std::vector<chip_details> dets; + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].ignore) + continue; + + rectangle rect = data.images[i].boxes[j].rect; + dets.push_back(chip_details(rect, chip_size)); + } + // Now grab all those chips at once. + dlib::array<array2d<rgb_pixel> > chips; + extract_image_chips(img, dets, chips); + // and put the chips into the output. + for (unsigned long j = 0; j < chips.size(); ++j) + images.push_back(chips[j]); + } + + chdir.revert(); + + if (ext == "png") + save_png(tile_images(images), out_image); + else + save_jpeg(tile_images(images), out_image); + + return EXIT_SUCCESS; +} + + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) +{ + try + { + + command_line_parser parser; + + parser.add_option("h","Displays this information."); + parser.add_option("v","Display version."); + + parser.set_group_name("Creating XML files"); + parser.add_option("c","Create an XML file named <arg> listing a set of images.",1); + parser.add_option("r","Search directories recursively for images."); + parser.add_option("convert","Convert foreign image Annotations from <arg> format to the imglab format. " + "Supported formats: pascal-xml, pascal-v1, idl.",1); + + parser.set_group_name("Viewing XML files"); + parser.add_option("tile","Chip out all the objects and save them as one big image called <arg>.",1); + parser.add_option("size","When using --tile or --cluster, make each extracted object contain " + "about <arg> pixels (default 8000).",1); + parser.add_option("l","List all the labels in the given XML file."); + parser.add_option("stats","List detailed statistics on the object labels in the given XML file."); + parser.add_option("files","List all the files in the given XML file."); + + parser.set_group_name("Editing/Transforming XML datasets"); + parser.add_option("rename", "Rename all labels of <arg1> to <arg2>.",2); + parser.add_option("parts","The display will allow image parts to be labeled. The set of allowable parts " + "is defined by <arg> which should be a space separated list of parts.",1); + parser.add_option("rmempty","Remove all images that don't contain non-ignored annotations and save the results to a new XML file."); + parser.add_option("rmdupes","Remove duplicate images from the dataset. This is done by comparing " + "the md5 hash of each image file and removing duplicate images. " ); + parser.add_option("rmdiff","Set the ignored flag to true for boxes marked as difficult."); + parser.add_option("rmtrunc","Set the ignored flag to true for boxes that are partially outside the image."); + parser.add_option("sort-num-objects","Sort the images listed an XML file so images with many objects are listed first."); + parser.add_option("sort","Alphabetically sort the images in an XML file."); + parser.add_option("shuffle","Randomly shuffle the order of the images listed in an XML file."); + parser.add_option("seed", "When using --shuffle, set the random seed to the string <arg>.",1); + parser.add_option("split", "Split the contents of an XML file into two separate files. One containing the " + "images with objects labeled <arg> and another file with all the other images. ",1); + parser.add_option("add", "Add the image metadata from <arg1> into <arg2>. If any of the image " + "tags are in both files then the ones in <arg2> are deleted and replaced with the " + "image tags from <arg1>. The results are saved into merged.xml and neither <arg1> or " + "<arg2> files are modified.",2); + parser.add_option("flip", "Read an XML image dataset from the <arg> XML file and output a left-right flipped " + "version of the dataset and an accompanying flipped XML file named flipped_<arg>. " + "We also adjust object part labels after flipping so that the new flipped dataset " + "has the same average part layout as the source dataset." ,1); + parser.add_option("flip-basic", "This option is just like --flip, except we don't adjust any object part labels after flipping. " + "The parts are instead simply mirrored to the flipped dataset.", 1); + parser.add_option("rotate", "Read an XML image dataset and output a copy that is rotated counter clockwise by <arg> degrees. " + "The output is saved to an XML file prefixed with rotated_<arg>.",1); + parser.add_option("cluster", "Cluster all the objects in an XML file into <arg> different clusters and save " + "the results as cluster_###.xml and cluster_###.jpg files.",1); + parser.add_option("ignore", "Mark boxes labeled as <arg> as ignored. The resulting XML file is output as a separate file and the original is not modified.",1); + parser.add_option("rmlabel","Remove all boxes labeled <arg> and save the results to a new XML file.",1); + parser.add_option("rm-other-labels","Remove all boxes not labeled <arg> and save the results to a new XML file.",1); + parser.add_option("rmignore","Remove all boxes marked ignore and save the results to a new XML file."); + parser.add_option("rm-if-overlaps","Remove all boxes labeled <arg> if they overlap any box not labeled <arg> and save the results to a new XML file.",1); + parser.add_option("jpg", "When saving images to disk, write them as jpg files instead of png."); + + parser.set_group_name("Cropping sub images"); + parser.add_option("resample", "Crop out images that are centered on each object in the dataset. " + "The output is a new XML dataset."); + parser.add_option("cropped-object-size", "When doing --resample, make the cropped objects contain about <arg> pixels (default 10000).",1); + parser.add_option("min-object-size", "When doing --resample, skip objects that have fewer than <arg> pixels in them (default 1).",1); + parser.add_option("crop-size", "When doing --resample, the entire cropped image will be <arg> times wider than the object (default 2.5).",1); + parser.add_option("one-object-per-image", "When doing --resample, only include one non-ignored object per image (i.e. the central object)."); + + + + parser.parse(argc, argv); + + const char* singles[] = {"h","c","r","l","files","convert","parts","rmdiff", "rmtrunc", "rmdupes", "seed", "shuffle", "split", "add", + "flip-basic", "flip", "rotate", "tile", "size", "cluster", "resample", "min-object-size", "rmempty", + "crop-size", "cropped-object-size", "rmlabel", "rm-other-labels", "rm-if-overlaps", "sort-num-objects", + "one-object-per-image", "jpg", "rmignore", "sort"}; + parser.check_one_time_options(singles); + const char* c_sub_ops[] = {"r", "convert"}; + parser.check_sub_options("c", c_sub_ops); + parser.check_sub_option("shuffle", "seed"); + const char* resample_sub_ops[] = {"min-object-size", "crop-size", "cropped-object-size", "one-object-per-image"}; + parser.check_sub_options("resample", resample_sub_ops); + const char* size_parent_ops[] = {"tile", "cluster"}; + parser.check_sub_options(size_parent_ops, "size"); + parser.check_incompatible_options("c", "l"); + parser.check_incompatible_options("c", "files"); + parser.check_incompatible_options("c", "rmdiff"); + parser.check_incompatible_options("c", "rmempty"); + parser.check_incompatible_options("c", "rmlabel"); + parser.check_incompatible_options("c", "rm-other-labels"); + parser.check_incompatible_options("c", "rmignore"); + parser.check_incompatible_options("c", "rm-if-overlaps"); + parser.check_incompatible_options("c", "rmdupes"); + parser.check_incompatible_options("c", "rmtrunc"); + parser.check_incompatible_options("c", "add"); + parser.check_incompatible_options("c", "flip"); + parser.check_incompatible_options("c", "flip-basic"); + parser.check_incompatible_options("flip", "flip-basic"); + parser.check_incompatible_options("c", "rotate"); + parser.check_incompatible_options("c", "rename"); + parser.check_incompatible_options("c", "ignore"); + parser.check_incompatible_options("c", "parts"); + parser.check_incompatible_options("c", "tile"); + parser.check_incompatible_options("c", "cluster"); + parser.check_incompatible_options("c", "resample"); + parser.check_incompatible_options("l", "rename"); + parser.check_incompatible_options("l", "ignore"); + parser.check_incompatible_options("l", "add"); + parser.check_incompatible_options("l", "parts"); + parser.check_incompatible_options("l", "flip"); + parser.check_incompatible_options("l", "flip-basic"); + parser.check_incompatible_options("l", "rotate"); + parser.check_incompatible_options("files", "rename"); + parser.check_incompatible_options("files", "ignore"); + parser.check_incompatible_options("files", "add"); + parser.check_incompatible_options("files", "parts"); + parser.check_incompatible_options("files", "flip"); + parser.check_incompatible_options("files", "flip-basic"); + parser.check_incompatible_options("files", "rotate"); + parser.check_incompatible_options("add", "flip"); + parser.check_incompatible_options("add", "flip-basic"); + parser.check_incompatible_options("add", "rotate"); + parser.check_incompatible_options("add", "tile"); + parser.check_incompatible_options("flip", "tile"); + parser.check_incompatible_options("flip-basic", "tile"); + parser.check_incompatible_options("rotate", "tile"); + parser.check_incompatible_options("cluster", "tile"); + parser.check_incompatible_options("resample", "tile"); + parser.check_incompatible_options("flip", "cluster"); + parser.check_incompatible_options("flip-basic", "cluster"); + parser.check_incompatible_options("rotate", "cluster"); + parser.check_incompatible_options("add", "cluster"); + parser.check_incompatible_options("flip", "resample"); + parser.check_incompatible_options("flip-basic", "resample"); + parser.check_incompatible_options("rotate", "resample"); + parser.check_incompatible_options("add", "resample"); + parser.check_incompatible_options("shuffle", "tile"); + parser.check_incompatible_options("sort-num-objects", "tile"); + parser.check_incompatible_options("sort", "tile"); + parser.check_incompatible_options("convert", "l"); + parser.check_incompatible_options("convert", "files"); + parser.check_incompatible_options("convert", "rename"); + parser.check_incompatible_options("convert", "ignore"); + parser.check_incompatible_options("convert", "parts"); + parser.check_incompatible_options("convert", "cluster"); + parser.check_incompatible_options("convert", "resample"); + parser.check_incompatible_options("rmdiff", "rename"); + parser.check_incompatible_options("rmdiff", "ignore"); + parser.check_incompatible_options("rmempty", "ignore"); + parser.check_incompatible_options("rmempty", "rename"); + parser.check_incompatible_options("rmlabel", "ignore"); + parser.check_incompatible_options("rmlabel", "rename"); + parser.check_incompatible_options("rm-other-labels", "ignore"); + parser.check_incompatible_options("rm-other-labels", "rename"); + parser.check_incompatible_options("rmignore", "ignore"); + parser.check_incompatible_options("rmignore", "rename"); + parser.check_incompatible_options("rm-if-overlaps", "ignore"); + parser.check_incompatible_options("rm-if-overlaps", "rename"); + parser.check_incompatible_options("rmdupes", "rename"); + parser.check_incompatible_options("rmdupes", "ignore"); + parser.check_incompatible_options("rmtrunc", "rename"); + parser.check_incompatible_options("rmtrunc", "ignore"); + const char* convert_args[] = {"pascal-xml","pascal-v1","idl"}; + parser.check_option_arg_range("convert", convert_args); + parser.check_option_arg_range("cluster", 2, 999); + parser.check_option_arg_range("rotate", -360, 360); + parser.check_option_arg_range("size", 10*10, 1000*1000); + parser.check_option_arg_range("min-object-size", 1, 10000*10000); + parser.check_option_arg_range("cropped-object-size", 4, 10000*10000); + parser.check_option_arg_range("crop-size", 1.0, 100.0); + + if (parser.option("h")) + { + cout << "Usage: imglab [options] <image files/directories or XML file>\n"; + parser.print_options(cout); + cout << endl << endl; + return EXIT_SUCCESS; + } + + if (parser.option("add")) + { + merge_metadata_files(parser); + return EXIT_SUCCESS; + } + + if (parser.option("flip") || parser.option("flip-basic")) + { + flip_dataset(parser); + return EXIT_SUCCESS; + } + + if (parser.option("rotate")) + { + rotate_dataset(parser); + return EXIT_SUCCESS; + } + + if (parser.option("v")) + { + cout << "imglab v" << VERSION + << "\nCompiled: " << __TIME__ << " " << __DATE__ + << "\nWritten by Davis King\n"; + cout << "Check for updates at http://dlib.net\n\n"; + return EXIT_SUCCESS; + } + + if (parser.option("tile")) + { + return tile_dataset(parser); + } + + if (parser.option("cluster")) + { + return cluster_dataset(parser); + } + + if (parser.option("resample")) + { + return resample_dataset(parser); + } + + if (parser.option("c")) + { + if (parser.option("convert")) + { + if (parser.option("convert").argument() == "pascal-xml") + convert_pascal_xml(parser); + else if (parser.option("convert").argument() == "pascal-v1") + convert_pascal_v1(parser); + else if (parser.option("convert").argument() == "idl") + convert_idl(parser); + } + else + { + create_new_dataset(parser); + } + return EXIT_SUCCESS; + } + + if (parser.option("rmdiff")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmdiff option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (data.images[i].boxes[j].difficult) + data.images[i].boxes[j].ignore = true; + } + } + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("rmempty")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmempty option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data, data2; + load_image_dataset_metadata(data, parser[0]); + + data2 = data; + data2.images.clear(); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + bool has_label = false; + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (!data.images[i].boxes[j].ignore) + has_label = true; + } + if (has_label) + data2.images.push_back(data.images[i]); + } + save_image_dataset_metadata(data2, parser[0] + ".rmempty.xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rmlabel")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmlabel option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + + const auto label = parser.option("rmlabel").argument(); + + for (auto&& img : data.images) + { + std::vector<dlib::image_dataset_metadata::box> boxes; + for (auto&& b : img.boxes) + { + if (b.label != label) + boxes.push_back(b); + } + img.boxes = boxes; + } + + save_image_dataset_metadata(data, parser[0] + ".rmlabel-"+label+".xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rm-other-labels")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rm-other-labels option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + + const auto labels = parser.option("rm-other-labels").argument(); + // replace comma by dash to form the file name + std::string strlabels = labels; + std::replace(strlabels.begin(), strlabels.end(), ',', '-'); + std::vector<string> all_labels = split(labels, ","); + for (auto&& img : data.images) + { + std::vector<dlib::image_dataset_metadata::box> boxes; + for (auto&& b : img.boxes) + { + if (std::find(all_labels.begin(), all_labels.end(), b.label) != all_labels.end()) + boxes.push_back(b); + } + img.boxes = boxes; + } + + save_image_dataset_metadata(data, parser[0] + ".rm-other-labels-"+ strlabels +".xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rmignore")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmignore option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + + for (auto&& img : data.images) + { + std::vector<dlib::image_dataset_metadata::box> boxes; + for (auto&& b : img.boxes) + { + if (!b.ignore) + boxes.push_back(b); + } + img.boxes = boxes; + } + + save_image_dataset_metadata(data, parser[0] + ".rmignore.xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rm-if-overlaps")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rm-if-overlaps option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + + const auto label = parser.option("rm-if-overlaps").argument(); + + test_box_overlap overlaps(0.5); + + for (auto&& img : data.images) + { + std::vector<dlib::image_dataset_metadata::box> boxes; + for (auto&& b : img.boxes) + { + if (b.label != label) + { + boxes.push_back(b); + } + else + { + bool has_overlap = false; + for (auto&& b2 : img.boxes) + { + if (b2.label != label && overlaps(b2.rect, b.rect)) + { + has_overlap = true; + break; + } + } + if (!has_overlap) + boxes.push_back(b); + } + } + img.boxes = boxes; + } + + save_image_dataset_metadata(data, parser[0] + ".rm-if-overlaps-"+label+".xml"); + return EXIT_SUCCESS; + } + + if (parser.option("rmdupes")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmdupes option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data, data_out; + std::set<std::string> hashes; + load_image_dataset_metadata(data, parser[0]); + data_out = data; + data_out.images.clear(); + + for (unsigned long i = 0; i < data.images.size(); ++i) + { + ifstream fin(data.images[i].filename.c_str(), ios::binary); + string hash = md5(fin); + if (hashes.count(hash) == 0) + { + hashes.insert(hash); + data_out.images.push_back(data.images[i]); + } + } + save_image_dataset_metadata(data_out, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("rmtrunc")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rmtrunc option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + { + locally_change_current_dir chdir(get_parent_directory(file(parser[0]))); + for (unsigned long i = 0; i < data.images.size(); ++i) + { + array2d<unsigned char> img; + load_image(img, data.images[i].filename); + const rectangle area = get_rect(img); + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (!area.contains(data.images[i].boxes[j].rect)) + data.images[i].boxes[j].ignore = true; + } + } + } + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("l")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The -l option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + print_all_labels(data); + return EXIT_SUCCESS; + } + + if (parser.option("files")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --files option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + for (size_t i = 0; i < data.images.size(); ++i) + cout << data.images[i].filename << "\n"; + return EXIT_SUCCESS; + } + + if (parser.option("split")) + { + return split_dataset(parser); + } + + if (parser.option("shuffle")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --shuffle option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + const string default_seed = cast_to_string(time(0)); + const string seed = get_option(parser, "seed", default_seed); + dlib::rand rnd(seed); + randomize_samples(data.images, rnd); + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("sort-num-objects")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --sort-num-objects option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + std::sort(data.images.rbegin(), data.images.rend(), + [](const image_dataset_metadata::image& a, const image_dataset_metadata::image& b) { return a.boxes.size() < b.boxes.size(); }); + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("sort")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --sort option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + std::sort(data.images.begin(), data.images.end(), + [](const image_dataset_metadata::image& a, const image_dataset_metadata::image& b) { return a.filename < b.filename; }); + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("stats")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --stats option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + print_all_label_stats(data); + return EXIT_SUCCESS; + } + + if (parser.option("rename")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --rename option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + for (unsigned long i = 0; i < parser.option("rename").count(); ++i) + { + rename_labels(data, parser.option("rename").argument(0,i), parser.option("rename").argument(1,i)); + } + save_image_dataset_metadata(data, parser[0]); + return EXIT_SUCCESS; + } + + if (parser.option("ignore")) + { + if (parser.number_of_arguments() != 1) + { + cerr << "The --ignore option requires you to give one XML file on the command line." << endl; + return EXIT_FAILURE; + } + + dlib::image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, parser[0]); + for (unsigned long i = 0; i < parser.option("ignore").count(); ++i) + { + ignore_labels(data, parser.option("ignore").argument()); + } + save_image_dataset_metadata(data, parser[0]+".ignored.xml"); + return EXIT_SUCCESS; + } + + if (parser.number_of_arguments() == 1) + { + metadata_editor editor(parser[0]); + if (parser.option("parts")) + { + std::vector<string> parts = split(parser.option("parts").argument()); + for (unsigned long i = 0; i < parts.size(); ++i) + { + editor.add_labelable_part_name(parts[i]); + } + } + editor.wait_until_closed(); + return EXIT_SUCCESS; + } + + cout << "Invalid command, give -h to see options." << endl; + return EXIT_FAILURE; + } + catch (exception& e) + { + cerr << e.what() << endl; + return EXIT_FAILURE; + } +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/metadata_editor.cpp b/ml/dlib/tools/imglab/src/metadata_editor.cpp new file mode 100644 index 000000000..76177e893 --- /dev/null +++ b/ml/dlib/tools/imglab/src/metadata_editor.cpp @@ -0,0 +1,671 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "metadata_editor.h" +#include <dlib/array.h> +#include <dlib/queue.h> +#include <dlib/static_set.h> +#include <dlib/misc_api.h> +#include <dlib/image_io.h> +#include <dlib/array2d.h> +#include <dlib/pixel.h> +#include <dlib/image_transforms.h> +#include <dlib/image_processing.h> +#include <sstream> +#include <ctime> + +using namespace std; +using namespace dlib; + +extern const char* VERSION; + +// ---------------------------------------------------------------------------------------- + +metadata_editor:: +metadata_editor( + const std::string& filename_ +) : + mbar(*this), + lb_images(*this), + image_pos(0), + display(*this), + overlay_label_name(*this), + overlay_label(*this), + keyboard_jump_pos(0), + last_keyboard_jump_pos_update(0) +{ + file metadata_file(filename_); + filename = metadata_file.full_name(); + // Make our current directory be the one that contains the metadata file. We + // do this because that file might contain relative paths to the image files + // we are supposed to be loading. + set_current_dir(get_parent_directory(metadata_file).full_name()); + + load_image_dataset_metadata(metadata, filename); + + dlib::array<std::string>::expand_1a files; + files.resize(metadata.images.size()); + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + files[i] = metadata.images[i].filename; + } + lb_images.load(files); + lb_images.enable_multiple_select(); + + lb_images.set_click_handler(*this, &metadata_editor::on_lb_images_clicked); + + overlay_label_name.set_text("Next Label: "); + overlay_label.set_width(200); + + display.set_image_clicked_handler(*this, &metadata_editor::on_image_clicked); + display.set_overlay_rects_changed_handler(*this, &metadata_editor::on_overlay_rects_changed); + display.set_overlay_rect_selected_handler(*this, &metadata_editor::on_overlay_rect_selected); + overlay_label.set_text_modified_handler(*this, &metadata_editor::on_overlay_label_changed); + + mbar.set_number_of_menus(2); + mbar.set_menu_name(0,"File",'F'); + mbar.set_menu_name(1,"Help",'H'); + + + mbar.menu(0).add_menu_item(menu_item_text("Save",*this,&metadata_editor::file_save,'S')); + mbar.menu(0).add_menu_item(menu_item_text("Save As",*this,&metadata_editor::file_save_as,'A')); + mbar.menu(0).add_menu_item(menu_item_separator()); + mbar.menu(0).add_menu_item(menu_item_text("Remove Selected Images",*this,&metadata_editor::remove_selected_images,'R')); + mbar.menu(0).add_menu_item(menu_item_separator()); + mbar.menu(0).add_menu_item(menu_item_text("Exit",static_cast<base_window&>(*this),&drawable_window::close_window,'x')); + + mbar.menu(1).add_menu_item(menu_item_text("About",*this,&metadata_editor::display_about,'A')); + + // set the size of this window. + on_window_resized(); + load_image_and_set_size(0); + on_window_resized(); + if (image_pos < lb_images.size() ) + lb_images.select(image_pos); + + // make sure the window is centered on the screen. + unsigned long width, height; + get_size(width, height); + unsigned long screen_width, screen_height; + get_display_size(screen_width, screen_height); + set_pos((screen_width-width)/2, (screen_height-height)/2); + + show(); +} + +// ---------------------------------------------------------------------------------------- + +metadata_editor:: +~metadata_editor( +) +{ + close_window(); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +add_labelable_part_name ( + const std::string& name +) +{ + display.add_labelable_part_name(name); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +file_save() +{ + save_metadata_to_file(filename); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +save_metadata_to_file ( + const std::string& file +) +{ + try + { + save_image_dataset_metadata(metadata, file); + } + catch (dlib::error& e) + { + message_box("Error saving file", e.what()); + } +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +file_save_as() +{ + save_file_box(*this, &metadata_editor::save_metadata_to_file); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +remove_selected_images() +{ + dlib::queue<unsigned long>::kernel_1a list; + lb_images.get_selected(list); + list.reset(); + unsigned long min_idx = lb_images.size(); + while (list.move_next()) + { + lb_images.unselect(list.element()); + min_idx = std::min(min_idx, list.element()); + } + + + // remove all the selected items from metadata.images + dlib::static_set<unsigned long>::kernel_1a to_remove; + to_remove.load(list); + std::vector<dlib::image_dataset_metadata::image> images; + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + if (to_remove.is_member(i) == false) + { + images.push_back(metadata.images[i]); + } + } + images.swap(metadata.images); + + + // reload metadata into lb_images + dlib::array<std::string>::expand_1a files; + files.resize(metadata.images.size()); + for (unsigned long i = 0; i < metadata.images.size(); ++i) + { + files[i] = metadata.images[i].filename; + } + lb_images.load(files); + + + if (min_idx != 0) + min_idx--; + select_image(min_idx); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_window_resized( +) +{ + drawable_window::on_window_resized(); + + unsigned long width, height; + get_size(width, height); + + lb_images.set_pos(0,mbar.bottom()+1); + lb_images.set_size(180, height - mbar.height()); + + overlay_label_name.set_pos(lb_images.right()+10, mbar.bottom() + (overlay_label.height()-overlay_label_name.height())/2+1); + overlay_label.set_pos(overlay_label_name.right(), mbar.bottom()+1); + display.set_pos(lb_images.right(), overlay_label.bottom()+3); + + display.set_size(width - display.left(), height - display.top()); +} + +// ---------------------------------------------------------------------------------------- + +void propagate_boxes( + dlib::image_dataset_metadata::dataset& data, + unsigned long prev, + unsigned long next +) +{ + if (prev == next || next >= data.images.size()) + return; + + array2d<rgb_pixel> img1, img2; + dlib::load_image(img1, data.images[prev].filename); + dlib::load_image(img2, data.images[next].filename); + for (unsigned long i = 0; i < data.images[prev].boxes.size(); ++i) + { + correlation_tracker tracker; + tracker.start_track(img1, data.images[prev].boxes[i].rect); + tracker.update(img2); + dlib::image_dataset_metadata::box box = data.images[prev].boxes[i]; + box.rect = tracker.get_position(); + data.images[next].boxes.push_back(box); + } +} + +// ---------------------------------------------------------------------------------------- + +void propagate_labels( + const std::string& label, + dlib::image_dataset_metadata::dataset& data, + unsigned long prev, + unsigned long next +) +{ + if (prev == next || next >= data.images.size()) + return; + + + for (unsigned long i = 0; i < data.images[prev].boxes.size(); ++i) + { + if (data.images[prev].boxes[i].label != label) + continue; + + // figure out which box in the next image matches the current one the best + const rectangle cur = data.images[prev].boxes[i].rect; + double best_overlap = 0; + unsigned long best_idx = 0; + for (unsigned long j = 0; j < data.images[next].boxes.size(); ++j) + { + const rectangle next_box = data.images[next].boxes[j].rect; + const double overlap = cur.intersect(next_box).area()/(double)(cur+next_box).area(); + if (overlap > best_overlap) + { + best_overlap = overlap; + best_idx = j; + } + } + + // If we found a matching rectangle in the next image and the best match doesn't + // already have a label. + if (best_overlap > 0.5 && data.images[next].boxes[best_idx].label == "") + { + data.images[next].boxes[best_idx].label = label; + } + } + +} + +// ---------------------------------------------------------------------------------------- + +bool has_label_or_all_boxes_labeled ( + const std::string& label, + const dlib::image_dataset_metadata::image& img +) +{ + if (label.size() == 0) + return true; + + bool all_boxes_labeled = true; + for (unsigned long i = 0; i < img.boxes.size(); ++i) + { + if (img.boxes[i].label == label) + return true; + if (img.boxes[i].label.size() == 0) + all_boxes_labeled = false; + } + + return all_boxes_labeled; +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state +) +{ + drawable_window::on_keydown(key, is_printable, state); + + if (is_printable) + { + if (key == '\t') + { + overlay_label.give_input_focus(); + overlay_label.select_all_text(); + } + + // If the user types a number then jump to that image. + if ('0' <= key && key <= '9' && metadata.images.size() != 0 && !overlay_label.has_input_focus()) + { + time_t curtime = time(0); + // If it's been a while since the user typed numbers then forget the last jump + // position and start accumulating numbers over again. + if (curtime-last_keyboard_jump_pos_update >= 2) + keyboard_jump_pos = 0; + last_keyboard_jump_pos_update = curtime; + + keyboard_jump_pos *= 10; + keyboard_jump_pos += key-'0'; + if (keyboard_jump_pos >= metadata.images.size()) + keyboard_jump_pos = metadata.images.size()-1; + + image_pos = keyboard_jump_pos; + select_image(image_pos); + } + else + { + last_keyboard_jump_pos_update = 0; + } + + if (key == 'd' && (state&base_window::KBD_MOD_ALT)) + { + remove_selected_images(); + } + + if (key == 'e' && !overlay_label.has_input_focus()) + { + display_equialized_image = !display_equialized_image; + select_image(image_pos); + } + + // Make 'w' and 's' act like KEY_UP and KEY_DOWN + if ((key == 'w' || key == 'W') && !overlay_label.has_input_focus()) + { + key = base_window::KEY_UP; + } + else if ((key == 's' || key == 'S') && !overlay_label.has_input_focus()) + { + key = base_window::KEY_DOWN; + } + else + { + return; + } + } + + if (key == base_window::KEY_UP) + { + if ((state&KBD_MOD_CONTROL) && (state&KBD_MOD_SHIFT)) + { + // Don't do anything if there are no boxes in the current image. + if (metadata.images[image_pos].boxes.size() == 0) + return; + // Also don't do anything if there *are* boxes in the next image. + if (image_pos > 1 && metadata.images[image_pos-1].boxes.size() != 0) + return; + + propagate_boxes(metadata, image_pos, image_pos-1); + } + else if (state&base_window::KBD_MOD_CONTROL) + { + // If the label we are supposed to propagate doesn't exist in the current image + // then don't advance. + if (!has_label_or_all_boxes_labeled(display.get_default_overlay_rect_label(),metadata.images[image_pos])) + return; + + // if the next image is going to be empty then fast forward to the next one + while (image_pos > 1 && metadata.images[image_pos-1].boxes.size() == 0) + --image_pos; + + propagate_labels(display.get_default_overlay_rect_label(), metadata, image_pos, image_pos-1); + } + select_image(image_pos-1); + } + else if (key == base_window::KEY_DOWN) + { + if ((state&KBD_MOD_CONTROL) && (state&KBD_MOD_SHIFT)) + { + // Don't do anything if there are no boxes in the current image. + if (metadata.images[image_pos].boxes.size() == 0) + return; + // Also don't do anything if there *are* boxes in the next image. + if (image_pos+1 < metadata.images.size() && metadata.images[image_pos+1].boxes.size() != 0) + return; + + propagate_boxes(metadata, image_pos, image_pos+1); + } + else if (state&base_window::KBD_MOD_CONTROL) + { + // If the label we are supposed to propagate doesn't exist in the current image + // then don't advance. + if (!has_label_or_all_boxes_labeled(display.get_default_overlay_rect_label(),metadata.images[image_pos])) + return; + + // if the next image is going to be empty then fast forward to the next one + while (image_pos+1 < metadata.images.size() && metadata.images[image_pos+1].boxes.size() == 0) + ++image_pos; + + propagate_labels(display.get_default_overlay_rect_label(), metadata, image_pos, image_pos+1); + } + select_image(image_pos+1); + } +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +select_image( + unsigned long idx +) +{ + if (idx < lb_images.size()) + { + // unselect all currently selected images + dlib::queue<unsigned long>::kernel_1a list; + lb_images.get_selected(list); + list.reset(); + while (list.move_next()) + { + lb_images.unselect(list.element()); + } + + + lb_images.select(idx); + load_image(idx); + } + else if (lb_images.size() == 0) + { + display.clear_overlay(); + array2d<unsigned char> empty_img; + display.set_image(empty_img); + } +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_lb_images_clicked( + unsigned long idx +) +{ + load_image(idx); +} + +// ---------------------------------------------------------------------------------------- + +std::vector<dlib::image_display::overlay_rect> get_overlays ( + const dlib::image_dataset_metadata::image& data, + color_mapper& string_to_color +) +{ + std::vector<dlib::image_display::overlay_rect> temp(data.boxes.size()); + for (unsigned long i = 0; i < temp.size(); ++i) + { + temp[i].rect = data.boxes[i].rect; + temp[i].label = data.boxes[i].label; + temp[i].parts = data.boxes[i].parts; + temp[i].crossed_out = data.boxes[i].ignore; + temp[i].color = string_to_color(data.boxes[i].label); + } + return temp; +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +load_image( + unsigned long idx +) +{ + if (idx >= metadata.images.size()) + return; + + image_pos = idx; + + array2d<rgb_pixel> img; + display.clear_overlay(); + try + { + dlib::load_image(img, metadata.images[idx].filename); + set_title(metadata.name + " #"+cast_to_string(idx)+": " +metadata.images[idx].filename); + } + catch (exception& e) + { + message_box("Error loading image", e.what()); + } + + if (display_equialized_image) + equalize_histogram(img); + display.set_image(img); + display.add_overlay(get_overlays(metadata.images[idx], string_to_color)); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +load_image_and_set_size( + unsigned long idx +) +{ + if (idx >= metadata.images.size()) + return; + + image_pos = idx; + + array2d<rgb_pixel> img; + display.clear_overlay(); + try + { + dlib::load_image(img, metadata.images[idx].filename); + set_title(metadata.name + " #"+cast_to_string(idx)+": " +metadata.images[idx].filename); + } + catch (exception& e) + { + message_box("Error loading image", e.what()); + } + + + unsigned long screen_width, screen_height; + get_display_size(screen_width, screen_height); + + + unsigned long needed_width = display.left() + img.nc() + 4; + unsigned long needed_height = display.top() + img.nr() + 4; + if (needed_width < 300) needed_width = 300; + if (needed_height < 300) needed_height = 300; + + if (needed_width > 100 + screen_width) + needed_width = screen_width - 100; + if (needed_height > 100 + screen_height) + needed_height = screen_height - 100; + + set_size(needed_width, needed_height); + + + if (display_equialized_image) + equalize_histogram(img); + display.set_image(img); + display.add_overlay(get_overlays(metadata.images[idx], string_to_color)); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_overlay_rects_changed( +) +{ + using namespace dlib::image_dataset_metadata; + if (image_pos < metadata.images.size()) + { + const std::vector<image_display::overlay_rect>& rects = display.get_overlay_rects(); + + std::vector<box>& boxes = metadata.images[image_pos].boxes; + + boxes.clear(); + for (unsigned long i = 0; i < rects.size(); ++i) + { + box temp; + temp.label = rects[i].label; + temp.rect = rects[i].rect; + temp.parts = rects[i].parts; + temp.ignore = rects[i].crossed_out; + boxes.push_back(temp); + } + } +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_image_clicked( + const point& /*p*/, bool /*is_double_click*/, unsigned long /*btn*/ +) +{ + display.set_default_overlay_rect_color(string_to_color(trim(overlay_label.text()))); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_overlay_label_changed( +) +{ + display.set_default_overlay_rect_label(trim(overlay_label.text())); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +on_overlay_rect_selected( + const image_display::overlay_rect& orect +) +{ + overlay_label.set_text(orect.label); + display.set_default_overlay_rect_label(orect.label); + display.set_default_overlay_rect_color(string_to_color(orect.label)); +} + +// ---------------------------------------------------------------------------------------- + +void metadata_editor:: +display_about( +) +{ + std::ostringstream sout; + sout << wrap_string("Image Labeler v" + string(VERSION) + "." ,0,0) << endl << endl; + sout << wrap_string("This program is a tool for labeling images with rectangles. " ,0,0) << endl << endl; + + sout << wrap_string("You can add a new rectangle by holding the shift key, left clicking " + "the mouse, and dragging it. New rectangles are given the label from the \"Next Label\" " + "field at the top of the application. You can quickly edit the contents of the Next Label field " + "by hitting the tab key. Double clicking " + "a rectangle selects it and the delete key removes it. You can also mark " + "a rectangle as ignored by hitting the i or END keys when it is selected. Ignored " + "rectangles are visually displayed with an X through them. You can remove an image " + "entirely by selecting it in the list on the left and pressing alt+d." + ,0,0) << endl << endl; + + sout << wrap_string("It is also possible to label object parts by selecting a rectangle and " + "then right clicking. A popup menu will appear and you can select a part label. " + "Note that you must define the allowable part labels by giving --parts on the " + "command line. An example would be '--parts \"leye reye nose mouth\"'." + ,0,0) << endl << endl; + + sout << wrap_string("Press the down or s key to select the next image in the list and the up or w " + "key to select the previous one.",0,0) << endl << endl; + + sout << wrap_string("Additionally, you can hold ctrl and then scroll the mouse wheel to zoom. A normal left click " + "and drag allows you to navigate around the image. Holding ctrl and " + "left clicking a rectangle will give it the label from the Next Label field. " + "Holding shift + right click and then dragging allows you to move things around. " + "Holding ctrl and pressing the up or down keyboard keys will propagate " + "rectangle labels from one image to the next and also skip empty images. " + "Similarly, holding ctrl+shift will propagate entire boxes via a visual tracking " + "algorithm from one image to the next. " + "Finally, typing a number on the keyboard will jump you to a specific image.",0,0) << endl << endl; + + sout << wrap_string("You can also toggle image histogram equalization by pressing the e key." + ,0,0) << endl; + + + message_box("About Image Labeler",sout.str()); +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/imglab/src/metadata_editor.h b/ml/dlib/tools/imglab/src/metadata_editor.h new file mode 100644 index 000000000..71aa14ace --- /dev/null +++ b/ml/dlib/tools/imglab/src/metadata_editor.h @@ -0,0 +1,116 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_METADATA_EdITOR_H__ +#define DLIB_METADATA_EdITOR_H__ + +#include <dlib/gui_widgets.h> +#include <dlib/data_io.h> +#include <dlib/pixel.h> +#include <map> + +// ---------------------------------------------------------------------------------------- + +class color_mapper +{ +public: + + dlib::rgb_alpha_pixel operator() ( + const std::string& str + ) + { + auto i = colors.find(str); + if (i != colors.end()) + { + return i->second; + } + else + { + using namespace dlib; + hsi_pixel pix; + pix.h = reverse(colors.size()); + pix.s = 255; + pix.i = 150; + rgb_alpha_pixel result; + assign_pixel(result, pix); + colors[str] = result; + return result; + } + } + +private: + + // We use a bit reverse here because it causes us to evenly spread the colors as we + // allocated them. First the colors are maximally different, then become interleaved + // and progressively more similar as they are allocated. + unsigned char reverse(unsigned char b) + { + // reverse the order of the bits in b. + b = ((b * 0x0802LU & 0x22110LU) | (b * 0x8020LU & 0x88440LU)) * 0x10101LU >> 16; + return b; + } + + std::map<std::string, dlib::rgb_alpha_pixel> colors; +}; + +// ---------------------------------------------------------------------------------------- + +class metadata_editor : public dlib::drawable_window +{ +public: + metadata_editor( + const std::string& filename_ + ); + + ~metadata_editor(); + + void add_labelable_part_name ( + const std::string& name + ); + +private: + + void file_save(); + void file_save_as(); + void remove_selected_images(); + + virtual void on_window_resized(); + virtual void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + void on_lb_images_clicked(unsigned long idx); + void select_image(unsigned long idx); + void save_metadata_to_file (const std::string& file); + void load_image(unsigned long idx); + void load_image_and_set_size(unsigned long idx); + void on_image_clicked(const dlib::point& p, bool is_double_click, unsigned long btn); + void on_overlay_rects_changed(); + void on_overlay_label_changed(); + void on_overlay_rect_selected(const dlib::image_display::overlay_rect& orect); + + void display_about(); + + std::string filename; + dlib::image_dataset_metadata::dataset metadata; + + dlib::menu_bar mbar; + dlib::list_box lb_images; + unsigned long image_pos; + + dlib::image_display display; + dlib::label overlay_label_name; + dlib::text_field overlay_label; + + unsigned long keyboard_jump_pos; + time_t last_keyboard_jump_pos_update; + bool display_equialized_image = false; + color_mapper string_to_color; +}; + +// ---------------------------------------------------------------------------------------- + + +#endif // DLIB_METADATA_EdITOR_H__ + diff --git a/ml/dlib/tools/python/CMakeLists.txt b/ml/dlib/tools/python/CMakeLists.txt new file mode 100644 index 000000000..d3b947485 --- /dev/null +++ b/ml/dlib/tools/python/CMakeLists.txt @@ -0,0 +1,106 @@ + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12) + +set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL "Use SSE4 instructions") +# Make DLIB_ASSERT statements not abort the python interpreter, but just return an error. +add_definitions(-DDLIB_NO_ABORT_ON_2ND_FATAL_ERROR) + +# Set this to disable link time optimization. The only reason for +# doing this to make the compile faster which is nice when developing +# new modules. +#set(PYBIND11_LTO_CXX_FLAGS "") + + +# Avoid cmake warnings about changes in behavior of some Mac OS X path +# variable we don't care about. +if (POLICY CMP0042) + cmake_policy(SET CMP0042 NEW) +endif() + + +if (CMAKE_COMPILER_IS_GNUCXX) + # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set + # -fPIC for GCC but sometimes it still doesn't get set, so make sure it + # does. + add_definitions("-fPIC") + set(CMAKE_POSITION_INDEPENDENT_CODE True) +else() + set(CMAKE_POSITION_INDEPENDENT_CODE True) +endif() + +# To avoid dll hell, always link everything statically when compiling in +# visual studio. This way, the resulting library won't depend on a bunch +# of other dll files and can be safely copied to someone elese's computer +# and expected to run. +if (MSVC) + include(${CMAKE_CURRENT_LIST_DIR}/../../dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake) +endif() + +add_subdirectory(../../dlib/external/pybind11 ./pybind11_build) +add_subdirectory(../../dlib ./dlib_build) + +if (USING_OLD_VISUAL_STUDIO_COMPILER) + message(FATAL_ERROR "You have to use a version of Visual Studio that supports C++11. As of December 2017, the only versions that have good enough C++11 support to compile the dlib Pyhton API is a fully updated Visual Studio 2015 or a fully updated Visual Studio 2017. Older versions of either of these compilers have bad C++11 support and will fail to compile the Python extension. ***SO UPDATE YOUR VISUAL STUDIO TO MAKE THIS ERROR GO AWAY***") +endif() + + +# Test for numpy +find_package(PythonInterp) +if(PYTHONINTERP_FOUND) + execute_process( COMMAND ${PYTHON_EXECUTABLE} -c "import numpy" OUTPUT_QUIET ERROR_QUIET RESULT_VARIABLE NUMPYRC) + if(NUMPYRC EQUAL 1) + message(WARNING "Numpy not found. Functions that return numpy arrays will throw exceptions!") + else() + message(STATUS "Found Python with installed numpy package") + execute_process( COMMAND ${PYTHON_EXECUTABLE} -c "import sys; from numpy import get_include; sys.stdout.write(get_include())" OUTPUT_VARIABLE NUMPY_INCLUDE_PATH) + message(STATUS "Numpy include path '${NUMPY_INCLUDE_PATH}'") + include_directories(${NUMPY_INCLUDE_PATH}) + endif() +else() + message(WARNING "Numpy not found. Functions that return numpy arrays will throw exceptions!") + set(NUMPYRC 1) +endif() + +add_definitions(-DDLIB_VERSION=${DLIB_VERSION}) + +# Tell cmake to compile all these cpp files into a dlib python module. +set(python_srcs + src/dlib.cpp + src/matrix.cpp + src/vector.cpp + src/svm_c_trainer.cpp + src/svm_rank_trainer.cpp + src/decision_functions.cpp + src/other.cpp + src/basic.cpp + src/cca.cpp + src/sequence_segmenter.cpp + src/svm_struct.cpp + src/image.cpp + src/rectangles.cpp + src/object_detection.cpp + src/shape_predictor.cpp + src/correlation_tracker.cpp + src/face_recognition.cpp + src/cnn_face_detector.cpp + src/global_optimization.cpp + src/image_dataset_metadata.cpp +) + +# Only add the Numpy returning functions if Numpy is present +if(NUMPYRC EQUAL 1) + list(APPEND python_srcs src/numpy_returns_stub.cpp) +else() + list(APPEND python_srcs src/numpy_returns.cpp) +endif() + +# Only add the GUI module if requested +if(NOT ${DLIB_NO_GUI_SUPPORT}) + list(APPEND python_srcs src/gui.cpp) +endif() + +pybind11_add_module(dlib_python ${python_srcs}) +target_link_libraries(dlib_python PRIVATE dlib::dlib) +# Set the output library name to dlib because that's what setup.py and distutils expects. +set_target_properties(dlib_python PROPERTIES OUTPUT_NAME dlib) + diff --git a/ml/dlib/tools/python/src/basic.cpp b/ml/dlib/tools/python/src/basic.cpp new file mode 100644 index 000000000..d87a53cc3 --- /dev/null +++ b/ml/dlib/tools/python/src/basic.cpp @@ -0,0 +1,272 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <sstream> +#include <string> +#include "opaque_types.h" + +#include <dlib/string.h> +#include <pybind11/stl_bind.h> + +using namespace std; +using namespace dlib; +namespace py = pybind11; + + +std::shared_ptr<std::vector<double> > array_from_object(py::object obj) +{ + try { + long nr = obj.cast<long>(); + return std::make_shared<std::vector<double>>(nr); + } catch (py::cast_error &e) { + py::list li = obj.cast<py::list>(); + const long nr = len(li); + auto temp = std::make_shared<std::vector<double>>(nr); + for ( long r = 0; r < nr; ++r) + { + (*temp)[r] = li[r].cast<double>(); + } + return temp; + } +} + +string array__str__ (const std::vector<double>& v) +{ + std::ostringstream sout; + for (unsigned long i = 0; i < v.size(); ++i) + { + sout << v[i]; + if (i+1 < v.size()) + sout << "\n"; + } + return sout.str(); +} + +string array__repr__ (const std::vector<double>& v) +{ + std::ostringstream sout; + sout << "dlib.array(["; + for (unsigned long i = 0; i < v.size(); ++i) + { + sout << v[i]; + if (i+1 < v.size()) + sout << ", "; + } + sout << "])"; + return sout.str(); +} + +string range__str__ (const std::pair<unsigned long,unsigned long>& p) +{ + std::ostringstream sout; + sout << p.first << ", " << p.second; + return sout.str(); +} + +string range__repr__ (const std::pair<unsigned long,unsigned long>& p) +{ + std::ostringstream sout; + sout << "dlib.range(" << p.first << ", " << p.second << ")"; + return sout.str(); +} + +struct range_iter +{ + std::pair<unsigned long,unsigned long> range; + unsigned long cur; + + unsigned long next() + { + if (cur < range.second) + { + return cur++; + } + else + { + PyErr_SetString(PyExc_StopIteration, "No more data."); + throw py::error_already_set(); + } + } +}; + +range_iter make_range_iterator (const std::pair<unsigned long,unsigned long>& p) +{ + range_iter temp; + temp.range = p; + temp.cur = p.first; + return temp; +} + +string pair__str__ (const std::pair<unsigned long,double>& p) +{ + std::ostringstream sout; + sout << p.first << ": " << p.second; + return sout.str(); +} + +string pair__repr__ (const std::pair<unsigned long,double>& p) +{ + std::ostringstream sout; + sout << "dlib.pair(" << p.first << ", " << p.second << ")"; + return sout.str(); +} + +string sparse_vector__str__ (const std::vector<std::pair<unsigned long,double> >& v) +{ + std::ostringstream sout; + for (unsigned long i = 0; i < v.size(); ++i) + { + sout << v[i].first << ": " << v[i].second; + if (i+1 < v.size()) + sout << "\n"; + } + return sout.str(); +} + +string sparse_vector__repr__ (const std::vector<std::pair<unsigned long,double> >& v) +{ + std::ostringstream sout; + sout << "< dlib.sparse_vector containing: \n" << sparse_vector__str__(v) << " >"; + return sout.str(); +} + +unsigned long range_len(const std::pair<unsigned long, unsigned long>& r) +{ + if (r.second > r.first) + return r.second-r.first; + else + return 0; +} + +template <typename T> +void resize(T& v, unsigned long n) { v.resize(n); } + +void bind_basic_types(py::module& m) +{ + { + typedef double item_type; + typedef std::vector<item_type> type; + typedef std::shared_ptr<type> type_ptr; + py::bind_vector<type, type_ptr >(m, "array", "This object represents a 1D array of floating point numbers. " + "Moreover, it binds directly to the C++ type std::vector<double>.") + .def(py::init(&array_from_object)) + .def("__str__", array__str__) + .def("__repr__", array__repr__) + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<item_type>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + { + typedef matrix<double,0,1> item_type; + typedef std::vector<item_type > type; + py::bind_vector<type>(m, "vectors", "This object is an array of vector objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<item_type>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + { + typedef std::vector<matrix<double,0,1> > item_type; + typedef std::vector<item_type > type; + py::bind_vector<type>(m, "vectorss", "This object is an array of arrays of vector objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<item_type>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + typedef pair<unsigned long,unsigned long> range_type; + py::class_<range_type>(m, "range", "This object is used to represent a range of elements in an array.") + .def(py::init<unsigned long,unsigned long>()) + .def_readwrite("begin",&range_type::first, "The index of the first element in the range. This is represented using an unsigned integer.") + .def_readwrite("end",&range_type::second, "One past the index of the last element in the range. This is represented using an unsigned integer.") + .def("__str__", range__str__) + .def("__repr__", range__repr__) + .def("__iter__", &make_range_iterator) + .def("__len__", &range_len) + .def(py::pickle(&getstate<range_type>, &setstate<range_type>)); + + py::class_<range_iter>(m, "_range_iter") + .def("next", &range_iter::next) + .def("__next__", &range_iter::next); + + { + typedef std::pair<unsigned long, unsigned long> item_type; + typedef std::vector<item_type > type; + py::bind_vector<type>(m, "ranges", "This object is an array of range objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<item_type>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + { + typedef std::vector<std::pair<unsigned long, unsigned long> > item_type; + typedef std::vector<item_type > type; + py::bind_vector<type>(m, "rangess", "This object is an array of arrays of range objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<item_type>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + + typedef pair<unsigned long,double> pair_type; + py::class_<pair_type>(m, "pair", "This object is used to represent the elements of a sparse_vector.") + .def(py::init<unsigned long,double>()) + .def_readwrite("first",&pair_type::first, "This field represents the index/dimension number.") + .def_readwrite("second",&pair_type::second, "This field contains the value in a vector at dimension specified by the first field.") + .def("__str__", pair__str__) + .def("__repr__", pair__repr__) + .def(py::pickle(&getstate<pair_type>, &setstate<pair_type>)); + + { + typedef std::vector<pair_type> type; + py::bind_vector<type>(m, "sparse_vector", +"This object represents the mathematical idea of a sparse column vector. It is \n\ +simply an array of dlib.pair objects, each representing an index/value pair in \n\ +the vector. Any elements of the vector which are missing are implicitly set to \n\ +zero. \n\ + \n\ +Unless otherwise noted, any routines taking a sparse_vector assume the sparse \n\ +vector is sorted and has unique elements. That is, the index values of the \n\ +pairs in a sparse_vector should be listed in increasing order and there should \n\ +not be duplicates. However, some functions work with \"unsorted\" sparse \n\ +vectors. These are dlib.sparse_vector objects that have either duplicate \n\ +entries or non-sorted index values. Note further that you can convert an \n\ +\"unsorted\" sparse_vector into a properly sorted sparse vector by calling \n\ +dlib.make_sparse_vector() on it. " + ) + .def("__str__", sparse_vector__str__) + .def("__repr__", sparse_vector__repr__) + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<pair_type>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + { + typedef std::vector<pair_type> item_type; + typedef std::vector<item_type > type; + py::bind_vector<type>(m, "sparse_vectors", "This object is an array of sparse_vector objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<item_type>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + { + typedef std::vector<std::vector<pair_type> > item_type; + typedef std::vector<item_type > type; + py::bind_vector<type>(m, "sparse_vectorss", "This object is an array of arrays of sparse_vector objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<item_type>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } +} + diff --git a/ml/dlib/tools/python/src/cca.cpp b/ml/dlib/tools/python/src/cca.cpp new file mode 100644 index 000000000..dcf476522 --- /dev/null +++ b/ml/dlib/tools/python/src/cca.cpp @@ -0,0 +1,137 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/statistics.h> + +using namespace dlib; +namespace py = pybind11; + +typedef std::vector<std::pair<unsigned long,double> > sparse_vect; + +struct cca_outputs +{ + matrix<double,0,1> correlations; + matrix<double> Ltrans; + matrix<double> Rtrans; +}; + +cca_outputs _cca1 ( + const std::vector<sparse_vect>& L, + const std::vector<sparse_vect>& R, + unsigned long num_correlations, + unsigned long extra_rank, + unsigned long q, + double regularization +) +{ + pyassert(num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.size() == R.size() && regularization >= 0, + "Invalid inputs"); + + cca_outputs temp; + temp.correlations = cca(L,R,temp.Ltrans,temp.Rtrans,num_correlations,extra_rank,q,regularization); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +unsigned long sparse_vector_max_index_plus_one ( + const sparse_vect& v +) +{ + return max_index_plus_one(v); +} + +matrix<double,0,1> apply_cca_transform ( + const matrix<double>& m, + const sparse_vect& v +) +{ + pyassert((long)max_index_plus_one(v) <= m.nr(), "Invalid Inputs"); + return sparse_matrix_vector_multiply(trans(m), v); +} + +void bind_cca(py::module& m) +{ + py::class_<cca_outputs>(m, "cca_outputs") + .def_readwrite("correlations", &cca_outputs::correlations) + .def_readwrite("Ltrans", &cca_outputs::Ltrans) + .def_readwrite("Rtrans", &cca_outputs::Rtrans); + + m.def("max_index_plus_one", sparse_vector_max_index_plus_one, py::arg("v"), +"ensures \n\ + - returns the dimensionality of the given sparse vector. That is, returns a \n\ + number one larger than the maximum index value in the vector. If the vector \n\ + is empty then returns 0. " + ); + + + m.def("apply_cca_transform", apply_cca_transform, py::arg("m"), py::arg("v"), +"requires \n\ + - max_index_plus_one(v) <= m.nr() \n\ +ensures \n\ + - returns trans(m)*v \n\ + (i.e. multiply m by the vector v and return the result) " + ); + + + m.def("cca", _cca1, py::arg("L"), py::arg("R"), py::arg("num_correlations"), py::arg("extra_rank")=5, py::arg("q")=2, py::arg("regularization")=0, +"requires \n\ + - num_correlations > 0 \n\ + - len(L) > 0 \n\ + - len(R) > 0 \n\ + - len(L) == len(R) \n\ + - regularization >= 0 \n\ + - L and R must be properly sorted sparse vectors. This means they must list their \n\ + elements in ascending index order and not contain duplicate index values. You can use \n\ + make_sparse_vector() to ensure this is true. \n\ +ensures \n\ + - This function performs a canonical correlation analysis between the vectors \n\ + in L and R. That is, it finds two transformation matrices, Ltrans and \n\ + Rtrans, such that row vectors in the transformed matrices L*Ltrans and \n\ + R*Rtrans are as correlated as possible (note that in this notation we \n\ + interpret L as a matrix with the input vectors in its rows). Note also that \n\ + this function tries to find transformations which produce num_correlations \n\ + dimensional output vectors. \n\ + - Note that you can easily apply the transformation to a vector using \n\ + apply_cca_transform(). So for example, like this: \n\ + - apply_cca_transform(Ltrans, some_sparse_vector) \n\ + - returns a structure containing the Ltrans and Rtrans transformation matrices \n\ + as well as the estimated correlations between elements of the transformed \n\ + vectors. \n\ + - This function assumes the data vectors in L and R have already been centered \n\ + (i.e. we assume the vectors have zero means). However, in many cases it is \n\ + fine to use uncentered data with cca(). But if it is important for your \n\ + problem then you should center your data before passing it to cca(). \n\ + - This function works with reduced rank approximations of the L and R matrices. \n\ + This makes it fast when working with large matrices. In particular, we use \n\ + the dlib::svd_fast() routine to find reduced rank representations of the input \n\ + matrices by calling it as follows: svd_fast(L, U,D,V, num_correlations+extra_rank, q) \n\ + and similarly for R. This means that you can use the extra_rank and q \n\ + arguments to cca() to influence the accuracy of the reduced rank \n\ + approximation. However, the default values should work fine for most \n\ + problems. \n\ + - The dimensions of the output vectors produced by L*#Ltrans or R*#Rtrans are \n\ + ordered such that the dimensions with the highest correlations come first. \n\ + That is, after applying the transforms produced by cca() to a set of vectors \n\ + you will find that dimension 0 has the highest correlation, then dimension 1 \n\ + has the next highest, and so on. This also means that the list of estimated \n\ + correlations returned from cca() will always be listed in decreasing order. \n\ + - This function performs the ridge regression version of Canonical Correlation \n\ + Analysis when regularization is set to a value > 0. In particular, larger \n\ + values indicate the solution should be more heavily regularized. This can be \n\ + useful when the dimensionality of the data is larger than the number of \n\ + samples. \n\ + - A good discussion of CCA can be found in the paper \"Canonical Correlation \n\ + Analysis\" by David Weenink. In particular, this function is implemented \n\ + using equations 29 and 30 from his paper. We also use the idea of doing CCA \n\ + on a reduced rank approximation of L and R as suggested by Paramveer S. \n\ + Dhillon in his paper \"Two Step CCA: A new spectral method for estimating \n\ + vector models of words\". " + + ); +} + + + diff --git a/ml/dlib/tools/python/src/cnn_face_detector.cpp b/ml/dlib/tools/python/src/cnn_face_detector.cpp new file mode 100644 index 000000000..f18d99d95 --- /dev/null +++ b/ml/dlib/tools/python/src/cnn_face_detector.cpp @@ -0,0 +1,183 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/dnn.h> +#include <dlib/image_transforms.h> +#include "indexing.h" +#include <pybind11/stl_bind.h> + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + + +class cnn_face_detection_model_v1 +{ + +public: + + cnn_face_detection_model_v1(const std::string& model_filename) + { + deserialize(model_filename) >> net; + } + + std::vector<mmod_rect> detect ( + py::object pyimage, + const int upsample_num_times + ) + { + pyramid_down<2> pyr; + std::vector<mmod_rect> rects; + + // Copy the data into dlib based objects + matrix<rgb_pixel> image; + if (is_gray_python_image(pyimage)) + assign_image(image, numpy_gray_image(pyimage)); + else if (is_rgb_python_image(pyimage)) + assign_image(image, numpy_rgb_image(pyimage)); + else + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + + // Upsampling the image will allow us to detect smaller faces but will cause the + // program to use more RAM and run longer. + unsigned int levels = upsample_num_times; + while (levels > 0) + { + levels--; + pyramid_up(image, pyr); + } + + auto dets = net(image); + + // Scale the detection locations back to the original image size + // if the image was upscaled. + for (auto&& d : dets) { + d.rect = pyr.rect_down(d.rect, upsample_num_times); + rects.push_back(d); + } + + return rects; + } + + std::vector<std::vector<mmod_rect> > detect_mult ( + py::list imgs, + const int upsample_num_times, + const int batch_size = 128 + ) + { + pyramid_down<2> pyr; + std::vector<matrix<rgb_pixel> > dimgs; + dimgs.reserve(len(imgs)); + + for(int i = 0; i < len(imgs); i++) + { + // Copy the data into dlib based objects + matrix<rgb_pixel> image; + py::object tmp = imgs[i].cast<py::object>(); + if (is_gray_python_image(tmp)) + assign_image(image, numpy_gray_image(tmp)); + else if (is_rgb_python_image(tmp)) + assign_image(image, numpy_rgb_image(tmp)); + else + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + + for(int i = 0; i < upsample_num_times; i++) + { + pyramid_up(image); + } + dimgs.push_back(image); + } + + for(int i = 1; i < dimgs.size(); i++) + { + if + ( + dimgs[i - 1].nc() != dimgs[i].nc() || + dimgs[i - 1].nr() != dimgs[i].nr() + ) + throw dlib::error("Images in list must all have the same dimensions."); + + } + + auto dets = net(dimgs, batch_size); + std::vector<std::vector<mmod_rect> > all_rects; + + for(auto&& im_dets : dets) + { + std::vector<mmod_rect> rects; + rects.reserve(im_dets.size()); + for (auto&& d : im_dets) { + d.rect = pyr.rect_down(d.rect, upsample_num_times); + rects.push_back(d); + } + all_rects.push_back(rects); + } + + return all_rects; + } + +private: + + template <long num_filters, typename SUBNET> using con5d = con<num_filters,5,5,2,2,SUBNET>; + template <long num_filters, typename SUBNET> using con5 = con<num_filters,5,5,1,1,SUBNET>; + + template <typename SUBNET> using downsampler = relu<affine<con5d<32, relu<affine<con5d<32, relu<affine<con5d<16,SUBNET>>>>>>>>>; + template <typename SUBNET> using rcon5 = relu<affine<con5<45,SUBNET>>>; + + using net_type = loss_mmod<con<1,9,9,1,1,rcon5<rcon5<rcon5<downsampler<input_rgb_image_pyramid<pyramid_down<6>>>>>>>>; + + net_type net; +}; + +// ---------------------------------------------------------------------------------------- + +void bind_cnn_face_detection(py::module& m) +{ + { + py::class_<cnn_face_detection_model_v1>(m, "cnn_face_detection_model_v1", "This object detects human faces in an image. The constructor loads the face detection model from a file. You can download a pre-trained model from http://dlib.net/files/mmod_human_face_detector.dat.bz2.") + .def(py::init<std::string>()) + .def( + "__call__", + &cnn_face_detection_model_v1::detect_mult, + py::arg("imgs"), py::arg("upsample_num_times")=0, py::arg("batch_size")=128, + "takes a list of images as input returning a 2d list of mmod rectangles" + ) + .def( + "__call__", + &cnn_face_detection_model_v1::detect, + py::arg("img"), py::arg("upsample_num_times")=0, + "Find faces in an image using a deep learning model.\n\ + - Upsamples the image upsample_num_times before running the face \n\ + detector." + ); + } + + m.def("set_dnn_prefer_smallest_algorithms", &set_dnn_prefer_smallest_algorithms, "Tells cuDNN to use slower algorithms that use less RAM."); + + auto cuda = m.def_submodule("cuda", "Routines for setting CUDA specific properties."); + cuda.def("set_device", &dlib::cuda::set_device, py::arg("device_id"), + "Set the active CUDA device. It is required that 0 <= device_id < get_num_devices()."); + cuda.def("get_device", &dlib::cuda::get_device, "Get the active CUDA device."); + cuda.def("get_num_devices", &dlib::cuda::get_num_devices, "Find out how many CUDA devices are available."); + + { + typedef mmod_rect type; + py::class_<type>(m, "mmod_rectangle", "Wrapper around a rectangle object and a detection confidence score.") + .def_readwrite("rect", &type::rect) + .def_readwrite("confidence", &type::detection_confidence); + } + { + typedef std::vector<mmod_rect> type; + py::bind_vector<type>(m, "mmod_rectangles", "An array of mmod rectangle objects.") + .def("extend", extend_vector_with_python_list<mmod_rect>); + } + { + typedef std::vector<std::vector<mmod_rect> > type; + py::bind_vector<type>(m, "mmod_rectangless", "A 2D array of mmod rectangle objects.") + .def("extend", extend_vector_with_python_list<std::vector<mmod_rect>>); + } +} diff --git a/ml/dlib/tools/python/src/conversion.h b/ml/dlib/tools/python/src/conversion.h new file mode 100644 index 000000000..9ab2360a0 --- /dev/null +++ b/ml/dlib/tools/python/src/conversion.h @@ -0,0 +1,52 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYTHON_CONVERSION_H__ +#define DLIB_PYTHON_CONVERSION_H__ + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/pixel.h> + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +template <typename dest_image_type> +void pyimage_to_dlib_image(py::object img, dest_image_type& image) +{ + if (is_gray_python_image(img)) + assign_image(image, numpy_gray_image(img)); + else if (is_rgb_python_image(img)) + assign_image(image, numpy_rgb_image(img)); + else + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); +} + +template <typename image_array, typename param_type> +void images_and_nested_params_to_dlib( + const py::object& pyimages, + const py::object& pyparams, + image_array& images, + std::vector<std::vector<param_type> >& params +) +{ + // Now copy the data into dlib based objects. + py::iterator image_it = pyimages.begin(); + py::iterator params_it = pyparams.begin(); + + for (unsigned long image_idx = 0; + image_it != pyimages.end() + && params_it != pyparams.end(); + ++image_it, ++params_it, ++image_idx) + { + for (py::iterator param_it = params_it->begin(); + param_it != params_it->end(); + ++param_it) + params[image_idx].push_back(param_it->cast<param_type>()); + + pyimage_to_dlib_image(image_it->cast<py::object>(), images[image_idx]); + } +} + +#endif // DLIB_PYTHON_CONVERSION_H__ diff --git a/ml/dlib/tools/python/src/correlation_tracker.cpp b/ml/dlib/tools/python/src/correlation_tracker.cpp new file mode 100644 index 000000000..1b17ba54c --- /dev/null +++ b/ml/dlib/tools/python/src/correlation_tracker.cpp @@ -0,0 +1,167 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/geometry.h> +#include <dlib/image_processing.h> + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +void start_track ( + correlation_tracker& tracker, + py::object img, + const drectangle& bounding_box +) +{ + if (is_gray_python_image(img)) + { + tracker.start_track(numpy_gray_image(img), bounding_box); + } + else if (is_rgb_python_image(img)) + { + tracker.start_track(numpy_rgb_image(img), bounding_box); + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } +} + +void start_track_rec ( + correlation_tracker& tracker, + py::object img, + const rectangle& bounding_box +) +{ + drectangle dbounding_box(bounding_box); + start_track(tracker, img, dbounding_box); +} + +double update ( + correlation_tracker& tracker, + py::object img +) +{ + if (is_gray_python_image(img)) + { + return tracker.update(numpy_gray_image(img)); + } + else if (is_rgb_python_image(img)) + { + return tracker.update(numpy_rgb_image(img)); + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } +} + +double update_guess ( + correlation_tracker& tracker, + py::object img, + const drectangle& bounding_box +) +{ + if (is_gray_python_image(img)) + { + return tracker.update(numpy_gray_image(img), bounding_box); + } + else if (is_rgb_python_image(img)) + { + return tracker.update(numpy_rgb_image(img), bounding_box); + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } +} + +double update_guess_rec ( + correlation_tracker& tracker, + py::object img, + const rectangle& bounding_box +) +{ + drectangle dbounding_box(bounding_box); + return update_guess(tracker, img, dbounding_box); +} + +drectangle get_position (const correlation_tracker& tracker) { return tracker.get_position(); } + +// ---------------------------------------------------------------------------------------- + +void bind_correlation_tracker(py::module &m) +{ + { + typedef correlation_tracker type; + py::class_<type>(m, "correlation_tracker", "This is a tool for tracking moving objects in a video stream. You give it \n\ + the bounding box of an object in the first frame and it attempts to track the \n\ + object in the box from frame to frame. \n\ + This tool is an implementation of the method described in the following paper: \n\ + Danelljan, Martin, et al. 'Accurate scale estimation for robust visual \n\ + tracking.' Proceedings of the British Machine Vision Conference BMVC. 2014.") + .def(py::init()) + .def("start_track", &::start_track, py::arg("image"), py::arg("bounding_box"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - bounding_box.is_empty() == false \n\ + ensures \n\ + - This object will start tracking the thing inside the bounding box in the \n\ + given image. That is, if you call update() with subsequent video frames \n\ + then it will try to keep track of the position of the object inside bounding_box. \n\ + - #get_position() == bounding_box") + .def("start_track", &::start_track_rec, py::arg("image"), py::arg("bounding_box"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - bounding_box.is_empty() == false \n\ + ensures \n\ + - This object will start tracking the thing inside the bounding box in the \n\ + given image. That is, if you call update() with subsequent video frames \n\ + then it will try to keep track of the position of the object inside bounding_box. \n\ + - #get_position() == bounding_box") + .def("update", &::update, py::arg("image"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - get_position().is_empty() == false \n\ + (i.e. you must have started tracking by calling start_track()) \n\ + ensures \n\ + - performs: return update(img, get_position())") + .def("update", &::update_guess, py::arg("image"), py::arg("guess"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - get_position().is_empty() == false \n\ + (i.e. you must have started tracking by calling start_track()) \n\ + ensures \n\ + - When searching for the object in img, we search in the area around the \n\ + provided guess. \n\ + - #get_position() == the new predicted location of the object in img. This \n\ + location will be a copy of guess that has been translated and scaled \n\ + appropriately based on the content of img so that it, hopefully, bounds \n\ + the object in img. \n\ + - Returns the peak to side-lobe ratio. This is a number that measures how \n\ + confident the tracker is that the object is inside #get_position(). \n\ + Larger values indicate higher confidence.") + .def("update", &::update_guess_rec, py::arg("image"), py::arg("guess"), "\ + requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB image. \n\ + - get_position().is_empty() == false \n\ + (i.e. you must have started tracking by calling start_track()) \n\ + ensures \n\ + - When searching for the object in img, we search in the area around the \n\ + provided guess. \n\ + - #get_position() == the new predicted location of the object in img. This \n\ + location will be a copy of guess that has been translated and scaled \n\ + appropriately based on the content of img so that it, hopefully, bounds \n\ + the object in img. \n\ + - Returns the peak to side-lobe ratio. This is a number that measures how \n\ + confident the tracker is that the object is inside #get_position(). \n\ + Larger values indicate higher confidence.") + .def("get_position", &::get_position, "returns the predicted position of the object under track."); + } +} diff --git a/ml/dlib/tools/python/src/decision_functions.cpp b/ml/dlib/tools/python/src/decision_functions.cpp new file mode 100644 index 000000000..a93fe49b9 --- /dev/null +++ b/ml/dlib/tools/python/src/decision_functions.cpp @@ -0,0 +1,263 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include "testing_results.h" +#include <dlib/svm.h> + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +typedef matrix<double,0,1> sample_type; +typedef std::vector<std::pair<unsigned long,double> > sparse_vect; + +template <typename decision_function> +double predict ( + const decision_function& df, + const typename decision_function::kernel_type::sample_type& samp +) +{ + typedef typename decision_function::kernel_type::sample_type T; + if (df.basis_vectors.size() == 0) + { + return 0; + } + else if (is_matrix<T>::value && df.basis_vectors(0).size() != samp.size()) + { + std::ostringstream sout; + sout << "Input vector should have " << df.basis_vectors(0).size() + << " dimensions, not " << samp.size() << "."; + PyErr_SetString( PyExc_ValueError, sout.str().c_str() ); + throw py::error_already_set(); + } + return df(samp); +} + +template <typename kernel_type> +void add_df ( + py::module& m, + const std::string name +) +{ + typedef decision_function<kernel_type> df_type; + py::class_<df_type>(m, name.c_str()) + .def("__call__", &predict<df_type>) + .def(py::pickle(&getstate<df_type>, &setstate<df_type>)); +} + +template <typename df_type> +typename df_type::sample_type get_weights( + const df_type& df +) +{ + if (df.basis_vectors.size() == 0) + { + PyErr_SetString( PyExc_ValueError, "Decision function is empty." ); + throw py::error_already_set(); + } + df_type temp = simplify_linear_decision_function(df); + return temp.basis_vectors(0); +} + +template <typename df_type> +typename df_type::scalar_type get_bias( + const df_type& df +) +{ + if (df.basis_vectors.size() == 0) + { + PyErr_SetString( PyExc_ValueError, "Decision function is empty." ); + throw py::error_already_set(); + } + return df.b; +} + +template <typename df_type> +void set_bias( + df_type& df, + double b +) +{ + if (df.basis_vectors.size() == 0) + { + PyErr_SetString( PyExc_ValueError, "Decision function is empty." ); + throw py::error_already_set(); + } + df.b = b; +} + +template <typename kernel_type> +void add_linear_df ( + py::module &m, + const std::string name +) +{ + typedef decision_function<kernel_type> df_type; + py::class_<df_type>(m, name.c_str()) + .def("__call__", predict<df_type>) + .def_property_readonly("weights", &get_weights<df_type>) + .def_property("bias", get_bias<df_type>, set_bias<df_type>) + .def(py::pickle(&getstate<df_type>, &setstate<df_type>)); +} + +// ---------------------------------------------------------------------------------------- + +std::string binary_test__str__(const binary_test& item) +{ + std::ostringstream sout; + sout << "class1_accuracy: "<< item.class1_accuracy << " class2_accuracy: "<< item.class2_accuracy; + return sout.str(); +} +std::string binary_test__repr__(const binary_test& item) { return "< " + binary_test__str__(item) + " >";} + +std::string regression_test__str__(const regression_test& item) +{ + std::ostringstream sout; + sout << "mean_squared_error: "<< item.mean_squared_error << " R_squared: "<< item.R_squared; + sout << " mean_average_error: "<< item.mean_average_error << " mean_error_stddev: "<< item.mean_error_stddev; + return sout.str(); +} +std::string regression_test__repr__(const regression_test& item) { return "< " + regression_test__str__(item) + " >";} + +std::string ranking_test__str__(const ranking_test& item) +{ + std::ostringstream sout; + sout << "ranking_accuracy: "<< item.ranking_accuracy << " mean_ap: "<< item.mean_ap; + return sout.str(); +} +std::string ranking_test__repr__(const ranking_test& item) { return "< " + ranking_test__str__(item) + " >";} + +// ---------------------------------------------------------------------------------------- + +template <typename K> +binary_test _test_binary_decision_function ( + const decision_function<K>& dec_funct, + const std::vector<typename K::sample_type>& x_test, + const std::vector<double>& y_test +) { return binary_test(test_binary_decision_function(dec_funct, x_test, y_test)); } + +template <typename K> +regression_test _test_regression_function ( + const decision_function<K>& reg_funct, + const std::vector<typename K::sample_type>& x_test, + const std::vector<double>& y_test +) { return regression_test(test_regression_function(reg_funct, x_test, y_test)); } + +template < typename K > +ranking_test _test_ranking_function1 ( + const decision_function<K>& funct, + const std::vector<ranking_pair<typename K::sample_type> >& samples +) { return ranking_test(test_ranking_function(funct, samples)); } + +template < typename K > +ranking_test _test_ranking_function2 ( + const decision_function<K>& funct, + const ranking_pair<typename K::sample_type>& sample +) { return ranking_test(test_ranking_function(funct, sample)); } + + +void bind_decision_functions(py::module &m) +{ + add_linear_df<linear_kernel<sample_type> >(m, "_decision_function_linear"); + add_linear_df<sparse_linear_kernel<sparse_vect> >(m, "_decision_function_sparse_linear"); + + add_df<histogram_intersection_kernel<sample_type> >(m, "_decision_function_histogram_intersection"); + add_df<sparse_histogram_intersection_kernel<sparse_vect> >(m, "_decision_function_sparse_histogram_intersection"); + + add_df<polynomial_kernel<sample_type> >(m, "_decision_function_polynomial"); + add_df<sparse_polynomial_kernel<sparse_vect> >(m, "_decision_function_sparse_polynomial"); + + add_df<radial_basis_kernel<sample_type> >(m, "_decision_function_radial_basis"); + add_df<sparse_radial_basis_kernel<sparse_vect> >(m, "_decision_function_sparse_radial_basis"); + + add_df<sigmoid_kernel<sample_type> >(m, "_decision_function_sigmoid"); + add_df<sparse_sigmoid_kernel<sparse_vect> >(m, "_decision_function_sparse_sigmoid"); + + + m.def("test_binary_decision_function", _test_binary_decision_function<linear_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<sparse_linear_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<radial_basis_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<sparse_radial_basis_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<polynomial_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<sparse_polynomial_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<histogram_intersection_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<sparse_histogram_intersection_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<sigmoid_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + m.def("test_binary_decision_function", _test_binary_decision_function<sparse_sigmoid_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("labels")); + + m.def("test_regression_function", _test_regression_function<linear_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<sparse_linear_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<radial_basis_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<sparse_radial_basis_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<histogram_intersection_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<sparse_histogram_intersection_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<sigmoid_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<sparse_sigmoid_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<polynomial_kernel<sample_type> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + m.def("test_regression_function", _test_regression_function<sparse_polynomial_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples"), py::arg("targets")); + + m.def("test_ranking_function", _test_ranking_function1<linear_kernel<sample_type> >, + py::arg("function"), py::arg("samples")); + m.def("test_ranking_function", _test_ranking_function1<sparse_linear_kernel<sparse_vect> >, + py::arg("function"), py::arg("samples")); + m.def("test_ranking_function", _test_ranking_function2<linear_kernel<sample_type> >, + py::arg("function"), py::arg("sample")); + m.def("test_ranking_function", _test_ranking_function2<sparse_linear_kernel<sparse_vect> >, + py::arg("function"), py::arg("sample")); + + + py::class_<binary_test>(m, "_binary_test") + .def("__str__", binary_test__str__) + .def("__repr__", binary_test__repr__) + .def_readwrite("class1_accuracy", &binary_test::class1_accuracy, + "A value between 0 and 1, measures accuracy on the +1 class.") + .def_readwrite("class2_accuracy", &binary_test::class2_accuracy, + "A value between 0 and 1, measures accuracy on the -1 class."); + + py::class_<ranking_test>(m, "_ranking_test") + .def("__str__", ranking_test__str__) + .def("__repr__", ranking_test__repr__) + .def_readwrite("ranking_accuracy", &ranking_test::ranking_accuracy, + "A value between 0 and 1, measures the fraction of times a relevant sample was ordered before a non-relevant sample.") + .def_readwrite("mean_ap", &ranking_test::mean_ap, + "A value between 0 and 1, measures the mean average precision of the ranking."); + + py::class_<regression_test>(m, "_regression_test") + .def("__str__", regression_test__str__) + .def("__repr__", regression_test__repr__) + .def_readwrite("mean_average_error", ®ression_test::mean_average_error, + "The mean average error of a regression function on a dataset.") + .def_readwrite("mean_error_stddev", ®ression_test::mean_error_stddev, + "The standard deviation of the absolute value of the error of a regression function on a dataset.") + .def_readwrite("mean_squared_error", ®ression_test::mean_squared_error, + "The mean squared error of a regression function on a dataset.") + .def_readwrite("R_squared", ®ression_test::R_squared, + "A value between 0 and 1, measures the squared correlation between the output of a \n" + "regression function and the target values."); +} + + + diff --git a/ml/dlib/tools/python/src/dlib.cpp b/ml/dlib/tools/python/src/dlib.cpp new file mode 100644 index 000000000..ac6fea0db --- /dev/null +++ b/ml/dlib/tools/python/src/dlib.cpp @@ -0,0 +1,110 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <pybind11/pybind11.h> +#include <dlib/simd.h> +#include <string> + +namespace py = pybind11; + +void bind_matrix(py::module& m); +void bind_vector(py::module& m); +void bind_svm_c_trainer(py::module& m); +void bind_decision_functions(py::module& m); +void bind_basic_types(py::module& m); +void bind_other(py::module& m); +void bind_svm_rank_trainer(py::module& m); +void bind_cca(py::module& m); +void bind_sequence_segmenter(py::module& m); +void bind_svm_struct(py::module& m); +void bind_image_classes(py::module& m); +void bind_rectangles(py::module& m); +void bind_object_detection(py::module& m); +void bind_shape_predictors(py::module& m); +void bind_correlation_tracker(py::module& m); +void bind_face_recognition(py::module& m); +void bind_cnn_face_detection(py::module& m); +void bind_global_optimization(py::module& m); +void bind_numpy_returns(py::module& m); +void bind_image_dataset_metadata(py::module& m); + +#ifndef DLIB_NO_GUI_SUPPORT +void bind_gui(py::module& m); +#endif + +PYBIND11_MODULE(dlib, m) +{ + warn_about_unavailable_but_used_cpu_instructions(); + + +#define DLIB_QUOTE_STRING(x) DLIB_QUOTE_STRING2(x) +#define DLIB_QUOTE_STRING2(x) #x + m.attr("__version__") = DLIB_QUOTE_STRING(DLIB_VERSION); + m.attr("__time_compiled__") = std::string(__DATE__) + " " + std::string(__TIME__); + +#ifdef DLIB_USE_CUDA + m.attr("DLIB_USE_CUDA") = true; +#else + m.attr("DLIB_USE_CUDA") = false; +#endif +#ifdef DLIB_USE_BLAS + m.attr("DLIB_USE_BLAS") = true; +#else + m.attr("DLIB_USE_BLAS") = false; +#endif +#ifdef DLIB_USE_LAPACK + m.attr("DLIB_USE_LAPACK") = true; +#else + m.attr("DLIB_USE_LAPACK") = false; +#endif +#ifdef DLIB_HAVE_AVX + m.attr("USE_AVX_INSTRUCTIONS") = true; +#else + m.attr("USE_AVX_INSTRUCTIONS") = false; +#endif +#ifdef DLIB_HAVE_NEON + m.attr("USE_NEON_INSTRUCTIONS") = true; +#else + m.attr("USE_NEON_INSTRUCTIONS") = false; +#endif + + + + // Note that the order here matters. We need to do the basic types first. If we don't + // then what happens is the documentation created by sphinx will use horrible big + // template names to refer to C++ objects rather than the python names python users + // will expect. For instance, if bind_basic_types() isn't called early then when + // routines take a std::vector<double>, rather than saying dlib.array in the python + // docs it will say "std::vector<double, std::allocator<double> >" which is awful and + // confusing to python users. + // + // So when adding new things always add them to the end of the list. + bind_matrix(m); + bind_vector(m); + bind_basic_types(m); + bind_other(m); + + bind_svm_rank_trainer(m); + bind_decision_functions(m); + bind_cca(m); + bind_sequence_segmenter(m); + bind_svm_struct(m); + bind_image_classes(m); + bind_rectangles(m); + bind_object_detection(m); + bind_shape_predictors(m); + bind_correlation_tracker(m); + bind_face_recognition(m); + bind_cnn_face_detection(m); + bind_global_optimization(m); + bind_numpy_returns(m); + bind_svm_c_trainer(m); +#ifndef DLIB_NO_GUI_SUPPORT + bind_gui(m); +#endif + + bind_image_dataset_metadata(m); + + +} diff --git a/ml/dlib/tools/python/src/face_recognition.cpp b/ml/dlib/tools/python/src/face_recognition.cpp new file mode 100644 index 000000000..8d5dee678 --- /dev/null +++ b/ml/dlib/tools/python/src/face_recognition.cpp @@ -0,0 +1,245 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/geometry/vector.h> +#include <dlib/dnn.h> +#include <dlib/image_transforms.h> +#include "indexing.h" +#include <dlib/image_io.h> +#include <dlib/clustering.h> +#include <pybind11/stl_bind.h> + + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + + +typedef matrix<double,0,1> cv; + +class face_recognition_model_v1 +{ + +public: + + face_recognition_model_v1(const std::string& model_filename) + { + deserialize(model_filename) >> net; + } + + matrix<double,0,1> compute_face_descriptor ( + py::object img, + const full_object_detection& face, + const int num_jitters + ) + { + std::vector<full_object_detection> faces(1, face); + return compute_face_descriptors(img, faces, num_jitters)[0]; + } + + std::vector<matrix<double,0,1>> compute_face_descriptors ( + py::object img, + const std::vector<full_object_detection>& faces, + const int num_jitters + ) + { + if (!is_rgb_python_image(img)) + throw dlib::error("Unsupported image type, must be RGB image."); + + for (auto& f : faces) + { + if (f.num_parts() != 68 && f.num_parts() != 5) + throw dlib::error("The full_object_detection must use the iBUG 300W 68 point face landmark style or dlib's 5 point style."); + } + + + std::vector<chip_details> dets; + for (auto& f : faces) + dets.push_back(get_face_chip_details(f, 150, 0.25)); + dlib::array<matrix<rgb_pixel>> face_chips; + extract_image_chips(numpy_rgb_image(img), dets, face_chips); + + std::vector<matrix<double,0,1>> face_descriptors; + face_descriptors.reserve(face_chips.size()); + + if (num_jitters <= 1) + { + // extract descriptors and convert from float vectors to double vectors + for (auto& d : net(face_chips,16)) + face_descriptors.push_back(matrix_cast<double>(d)); + } + else + { + for (auto& fimg : face_chips) + face_descriptors.push_back(matrix_cast<double>(mean(mat(net(jitter_image(fimg,num_jitters),16))))); + } + + return face_descriptors; + } + +private: + + dlib::rand rnd; + + std::vector<matrix<rgb_pixel>> jitter_image( + const matrix<rgb_pixel>& img, + const int num_jitters + ) + { + std::vector<matrix<rgb_pixel>> crops; + for (int i = 0; i < num_jitters; ++i) + crops.push_back(dlib::jitter_image(img,rnd)); + return crops; + } + + + template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET> + using residual = add_prev1<block<N,BN,1,tag1<SUBNET>>>; + + template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET> + using residual_down = add_prev2<avg_pool<2,2,2,2,skip1<tag2<block<N,BN,2,tag1<SUBNET>>>>>>; + + template <int N, template <typename> class BN, int stride, typename SUBNET> + using block = BN<con<N,3,3,1,1,relu<BN<con<N,3,3,stride,stride,SUBNET>>>>>; + + template <int N, typename SUBNET> using ares = relu<residual<block,N,affine,SUBNET>>; + template <int N, typename SUBNET> using ares_down = relu<residual_down<block,N,affine,SUBNET>>; + + template <typename SUBNET> using alevel0 = ares_down<256,SUBNET>; + template <typename SUBNET> using alevel1 = ares<256,ares<256,ares_down<256,SUBNET>>>; + template <typename SUBNET> using alevel2 = ares<128,ares<128,ares_down<128,SUBNET>>>; + template <typename SUBNET> using alevel3 = ares<64,ares<64,ares<64,ares_down<64,SUBNET>>>>; + template <typename SUBNET> using alevel4 = ares<32,ares<32,ares<32,SUBNET>>>; + + using anet_type = loss_metric<fc_no_bias<128,avg_pool_everything< + alevel0< + alevel1< + alevel2< + alevel3< + alevel4< + max_pool<3,3,2,2,relu<affine<con<32,7,7,2,2, + input_rgb_image_sized<150> + >>>>>>>>>>>>; + anet_type net; +}; + +// ---------------------------------------------------------------------------------------- + +py::list chinese_whispers_clustering(py::list descriptors, float threshold) +{ + DLIB_CASSERT(threshold > 0); + py::list clusters; + + size_t num_descriptors = py::len(descriptors); + + // This next bit of code creates a graph of connected objects and then uses the Chinese + // whispers graph clustering algorithm to identify how many objects there are and which + // objects belong to which cluster. + std::vector<sample_pair> edges; + std::vector<unsigned long> labels; + for (size_t i = 0; i < num_descriptors; ++i) + { + for (size_t j = i; j < num_descriptors; ++j) + { + matrix<double,0,1>& first_descriptor = descriptors[i].cast<matrix<double,0,1>&>(); + matrix<double,0,1>& second_descriptor = descriptors[j].cast<matrix<double,0,1>&>(); + + if (length(first_descriptor-second_descriptor) < threshold) + edges.push_back(sample_pair(i,j)); + } + } + chinese_whispers(edges, labels); + for (size_t i = 0; i < labels.size(); ++i) + { + clusters.append(labels[i]); + } + return clusters; +} + +void save_face_chips ( + py::object img, + const std::vector<full_object_detection>& faces, + const std::string& chip_filename, + size_t size = 150, + float padding = 0.25 +) +{ + if (!is_rgb_python_image(img)) + throw dlib::error("Unsupported image type, must be RGB image."); + + int num_faces = faces.size(); + std::vector<chip_details> dets; + for (auto& f : faces) + dets.push_back(get_face_chip_details(f, size, padding)); + dlib::array<matrix<rgb_pixel>> face_chips; + extract_image_chips(numpy_rgb_image(img), dets, face_chips); + int i=0; + for (auto& chip : face_chips) + { + i++; + if(num_faces > 1) + { + const std::string& file_name = chip_filename + "_" + std::to_string(i) + ".jpg"; + save_jpeg(chip, file_name); + } + else + { + const std::string& file_name = chip_filename + ".jpg"; + save_jpeg(chip, file_name); + } + } +} + +void save_face_chip ( + py::object img, + const full_object_detection& face, + const std::string& chip_filename, + size_t size = 150, + float padding = 0.25 +) +{ + std::vector<full_object_detection> faces(1, face); + save_face_chips(img, faces, chip_filename, size, padding); + return; +} + +void bind_face_recognition(py::module &m) +{ + { + py::class_<face_recognition_model_v1>(m, "face_recognition_model_v1", "This object maps human faces into 128D vectors where pictures of the same person are mapped near to each other and pictures of different people are mapped far apart. The constructor loads the face recognition model from a file. The model file is available here: http://dlib.net/files/dlib_face_recognition_resnet_model_v1.dat.bz2") + .def(py::init<std::string>()) + .def("compute_face_descriptor", &face_recognition_model_v1::compute_face_descriptor, py::arg("img"),py::arg("face"),py::arg("num_jitters")=0, + "Takes an image and a full_object_detection that references a face in that image and converts it into a 128D face descriptor. " + "If num_jitters>1 then each face will be randomly jittered slightly num_jitters times, each run through the 128D projection, and the average used as the face descriptor." + ) + .def("compute_face_descriptor", &face_recognition_model_v1::compute_face_descriptors, py::arg("img"),py::arg("faces"),py::arg("num_jitters")=0, + "Takes an image and an array of full_object_detections that reference faces in that image and converts them into 128D face descriptors. " + "If num_jitters>1 then each face will be randomly jittered slightly num_jitters times, each run through the 128D projection, and the average used as the face descriptor." + ); + } + + m.def("save_face_chip", &save_face_chip, + "Takes an image and a full_object_detection that references a face in that image and saves the face with the specified file name prefix. The face will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.", + py::arg("img"), py::arg("face"), py::arg("chip_filename"), py::arg("size")=150, py::arg("padding")=0.25 + ); + m.def("save_face_chips", &save_face_chips, + "Takes an image and a full_object_detections object that reference faces in that image and saves the faces with the specified file name prefix. The faces will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.", + py::arg("img"), py::arg("faces"), py::arg("chip_filename"), py::arg("size")=150, py::arg("padding")=0.25 + ); + m.def("chinese_whispers_clustering", &chinese_whispers_clustering, py::arg("descriptors"), py::arg("threshold"), + "Takes a list of descriptors and returns a list that contains a label for each descriptor. Clustering is done using dlib::chinese_whispers." + ); + { + typedef std::vector<full_object_detection> type; + py::bind_vector<type>(m, "full_object_detections", "An array of full_object_detection objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<full_object_detection>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } +} + diff --git a/ml/dlib/tools/python/src/global_optimization.cpp b/ml/dlib/tools/python/src/global_optimization.cpp new file mode 100644 index 000000000..f27185c51 --- /dev/null +++ b/ml/dlib/tools/python/src/global_optimization.cpp @@ -0,0 +1,442 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/global_optimization.h> +#include <dlib/matrix.h> +#include <pybind11/stl.h> + + +using namespace dlib; +using namespace std; +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +std::vector<bool> list_to_bool_vector( + const py::list& l +) +{ + std::vector<bool> result(len(l)); + for (long i = 0; i < result.size(); ++i) + { + result[i] = l[i].cast<bool>(); + } + return result; +} + +matrix<double,0,1> list_to_mat( + const py::list& l +) +{ + matrix<double,0,1> result(len(l)); + for (long i = 0; i < result.size(); ++i) + result(i) = l[i].cast<double>(); + return result; +} + +py::list mat_to_list ( + const matrix<double,0,1>& m +) +{ + py::list l; + for (long i = 0; i < m.size(); ++i) + l.append(m(i)); + return l; +} + +size_t num_function_arguments(py::object f, size_t expected_num) +{ + const auto code_object = f.attr(hasattr(f,"func_code") ? "func_code" : "__code__"); + const auto num = code_object.attr("co_argcount").cast<std::size_t>(); + if (num < expected_num && (code_object.attr("co_flags").cast<int>() & CO_VARARGS)) + return expected_num; + return num; +} + +double call_func(py::object f, const matrix<double,0,1>& args) +{ + const auto num = num_function_arguments(f, args.size()); + DLIB_CASSERT(num == args.size(), + "The function being optimized takes a number of arguments that doesn't agree with the size of the bounds lists you provided to find_max_global()"); + DLIB_CASSERT(0 < num && num < 15, "Functions being optimized must take between 1 and 15 scalar arguments."); + +#define CALL_WITH_N_ARGS(N) case N: return dlib::gopt_impl::_cwv(f,args,typename make_compile_time_integer_range<N>::type()).cast<double>(); + switch (num) + { + CALL_WITH_N_ARGS(1) + CALL_WITH_N_ARGS(2) + CALL_WITH_N_ARGS(3) + CALL_WITH_N_ARGS(4) + CALL_WITH_N_ARGS(5) + CALL_WITH_N_ARGS(6) + CALL_WITH_N_ARGS(7) + CALL_WITH_N_ARGS(8) + CALL_WITH_N_ARGS(9) + CALL_WITH_N_ARGS(10) + CALL_WITH_N_ARGS(11) + CALL_WITH_N_ARGS(12) + CALL_WITH_N_ARGS(13) + CALL_WITH_N_ARGS(14) + CALL_WITH_N_ARGS(15) + + default: + DLIB_CASSERT(false, "oops"); + break; + } +} + +// ---------------------------------------------------------------------------------------- + +py::tuple py_find_max_global ( + py::object f, + py::list bound1, + py::list bound2, + py::list is_integer_variable, + unsigned long num_function_calls, + double solver_epsilon = 0 +) +{ + DLIB_CASSERT(len(bound1) == len(bound2)); + DLIB_CASSERT(len(bound1) == len(is_integer_variable)); + + auto func = [&](const matrix<double,0,1>& x) + { + return call_func(f, x); + }; + + auto result = find_max_global(func, list_to_mat(bound1), list_to_mat(bound2), + list_to_bool_vector(is_integer_variable), max_function_calls(num_function_calls), + solver_epsilon); + + return py::make_tuple(mat_to_list(result.x),result.y); +} + +py::tuple py_find_max_global2 ( + py::object f, + py::list bound1, + py::list bound2, + unsigned long num_function_calls, + double solver_epsilon = 0 +) +{ + DLIB_CASSERT(len(bound1) == len(bound2)); + + auto func = [&](const matrix<double,0,1>& x) + { + return call_func(f, x); + }; + + auto result = find_max_global(func, list_to_mat(bound1), list_to_mat(bound2), max_function_calls(num_function_calls), solver_epsilon); + + return py::make_tuple(mat_to_list(result.x),result.y); +} + +// ---------------------------------------------------------------------------------------- + +py::tuple py_find_min_global ( + py::object f, + py::list bound1, + py::list bound2, + py::list is_integer_variable, + unsigned long num_function_calls, + double solver_epsilon = 0 +) +{ + DLIB_CASSERT(len(bound1) == len(bound2)); + DLIB_CASSERT(len(bound1) == len(is_integer_variable)); + + auto func = [&](const matrix<double,0,1>& x) + { + return call_func(f, x); + }; + + auto result = find_min_global(func, list_to_mat(bound1), list_to_mat(bound2), + list_to_bool_vector(is_integer_variable), max_function_calls(num_function_calls), + solver_epsilon); + + return py::make_tuple(mat_to_list(result.x),result.y); +} + +py::tuple py_find_min_global2 ( + py::object f, + py::list bound1, + py::list bound2, + unsigned long num_function_calls, + double solver_epsilon = 0 +) +{ + DLIB_CASSERT(len(bound1) == len(bound2)); + + auto func = [&](const matrix<double,0,1>& x) + { + return call_func(f, x); + }; + + auto result = find_min_global(func, list_to_mat(bound1), list_to_mat(bound2), max_function_calls(num_function_calls), solver_epsilon); + + return py::make_tuple(mat_to_list(result.x),result.y); +} + +// ---------------------------------------------------------------------------------------- + +function_spec py_function_spec1 ( + py::list a, + py::list b +) +{ + return function_spec(list_to_mat(a), list_to_mat(b)); +} + +function_spec py_function_spec2 ( + py::list a, + py::list b, + py::list c +) +{ + return function_spec(list_to_mat(a), list_to_mat(b), list_to_bool_vector(c)); +} + +std::shared_ptr<global_function_search> py_global_function_search1 ( + py::list functions +) +{ + std::vector<function_spec> tmp; + for (auto i : functions) + tmp.emplace_back(i.cast<function_spec>()); + + return std::make_shared<global_function_search>(tmp); +} + +std::shared_ptr<global_function_search> py_global_function_search2 ( + py::list functions, + py::list initial_function_evals, + double relative_noise_magnitude +) +{ + std::vector<function_spec> specs; + for (auto i : functions) + specs.emplace_back(i.cast<function_spec>()); + + std::vector<std::vector<function_evaluation>> func_evals; + for (auto i : initial_function_evals) + { + std::vector<function_evaluation> evals; + for (auto j : i) + { + evals.emplace_back(j.cast<function_evaluation>()); + } + func_evals.emplace_back(std::move(evals)); + } + + return std::make_shared<global_function_search>(specs, func_evals, relative_noise_magnitude); +} + +function_evaluation py_function_evaluation( + const py::list& x, + double y +) +{ + return function_evaluation(list_to_mat(x), y); +} + +// ---------------------------------------------------------------------------------------- + +void bind_global_optimization(py::module& m) +{ + /*! + requires + - len(bound1) == len(bound2) == len(is_integer_variable) + - for all valid i: bound1[i] != bound2[i] + - solver_epsilon >= 0 + - f() is a real valued multi-variate function. It must take scalar real + numbers as its arguments and the number of arguments must be len(bound1). + ensures + - This function performs global optimization on the given f() function. + The goal is to maximize the following objective function: + f(x) + subject to the constraints: + min(bound1[i],bound2[i]) <= x[i] <= max(bound1[i],bound2[i]) + if (is_integer_variable[i]) then x[i] is an integer. + - find_max_global() runs until it has called f() num_function_calls times. + Then it returns the best x it has found along with the corresponding output + of f(). That is, it returns (best_x_seen,f(best_x_seen)). Here best_x_seen + is a list containing the best arguments to f() this function has found. + - find_max_global() uses a global optimization method based on a combination of + non-parametric global function modeling and quadratic trust region modeling + to efficiently find a global maximizer. It usually does a good job with a + relatively small number of calls to f(). For more information on how it + works read the documentation for dlib's global_function_search object. + However, one notable element is the solver epsilon, which you can adjust. + + The search procedure will only attempt to find a global maximizer to at most + solver_epsilon accuracy. Once a local maximizer is found to that accuracy + the search will focus entirely on finding other maxima elsewhere rather than + on further improving the current local optima found so far. That is, once a + local maxima is identified to about solver_epsilon accuracy, the algorithm + will spend all its time exploring the function to find other local maxima to + investigate. An epsilon of 0 means it will keep solving until it reaches + full floating point precision. Larger values will cause it to switch to pure + global exploration sooner and therefore might be more effective if your + objective function has many local maxima and you don't care about a super + high precision solution. + - Any variables that satisfy the following conditions are optimized on a log-scale: + - The lower bound on the variable is > 0 + - The ratio of the upper bound to lower bound is > 1000 + - The variable is not an integer variable + We do this because it's common to optimize machine learning models that have + parameters with bounds in a range such as [1e-5 to 1e10] (e.g. the SVM C + parameter) and it's much more appropriate to optimize these kinds of + variables on a log scale. So we transform them by applying log() to + them and then undo the transform via exp() before invoking the function + being optimized. Therefore, this transformation is invisible to the user + supplied functions. In most cases, it improves the efficiency of the + optimizer. + !*/ + { + m.def("find_max_global", &py_find_max_global, +"requires \n\ + - len(bound1) == len(bound2) == len(is_integer_variable) \n\ + - for all valid i: bound1[i] != bound2[i] \n\ + - solver_epsilon >= 0 \n\ + - f() is a real valued multi-variate function. It must take scalar real \n\ + numbers as its arguments and the number of arguments must be len(bound1). \n\ +ensures \n\ + - This function performs global optimization on the given f() function. \n\ + The goal is to maximize the following objective function: \n\ + f(x) \n\ + subject to the constraints: \n\ + min(bound1[i],bound2[i]) <= x[i] <= max(bound1[i],bound2[i]) \n\ + if (is_integer_variable[i]) then x[i] is an integer. \n\ + - find_max_global() runs until it has called f() num_function_calls times. \n\ + Then it returns the best x it has found along with the corresponding output \n\ + of f(). That is, it returns (best_x_seen,f(best_x_seen)). Here best_x_seen \n\ + is a list containing the best arguments to f() this function has found. \n\ + - find_max_global() uses a global optimization method based on a combination of \n\ + non-parametric global function modeling and quadratic trust region modeling \n\ + to efficiently find a global maximizer. It usually does a good job with a \n\ + relatively small number of calls to f(). For more information on how it \n\ + works read the documentation for dlib's global_function_search object. \n\ + However, one notable element is the solver epsilon, which you can adjust. \n\ + \n\ + The search procedure will only attempt to find a global maximizer to at most \n\ + solver_epsilon accuracy. Once a local maximizer is found to that accuracy \n\ + the search will focus entirely on finding other maxima elsewhere rather than \n\ + on further improving the current local optima found so far. That is, once a \n\ + local maxima is identified to about solver_epsilon accuracy, the algorithm \n\ + will spend all its time exploring the function to find other local maxima to \n\ + investigate. An epsilon of 0 means it will keep solving until it reaches \n\ + full floating point precision. Larger values will cause it to switch to pure \n\ + global exploration sooner and therefore might be more effective if your \n\ + objective function has many local maxima and you don't care about a super \n\ + high precision solution. \n\ + - Any variables that satisfy the following conditions are optimized on a log-scale: \n\ + - The lower bound on the variable is > 0 \n\ + - The ratio of the upper bound to lower bound is > 1000 \n\ + - The variable is not an integer variable \n\ + We do this because it's common to optimize machine learning models that have \n\ + parameters with bounds in a range such as [1e-5 to 1e10] (e.g. the SVM C \n\ + parameter) and it's much more appropriate to optimize these kinds of \n\ + variables on a log scale. So we transform them by applying log() to \n\ + them and then undo the transform via exp() before invoking the function \n\ + being optimized. Therefore, this transformation is invisible to the user \n\ + supplied functions. In most cases, it improves the efficiency of the \n\ + optimizer." + , + py::arg("f"), py::arg("bound1"), py::arg("bound2"), py::arg("is_integer_variable"), py::arg("num_function_calls"), py::arg("solver_epsilon")=0 + ); + } + + { + m.def("find_max_global", &py_find_max_global2, + "This function simply calls the other version of find_max_global() with is_integer_variable set to False for all variables.", + py::arg("f"), py::arg("bound1"), py::arg("bound2"), py::arg("num_function_calls"), py::arg("solver_epsilon")=0 + ); + } + + + + { + m.def("find_min_global", &py_find_min_global, + "This function is just like find_max_global(), except it performs minimization rather than maximization." + , + py::arg("f"), py::arg("bound1"), py::arg("bound2"), py::arg("is_integer_variable"), py::arg("num_function_calls"), py::arg("solver_epsilon")=0 + ); + } + + { + m.def("find_min_global", &py_find_min_global2, + "This function simply calls the other version of find_min_global() with is_integer_variable set to False for all variables.", + py::arg("f"), py::arg("bound1"), py::arg("bound2"), py::arg("num_function_calls"), py::arg("solver_epsilon")=0 + ); + } + + // ------------------------------------------------- + // ------------------------------------------------- + + + py::class_<function_evaluation> (m, "function_evaluation", R"RAW( +This object records the output of a real valued function in response to +some input. + +In particular, if you have a function F(x) then the function_evaluation is +simply a struct that records x and the scalar value F(x). )RAW") + .def(py::init<matrix<double,0,1>,double>(), py::arg("x"), py::arg("y")) + .def(py::init<>(&py_function_evaluation), py::arg("x"), py::arg("y")) + .def_readonly("x", &function_evaluation::x) + .def_readonly("y", &function_evaluation::y); + + + py::class_<function_spec> (m, "function_spec", "See: http://dlib.net/dlib/global_optimization/global_function_search_abstract.h.html") + .def(py::init<matrix<double,0,1>,matrix<double,0,1>>(), py::arg("bound1"), py::arg("bound2") ) + .def(py::init<matrix<double,0,1>,matrix<double,0,1>,std::vector<bool>>(), py::arg("bound1"), py::arg("bound2"), py::arg("is_integer") ) + .def(py::init<>(&py_function_spec1), py::arg("bound1"), py::arg("bound2")) + .def(py::init<>(&py_function_spec2), py::arg("bound1"), py::arg("bound2"), py::arg("is_integer")) + .def_readonly("lower", &function_spec::lower) + .def_readonly("upper", &function_spec::upper) + .def_readonly("is_integer_variable", &function_spec::is_integer_variable); + + + py::class_<function_evaluation_request> (m, "function_evaluation_request", "See: http://dlib.net/dlib/global_optimization/global_function_search_abstract.h.html") + .def_property_readonly("function_idx", &function_evaluation_request::function_idx) + .def_property_readonly("x", &function_evaluation_request::x) + .def_property_readonly("has_been_evaluated", &function_evaluation_request::has_been_evaluated) + .def("set", &function_evaluation_request::set); + + py::class_<global_function_search, std::shared_ptr<global_function_search>> (m, "global_function_search", "See: http://dlib.net/dlib/global_optimization/global_function_search_abstract.h.html") + .def(py::init<function_spec>(), py::arg("function")) + .def(py::init<>(&py_global_function_search1), py::arg("functions")) + .def(py::init<>(&py_global_function_search2), py::arg("functions"), py::arg("initial_function_evals"), py::arg("relative_noise_magnitude")) + .def("set_seed", &global_function_search::set_seed, py::arg("seed")) + .def("num_functions", &global_function_search::num_functions) + .def("get_function_evaluations", [](const global_function_search& self) { + std::vector<function_spec> specs; + std::vector<std::vector<function_evaluation>> function_evals; + self.get_function_evaluations(specs,function_evals); + py::list py_specs, py_func_evals; + for (auto& s : specs) + py_specs.append(s); + for (auto& i : function_evals) + { + py::list tmp; + for (auto& j : i) + tmp.append(j); + py_func_evals.append(tmp); + } + return py::make_tuple(py_specs,py_func_evals);}) + .def("get_best_function_eval", [](const global_function_search& self) { + matrix<double,0,1> x; double y; size_t idx; self.get_best_function_eval(x,y,idx); return py::make_tuple(x,y,idx);}) + .def("get_next_x", &global_function_search::get_next_x) + .def("get_pure_random_search_probability", &global_function_search::get_pure_random_search_probability) + .def("set_pure_random_search_probability", &global_function_search::set_pure_random_search_probability, py::arg("prob")) + .def("get_solver_epsilon", &global_function_search::get_solver_epsilon) + .def("set_solver_epsilon", &global_function_search::set_solver_epsilon, py::arg("eps")) + .def("get_relative_noise_magnitude", &global_function_search::get_relative_noise_magnitude) + .def("set_relative_noise_magnitude", &global_function_search::set_relative_noise_magnitude, py::arg("value")) + .def("get_monte_carlo_upper_bound_sample_num", &global_function_search::get_monte_carlo_upper_bound_sample_num) + .def("set_monte_carlo_upper_bound_sample_num", &global_function_search::set_monte_carlo_upper_bound_sample_num, py::arg("num")) + ; + +} + diff --git a/ml/dlib/tools/python/src/gui.cpp b/ml/dlib/tools/python/src/gui.cpp new file mode 100644 index 000000000..418cfaae3 --- /dev/null +++ b/ml/dlib/tools/python/src/gui.cpp @@ -0,0 +1,128 @@ +#ifndef DLIB_NO_GUI_SUPPORT + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/geometry.h> +#include <dlib/image_processing/frontal_face_detector.h> +#include <dlib/image_processing/render_face_detections.h> +#include <dlib/gui_widgets.h> +#include "simple_object_detector_py.h" + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +void image_window_set_image_fhog_detector ( + image_window& win, + const simple_object_detector& det +) +{ + win.set_image(draw_fhog(det)); +} + +void image_window_set_image_simple_detector_py ( + image_window& win, + const simple_object_detector_py& det +) +{ + win.set_image(draw_fhog(det.detector)); +} + +// ---------------------------------------------------------------------------------------- + +void image_window_set_image ( + image_window& win, + py::object img +) +{ + if (is_gray_python_image(img)) + return win.set_image(numpy_gray_image(img)); + else if (is_rgb_python_image(img)) + return win.set_image(numpy_rgb_image(img)); + else + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); +} + +void add_overlay_rect ( + image_window& win, + const rectangle& rect, + const rgb_pixel& color +) +{ + win.add_overlay(rect, color); +} + +void add_overlay_drect ( + image_window& win, + const drectangle& drect, + const rgb_pixel& color +) +{ + rectangle rect(drect.left(), drect.top(), drect.right(), drect.bottom()); + win.add_overlay(rect, color); +} + +void add_overlay_parts ( + image_window& win, + const full_object_detection& detection, + const rgb_pixel& color +) +{ + win.add_overlay(render_face_detections(detection, color)); +} + +std::shared_ptr<image_window> make_image_window_from_image(py::object img) +{ + auto win = std::make_shared<image_window>(); + image_window_set_image(*win, img); + return win; +} + +std::shared_ptr<image_window> make_image_window_from_image_and_title(py::object img, const string& title) +{ + auto win = std::make_shared<image_window>(); + image_window_set_image(*win, img); + win->set_title(title); + return win; +} + +// ---------------------------------------------------------------------------------------- + +void bind_gui(py::module& m) +{ + { + typedef image_window type; + typedef void (image_window::*set_title_funct)(const std::string&); + typedef void (image_window::*add_overlay_funct)(const std::vector<rectangle>& r, rgb_pixel p); + py::class_<type, std::shared_ptr<type>>(m, "image_window", + "This is a GUI window capable of showing images on the screen.") + .def(py::init()) + .def(py::init(&make_image_window_from_image), + "Create an image window that displays the given numpy image.") + .def(py::init(&make_image_window_from_image_and_title), + "Create an image window that displays the given numpy image and also has the given title.") + .def("set_image", image_window_set_image_simple_detector_py, py::arg("detector"), + "Make the image_window display the given HOG detector's filters.") + .def("set_image", image_window_set_image_fhog_detector, py::arg("detector"), + "Make the image_window display the given HOG detector's filters.") + .def("set_image", image_window_set_image, py::arg("image"), + "Make the image_window display the given image.") + .def("set_title", (set_title_funct)&type::set_title, py::arg("title"), + "Set the title of the window to the given value.") + .def("clear_overlay", &type::clear_overlay, "Remove all overlays from the image_window.") + .def("add_overlay", (add_overlay_funct)&type::add_overlay<rgb_pixel>, py::arg("rectangles"), py::arg("color")=rgb_pixel(255, 0, 0), + "Add a list of rectangles to the image_window. They will be displayed as red boxes by default, but the color can be passed.") + .def("add_overlay", add_overlay_rect, py::arg("rectangle"), py::arg("color")=rgb_pixel(255, 0, 0), + "Add a rectangle to the image_window. It will be displayed as a red box by default, but the color can be passed.") + .def("add_overlay", add_overlay_drect, py::arg("rectangle"), py::arg("color")=rgb_pixel(255, 0, 0), + "Add a rectangle to the image_window. It will be displayed as a red box by default, but the color can be passed.") + .def("add_overlay", add_overlay_parts, py::arg("detection"), py::arg("color")=rgb_pixel(0, 0, 255), + "Add full_object_detection parts to the image window. They will be displayed as blue lines by default, but the color can be passed.") + .def("wait_until_closed", &type::wait_until_closed, + "This function blocks until the window is closed."); + } +} +#endif diff --git a/ml/dlib/tools/python/src/image.cpp b/ml/dlib/tools/python/src/image.cpp new file mode 100644 index 000000000..bd43ce5ab --- /dev/null +++ b/ml/dlib/tools/python/src/image.cpp @@ -0,0 +1,40 @@ +#include "opaque_types.h" +#include <dlib/python.h> +#include "dlib/pixel.h" +#include <dlib/image_transforms.h> + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +string print_rgb_pixel_str(const rgb_pixel& p) +{ + std::ostringstream sout; + sout << "red: "<< (int)p.red + << ", green: "<< (int)p.green + << ", blue: "<< (int)p.blue; + return sout.str(); +} + +string print_rgb_pixel_repr(const rgb_pixel& p) +{ + std::ostringstream sout; + sout << "rgb_pixel(" << (int)p.red << "," << (int)p.green << "," << (int)p.blue << ")"; + return sout.str(); +} + +// ---------------------------------------------------------------------------------------- + +void bind_image_classes(py::module& m) +{ + py::class_<rgb_pixel>(m, "rgb_pixel") + .def(py::init<unsigned char,unsigned char,unsigned char>(), py::arg("red"), py::arg("green"), py::arg("blue")) + .def("__str__", &print_rgb_pixel_str) + .def("__repr__", &print_rgb_pixel_repr) + .def_readwrite("red", &rgb_pixel::red) + .def_readwrite("green", &rgb_pixel::green) + .def_readwrite("blue", &rgb_pixel::blue); +} diff --git a/ml/dlib/tools/python/src/image_dataset_metadata.cpp b/ml/dlib/tools/python/src/image_dataset_metadata.cpp new file mode 100644 index 000000000..8f23ddd3f --- /dev/null +++ b/ml/dlib/tools/python/src/image_dataset_metadata.cpp @@ -0,0 +1,279 @@ +// Copyright (C) 2018 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/data_io.h> +#include <dlib/image_processing.h> +#include <pybind11/stl_bind.h> +#include <pybind11/stl.h> +#include <iostream> + +namespace pybind11 +{ + + // a version of bind_map that doesn't force it's own __repr__ on you. +template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args> +class_<Map, holder_type> bind_map_no_default_repr(handle scope, const std::string &name, Args&&... args) { + using KeyType = typename Map::key_type; + using MappedType = typename Map::mapped_type; + using Class_ = class_<Map, holder_type>; + + // If either type is a non-module-local bound type then make the map binding non-local as well; + // otherwise (e.g. both types are either module-local or converting) the map will be + // module-local. + auto tinfo = detail::get_type_info(typeid(MappedType)); + bool local = !tinfo || tinfo->module_local; + if (local) { + tinfo = detail::get_type_info(typeid(KeyType)); + local = !tinfo || tinfo->module_local; + } + + Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...); + + cl.def(init<>()); + + + cl.def("__bool__", + [](const Map &m) -> bool { return !m.empty(); }, + "Check whether the map is nonempty" + ); + + cl.def("__iter__", + [](Map &m) { return make_key_iterator(m.begin(), m.end()); }, + keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); + + cl.def("items", + [](Map &m) { return make_iterator(m.begin(), m.end()); }, + keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); + + cl.def("__getitem__", + [](Map &m, const KeyType &k) -> MappedType & { + auto it = m.find(k); + if (it == m.end()) + throw key_error(); + return it->second; + }, + return_value_policy::reference_internal // ref + keepalive + ); + + // Assignment provided only if the type is copyable + detail::map_assignment<Map, Class_>(cl); + + cl.def("__delitem__", + [](Map &m, const KeyType &k) { + auto it = m.find(k); + if (it == m.end()) + throw key_error(); + return m.erase(it); + } + ); + + cl.def("__len__", &Map::size); + + return cl; +} + +} + +using namespace dlib; +using namespace std; +using namespace dlib::image_dataset_metadata; + +namespace py = pybind11; + + +dataset py_load_image_dataset_metadata( + const std::string& filename +) +{ + dataset temp; + load_image_dataset_metadata(temp, filename); + return temp; +} + +std::shared_ptr<std::map<std::string,point>> map_from_object(py::dict obj) +{ + auto ret = std::make_shared<std::map<std::string,point>>(); + for (auto& v : obj) + { + (*ret)[v.first.cast<std::string>()] = v.second.cast<point>(); + } + return ret; +} + +// ---------------------------------------------------------------------------------------- + +image_dataset_metadata::dataset py_make_bounding_box_regression_training_data ( + const image_dataset_metadata::dataset& truth, + const py::object& detections +) +{ + try + { + // if detections is a std::vector then call like this. + return make_bounding_box_regression_training_data(truth, detections.cast<const std::vector<std::vector<rectangle>>&>()); + } + catch (py::cast_error&) + { + // otherwise, detections should be a list of std::vectors. + py::list dets(detections); + std::vector<std::vector<rectangle>> temp; + for (auto& d : dets) + temp.emplace_back(d.cast<const std::vector<rectangle>&>()); + return make_bounding_box_regression_training_data(truth, temp); + } +} + +// ---------------------------------------------------------------------------------------- + +void bind_image_dataset_metadata(py::module &m_) +{ + auto m = m_.def_submodule("image_dataset_metadata", "Routines and objects for working with dlib's image dataset metadata XML files."); + + auto datasetstr = [](const dataset& item) { return "dlib.dataset_dataset_metadata.dataset: images:" + to_string(item.images.size()) + ", " + item.name; }; + auto datasetrepr = [datasetstr](const dataset& item) { return "<"+datasetstr(item)+">"; }; + py::class_<dataset>(m, "dataset", + "This object represents a labeled set of images. In particular, it contains the filename for each image as well as annotated boxes.") + .def("__str__", datasetstr) + .def("__repr__", datasetrepr) + .def_readwrite("images", &dataset::images) + .def_readwrite("comment", &dataset::comment) + .def_readwrite("name", &dataset::name); + + auto imagestr = [](const image& item) { return "dlib.image_dataset_metadata.image: boxes:"+to_string(item.boxes.size())+ ", " + item.filename; }; + auto imagerepr = [imagestr](const image& item) { return "<"+imagestr(item)+">"; }; + py::class_<image>(m, "image", "This object represents an annotated image.") + .def_readwrite("filename", &image::filename) + .def("__str__", imagestr) + .def("__repr__", imagerepr) + .def_readwrite("boxes", &image::boxes); + + + auto partsstr = [](const std::map<std::string,point>& item) { + std::ostringstream sout; + sout << "{"; + for (auto& v : item) + sout << "'" << v.first << "': " << v.second << ", "; + sout << "}"; + return sout.str(); + }; + auto partsrepr = [](const std::map<std::string,point>& item) { + std::ostringstream sout; + sout << "dlib.image_dataset_metadata.parts({\n"; + for (auto& v : item) + sout << "'" << v.first << "': dlib.point" << v.second << ",\n"; + sout << "})"; + return sout.str(); + }; + + py::bind_map_no_default_repr<std::map<std::string,point>, std::shared_ptr<std::map<std::string,point>> >(m, "parts", + "This object is a dictionary mapping string names to object part locations.") + .def(py::init(&map_from_object)) + .def("__str__", partsstr) + .def("__repr__", partsrepr); + + + auto rectstr = [](const rectangle& r) { + std::ostringstream sout; + sout << "dlib.rectangle(" << r.left() << "," << r.top() << "," << r.right() << "," << r.bottom() << ")"; + return sout.str(); + }; + auto boxstr = [rectstr](const box& item) { return "dlib.image_dataset_metadata.box at " + rectstr(item.rect); }; + auto boxrepr = [boxstr](const box& item) { return "<"+boxstr(item)+">"; }; + py::class_<box> pybox(m, "box", + "This object represents an annotated rectangular area of an image. \n" + "It is typically used to mark the location of an object such as a \n" + "person, car, etc.\n" + "\n" + "The main variable of interest is rect. It gives the location of \n" + "the box. All the other variables are optional." ); pybox + .def("__str__", boxstr) + .def("__repr__", boxrepr) + .def_readwrite("rect", &box::rect) + .def_readonly("parts", &box::parts) + .def_readwrite("label", &box::label) + .def_readwrite("difficult", &box::difficult) + .def_readwrite("truncated", &box::truncated) + .def_readwrite("occluded", &box::occluded) + .def_readwrite("ignore", &box::ignore) + .def_readwrite("pose", &box::pose) + .def_readwrite("detection_score", &box::detection_score) + .def_readwrite("angle", &box::angle) + .def_readwrite("gender", &box::gender) + .def_readwrite("age", &box::age); + + py::enum_<gender_t>(pybox,"gender_type") + .value("MALE", gender_t::MALE) + .value("FEMALE", gender_t::FEMALE) + .value("UNKNOWN", gender_t::UNKNOWN) + .export_values(); + + + m.def("save_image_dataset_metadata", &save_image_dataset_metadata, py::arg("data"), py::arg("filename"), + "Writes the contents of the meta object to a file with the given filename. The file will be in an XML format." + ); + + m.def("load_image_dataset_metadata", &py_load_image_dataset_metadata, py::arg("filename"), + "Attempts to interpret filename as a file containing XML formatted data as produced " + "by the save_image_dataset_metadata() function. The data is loaded and returned as a dlib.image_dataset_metadata.dataset object." + ); + + m_.def("make_bounding_box_regression_training_data", &py_make_bounding_box_regression_training_data, + py::arg("truth"), py::arg("detections"), +"requires \n\ + - len(truth.images) == len(detections) \n\ + - detections == A dlib.rectangless object or a list of dlib.rectangles. \n\ +ensures \n\ + - Suppose you have an object detector that can roughly locate objects in an \n\ + image. This means your detector draws boxes around objects, but these are \n\ + *rough* boxes in the sense that they aren't positioned super accurately. For \n\ + instance, HOG based detectors usually have a stride of 8 pixels. So the \n\ + positional accuracy is going to be, at best, +/-8 pixels. \n\ + \n\ + If you want to get better positional accuracy one easy thing to do is train a \n\ + shape_predictor to give you the corners of the object. The \n\ + make_bounding_box_regression_training_data() routine helps you do this by \n\ + creating an appropriate training dataset. It does this by taking the dataset \n\ + you used to train your detector (the truth object), and combining that with \n\ + the output of your detector on each image in the training dataset (the \n\ + detections object). In particular, it will create a new annotated dataset \n\ + where each object box is one of the rectangles from detections and that \n\ + object has 4 part annotations, the corners of the truth rectangle \n\ + corresponding to that detection rectangle. You can then take the returned \n\ + dataset and train a shape_predictor on it. The resulting shape_predictor can \n\ + then be used to do bounding box regression. \n\ + - We assume that detections[i] contains object detections corresponding to \n\ + the image truth.images[i]." + /*! + requires + - len(truth.images) == len(detections) + - detections == A dlib.rectangless object or a list of dlib.rectangles. + ensures + - Suppose you have an object detector that can roughly locate objects in an + image. This means your detector draws boxes around objects, but these are + *rough* boxes in the sense that they aren't positioned super accurately. For + instance, HOG based detectors usually have a stride of 8 pixels. So the + positional accuracy is going to be, at best, +/-8 pixels. + + If you want to get better positional accuracy one easy thing to do is train a + shape_predictor to give you the corners of the object. The + make_bounding_box_regression_training_data() routine helps you do this by + creating an appropriate training dataset. It does this by taking the dataset + you used to train your detector (the truth object), and combining that with + the output of your detector on each image in the training dataset (the + detections object). In particular, it will create a new annotated dataset + where each object box is one of the rectangles from detections and that + object has 4 part annotations, the corners of the truth rectangle + corresponding to that detection rectangle. You can then take the returned + dataset and train a shape_predictor on it. The resulting shape_predictor can + then be used to do bounding box regression. + - We assume that detections[i] contains object detections corresponding to + the image truth.images[i]. + !*/ + ); +} + + diff --git a/ml/dlib/tools/python/src/indexing.h b/ml/dlib/tools/python/src/indexing.h new file mode 100644 index 000000000..3aa398f02 --- /dev/null +++ b/ml/dlib/tools/python/src/indexing.h @@ -0,0 +1,11 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYTHON_INDEXING_H__ +#define DLIB_PYTHON_INDEXING_H__ + +namespace dlib +{ + template <typename T> + void resize(T& v, unsigned long n) { v.resize(n); } +} +#endif // DLIB_PYTHON_INDEXING_H__ diff --git a/ml/dlib/tools/python/src/matrix.cpp b/ml/dlib/tools/python/src/matrix.cpp new file mode 100644 index 000000000..a93544820 --- /dev/null +++ b/ml/dlib/tools/python/src/matrix.cpp @@ -0,0 +1,209 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/string.h> +#include <pybind11/pybind11.h> + +using namespace dlib; +namespace py = pybind11; +using std::string; +using std::ostringstream; + + +void matrix_set_size(matrix<double>& m, long nr, long nc) +{ + m.set_size(nr,nc); + m = 0; +} + +string matrix_double__repr__(matrix<double>& c) +{ + ostringstream sout; + sout << "< dlib.matrix containing: \n"; + sout << c; + return trim(sout.str()) + " >"; +} + +string matrix_double__str__(matrix<double>& c) +{ + ostringstream sout; + sout << c; + return trim(sout.str()); +} + +std::shared_ptr<matrix<double> > make_matrix_from_size(long nr, long nc) +{ + if (nr < 0 || nc < 0) + { + PyErr_SetString( PyExc_IndexError, "Input dimensions can't be negative." + ); + throw py::error_already_set(); + } + auto temp = std::make_shared<matrix<double>>(nr,nc); + *temp = 0; + return temp; +} + + +std::shared_ptr<matrix<double> > from_object(py::object obj) +{ + py::tuple s = obj.attr("shape").cast<py::tuple>(); + if (len(s) != 2) + { + PyErr_SetString( PyExc_IndexError, "Input must be a matrix or some kind of 2D array." + ); + throw py::error_already_set(); + } + + const long nr = s[0].cast<long>(); + const long nc = s[1].cast<long>(); + auto temp = std::make_shared<matrix<double>>(nr,nc); + for ( long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + (*temp)(r,c) = obj[py::make_tuple(r,c)].cast<double>(); + } + } + return temp; +} + +std::shared_ptr<matrix<double> > from_list(py::list l) +{ + const long nr = py::len(l); + if (py::isinstance<py::list>(l[0])) + { + const long nc = py::len(l[0]); + // make sure all the other rows have the same length + for (long r = 1; r < nr; ++r) + pyassert(py::len(l[r]) == nc, "All rows of a matrix must have the same number of columns."); + + auto temp = std::make_shared<matrix<double>>(nr,nc); + for ( long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + (*temp)(r,c) = l[r].cast<py::list>()[c].cast<double>(); + } + } + return temp; + } + else + { + // In this case we treat it like a column vector + auto temp = std::make_shared<matrix<double>>(nr,1); + for ( long r = 0; r < nr; ++r) + { + (*temp)(r) = l[r].cast<double>(); + } + return temp; + } +} + +long matrix_double__len__(matrix<double>& c) +{ + return c.nr(); +} + +struct mat_row +{ + mat_row() : data(0),size(0) {} + mat_row(double* data_, long size_) : data(data_),size(size_) {} + double* data; + long size; +}; + +void mat_row__setitem__(mat_row& c, long p, double val) +{ + if (p < 0) { + p = c.size + p; // negative index + } + if (p > c.size-1) { + PyErr_SetString( PyExc_IndexError, "3 index out of range" + ); + throw py::error_already_set(); + } + c.data[p] = val; +} + + +string mat_row__str__(mat_row& c) +{ + ostringstream sout; + sout << mat(c.data,1, c.size); + return sout.str(); +} + +string mat_row__repr__(mat_row& c) +{ + ostringstream sout; + sout << "< matrix row: " << mat(c.data,1, c.size); + return trim(sout.str()) + " >"; +} + +long mat_row__len__(mat_row& m) +{ + return m.size; +} + +double mat_row__getitem__(mat_row& m, long r) +{ + if (r < 0) { + r = m.size + r; // negative index + } + if (r > m.size-1 || r < 0) { + PyErr_SetString( PyExc_IndexError, "1 index out of range" + ); + throw py::error_already_set(); + } + return m.data[r]; +} + +mat_row matrix_double__getitem__(matrix<double>& m, long r) +{ + if (r < 0) { + r = m.nr() + r; // negative index + } + if (r > m.nr()-1 || r < 0) { + PyErr_SetString( PyExc_IndexError, (string("2 index out of range, got ") + cast_to_string(r)).c_str() + ); + throw py::error_already_set(); + } + return mat_row(&m(r,0),m.nc()); +} + + +py::tuple get_matrix_size(matrix<double>& m) +{ + return py::make_tuple(m.nr(), m.nc()); +} + +void bind_matrix(py::module& m) +{ + py::class_<mat_row>(m, "_row") + .def("__len__", &mat_row__len__) + .def("__repr__", &mat_row__repr__) + .def("__str__", &mat_row__str__) + .def("__setitem__", &mat_row__setitem__) + .def("__getitem__", &mat_row__getitem__); + + py::class_<matrix<double>, std::shared_ptr<matrix<double>>>(m, "matrix", + "This object represents a dense 2D matrix of floating point numbers." + "Moreover, it binds directly to the C++ type dlib::matrix<double>.") + .def(py::init<>()) + .def(py::init(&from_list)) + .def(py::init(&from_object)) + .def(py::init(&make_matrix_from_size)) + .def("set_size", &matrix_set_size, py::arg("rows"), py::arg("cols"), "Set the size of the matrix to the given number of rows and columns.") + .def("__repr__", &matrix_double__repr__) + .def("__str__", &matrix_double__str__) + .def("nr", &matrix<double>::nr, "Return the number of rows in the matrix.") + .def("nc", &matrix<double>::nc, "Return the number of columns in the matrix.") + .def("__len__", &matrix_double__len__) + .def("__getitem__", &matrix_double__getitem__, py::keep_alive<0,1>()) + .def_property_readonly("shape", &get_matrix_size) + .def(py::pickle(&getstate<matrix<double>>, &setstate<matrix<double>>)); +} diff --git a/ml/dlib/tools/python/src/numpy_returns.cpp b/ml/dlib/tools/python/src/numpy_returns.cpp new file mode 100644 index 000000000..235816a78 --- /dev/null +++ b/ml/dlib/tools/python/src/numpy_returns.cpp @@ -0,0 +1,158 @@ +#include "opaque_types.h" +#include <dlib/python.h> +#include "dlib/pixel.h" +#include <dlib/image_transforms.h> + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include <numpy/ndarrayobject.h> + + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +py::list get_jitter_images(py::object img, size_t num_jitters = 1, bool disturb_colors = false) +{ + static dlib::rand rnd_jitter; + if (!is_rgb_python_image(img)) + throw dlib::error("Unsupported image type, must be RGB image."); + + // Convert the image to matrix<rgb_pixel> for processing + matrix<rgb_pixel> img_mat; + assign_image(img_mat, numpy_rgb_image(img)); + + // The top level list (containing 1 or more images) to return to python + py::list jitter_list; + + size_t rows = num_rows(img_mat); + size_t cols = num_columns(img_mat); + + // Size of the numpy array + npy_intp dims[3] = { num_rows(img_mat), num_columns(img_mat), 3}; + + for (int i = 0; i < num_jitters; ++i) { + // Get a jittered crop + matrix<rgb_pixel> crop = dlib::jitter_image(img_mat, rnd_jitter); + // If required disturb colors of the image + if(disturb_colors) + dlib::disturb_colors(crop, rnd_jitter); + + PyObject *arr = PyArray_SimpleNew(3, dims, NPY_UINT8); + npy_uint8 *outdata = (npy_uint8 *) PyArray_DATA((PyArrayObject*) arr); + memcpy(outdata, image_data(crop), rows * width_step(crop)); + + py::handle handle(arr); + // Append image to jittered image list + jitter_list.append(handle); + } + + return jitter_list; +} + +// ---------------------------------------------------------------------------------------- + +py::list get_face_chips ( + py::object img, + const std::vector<full_object_detection>& faces, + size_t size = 150, + float padding = 0.25 +) +{ + if (!is_rgb_python_image(img)) + throw dlib::error("Unsupported image type, must be RGB image."); + + if (faces.size() < 1) { + throw dlib::error("No face were specified in the faces array."); + } + + py::list chips_list; + + std::vector<chip_details> dets; + for (auto& f : faces) + dets.push_back(get_face_chip_details(f, size, padding)); + dlib::array<matrix<rgb_pixel>> face_chips; + extract_image_chips(numpy_rgb_image(img), dets, face_chips); + + npy_intp rows = size; + npy_intp cols = size; + + // Size of the numpy array + npy_intp dims[3] = { rows, cols, 3}; + + for (auto& chip : face_chips) + { + PyObject *arr = PyArray_SimpleNew(3, dims, NPY_UINT8); + npy_uint8 *outdata = (npy_uint8 *) PyArray_DATA((PyArrayObject*) arr); + memcpy(outdata, image_data(chip), rows * width_step(chip)); + py::handle handle(arr); + + // Append image to chips list + chips_list.append(handle); + } + return chips_list; +} + +py::object get_face_chip ( + py::object img, + const full_object_detection& face, + size_t size = 150, + float padding = 0.25 +) +{ + if (!is_rgb_python_image(img)) + throw dlib::error("Unsupported image type, must be RGB image."); + + matrix<rgb_pixel> chip; + extract_image_chip(numpy_rgb_image(img), get_face_chip_details(face, size, padding), chip); + + // Size of the numpy array + npy_intp dims[3] = { num_rows(chip), num_columns(chip), 3}; + + PyObject *arr = PyArray_SimpleNew(3, dims, NPY_UINT8); + npy_uint8 *outdata = (npy_uint8 *) PyArray_DATA((PyArrayObject *) arr); + memcpy(outdata, image_data(chip), num_rows(chip) * width_step(chip)); + py::handle handle(arr); + return handle.cast<py::object>(); +} + +// ---------------------------------------------------------------------------------------- + +// we need this wonky stuff because different versions of numpy's import_array macro +// contain differently typed return statements inside import_array(). +#if PY_VERSION_HEX >= 0x03000000 +#define DLIB_NUMPY_IMPORT_ARRAY_RETURN_TYPE void* +#define DLIB_NUMPY_IMPORT_RETURN return 0 +#else +#define DLIB_NUMPY_IMPORT_ARRAY_RETURN_TYPE void +#define DLIB_NUMPY_IMPORT_RETURN return +#endif +DLIB_NUMPY_IMPORT_ARRAY_RETURN_TYPE import_numpy_stuff() +{ + import_array(); + DLIB_NUMPY_IMPORT_RETURN; +} + +void bind_numpy_returns(py::module &m) +{ + import_numpy_stuff(); + + m.def("jitter_image", &get_jitter_images, + "Takes an image and returns a list of jittered images." + "The returned list contains num_jitters images (default is 1)." + "If disturb_colors is set to True, the colors of the image are disturbed (default is False)", + py::arg("img"), py::arg("num_jitters")=1, py::arg("disturb_colors")=false + ); + + m.def("get_face_chip", &get_face_chip, + "Takes an image and a full_object_detection that references a face in that image and returns the face as a Numpy array representing the image. The face will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.", + py::arg("img"), py::arg("face"), py::arg("size")=150, py::arg("padding")=0.25 + ); + + m.def("get_face_chips", &get_face_chips, + "Takes an image and a full_object_detections object that reference faces in that image and returns the faces as a list of Numpy arrays representing the image. The faces will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.", + py::arg("img"), py::arg("faces"), py::arg("size")=150, py::arg("padding")=0.25 + ); +} diff --git a/ml/dlib/tools/python/src/numpy_returns_stub.cpp b/ml/dlib/tools/python/src/numpy_returns_stub.cpp new file mode 100644 index 000000000..07d38ceac --- /dev/null +++ b/ml/dlib/tools/python/src/numpy_returns_stub.cpp @@ -0,0 +1,59 @@ +#include "opaque_types.h" +#include <dlib/python.h> +#include "dlib/pixel.h" +#include <dlib/image_transforms.h> + +using namespace dlib; +using namespace std; +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +py::list get_jitter_images(py::object img, size_t num_jitters = 1, bool disturb_colors = false) +{ + throw dlib::error("jitter_image is only supported if you compiled dlib with numpy installed!"); +} + +// ---------------------------------------------------------------------------------------- + +py::list get_face_chips ( + py::object img, + const std::vector<full_object_detection>& faces, + size_t size = 150, + float padding = 0.25 +) +{ + throw dlib::error("get_face_chips is only supported if you compiled dlib with numpy installed!"); +} + +py::object get_face_chip ( + py::object img, + const full_object_detection& face, + size_t size = 150, + float padding = 0.25 +) +{ + throw dlib::error("get_face_chip is only supported if you compiled dlib with numpy installed!"); +} + +// ---------------------------------------------------------------------------------------- + +void bind_numpy_returns(py::module &m) +{ + m.def("jitter_image", &get_jitter_images, + "Takes an image and returns a list of jittered images." + "The returned list contains num_jitters images (default is 1)." + "If disturb_colors is set to True, the colors of the image are disturbed (default is False)", + py::arg("img"), py::arg("num_jitters")=1, py::arg("disturb_colors")=false + ); + + m.def("get_face_chip", &get_face_chip, + "Takes an image and a full_object_detection that references a face in that image and returns the face as a Numpy array representing the image. The face will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.", + py::arg("img"), py::arg("face"), py::arg("size")=150, py::arg("padding")=0.25 + ); + + m.def("get_face_chips", &get_face_chips, + "Takes an image and a full_object_detections object that reference faces in that image and returns the faces as a list of Numpy arrays representing the image. The faces will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.", + py::arg("img"), py::arg("faces"), py::arg("size")=150, py::arg("padding")=0.25 + ); +} diff --git a/ml/dlib/tools/python/src/object_detection.cpp b/ml/dlib/tools/python/src/object_detection.cpp new file mode 100644 index 000000000..bda570d7d --- /dev/null +++ b/ml/dlib/tools/python/src/object_detection.cpp @@ -0,0 +1,376 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/geometry.h> +#include <dlib/image_processing/frontal_face_detector.h> +#include "simple_object_detector.h" +#include "simple_object_detector_py.h" +#include "conversion.h" + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +string print_simple_test_results(const simple_test_results& r) +{ + std::ostringstream sout; + sout << "precision: "<<r.precision << ", recall: "<< r.recall << ", average precision: " << r.average_precision; + return sout.str(); +} + +// ---------------------------------------------------------------------------------------- + +inline simple_object_detector_py train_simple_object_detector_on_images_py ( + const py::list& pyimages, + const py::list& pyboxes, + const simple_object_detector_training_options& options +) +{ + const unsigned long num_images = py::len(pyimages); + if (num_images != py::len(pyboxes)) + throw dlib::error("The length of the boxes list must match the length of the images list."); + + // We never have any ignore boxes for this version of the API. + std::vector<std::vector<rectangle> > ignore(num_images), boxes(num_images); + dlib::array<array2d<rgb_pixel> > images(num_images); + images_and_nested_params_to_dlib(pyimages, pyboxes, images, boxes); + + return train_simple_object_detector_on_images("", images, boxes, ignore, options); +} + +inline simple_test_results test_simple_object_detector_with_images_py ( + const py::list& pyimages, + const py::list& pyboxes, + simple_object_detector& detector, + const unsigned int upsampling_amount +) +{ + const unsigned long num_images = py::len(pyimages); + if (num_images != py::len(pyboxes)) + throw dlib::error("The length of the boxes list must match the length of the images list."); + + // We never have any ignore boxes for this version of the API. + std::vector<std::vector<rectangle> > ignore(num_images), boxes(num_images); + dlib::array<array2d<rgb_pixel> > images(num_images); + images_and_nested_params_to_dlib(pyimages, pyboxes, images, boxes); + + return test_simple_object_detector_with_images(images, upsampling_amount, boxes, ignore, detector); +} + +// ---------------------------------------------------------------------------------------- + +inline simple_test_results test_simple_object_detector_py_with_images_py ( + const py::list& pyimages, + const py::list& pyboxes, + simple_object_detector_py& detector, + const int upsampling_amount +) +{ + // Allow users to pass an upsampling amount ELSE use the one cached on the object + // Anything less than 0 is ignored and the cached value is used. + unsigned int final_upsampling_amount = 0; + if (upsampling_amount >= 0) + final_upsampling_amount = upsampling_amount; + else + final_upsampling_amount = detector.upsampling_amount; + + return test_simple_object_detector_with_images_py(pyimages, pyboxes, detector.detector, final_upsampling_amount); +} + +// ---------------------------------------------------------------------------------------- + +inline void find_candidate_object_locations_py ( + py::object pyimage, + py::list& pyboxes, + py::tuple pykvals, + unsigned long min_size, + unsigned long max_merging_iterations +) +{ + // Copy the data into dlib based objects + array2d<rgb_pixel> image; + if (is_gray_python_image(pyimage)) + assign_image(image, numpy_gray_image(pyimage)); + else if (is_rgb_python_image(pyimage)) + assign_image(image, numpy_rgb_image(pyimage)); + else + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + + if (py::len(pykvals) != 3) + throw dlib::error("kvals must be a tuple with three elements for start, end, num."); + + double start = pykvals[0].cast<double>(); + double end = pykvals[1].cast<double>(); + long num = pykvals[2].cast<long>(); + matrix_range_exp<double> kvals = linspace(start, end, num); + + std::vector<rectangle> rects; + const long count = py::len(pyboxes); + // Copy any rectangles in the input pyboxes into rects so that any rectangles will be + // properly deduped in the resulting output. + for (long i = 0; i < count; ++i) + rects.push_back(pyboxes[i].cast<rectangle>()); + // Find candidate objects + find_candidate_object_locations(image, rects, kvals, min_size, max_merging_iterations); + + // Collect boxes containing candidate objects + std::vector<rectangle>::iterator iter; + for (iter = rects.begin(); iter != rects.end(); ++iter) + pyboxes.append(*iter); +} + +// ---------------------------------------------------------------------------------------- + +void bind_object_detection(py::module& m) +{ + { + typedef simple_object_detector_training_options type; + py::class_<type>(m, "simple_object_detector_training_options", + "This object is a container for the options to the train_simple_object_detector() routine.") + .def(py::init()) + .def_readwrite("be_verbose", &type::be_verbose, +"If true, train_simple_object_detector() will print out a lot of information to the screen while training.") + .def_readwrite("add_left_right_image_flips", &type::add_left_right_image_flips, +"if true, train_simple_object_detector() will assume the objects are \n\ +left/right symmetric and add in left right flips of the training \n\ +images. This doubles the size of the training dataset.") + .def_readwrite("detection_window_size", &type::detection_window_size, + "The sliding window used will have about this many pixels inside it.") + .def_readwrite("C", &type::C, +"C is the usual SVM C regularization parameter. So it is passed to \n\ +structural_object_detection_trainer::set_c(). Larger values of C \n\ +will encourage the trainer to fit the data better but might lead to \n\ +overfitting. Therefore, you must determine the proper setting of \n\ +this parameter experimentally.") + .def_readwrite("epsilon", &type::epsilon, +"epsilon is the stopping epsilon. Smaller values make the trainer's \n\ +solver more accurate but might take longer to train.") + .def_readwrite("num_threads", &type::num_threads, +"train_simple_object_detector() will use this many threads of \n\ +execution. Set this to the number of CPU cores on your machine to \n\ +obtain the fastest training speed.") + .def_readwrite("upsample_limit", &type::upsample_limit, +"train_simple_object_detector() will upsample images if needed \n\ +no more than upsample_limit times. Value 0 will forbid trainer to \n\ +upsample any images. If trainer is unable to fit all boxes with \n\ +required upsample_limit, exception will be thrown. Higher values \n\ +of upsample_limit exponentially increases memory requiremens. \n\ +Values higher than 2 (default) are not recommended."); + } + { + typedef simple_test_results type; + py::class_<type>(m, "simple_test_results") + .def_readwrite("precision", &type::precision) + .def_readwrite("recall", &type::recall) + .def_readwrite("average_precision", &type::average_precision) + .def("__str__", &::print_simple_test_results); + } + + // Here, kvals is actually the result of linspace(start, end, num) and it is different from kvals used + // in find_candidate_object_locations(). See dlib/image_transforms/segment_image_abstract.h for more details. + m.def("find_candidate_object_locations", find_candidate_object_locations_py, py::arg("image"), py::arg("rects"), py::arg("kvals")=py::make_tuple(50, 200, 3), py::arg("min_size")=20, py::arg("max_merging_iterations")=50, +"Returns found candidate objects\n\ +requires\n\ + - image == an image object which is a numpy ndarray\n\ + - len(kvals) == 3\n\ + - kvals should be a tuple that specifies the range of k values to use. In\n\ + particular, it should take the form (start, end, num) where num > 0. \n\ +ensures\n\ + - This function takes an input image and generates a set of candidate\n\ + rectangles which are expected to bound any objects in the image. It does\n\ + this by running a version of the segment_image() routine on the image and\n\ + then reports rectangles containing each of the segments as well as rectangles\n\ + containing unions of adjacent segments. The basic idea is described in the\n\ + paper: \n\ + Segmentation as Selective Search for Object Recognition by Koen E. A. van de Sande, et al.\n\ + Note that this function deviates from what is described in the paper slightly. \n\ + See the code for details.\n\ + - The basic segmentation is performed kvals[2] times, each time with the k parameter\n\ + (see segment_image() and the Felzenszwalb paper for details on k) set to a different\n\ + value from the range of numbers linearly spaced between kvals[0] to kvals[1].\n\ + - When doing the basic segmentations prior to any box merging, we discard all\n\ + rectangles that have an area < min_size. Therefore, all outputs and\n\ + subsequent merged rectangles are built out of rectangles that contain at\n\ + least min_size pixels. Note that setting min_size to a smaller value than\n\ + you might otherwise be interested in using can be useful since it allows a\n\ + larger number of possible merged boxes to be created.\n\ + - There are max_merging_iterations rounds of neighboring blob merging.\n\ + Therefore, this parameter has some effect on the number of output rectangles\n\ + you get, with larger values of the parameter giving more output rectangles.\n\ + - This function appends the output rectangles into #rects. This means that any\n\ + rectangles in rects before this function was called will still be in there\n\ + after it terminates. Note further that #rects will not contain any duplicate\n\ + rectangles. That is, for all valid i and j where i != j it will be true\n\ + that:\n\ + - #rects[i] != rects[j]"); + + m.def("get_frontal_face_detector", get_frontal_face_detector, + "Returns the default face detector"); + + m.def("train_simple_object_detector", train_simple_object_detector, + py::arg("dataset_filename"), py::arg("detector_output_filename"), py::arg("options"), +"requires \n\ + - options.C > 0 \n\ +ensures \n\ + - Uses the structural_object_detection_trainer to train a \n\ + simple_object_detector based on the labeled images in the XML file \n\ + dataset_filename. This function assumes the file dataset_filename is in the \n\ + XML format produced by dlib's save_image_dataset_metadata() routine. \n\ + - This function will apply a reasonable set of default parameters and \n\ + preprocessing techniques to the training procedure for simple_object_detector \n\ + objects. So the point of this function is to provide you with a very easy \n\ + way to train a basic object detector. \n\ + - The trained object detector is serialized to the file detector_output_filename."); + + m.def("train_simple_object_detector", train_simple_object_detector_on_images_py, + py::arg("images"), py::arg("boxes"), py::arg("options"), +"requires \n\ + - options.C > 0 \n\ + - len(images) == len(boxes) \n\ + - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ + - boxes should be a list of lists of dlib.rectangle object. \n\ +ensures \n\ + - Uses the structural_object_detection_trainer to train a \n\ + simple_object_detector based on the labeled images and bounding boxes. \n\ + - This function will apply a reasonable set of default parameters and \n\ + preprocessing techniques to the training procedure for simple_object_detector \n\ + objects. So the point of this function is to provide you with a very easy \n\ + way to train a basic object detector. \n\ + - The trained object detector is returned."); + + m.def("test_simple_object_detector", test_simple_object_detector, + // Please see test_simple_object_detector for the reason upsampling_amount is -1 + py::arg("dataset_filename"), py::arg("detector_filename"), py::arg("upsampling_amount")=-1, + "requires \n\ + - Optionally, take the number of times to upsample the testing images (upsampling_amount >= 0). \n\ + ensures \n\ + - Loads an image dataset from dataset_filename. We assume dataset_filename is \n\ + a file using the XML format written by save_image_dataset_metadata(). \n\ + - Loads a simple_object_detector from the file detector_filename. This means \n\ + detector_filename should be a file produced by the train_simple_object_detector() \n\ + routine. \n\ + - This function tests the detector against the dataset and returns the \n\ + precision, recall, and average precision of the detector. In fact, The \n\ + return value of this function is identical to that of dlib's \n\ + test_object_detection_function() routine. Therefore, see the documentation \n\ + for test_object_detection_function() for a detailed definition of these \n\ + metrics. " + ); + + m.def("test_simple_object_detector", test_simple_object_detector_with_images_py, + py::arg("images"), py::arg("boxes"), py::arg("detector"), py::arg("upsampling_amount")=0, + "requires \n\ + - len(images) == len(boxes) \n\ + - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ + - boxes should be a list of lists of dlib.rectangle object. \n\ + - Optionally, take the number of times to upsample the testing images (upsampling_amount >= 0). \n\ + ensures \n\ + - Loads a simple_object_detector from the file detector_filename. This means \n\ + detector_filename should be a file produced by the train_simple_object_detector() \n\ + routine. \n\ + - This function tests the detector against the dataset and returns the \n\ + precision, recall, and average precision of the detector. In fact, The \n\ + return value of this function is identical to that of dlib's \n\ + test_object_detection_function() routine. Therefore, see the documentation \n\ + for test_object_detection_function() for a detailed definition of these \n\ + metrics. " + ); + + m.def("test_simple_object_detector", test_simple_object_detector_py_with_images_py, + // Please see test_simple_object_detector_py_with_images_py for the reason upsampling_amount is -1 + py::arg("images"), py::arg("boxes"), py::arg("detector"), py::arg("upsampling_amount")=-1, + "requires \n\ + - len(images) == len(boxes) \n\ + - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ + - boxes should be a list of lists of dlib.rectangle object. \n\ + ensures \n\ + - Loads a simple_object_detector from the file detector_filename. This means \n\ + detector_filename should be a file produced by the train_simple_object_detector() \n\ + routine. \n\ + - This function tests the detector against the dataset and returns the \n\ + precision, recall, and average precision of the detector. In fact, The \n\ + return value of this function is identical to that of dlib's \n\ + test_object_detection_function() routine. Therefore, see the documentation \n\ + for test_object_detection_function() for a detailed definition of these \n\ + metrics. " + ); + { + typedef simple_object_detector type; + py::class_<type, std::shared_ptr<type>>(m, "fhog_object_detector", + "This object represents a sliding window histogram-of-oriented-gradients based object detector.") + .def(py::init(&load_object_from_file<type>), +"Loads an object detector from a file that contains the output of the \n\ +train_simple_object_detector() routine or a serialized C++ object of type\n\ +object_detector<scan_fhog_pyramid<pyramid_down<6>>>.") + .def("__call__", run_detector_with_upscale2, py::arg("image"), py::arg("upsample_num_times")=0, +"requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB \n\ + image. \n\ + - upsample_num_times >= 0 \n\ +ensures \n\ + - This function runs the object detector on the input image and returns \n\ + a list of detections. \n\ + - Upsamples the image upsample_num_times before running the basic \n\ + detector.") + .def("run", run_rect_detector, py::arg("image"), py::arg("upsample_num_times")=0, py::arg("adjust_threshold")=0.0, +"requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB \n\ + image. \n\ + - upsample_num_times >= 0 \n\ +ensures \n\ + - This function runs the object detector on the input image and returns \n\ + a tuple of (list of detections, list of scores, list of weight_indices). \n\ + - Upsamples the image upsample_num_times before running the basic \n\ + detector.") + .def_static("run_multiple", run_multiple_rect_detectors, py::arg("detectors"), py::arg("image"), py::arg("upsample_num_times")=0, py::arg("adjust_threshold")=0.0, +"requires \n\ + - detectors is a list of detectors. \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB \n\ + image. \n\ + - upsample_num_times >= 0 \n\ +ensures \n\ + - This function runs the list of object detectors at once on the input image and returns \n\ + a tuple of (list of detections, list of scores, list of weight_indices). \n\ + - Upsamples the image upsample_num_times before running the basic \n\ + detector.") + .def("save", save_simple_object_detector, py::arg("detector_output_filename"), "Save a simple_object_detector to the provided path.") + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + { + typedef simple_object_detector_py type; + py::class_<type, std::shared_ptr<type>>(m, "simple_object_detector", + "This object represents a sliding window histogram-of-oriented-gradients based object detector.") + .def(py::init(&load_object_from_file<type>), +"Loads a simple_object_detector from a file that contains the output of the \n\ +train_simple_object_detector() routine.") + .def("__call__", &type::run_detector1, py::arg("image"), py::arg("upsample_num_times"), +"requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB \n\ + image. \n\ + - upsample_num_times >= 0 \n\ +ensures \n\ + - This function runs the object detector on the input image and returns \n\ + a list of detections. \n\ + - Upsamples the image upsample_num_times before running the basic \n\ + detector. If you don't know how many times you want to upsample then \n\ + don't provide a value for upsample_num_times and an appropriate \n\ + default will be used.") + .def("__call__", &type::run_detector2, py::arg("image"), +"requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB \n\ + image. \n\ +ensures \n\ + - This function runs the object detector on the input image and returns \n\ + a list of detections.") + .def("save", save_simple_object_detector_py, py::arg("detector_output_filename"), "Save a simple_object_detector to the provided path.") + .def(py::pickle(&getstate<type>, &setstate<type>)); + } +} + +// ---------------------------------------------------------------------------------------- diff --git a/ml/dlib/tools/python/src/opaque_types.h b/ml/dlib/tools/python/src/opaque_types.h new file mode 100644 index 000000000..1a31c08df --- /dev/null +++ b/ml/dlib/tools/python/src/opaque_types.h @@ -0,0 +1,55 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PyTHON_OPAQUE_TYPES_H_ +#define DLIB_PyTHON_OPAQUE_TYPES_H_ + +#include <dlib/python.h> +#include <dlib/geometry.h> +#include <pybind11/stl_bind.h> +#include <vector> +#include <dlib/matrix.h> +#include <dlib/image_processing/full_object_detection.h> +#include <map> +#include <dlib/svm/ranking_tools.h> + +// All uses of PYBIND11_MAKE_OPAQUE need to be in this common header to avoid ODR +// violations. +PYBIND11_MAKE_OPAQUE(std::vector<dlib::rectangle>); +PYBIND11_MAKE_OPAQUE(std::vector<std::vector<dlib::rectangle>>); + +PYBIND11_MAKE_OPAQUE(std::vector<double>); + + +typedef std::vector<dlib::matrix<double,0,1>> column_vectors; +PYBIND11_MAKE_OPAQUE(column_vectors); +PYBIND11_MAKE_OPAQUE(std::vector<column_vectors>); + +typedef std::pair<unsigned long,unsigned long> ulong_pair; +PYBIND11_MAKE_OPAQUE(ulong_pair); +PYBIND11_MAKE_OPAQUE(std::vector<ulong_pair>); +PYBIND11_MAKE_OPAQUE(std::vector<std::vector<ulong_pair>>); + +typedef std::pair<unsigned long,double> ulong_double_pair; +PYBIND11_MAKE_OPAQUE(ulong_double_pair); +PYBIND11_MAKE_OPAQUE(std::vector<ulong_double_pair>); +PYBIND11_MAKE_OPAQUE(std::vector<std::vector<ulong_double_pair>>); +PYBIND11_MAKE_OPAQUE(std::vector<std::vector<std::vector<ulong_double_pair> > >); + +PYBIND11_MAKE_OPAQUE(std::vector<dlib::mmod_rect>); +PYBIND11_MAKE_OPAQUE(std::vector<std::vector<dlib::mmod_rect> >); +PYBIND11_MAKE_OPAQUE(std::vector<dlib::full_object_detection>); + +typedef std::map<std::string,dlib::point> parts_list_type; +PYBIND11_MAKE_OPAQUE(parts_list_type); + +typedef std::vector<dlib::ranking_pair<dlib::matrix<double,0,1>>> ranking_pairs; +typedef std::vector<std::pair<unsigned long,double> > sparse_vect; +typedef std::vector<dlib::ranking_pair<sparse_vect> > sparse_ranking_pairs; +PYBIND11_MAKE_OPAQUE(ranking_pairs); +PYBIND11_MAKE_OPAQUE(sparse_ranking_pairs); + + +PYBIND11_MAKE_OPAQUE(std::vector<dlib::point>); + +#endif // DLIB_PyTHON_OPAQUE_TYPES_H_ + diff --git a/ml/dlib/tools/python/src/other.cpp b/ml/dlib/tools/python/src/other.cpp new file mode 100644 index 000000000..3e0149022 --- /dev/null +++ b/ml/dlib/tools/python/src/other.cpp @@ -0,0 +1,268 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/data_io.h> +#include <dlib/sparse_vector.h> +#include <dlib/optimization.h> +#include <dlib/statistics/running_gradient.h> + +using namespace dlib; +using namespace std; +namespace py = pybind11; + +typedef std::vector<std::pair<unsigned long,double> > sparse_vect; + + +void _make_sparse_vector ( + sparse_vect& v +) +{ + make_sparse_vector_inplace(v); +} + +void _make_sparse_vector2 ( + std::vector<sparse_vect>& v +) +{ + for (unsigned long i = 0; i < v.size(); ++i) + make_sparse_vector_inplace(v[i]); +} + +py::tuple _load_libsvm_formatted_data( + const std::string& file_name +) +{ + std::vector<sparse_vect> samples; + std::vector<double> labels; + load_libsvm_formatted_data(file_name, samples, labels); + return py::make_tuple(samples, labels); +} + +void _save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector<sparse_vect>& samples, + const std::vector<double>& labels +) +{ + pyassert(samples.size() == labels.size(), "Invalid inputs"); + save_libsvm_formatted_data(file_name, samples, labels); +} + +// ---------------------------------------------------------------------------------------- + +py::list _max_cost_assignment ( + const matrix<double>& cost +) +{ + if (cost.nr() != cost.nc()) + throw dlib::error("The input matrix must be square."); + + // max_cost_assignment() only works with integer matrices, so convert from + // double to integer. + const double scale = (std::numeric_limits<dlib::int64>::max()/1000)/max(abs(cost)); + matrix<dlib::int64> int_cost = matrix_cast<dlib::int64>(round(cost*scale)); + return vector_to_python_list(max_cost_assignment(int_cost)); +} + +double _assignment_cost ( + const matrix<double>& cost, + const py::list& assignment +) +{ + return assignment_cost(cost, python_list_to_vector<long>(assignment)); +} + +// ---------------------------------------------------------------------------------------- + +size_t py_count_steps_without_decrease ( + py::object arr, + double probability_of_decrease +) +{ + DLIB_CASSERT(0.5 < probability_of_decrease && probability_of_decrease < 1); + return count_steps_without_decrease(python_list_to_vector<double>(arr), probability_of_decrease); +} + +// ---------------------------------------------------------------------------------------- + +size_t py_count_steps_without_decrease_robust ( + py::object arr, + double probability_of_decrease, + double quantile_discard +) +{ + DLIB_CASSERT(0.5 < probability_of_decrease && probability_of_decrease < 1); + DLIB_CASSERT(0 <= quantile_discard && quantile_discard <= 1); + return count_steps_without_decrease_robust(python_list_to_vector<double>(arr), probability_of_decrease, quantile_discard); +} + +// ---------------------------------------------------------------------------------------- + +double probability_that_sequence_is_increasing ( + py::object arr +) +{ + DLIB_CASSERT(len(arr) > 2); + return probability_gradient_greater_than(python_list_to_vector<double>(arr), 0); +} + +// ---------------------------------------------------------------------------------------- + +void hit_enter_to_continue() +{ + std::cout << "Hit enter to continue"; + std::cin.get(); +} + +// ---------------------------------------------------------------------------------------- + +void bind_other(py::module &m) +{ + m.def("max_cost_assignment", _max_cost_assignment, py::arg("cost"), +"requires \n\ + - cost.nr() == cost.nc() \n\ + (i.e. the input must be a square matrix) \n\ +ensures \n\ + - Finds and returns the solution to the following optimization problem: \n\ + \n\ + Maximize: f(A) == assignment_cost(cost, A) \n\ + Subject to the following constraints: \n\ + - The elements of A are unique. That is, there aren't any \n\ + elements of A which are equal. \n\ + - len(A) == cost.nr() \n\ + \n\ + - Note that this function converts the input cost matrix into a 64bit fixed \n\ + point representation. Therefore, you should make sure that the values in \n\ + your cost matrix can be accurately represented by 64bit fixed point values. \n\ + If this is not the case then the solution my become inaccurate due to \n\ + rounding error. In general, this function will work properly when the ratio \n\ + of the largest to the smallest value in cost is no more than about 1e16. " + ); + + m.def("assignment_cost", _assignment_cost, py::arg("cost"),py::arg("assignment"), +"requires \n\ + - cost.nr() == cost.nc() \n\ + (i.e. the input must be a square matrix) \n\ + - for all valid i: \n\ + - 0 <= assignment[i] < cost.nr() \n\ +ensures \n\ + - Interprets cost as a cost assignment matrix. That is, cost[i][j] \n\ + represents the cost of assigning i to j. \n\ + - Interprets assignment as a particular set of assignments. That is, \n\ + i is assigned to assignment[i]. \n\ + - returns the cost of the given assignment. That is, returns \n\ + a number which is: \n\ + sum over i: cost[i][assignment[i]] " + ); + + m.def("make_sparse_vector", _make_sparse_vector , +"This function modifies its argument so that it is a properly sorted sparse vector. \n\ +This means that the elements of the sparse vector will be ordered so that pairs \n\ +with smaller indices come first. Additionally, there won't be any pairs with \n\ +identical indices. If such pairs were present in the input sparse vector then \n\ +their values will be added together and only one pair with their index will be \n\ +present in the output. " + ); + m.def("make_sparse_vector", _make_sparse_vector2 , + "This function modifies a sparse_vectors object so that all elements it contains are properly sorted sparse vectors."); + + m.def("load_libsvm_formatted_data",_load_libsvm_formatted_data, py::arg("file_name"), +"ensures \n\ + - Attempts to read a file of the given name that should contain libsvm \n\ + formatted data. The data is returned as a tuple where the first tuple \n\ + element is an array of sparse vectors and the second element is an array of \n\ + labels. " + ); + + m.def("save_libsvm_formatted_data",_save_libsvm_formatted_data, py::arg("file_name"), py::arg("samples"), py::arg("labels"), +"requires \n\ + - len(samples) == len(labels) \n\ +ensures \n\ + - saves the data to the given file in libsvm format " + ); + + m.def("hit_enter_to_continue", hit_enter_to_continue, + "Asks the user to hit enter to continue and pauses until they do so."); + + + + + m.def("count_steps_without_decrease",py_count_steps_without_decrease, py::arg("time_series"), py::arg("probability_of_decrease")=0.51, +"requires \n\ + - time_series must be a one dimensional array of real numbers. \n\ + - 0.5 < probability_of_decrease < 1 \n\ +ensures \n\ + - If you think of the contents of time_series as a potentially noisy time \n\ + series, then this function returns a count of how long the time series has \n\ + gone without noticeably decreasing in value. It does this by scanning along \n\ + the elements, starting from the end (i.e. time_series[-1]) to the beginning, \n\ + and checking how many elements you need to examine before you are confident \n\ + that the series has been decreasing in value. Here, \"confident of decrease\" \n\ + means the probability of decrease is >= probability_of_decrease. \n\ + - Setting probability_of_decrease to 0.51 means we count until we see even a \n\ + small hint of decrease, whereas a larger value of 0.99 would return a larger \n\ + count since it keeps going until it is nearly certain the time series is \n\ + decreasing. \n\ + - The max possible output from this function is len(time_series). \n\ + - The implementation of this function is done using the dlib::running_gradient \n\ + object, which is a tool that finds the least squares fit of a line to the \n\ + time series and the confidence interval around the slope of that line. That \n\ + can then be used in a simple statistical test to determine if the slope is \n\ + positive or negative." + /*! + requires + - time_series must be a one dimensional array of real numbers. + - 0.5 < probability_of_decrease < 1 + ensures + - If you think of the contents of time_series as a potentially noisy time + series, then this function returns a count of how long the time series has + gone without noticeably decreasing in value. It does this by scanning along + the elements, starting from the end (i.e. time_series[-1]) to the beginning, + and checking how many elements you need to examine before you are confident + that the series has been decreasing in value. Here, "confident of decrease" + means the probability of decrease is >= probability_of_decrease. + - Setting probability_of_decrease to 0.51 means we count until we see even a + small hint of decrease, whereas a larger value of 0.99 would return a larger + count since it keeps going until it is nearly certain the time series is + decreasing. + - The max possible output from this function is len(time_series). + - The implementation of this function is done using the dlib::running_gradient + object, which is a tool that finds the least squares fit of a line to the + time series and the confidence interval around the slope of that line. That + can then be used in a simple statistical test to determine if the slope is + positive or negative. + !*/ + ); + + m.def("count_steps_without_decrease_robust",py_count_steps_without_decrease_robust, py::arg("time_series"), py::arg("probability_of_decrease")=0.51, py::arg("quantile_discard")=0.1, +"requires \n\ + - time_series must be a one dimensional array of real numbers. \n\ + - 0.5 < probability_of_decrease < 1 \n\ + - 0 <= quantile_discard <= 1 \n\ +ensures \n\ + - This function behaves just like \n\ + count_steps_without_decrease(time_series,probability_of_decrease) except that \n\ + it ignores values in the time series that are in the upper quantile_discard \n\ + quantile. So for example, if the quantile discard is 0.1 then the 10% \n\ + largest values in the time series are ignored." + /*! + requires + - time_series must be a one dimensional array of real numbers. + - 0.5 < probability_of_decrease < 1 + - 0 <= quantile_discard <= 1 + ensures + - This function behaves just like + count_steps_without_decrease(time_series,probability_of_decrease) except that + it ignores values in the time series that are in the upper quantile_discard + quantile. So for example, if the quantile discard is 0.1 then the 10% + largest values in the time series are ignored. + !*/ + ); + + m.def("probability_that_sequence_is_increasing",probability_that_sequence_is_increasing, py::arg("time_series"), + "returns the probability that the given sequence of real numbers is increasing in value over time."); +} + diff --git a/ml/dlib/tools/python/src/rectangles.cpp b/ml/dlib/tools/python/src/rectangles.cpp new file mode 100644 index 000000000..d06ec591b --- /dev/null +++ b/ml/dlib/tools/python/src/rectangles.cpp @@ -0,0 +1,268 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include <dlib/python.h> +#include <dlib/geometry.h> +#include <pybind11/stl_bind.h> +#include "indexing.h" +#include "opaque_types.h" +#include <dlib/filtering.h> + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + + +// ---------------------------------------------------------------------------------------- + +long left(const rectangle& r) { return r.left(); } +long top(const rectangle& r) { return r.top(); } +long right(const rectangle& r) { return r.right(); } +long bottom(const rectangle& r) { return r.bottom(); } +long width(const rectangle& r) { return r.width(); } +long height(const rectangle& r) { return r.height(); } +unsigned long area(const rectangle& r) { return r.area(); } + +double dleft(const drectangle& r) { return r.left(); } +double dtop(const drectangle& r) { return r.top(); } +double dright(const drectangle& r) { return r.right(); } +double dbottom(const drectangle& r) { return r.bottom(); } +double dwidth(const drectangle& r) { return r.width(); } +double dheight(const drectangle& r) { return r.height(); } +double darea(const drectangle& r) { return r.area(); } + +template <typename rect_type> +bool is_empty(const rect_type& r) { return r.is_empty(); } + +template <typename rect_type> +point center(const rect_type& r) { return center(r); } + +template <typename rect_type> +point dcenter(const rect_type& r) { return dcenter(r); } + +template <typename rect_type> +bool contains(const rect_type& r, const point& p) { return r.contains(p); } + +template <typename rect_type> +bool contains_xy(const rect_type& r, const long x, const long y) { return r.contains(point(x, y)); } + +template <typename rect_type> +bool contains_rec(const rect_type& r, const rect_type& r2) { return r.contains(r2); } + +template <typename rect_type> +rect_type intersect(const rect_type& r, const rect_type& r2) { return r.intersect(r2); } + +template <typename rect_type> +string print_rectangle_str(const rect_type& r) +{ + std::ostringstream sout; + sout << r; + return sout.str(); +} + +string print_rectangle_repr(const rectangle& r) +{ + std::ostringstream sout; + sout << "rectangle(" << r.left() << "," << r.top() << "," << r.right() << "," << r.bottom() << ")"; + return sout.str(); +} + +string print_drectangle_repr(const drectangle& r) +{ + std::ostringstream sout; + sout << "drectangle(" << r.left() << "," << r.top() << "," << r.right() << "," << r.bottom() << ")"; + return sout.str(); +} + +string print_rect_filter(const rect_filter& r) +{ + std::ostringstream sout; + sout << "rect_filter("; + sout << "measurement_noise="<<r.get_left().get_measurement_noise(); + sout << ", typical_acceleration="<<r.get_left().get_typical_acceleration(); + sout << ", max_measurement_deviation="<<r.get_left().get_max_measurement_deviation(); + sout << ")"; + return sout.str(); +} + + +rectangle add_point_to_rect(const rectangle& r, const point& p) +{ + return r + p; +} + +rectangle add_rect_to_rect(const rectangle& r, const rectangle& p) +{ + return r + p; +} + +rectangle& iadd_point_to_rect(rectangle& r, const point& p) +{ + r += p; + return r; +} + +rectangle& iadd_rect_to_rect(rectangle& r, const rectangle& p) +{ + r += p; + return r; +} + + + +// ---------------------------------------------------------------------------------------- + +void bind_rectangles(py::module& m) +{ + { + typedef rectangle type; + py::class_<type>(m, "rectangle", "This object represents a rectangular area of an image.") + .def(py::init<long,long,long,long>(), py::arg("left"),py::arg("top"),py::arg("right"),py::arg("bottom")) + .def(py::init()) + .def("area", &::area) + .def("left", &::left) + .def("top", &::top) + .def("right", &::right) + .def("bottom", &::bottom) + .def("width", &::width) + .def("height", &::height) + .def("is_empty", &::is_empty<type>) + .def("center", &::center<type>) + .def("dcenter", &::dcenter<type>) + .def("contains", &::contains<type>, py::arg("point")) + .def("contains", &::contains_xy<type>, py::arg("x"), py::arg("y")) + .def("contains", &::contains_rec<type>, py::arg("rectangle")) + .def("intersect", &::intersect<type>, py::arg("rectangle")) + .def("__str__", &::print_rectangle_str<type>) + .def("__repr__", &::print_rectangle_repr) + .def("__add__", &::add_point_to_rect) + .def("__add__", &::add_rect_to_rect) + .def("__iadd__", &::iadd_point_to_rect) + .def("__iadd__", &::iadd_rect_to_rect) + .def(py::self == py::self) + .def(py::self != py::self) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + { + typedef drectangle type; + py::class_<type>(m, "drectangle", "This object represents a rectangular area of an image with floating point coordinates.") + .def(py::init<double,double,double,double>(), py::arg("left"), py::arg("top"), py::arg("right"), py::arg("bottom")) + .def("area", &::darea) + .def("left", &::dleft) + .def("top", &::dtop) + .def("right", &::dright) + .def("bottom", &::dbottom) + .def("width", &::dwidth) + .def("height", &::dheight) + .def("is_empty", &::is_empty<type>) + .def("center", &::center<type>) + .def("dcenter", &::dcenter<type>) + .def("contains", &::contains<type>, py::arg("point")) + .def("contains", &::contains_xy<type>, py::arg("x"), py::arg("y")) + .def("contains", &::contains_rec<type>, py::arg("rectangle")) + .def("intersect", &::intersect<type>, py::arg("rectangle")) + .def("__str__", &::print_rectangle_str<type>) + .def("__repr__", &::print_drectangle_repr) + .def(py::self == py::self) + .def(py::self != py::self) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + { + typedef rect_filter type; + py::class_<type>(m, "rect_filter", + R"asdf( + This object is a simple tool for filtering a rectangle that + measures the location of a moving object that has some non-trivial + momentum. Importantly, the measurements are noisy and the object can + experience sudden unpredictable accelerations. To accomplish this + filtering we use a simple Kalman filter with a state transition model of: + + position_{i+1} = position_{i} + velocity_{i} + velocity_{i+1} = velocity_{i} + some_unpredictable_acceleration + + and a measurement model of: + + measured_position_{i} = position_{i} + measurement_noise + + Where some_unpredictable_acceleration and measurement_noise are 0 mean Gaussian + noise sources with standard deviations of typical_acceleration and + measurement_noise respectively. + + To allow for really sudden and large but infrequent accelerations, at each + step we check if the current measured position deviates from the predicted + filtered position by more than max_measurement_deviation*measurement_noise + and if so we adjust the filter's state to keep it within these bounds. + This allows the moving object to undergo large unmodeled accelerations, far + in excess of what would be suggested by typical_acceleration, without + then experiencing a long lag time where the Kalman filter has to "catches + up" to the new position. )asdf" + ) + .def(py::init<double,double,double>(), py::arg("measurement_noise"), py::arg("typical_acceleration"), py::arg("max_measurement_deviation")) + .def("measurement_noise", [](const rect_filter& a){return a.get_left().get_measurement_noise();}) + .def("typical_acceleration", [](const rect_filter& a){return a.get_left().get_typical_acceleration();}) + .def("max_measurement_deviation", [](const rect_filter& a){return a.get_left().get_max_measurement_deviation();}) + .def("__call__", [](rect_filter& f, const dlib::rectangle& r){return rectangle(f(r)); }, py::arg("rect")) + .def("__repr__", print_rect_filter) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + m.def("find_optimal_rect_filter", + [](const std::vector<rectangle>& rects, const double smoothness ) { return find_optimal_rect_filter(rects, smoothness); }, + py::arg("rects"), + py::arg("smoothness")=1, +"requires \n\ + - rects.size() > 4 \n\ + - smoothness >= 0 \n\ +ensures \n\ + - This function finds the \"optimal\" settings of a rect_filter based on recorded \n\ + measurement data stored in rects. Here we assume that rects is a complete \n\ + track history of some object's measured positions. Essentially, what we do \n\ + is find the rect_filter that minimizes the following objective function: \n\ + sum of abs(predicted_location[i] - measured_location[i]) + smoothness*abs(filtered_location[i]-filtered_location[i-1]) \n\ + Where i is a time index. \n\ + The sum runs over all the data in rects. So what we do is find the \n\ + filter settings that produce smooth filtered trajectories but also produce \n\ + filtered outputs that are as close to the measured positions as possible. \n\ + The larger the value of smoothness the less jittery the filter outputs will \n\ + be, but they might become biased or laggy if smoothness is set really high. " + /*! + requires + - rects.size() > 4 + - smoothness >= 0 + ensures + - This function finds the "optimal" settings of a rect_filter based on recorded + measurement data stored in rects. Here we assume that rects is a complete + track history of some object's measured positions. Essentially, what we do + is find the rect_filter that minimizes the following objective function: + sum of abs(predicted_location[i] - measured_location[i]) + smoothness*abs(filtered_location[i]-filtered_location[i-1]) + Where i is a time index. + The sum runs over all the data in rects. So what we do is find the + filter settings that produce smooth filtered trajectories but also produce + filtered outputs that are as close to the measured positions as possible. + The larger the value of smoothness the less jittery the filter outputs will + be, but they might become biased or laggy if smoothness is set really high. + !*/ + ); + + { + typedef std::vector<rectangle> type; + py::bind_vector<type>(m, "rectangles", "An array of rectangle objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<rectangle>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + + { + typedef std::vector<std::vector<rectangle>> type; + py::bind_vector<type>(m, "rectangless", "An array of arrays of rectangle objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<rectangle>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } +} + +// ---------------------------------------------------------------------------------------- diff --git a/ml/dlib/tools/python/src/sequence_segmenter.cpp b/ml/dlib/tools/python/src/sequence_segmenter.cpp new file mode 100644 index 000000000..9fde1e771 --- /dev/null +++ b/ml/dlib/tools/python/src/sequence_segmenter.cpp @@ -0,0 +1,827 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/svm_threaded.h> + +using namespace dlib; +using namespace std; +namespace py = pybind11; + +typedef matrix<double,0,1> dense_vect; +typedef std::vector<std::pair<unsigned long,double> > sparse_vect; +typedef std::vector<std::pair<unsigned long, unsigned long> > ranges; + +// ---------------------------------------------------------------------------------------- + +template <typename samp_type, bool BIO, bool high_order, bool nonnegative> +class segmenter_feature_extractor +{ + +public: + typedef std::vector<samp_type> sequence_type; + const static bool use_BIO_model = BIO; + const static bool use_high_order_features = high_order; + const static bool allow_negative_weights = nonnegative; + + + unsigned long _num_features; + unsigned long _window_size; + + segmenter_feature_extractor( + ) : _num_features(1), _window_size(1) {} + + segmenter_feature_extractor( + unsigned long _num_features_, + unsigned long _window_size_ + ) : _num_features(_num_features_), _window_size(_window_size_) {} + + unsigned long num_features( + ) const { return _num_features; } + + unsigned long window_size( + ) const {return _window_size; } + + template <typename feature_setter> + void get_features ( + feature_setter& set_feature, + const std::vector<dense_vect>& x, + unsigned long position + ) const + { + for (long i = 0; i < x[position].size(); ++i) + { + set_feature(i, x[position](i)); + } + } + + template <typename feature_setter> + void get_features ( + feature_setter& set_feature, + const std::vector<sparse_vect>& x, + unsigned long position + ) const + { + for (unsigned long i = 0; i < x[position].size(); ++i) + { + set_feature(x[position][i].first, x[position][i].second); + } + } + + friend void serialize(const segmenter_feature_extractor& item, std::ostream& out) + { + dlib::serialize(item._num_features, out); + dlib::serialize(item._window_size, out); + } + friend void deserialize(segmenter_feature_extractor& item, std::istream& in) + { + dlib::deserialize(item._num_features, in); + dlib::deserialize(item._window_size, in); + } +}; + +// ---------------------------------------------------------------------------------------- + +struct segmenter_type +{ + /*! + WHAT THIS OBJECT REPRESENTS + This the object that python will use directly to represent a + sequence_segmenter. All it does is contain all the possible template + instantiations of a sequence_segmenter and invoke the right one depending on + the mode variable. + !*/ + + segmenter_type() : mode(-1) + { } + + ranges segment_sequence_dense ( + const std::vector<dense_vect>& x + ) const + { + switch (mode) + { + case 0: return segmenter0(x); + case 1: return segmenter1(x); + case 2: return segmenter2(x); + case 3: return segmenter3(x); + case 4: return segmenter4(x); + case 5: return segmenter5(x); + case 6: return segmenter6(x); + case 7: return segmenter7(x); + default: throw dlib::error("Invalid mode"); + } + } + + ranges segment_sequence_sparse ( + const std::vector<sparse_vect>& x + ) const + { + switch (mode) + { + case 8: return segmenter8(x); + case 9: return segmenter9(x); + case 10: return segmenter10(x); + case 11: return segmenter11(x); + case 12: return segmenter12(x); + case 13: return segmenter13(x); + case 14: return segmenter14(x); + case 15: return segmenter15(x); + default: throw dlib::error("Invalid mode"); + } + } + + const matrix<double,0,1> get_weights() + { + switch(mode) + { + case 0: return segmenter0.get_weights(); + case 1: return segmenter1.get_weights(); + case 2: return segmenter2.get_weights(); + case 3: return segmenter3.get_weights(); + case 4: return segmenter4.get_weights(); + case 5: return segmenter5.get_weights(); + case 6: return segmenter6.get_weights(); + case 7: return segmenter7.get_weights(); + + case 8: return segmenter8.get_weights(); + case 9: return segmenter9.get_weights(); + case 10: return segmenter10.get_weights(); + case 11: return segmenter11.get_weights(); + case 12: return segmenter12.get_weights(); + case 13: return segmenter13.get_weights(); + case 14: return segmenter14.get_weights(); + case 15: return segmenter15.get_weights(); + + default: throw dlib::error("Invalid mode"); + } + } + + friend void serialize (const segmenter_type& item, std::ostream& out) + { + serialize(item.mode, out); + switch(item.mode) + { + case 0: serialize(item.segmenter0, out); break; + case 1: serialize(item.segmenter1, out); break; + case 2: serialize(item.segmenter2, out); break; + case 3: serialize(item.segmenter3, out); break; + case 4: serialize(item.segmenter4, out); break; + case 5: serialize(item.segmenter5, out); break; + case 6: serialize(item.segmenter6, out); break; + case 7: serialize(item.segmenter7, out); break; + + case 8: serialize(item.segmenter8, out); break; + case 9: serialize(item.segmenter9, out); break; + case 10: serialize(item.segmenter10, out); break; + case 11: serialize(item.segmenter11, out); break; + case 12: serialize(item.segmenter12, out); break; + case 13: serialize(item.segmenter13, out); break; + case 14: serialize(item.segmenter14, out); break; + case 15: serialize(item.segmenter15, out); break; + default: throw dlib::error("Invalid mode"); + } + } + friend void deserialize (segmenter_type& item, std::istream& in) + { + deserialize(item.mode, in); + switch(item.mode) + { + case 0: deserialize(item.segmenter0, in); break; + case 1: deserialize(item.segmenter1, in); break; + case 2: deserialize(item.segmenter2, in); break; + case 3: deserialize(item.segmenter3, in); break; + case 4: deserialize(item.segmenter4, in); break; + case 5: deserialize(item.segmenter5, in); break; + case 6: deserialize(item.segmenter6, in); break; + case 7: deserialize(item.segmenter7, in); break; + + case 8: deserialize(item.segmenter8, in); break; + case 9: deserialize(item.segmenter9, in); break; + case 10: deserialize(item.segmenter10, in); break; + case 11: deserialize(item.segmenter11, in); break; + case 12: deserialize(item.segmenter12, in); break; + case 13: deserialize(item.segmenter13, in); break; + case 14: deserialize(item.segmenter14, in); break; + case 15: deserialize(item.segmenter15, in); break; + default: throw dlib::error("Invalid mode"); + } + } + + int mode; + + typedef segmenter_feature_extractor<dense_vect, false,false,false> fe0; + typedef segmenter_feature_extractor<dense_vect, false,false,true> fe1; + typedef segmenter_feature_extractor<dense_vect, false,true, false> fe2; + typedef segmenter_feature_extractor<dense_vect, false,true, true> fe3; + typedef segmenter_feature_extractor<dense_vect, true, false,false> fe4; + typedef segmenter_feature_extractor<dense_vect, true, false,true> fe5; + typedef segmenter_feature_extractor<dense_vect, true, true, false> fe6; + typedef segmenter_feature_extractor<dense_vect, true, true, true> fe7; + sequence_segmenter<fe0> segmenter0; + sequence_segmenter<fe1> segmenter1; + sequence_segmenter<fe2> segmenter2; + sequence_segmenter<fe3> segmenter3; + sequence_segmenter<fe4> segmenter4; + sequence_segmenter<fe5> segmenter5; + sequence_segmenter<fe6> segmenter6; + sequence_segmenter<fe7> segmenter7; + + typedef segmenter_feature_extractor<sparse_vect, false,false,false> fe8; + typedef segmenter_feature_extractor<sparse_vect, false,false,true> fe9; + typedef segmenter_feature_extractor<sparse_vect, false,true, false> fe10; + typedef segmenter_feature_extractor<sparse_vect, false,true, true> fe11; + typedef segmenter_feature_extractor<sparse_vect, true, false,false> fe12; + typedef segmenter_feature_extractor<sparse_vect, true, false,true> fe13; + typedef segmenter_feature_extractor<sparse_vect, true, true, false> fe14; + typedef segmenter_feature_extractor<sparse_vect, true, true, true> fe15; + sequence_segmenter<fe8> segmenter8; + sequence_segmenter<fe9> segmenter9; + sequence_segmenter<fe10> segmenter10; + sequence_segmenter<fe11> segmenter11; + sequence_segmenter<fe12> segmenter12; + sequence_segmenter<fe13> segmenter13; + sequence_segmenter<fe14> segmenter14; + sequence_segmenter<fe15> segmenter15; +}; + + +// ---------------------------------------------------------------------------------------- + +struct segmenter_params +{ + segmenter_params() + { + use_BIO_model = true; + use_high_order_features = true; + allow_negative_weights = true; + window_size = 5; + num_threads = 4; + epsilon = 0.1; + max_cache_size = 40; + be_verbose = false; + C = 100; + } + + bool use_BIO_model; + bool use_high_order_features; + bool allow_negative_weights; + unsigned long window_size; + unsigned long num_threads; + double epsilon; + unsigned long max_cache_size; + bool be_verbose; + double C; +}; + + +string segmenter_params__str__(const segmenter_params& p) +{ + ostringstream sout; + if (p.use_BIO_model) + sout << "BIO,"; + else + sout << "BILOU,"; + + if (p.use_high_order_features) + sout << "highFeats,"; + else + sout << "lowFeats,"; + + if (p.allow_negative_weights) + sout << "signed,"; + else + sout << "non-negative,"; + + sout << "win="<<p.window_size << ","; + sout << "threads="<<p.num_threads << ","; + sout << "eps="<<p.epsilon << ","; + sout << "cache="<<p.max_cache_size << ","; + if (p.be_verbose) + sout << "verbose,"; + else + sout << "non-verbose,"; + sout << "C="<<p.C; + return trim(sout.str()); +} + +string segmenter_params__repr__(const segmenter_params& p) +{ + ostringstream sout; + sout << "<"; + sout << segmenter_params__str__(p); + sout << ">"; + return sout.str(); +} + +void serialize ( const segmenter_params& item, std::ostream& out) +{ + serialize(item.use_BIO_model, out); + serialize(item.use_high_order_features, out); + serialize(item.allow_negative_weights, out); + serialize(item.window_size, out); + serialize(item.num_threads, out); + serialize(item.epsilon, out); + serialize(item.max_cache_size, out); + serialize(item.be_verbose, out); + serialize(item.C, out); +} + +void deserialize (segmenter_params& item, std::istream& in) +{ + deserialize(item.use_BIO_model, in); + deserialize(item.use_high_order_features, in); + deserialize(item.allow_negative_weights, in); + deserialize(item.window_size, in); + deserialize(item.num_threads, in); + deserialize(item.epsilon, in); + deserialize(item.max_cache_size, in); + deserialize(item.be_verbose, in); + deserialize(item.C, in); +} + +// ---------------------------------------------------------------------------------------- + +template <typename T> +void configure_trainer ( + const std::vector<std::vector<dense_vect> >& samples, + structural_sequence_segmentation_trainer<T>& trainer, + const segmenter_params& params +) +{ + pyassert(samples.size() != 0, "Invalid arguments. You must give some training sequences."); + pyassert(samples[0].size() != 0, "Invalid arguments. You can't have zero length training sequences."); + pyassert(params.window_size != 0, "Invalid window_size parameter, it must be > 0."); + pyassert(params.epsilon > 0, "Invalid epsilon parameter, it must be > 0."); + pyassert(params.C > 0, "Invalid C parameter, it must be > 0."); + const long dims = samples[0][0].size(); + + trainer = structural_sequence_segmentation_trainer<T>(T(dims, params.window_size)); + trainer.set_num_threads(params.num_threads); + trainer.set_epsilon(params.epsilon); + trainer.set_max_cache_size(params.max_cache_size); + trainer.set_c(params.C); + if (params.be_verbose) + trainer.be_verbose(); +} + +// ---------------------------------------------------------------------------------------- + +template <typename T> +void configure_trainer ( + const std::vector<std::vector<sparse_vect> >& samples, + structural_sequence_segmentation_trainer<T>& trainer, + const segmenter_params& params +) +{ + pyassert(samples.size() != 0, "Invalid arguments. You must give some training sequences."); + pyassert(samples[0].size() != 0, "Invalid arguments. You can't have zero length training sequences."); + + unsigned long dims = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + dims = std::max(dims, max_index_plus_one(samples[i])); + } + + trainer = structural_sequence_segmentation_trainer<T>(T(dims, params.window_size)); + trainer.set_num_threads(params.num_threads); + trainer.set_epsilon(params.epsilon); + trainer.set_max_cache_size(params.max_cache_size); + trainer.set_c(params.C); + if (params.be_verbose) + trainer.be_verbose(); +} + +// ---------------------------------------------------------------------------------------- + +segmenter_type train_dense ( + const std::vector<std::vector<dense_vect> >& samples, + const std::vector<ranges>& segments, + segmenter_params params +) +{ + pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs"); + + int mode = 0; + if (params.use_BIO_model) + mode = mode*2 + 1; + else + mode = mode*2; + if (params.use_high_order_features) + mode = mode*2 + 1; + else + mode = mode*2; + if (params.allow_negative_weights) + mode = mode*2 + 1; + else + mode = mode*2; + + + segmenter_type res; + res.mode = mode; + switch(mode) + { + case 0: { structural_sequence_segmentation_trainer<segmenter_type::fe0> trainer; + configure_trainer(samples, trainer, params); + res.segmenter0 = trainer.train(samples, segments); + } break; + case 1: { structural_sequence_segmentation_trainer<segmenter_type::fe1> trainer; + configure_trainer(samples, trainer, params); + res.segmenter1 = trainer.train(samples, segments); + } break; + case 2: { structural_sequence_segmentation_trainer<segmenter_type::fe2> trainer; + configure_trainer(samples, trainer, params); + res.segmenter2 = trainer.train(samples, segments); + } break; + case 3: { structural_sequence_segmentation_trainer<segmenter_type::fe3> trainer; + configure_trainer(samples, trainer, params); + res.segmenter3 = trainer.train(samples, segments); + } break; + case 4: { structural_sequence_segmentation_trainer<segmenter_type::fe4> trainer; + configure_trainer(samples, trainer, params); + res.segmenter4 = trainer.train(samples, segments); + } break; + case 5: { structural_sequence_segmentation_trainer<segmenter_type::fe5> trainer; + configure_trainer(samples, trainer, params); + res.segmenter5 = trainer.train(samples, segments); + } break; + case 6: { structural_sequence_segmentation_trainer<segmenter_type::fe6> trainer; + configure_trainer(samples, trainer, params); + res.segmenter6 = trainer.train(samples, segments); + } break; + case 7: { structural_sequence_segmentation_trainer<segmenter_type::fe7> trainer; + configure_trainer(samples, trainer, params); + res.segmenter7 = trainer.train(samples, segments); + } break; + default: throw dlib::error("Invalid mode"); + } + + + return res; +} + +// ---------------------------------------------------------------------------------------- + +segmenter_type train_sparse ( + const std::vector<std::vector<sparse_vect> >& samples, + const std::vector<ranges>& segments, + segmenter_params params +) +{ + pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs"); + + int mode = 0; + if (params.use_BIO_model) + mode = mode*2 + 1; + else + mode = mode*2; + if (params.use_high_order_features) + mode = mode*2 + 1; + else + mode = mode*2; + if (params.allow_negative_weights) + mode = mode*2 + 1; + else + mode = mode*2; + + mode += 8; + + segmenter_type res; + res.mode = mode; + switch(mode) + { + case 8: { structural_sequence_segmentation_trainer<segmenter_type::fe8> trainer; + configure_trainer(samples, trainer, params); + res.segmenter8 = trainer.train(samples, segments); + } break; + case 9: { structural_sequence_segmentation_trainer<segmenter_type::fe9> trainer; + configure_trainer(samples, trainer, params); + res.segmenter9 = trainer.train(samples, segments); + } break; + case 10: { structural_sequence_segmentation_trainer<segmenter_type::fe10> trainer; + configure_trainer(samples, trainer, params); + res.segmenter10 = trainer.train(samples, segments); + } break; + case 11: { structural_sequence_segmentation_trainer<segmenter_type::fe11> trainer; + configure_trainer(samples, trainer, params); + res.segmenter11 = trainer.train(samples, segments); + } break; + case 12: { structural_sequence_segmentation_trainer<segmenter_type::fe12> trainer; + configure_trainer(samples, trainer, params); + res.segmenter12 = trainer.train(samples, segments); + } break; + case 13: { structural_sequence_segmentation_trainer<segmenter_type::fe13> trainer; + configure_trainer(samples, trainer, params); + res.segmenter13 = trainer.train(samples, segments); + } break; + case 14: { structural_sequence_segmentation_trainer<segmenter_type::fe14> trainer; + configure_trainer(samples, trainer, params); + res.segmenter14 = trainer.train(samples, segments); + } break; + case 15: { structural_sequence_segmentation_trainer<segmenter_type::fe15> trainer; + configure_trainer(samples, trainer, params); + res.segmenter15 = trainer.train(samples, segments); + } break; + default: throw dlib::error("Invalid mode"); + } + + + return res; +} + +// ---------------------------------------------------------------------------------------- + + +struct segmenter_test +{ + double precision; + double recall; + double f1; +}; + +void serialize(const segmenter_test& item, std::ostream& out) +{ + serialize(item.precision, out); + serialize(item.recall, out); + serialize(item.f1, out); +} + +void deserialize(segmenter_test& item, std::istream& in) +{ + deserialize(item.precision, in); + deserialize(item.recall, in); + deserialize(item.f1, in); +} + +std::string segmenter_test__str__(const segmenter_test& item) +{ + std::ostringstream sout; + sout << "precision: "<< item.precision << " recall: "<< item.recall << " f1-score: " << item.f1; + return sout.str(); +} +std::string segmenter_test__repr__(const segmenter_test& item) { return "< " + segmenter_test__str__(item) + " >";} + +// ---------------------------------------------------------------------------------------- + +const segmenter_test test_sequence_segmenter1 ( + const segmenter_type& segmenter, + const std::vector<std::vector<dense_vect> >& samples, + const std::vector<ranges>& segments +) +{ + pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs"); + matrix<double,1,3> res; + + switch(segmenter.mode) + { + case 0: res = test_sequence_segmenter(segmenter.segmenter0, samples, segments); break; + case 1: res = test_sequence_segmenter(segmenter.segmenter1, samples, segments); break; + case 2: res = test_sequence_segmenter(segmenter.segmenter2, samples, segments); break; + case 3: res = test_sequence_segmenter(segmenter.segmenter3, samples, segments); break; + case 4: res = test_sequence_segmenter(segmenter.segmenter4, samples, segments); break; + case 5: res = test_sequence_segmenter(segmenter.segmenter5, samples, segments); break; + case 6: res = test_sequence_segmenter(segmenter.segmenter6, samples, segments); break; + case 7: res = test_sequence_segmenter(segmenter.segmenter7, samples, segments); break; + default: throw dlib::error("Invalid mode"); + } + + + segmenter_test temp; + temp.precision = res(0); + temp.recall = res(1); + temp.f1 = res(2); + return temp; +} + +const segmenter_test test_sequence_segmenter2 ( + const segmenter_type& segmenter, + const std::vector<std::vector<sparse_vect> >& samples, + const std::vector<ranges>& segments +) +{ + pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs"); + matrix<double,1,3> res; + + switch(segmenter.mode) + { + case 8: res = test_sequence_segmenter(segmenter.segmenter8, samples, segments); break; + case 9: res = test_sequence_segmenter(segmenter.segmenter9, samples, segments); break; + case 10: res = test_sequence_segmenter(segmenter.segmenter10, samples, segments); break; + case 11: res = test_sequence_segmenter(segmenter.segmenter11, samples, segments); break; + case 12: res = test_sequence_segmenter(segmenter.segmenter12, samples, segments); break; + case 13: res = test_sequence_segmenter(segmenter.segmenter13, samples, segments); break; + case 14: res = test_sequence_segmenter(segmenter.segmenter14, samples, segments); break; + case 15: res = test_sequence_segmenter(segmenter.segmenter15, samples, segments); break; + default: throw dlib::error("Invalid mode"); + } + + + segmenter_test temp; + temp.precision = res(0); + temp.recall = res(1); + temp.f1 = res(2); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +const segmenter_test cross_validate_sequence_segmenter1 ( + const std::vector<std::vector<dense_vect> >& samples, + const std::vector<ranges>& segments, + long folds, + segmenter_params params +) +{ + pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs"); + pyassert(1 < folds && folds <= static_cast<long>(samples.size()), "folds argument is outside the valid range."); + + matrix<double,1,3> res; + + int mode = 0; + if (params.use_BIO_model) + mode = mode*2 + 1; + else + mode = mode*2; + if (params.use_high_order_features) + mode = mode*2 + 1; + else + mode = mode*2; + if (params.allow_negative_weights) + mode = mode*2 + 1; + else + mode = mode*2; + + + switch(mode) + { + case 0: { structural_sequence_segmentation_trainer<segmenter_type::fe0> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 1: { structural_sequence_segmentation_trainer<segmenter_type::fe1> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 2: { structural_sequence_segmentation_trainer<segmenter_type::fe2> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 3: { structural_sequence_segmentation_trainer<segmenter_type::fe3> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 4: { structural_sequence_segmentation_trainer<segmenter_type::fe4> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 5: { structural_sequence_segmentation_trainer<segmenter_type::fe5> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 6: { structural_sequence_segmentation_trainer<segmenter_type::fe6> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 7: { structural_sequence_segmentation_trainer<segmenter_type::fe7> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + default: throw dlib::error("Invalid mode"); + } + + + segmenter_test temp; + temp.precision = res(0); + temp.recall = res(1); + temp.f1 = res(2); + return temp; +} + +const segmenter_test cross_validate_sequence_segmenter2 ( + const std::vector<std::vector<sparse_vect> >& samples, + const std::vector<ranges>& segments, + long folds, + segmenter_params params +) +{ + pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs"); + pyassert(1 < folds && folds <= static_cast<long>(samples.size()), "folds argument is outside the valid range."); + + matrix<double,1,3> res; + + int mode = 0; + if (params.use_BIO_model) + mode = mode*2 + 1; + else + mode = mode*2; + if (params.use_high_order_features) + mode = mode*2 + 1; + else + mode = mode*2; + if (params.allow_negative_weights) + mode = mode*2 + 1; + else + mode = mode*2; + + mode += 8; + + switch(mode) + { + case 8: { structural_sequence_segmentation_trainer<segmenter_type::fe8> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 9: { structural_sequence_segmentation_trainer<segmenter_type::fe9> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 10: { structural_sequence_segmentation_trainer<segmenter_type::fe10> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 11: { structural_sequence_segmentation_trainer<segmenter_type::fe11> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 12: { structural_sequence_segmentation_trainer<segmenter_type::fe12> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 13: { structural_sequence_segmentation_trainer<segmenter_type::fe13> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 14: { structural_sequence_segmentation_trainer<segmenter_type::fe14> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + case 15: { structural_sequence_segmentation_trainer<segmenter_type::fe15> trainer; + configure_trainer(samples, trainer, params); + res = cross_validate_sequence_segmenter(trainer, samples, segments, folds); + } break; + default: throw dlib::error("Invalid mode"); + } + + + segmenter_test temp; + temp.precision = res(0); + temp.recall = res(1); + temp.f1 = res(2); + return temp; +} + +// ---------------------------------------------------------------------------------------- + +void bind_sequence_segmenter(py::module& m) +{ + py::class_<segmenter_params>(m, "segmenter_params", +"This class is used to define all the optional parameters to the \n\ +train_sequence_segmenter() and cross_validate_sequence_segmenter() routines. ") + .def(py::init<>()) + .def_readwrite("use_BIO_model", &segmenter_params::use_BIO_model) + .def_readwrite("use_high_order_features", &segmenter_params::use_high_order_features) + .def_readwrite("allow_negative_weights", &segmenter_params::allow_negative_weights) + .def_readwrite("window_size", &segmenter_params::window_size) + .def_readwrite("num_threads", &segmenter_params::num_threads) + .def_readwrite("epsilon", &segmenter_params::epsilon) + .def_readwrite("max_cache_size", &segmenter_params::max_cache_size) + .def_readwrite("C", &segmenter_params::C, "SVM C parameter") + .def_readwrite("be_verbose", &segmenter_params::be_verbose) + .def("__repr__",&segmenter_params__repr__) + .def("__str__",&segmenter_params__str__) + .def(py::pickle(&getstate<segmenter_params>, &setstate<segmenter_params>)); + + py::class_<segmenter_type> (m, "segmenter_type", "This object represents a sequence segmenter and is the type of object " + "returned by the dlib.train_sequence_segmenter() routine.") + .def("__call__", &segmenter_type::segment_sequence_dense) + .def("__call__", &segmenter_type::segment_sequence_sparse) + .def_property_readonly("weights", &segmenter_type::get_weights) + .def(py::pickle(&getstate<segmenter_type>, &setstate<segmenter_type>)); + + py::class_<segmenter_test> (m, "segmenter_test", "This object is the output of the dlib.test_sequence_segmenter() and " + "dlib.cross_validate_sequence_segmenter() routines.") + .def_readwrite("precision", &segmenter_test::precision) + .def_readwrite("recall", &segmenter_test::recall) + .def_readwrite("f1", &segmenter_test::f1) + .def("__repr__",&segmenter_test__repr__) + .def("__str__",&segmenter_test__str__) + .def(py::pickle(&getstate<segmenter_test>, &setstate<segmenter_test>)); + + m.def("train_sequence_segmenter", train_dense, py::arg("samples"), py::arg("segments"), py::arg("params")=segmenter_params()); + m.def("train_sequence_segmenter", train_sparse, py::arg("samples"), py::arg("segments"), py::arg("params")=segmenter_params()); + + + m.def("test_sequence_segmenter", test_sequence_segmenter1); + m.def("test_sequence_segmenter", test_sequence_segmenter2); + + m.def("cross_validate_sequence_segmenter", cross_validate_sequence_segmenter1, + py::arg("samples"), py::arg("segments"), py::arg("folds"), py::arg("params")=segmenter_params()); + m.def("cross_validate_sequence_segmenter", cross_validate_sequence_segmenter2, + py::arg("samples"), py::arg("segments"), py::arg("folds"), py::arg("params")=segmenter_params()); +} + + + + diff --git a/ml/dlib/tools/python/src/serialize_object_detector.h b/ml/dlib/tools/python/src/serialize_object_detector.h new file mode 100644 index 000000000..e53401c81 --- /dev/null +++ b/ml/dlib/tools/python/src/serialize_object_detector.h @@ -0,0 +1,49 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERIALIZE_OBJECT_DETECTOR_H__ +#define DLIB_SERIALIZE_OBJECT_DETECTOR_H__ + +#include "simple_object_detector_py.h" + +namespace dlib +{ + inline void serialize (const dlib::simple_object_detector_py& item, std::ostream& out) + { + int version = 1; + serialize(item.detector, out); + serialize(version, out); + serialize(item.upsampling_amount, out); + } + + inline void deserialize (dlib::simple_object_detector_py& item, std::istream& in) + { + int version = 0; + deserialize(item.detector, in); + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Unexpected version found while deserializing a simple_object_detector."); + deserialize(item.upsampling_amount, in); + } + + inline void save_simple_object_detector_py(const simple_object_detector_py& detector, const std::string& detector_output_filename) + { + std::ofstream fout(detector_output_filename.c_str(), std::ios::binary); + int version = 1; + serialize(detector.detector, fout); + serialize(version, fout); + serialize(detector.upsampling_amount, fout); + } + +// ---------------------------------------------------------------------------------------- + + inline void save_simple_object_detector(const simple_object_detector& detector, const std::string& detector_output_filename) + { + std::ofstream fout(detector_output_filename.c_str(), std::ios::binary); + serialize(detector, fout); + // Don't need to save version of upsampling amount because want to write out the + // object detector just like the C++ code that serializes an object_detector would. + // We also don't know the upsampling amount in this case anyway. + } +} + +#endif // DLIB_SERIALIZE_OBJECT_DETECTOR_H__ diff --git a/ml/dlib/tools/python/src/shape_predictor.cpp b/ml/dlib/tools/python/src/shape_predictor.cpp new file mode 100644 index 000000000..76f21750a --- /dev/null +++ b/ml/dlib/tools/python/src/shape_predictor.cpp @@ -0,0 +1,319 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/geometry.h> +#include <dlib/image_processing.h> +#include "shape_predictor.h" +#include "conversion.h" + +using namespace dlib; +using namespace std; + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +full_object_detection run_predictor ( + shape_predictor& predictor, + py::object img, + py::object rect +) +{ + rectangle box = rect.cast<rectangle>(); + if (is_gray_python_image(img)) + { + return predictor(numpy_gray_image(img), box); + } + else if (is_rgb_python_image(img)) + { + return predictor(numpy_rgb_image(img), box); + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } +} + +void save_shape_predictor(const shape_predictor& predictor, const std::string& predictor_output_filename) +{ + std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary); + serialize(predictor, fout); +} + +// ---------------------------------------------------------------------------------------- + +rectangle full_obj_det_get_rect (const full_object_detection& detection) +{ return detection.get_rect(); } + +unsigned long full_obj_det_num_parts (const full_object_detection& detection) +{ return detection.num_parts(); } + +point full_obj_det_part (const full_object_detection& detection, const unsigned long idx) +{ + if (idx >= detection.num_parts()) + { + PyErr_SetString(PyExc_IndexError, "Index out of range"); + throw py::error_already_set(); + } + return detection.part(idx); +} + +std::vector<point> full_obj_det_parts (const full_object_detection& detection) +{ + const unsigned long num_parts = detection.num_parts(); + std::vector<point> parts(num_parts); + for (unsigned long j = 0; j < num_parts; ++j) + parts[j] = detection.part(j); + return parts; +} + +std::shared_ptr<full_object_detection> full_obj_det_init(py::object& pyrect, py::object& pyparts) +{ + const unsigned long num_parts = py::len(pyparts); + std::vector<point> parts(num_parts); + rectangle rect = pyrect.cast<rectangle>(); + py::iterator parts_it = pyparts.begin(); + + for (unsigned long j = 0; + parts_it != pyparts.end(); + ++j, ++parts_it) + parts[j] = parts_it->cast<point>(); + + return std::make_shared<full_object_detection>(rect, parts); +} + +// ---------------------------------------------------------------------------------------- + +inline shape_predictor train_shape_predictor_on_images_py ( + const py::list& pyimages, + const py::list& pydetections, + const shape_predictor_training_options& options +) +{ + const unsigned long num_images = py::len(pyimages); + if (num_images != py::len(pydetections)) + throw dlib::error("The length of the detections list must match the length of the images list."); + + std::vector<std::vector<full_object_detection> > detections(num_images); + dlib::array<array2d<unsigned char> > images(num_images); + images_and_nested_params_to_dlib(pyimages, pydetections, images, detections); + + return train_shape_predictor_on_images(images, detections, options); +} + + +inline double test_shape_predictor_with_images_py ( + const py::list& pyimages, + const py::list& pydetections, + const py::list& pyscales, + const shape_predictor& predictor +) +{ + const unsigned long num_images = py::len(pyimages); + const unsigned long num_scales = py::len(pyscales); + if (num_images != py::len(pydetections)) + throw dlib::error("The length of the detections list must match the length of the images list."); + + if (num_scales > 0 && num_scales != num_images) + throw dlib::error("The length of the scales list must match the length of the detections list."); + + std::vector<std::vector<full_object_detection> > detections(num_images); + std::vector<std::vector<double> > scales; + if (num_scales > 0) + scales.resize(num_scales); + dlib::array<array2d<unsigned char> > images(num_images); + + // Now copy the data into dlib based objects so we can call the testing routine. + for (unsigned long i = 0; i < num_images; ++i) + { + const unsigned long num_boxes = py::len(pydetections[i]); + for (py::iterator det_it = pydetections[i].begin(); + det_it != pydetections[i].end(); + ++det_it) + detections[i].push_back(det_it->cast<full_object_detection>()); + + pyimage_to_dlib_image(pyimages[i], images[i]); + if (num_scales > 0) + { + if (num_boxes != py::len(pyscales[i])) + throw dlib::error("The length of the scales list must match the length of the detections list."); + for (py::iterator scale_it = pyscales[i].begin(); + scale_it != pyscales[i].end(); + ++scale_it) + scales[i].push_back(scale_it->cast<double>()); + } + } + + return test_shape_predictor_with_images(images, detections, scales, predictor); +} + +inline double test_shape_predictor_with_images_no_scales_py ( + const py::list& pyimages, + const py::list& pydetections, + const shape_predictor& predictor +) +{ + py::list pyscales; + return test_shape_predictor_with_images_py(pyimages, pydetections, pyscales, predictor); +} + +// ---------------------------------------------------------------------------------------- + +void bind_shape_predictors(py::module &m) +{ + { + typedef full_object_detection type; + py::class_<type, std::shared_ptr<type>>(m, "full_object_detection", + "This object represents the location of an object in an image along with the \ + positions of each of its constituent parts.") + .def(py::init(&full_obj_det_init), +"requires \n\ + - rect: dlib rectangle \n\ + - parts: list of dlib points") + .def_property_readonly("rect", &full_obj_det_get_rect, "Bounding box from the underlying detector. Parts can be outside box if appropriate.") + .def_property_readonly("num_parts", &full_obj_det_num_parts, "The number of parts of the object.") + .def("part", &full_obj_det_part, py::arg("idx"), "A single part of the object as a dlib point.") + .def("parts", &full_obj_det_parts, "A vector of dlib points representing all of the parts.") + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + { + typedef shape_predictor_training_options type; + py::class_<type>(m, "shape_predictor_training_options", + "This object is a container for the options to the train_shape_predictor() routine.") + .def(py::init()) + .def_readwrite("be_verbose", &type::be_verbose, + "If true, train_shape_predictor() will print out a lot of information to stdout while training.") + .def_readwrite("cascade_depth", &type::cascade_depth, + "The number of cascades created to train the model with.") + .def_readwrite("tree_depth", &type::tree_depth, + "The depth of the trees used in each cascade. There are pow(2, get_tree_depth()) leaves in each tree") + .def_readwrite("num_trees_per_cascade_level", &type::num_trees_per_cascade_level, + "The number of trees created for each cascade.") + .def_readwrite("nu", &type::nu, + "The regularization parameter. Larger values of this parameter \ + will cause the algorithm to fit the training data better but may also \ + cause overfitting. The value must be in the range (0, 1].") + .def_readwrite("oversampling_amount", &type::oversampling_amount, + "The number of randomly selected initial starting points sampled for each training example") + .def_readwrite("feature_pool_size", &type::feature_pool_size, + "Number of pixels used to generate features for the random trees.") + .def_readwrite("lambda_param", &type::lambda_param, + "Controls how tight the feature sampling should be. Lower values enforce closer features.") + .def_readwrite("num_test_splits", &type::num_test_splits, + "Number of split features at each node to sample. The one that gives the best split is chosen.") + .def_readwrite("feature_pool_region_padding", &type::feature_pool_region_padding, + "Size of region within which to sample features for the feature pool, \ + e.g a padding of 0.5 would cause the algorithm to sample pixels from a box that was 2x2 pixels") + .def_readwrite("random_seed", &type::random_seed, + "The random seed used by the internal random number generator") + .def_readwrite("num_threads", &type::num_threads, + "Use this many threads/CPU cores for training.") + .def("__str__", &::print_shape_predictor_training_options) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + { + typedef shape_predictor type; + py::class_<type, std::shared_ptr<type>>(m, "shape_predictor", +"This object is a tool that takes in an image region containing some object and \ +outputs a set of point locations that define the pose of the object. The classic \ +example of this is human face pose prediction, where you take an image of a human \ +face as input and are expected to identify the locations of important facial \ +landmarks such as the corners of the mouth and eyes, tip of the nose, and so forth.") + .def(py::init()) + .def(py::init(&load_object_from_file<type>), +"Loads a shape_predictor from a file that contains the output of the \n\ +train_shape_predictor() routine.") + .def("__call__", &run_predictor, py::arg("image"), py::arg("box"), +"requires \n\ + - image is a numpy ndarray containing either an 8bit grayscale or RGB \n\ + image. \n\ + - box is the bounding box to begin the shape prediction inside. \n\ +ensures \n\ + - This function runs the shape predictor on the input image and returns \n\ + a single full_object_detection.") + .def("save", save_shape_predictor, py::arg("predictor_output_filename"), "Save a shape_predictor to the provided path.") + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + { + m.def("train_shape_predictor", train_shape_predictor_on_images_py, + py::arg("images"), py::arg("object_detections"), py::arg("options"), +"requires \n\ + - options.lambda_param > 0 \n\ + - 0 < options.nu <= 1 \n\ + - options.feature_pool_region_padding >= 0 \n\ + - len(images) == len(object_detections) \n\ + - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ + - object_detections should be a list of lists of dlib.full_object_detection objects. \ + Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\ +ensures \n\ + - Uses dlib's shape_predictor_trainer object to train a \n\ + shape_predictor based on the provided labeled images, full_object_detections, and options.\n\ + - The trained shape_predictor is returned"); + + m.def("train_shape_predictor", train_shape_predictor, + py::arg("dataset_filename"), py::arg("predictor_output_filename"), py::arg("options"), +"requires \n\ + - options.lambda_param > 0 \n\ + - 0 < options.nu <= 1 \n\ + - options.feature_pool_region_padding >= 0 \n\ +ensures \n\ + - Uses dlib's shape_predictor_trainer to train a \n\ + shape_predictor based on the labeled images in the XML file \n\ + dataset_filename and the provided options. This function assumes the file dataset_filename is in the \n\ + XML format produced by dlib's save_image_dataset_metadata() routine. \n\ + - The trained shape predictor is serialized to the file predictor_output_filename."); + + m.def("test_shape_predictor", test_shape_predictor_py, + py::arg("dataset_filename"), py::arg("predictor_filename"), +"ensures \n\ + - Loads an image dataset from dataset_filename. We assume dataset_filename is \n\ + a file using the XML format written by save_image_dataset_metadata(). \n\ + - Loads a shape_predictor from the file predictor_filename. This means \n\ + predictor_filename should be a file produced by the train_shape_predictor() \n\ + routine. \n\ + - This function tests the predictor against the dataset and returns the \n\ + mean average error of the detector. In fact, The \n\ + return value of this function is identical to that of dlib's \n\ + shape_predictor_trainer() routine. Therefore, see the documentation \n\ + for shape_predictor_trainer() for a detailed definition of the mean average error."); + + m.def("test_shape_predictor", test_shape_predictor_with_images_no_scales_py, + py::arg("images"), py::arg("detections"), py::arg("shape_predictor"), +"requires \n\ + - len(images) == len(object_detections) \n\ + - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ + - object_detections should be a list of lists of dlib.full_object_detection objects. \ + Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\ + ensures \n\ + - shape_predictor should be a file produced by the train_shape_predictor() \n\ + routine. \n\ + - This function tests the predictor against the dataset and returns the \n\ + mean average error of the detector. In fact, The \n\ + return value of this function is identical to that of dlib's \n\ + shape_predictor_trainer() routine. Therefore, see the documentation \n\ + for shape_predictor_trainer() for a detailed definition of the mean average error."); + + + m.def("test_shape_predictor", test_shape_predictor_with_images_py, + py::arg("images"), py::arg("detections"), py::arg("scales"), py::arg("shape_predictor"), +"requires \n\ + - len(images) == len(object_detections) \n\ + - len(object_detections) == len(scales) \n\ + - for every sublist in object_detections: len(object_detections[i]) == len(scales[i]) \n\ + - scales is a list of floating point scales that each predicted part location \ + should be divided by. Useful for normalization. \n\ + - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ + - object_detections should be a list of lists of dlib.full_object_detection objects. \ + Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\ + ensures \n\ + - shape_predictor should be a file produced by the train_shape_predictor() \n\ + routine. \n\ + - This function tests the predictor against the dataset and returns the \n\ + mean average error of the detector. In fact, The \n\ + return value of this function is identical to that of dlib's \n\ + shape_predictor_trainer() routine. Therefore, see the documentation \n\ + for shape_predictor_trainer() for a detailed definition of the mean average error."); + } +} diff --git a/ml/dlib/tools/python/src/shape_predictor.h b/ml/dlib/tools/python/src/shape_predictor.h new file mode 100644 index 000000000..f7a071a75 --- /dev/null +++ b/ml/dlib/tools/python/src/shape_predictor.h @@ -0,0 +1,259 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SHAPE_PREDICTOR_H__ +#define DLIB_SHAPE_PREDICTOR_H__ + +#include "dlib/string.h" +#include "dlib/geometry.h" +#include "dlib/data_io/load_image_dataset.h" +#include "dlib/image_processing.h" + +using namespace std; + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct shape_predictor_training_options + { + shape_predictor_training_options() + { + be_verbose = false; + cascade_depth = 10; + tree_depth = 4; + num_trees_per_cascade_level = 500; + nu = 0.1; + oversampling_amount = 20; + feature_pool_size = 400; + lambda_param = 0.1; + num_test_splits = 20; + feature_pool_region_padding = 0; + random_seed = ""; + num_threads = 0; + } + + bool be_verbose; + unsigned long cascade_depth; + unsigned long tree_depth; + unsigned long num_trees_per_cascade_level; + double nu; + unsigned long oversampling_amount; + unsigned long feature_pool_size; + double lambda_param; + unsigned long num_test_splits; + double feature_pool_region_padding; + std::string random_seed; + + // not serialized + unsigned long num_threads; + }; + + inline void serialize ( + const shape_predictor_training_options& item, + std::ostream& out + ) + { + try + { + serialize(item.be_verbose,out); + serialize(item.cascade_depth,out); + serialize(item.tree_depth,out); + serialize(item.num_trees_per_cascade_level,out); + serialize(item.nu,out); + serialize(item.oversampling_amount,out); + serialize(item.feature_pool_size,out); + serialize(item.lambda_param,out); + serialize(item.num_test_splits,out); + serialize(item.feature_pool_region_padding,out); + serialize(item.random_seed,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing an object of type shape_predictor_training_options"); + } + } + + inline void deserialize ( + shape_predictor_training_options& item, + std::istream& in + ) + { + try + { + deserialize(item.be_verbose,in); + deserialize(item.cascade_depth,in); + deserialize(item.tree_depth,in); + deserialize(item.num_trees_per_cascade_level,in); + deserialize(item.nu,in); + deserialize(item.oversampling_amount,in); + deserialize(item.feature_pool_size,in); + deserialize(item.lambda_param,in); + deserialize(item.num_test_splits,in); + deserialize(item.feature_pool_region_padding,in); + deserialize(item.random_seed,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing an object of type shape_predictor_training_options"); + } + } + + string print_shape_predictor_training_options(const shape_predictor_training_options& o) + { + std::ostringstream sout; + sout << "shape_predictor_training_options(" + << "be_verbose=" << o.be_verbose << "," + << "cascade_depth=" << o.cascade_depth << "," + << "tree_depth=" << o.tree_depth << "," + << "num_trees_per_cascade_level=" << o.num_trees_per_cascade_level << "," + << "nu=" << o.nu << "," + << "oversampling_amount=" << o.oversampling_amount << "," + << "feature_pool_size=" << o.feature_pool_size << "," + << "lambda_param=" << o.lambda_param << "," + << "num_test_splits=" << o.num_test_splits << "," + << "feature_pool_region_padding=" << o.feature_pool_region_padding << "," + << "random_seed=" << o.random_seed << "," + << "num_threads=" << o.num_threads + << ")"; + return sout.str(); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline bool contains_any_detections ( + const std::vector<std::vector<full_object_detection> >& detections + ) + { + for (unsigned long i = 0; i < detections.size(); ++i) + { + if (detections[i].size() != 0) + return true; + } + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template <typename image_array> + inline shape_predictor train_shape_predictor_on_images ( + image_array& images, + std::vector<std::vector<full_object_detection> >& detections, + const shape_predictor_training_options& options + ) + { + if (options.lambda_param <= 0) + throw error("Invalid lambda_param value given to train_shape_predictor(), lambda_param must be > 0."); + if (!(0 < options.nu && options.nu <= 1)) + throw error("Invalid nu value given to train_shape_predictor(). It is required that 0 < nu <= 1."); + if (options.feature_pool_region_padding <= -0.5) + throw error("Invalid feature_pool_region_padding value given to train_shape_predictor(), feature_pool_region_padding must be > -0.5."); + + if (images.size() != detections.size()) + throw error("The list of images must have the same length as the list of detections."); + + if (!impl::contains_any_detections(detections)) + throw error("Error, the training dataset does not have any labeled object detections in it."); + + shape_predictor_trainer trainer; + + trainer.set_cascade_depth(options.cascade_depth); + trainer.set_tree_depth(options.tree_depth); + trainer.set_num_trees_per_cascade_level(options.num_trees_per_cascade_level); + trainer.set_nu(options.nu); + trainer.set_random_seed(options.random_seed); + trainer.set_oversampling_amount(options.oversampling_amount); + trainer.set_feature_pool_size(options.feature_pool_size); + trainer.set_feature_pool_region_padding(options.feature_pool_region_padding); + trainer.set_lambda(options.lambda_param); + trainer.set_num_test_splits(options.num_test_splits); + trainer.set_num_threads(options.num_threads); + + if (options.be_verbose) + { + std::cout << "Training with cascade depth: " << options.cascade_depth << std::endl; + std::cout << "Training with tree depth: " << options.tree_depth << std::endl; + std::cout << "Training with " << options.num_trees_per_cascade_level << " trees per cascade level."<< std::endl; + std::cout << "Training with nu: " << options.nu << std::endl; + std::cout << "Training with random seed: " << options.random_seed << std::endl; + std::cout << "Training with oversampling amount: " << options.oversampling_amount << std::endl; + std::cout << "Training with feature pool size: " << options.feature_pool_size << std::endl; + std::cout << "Training with feature pool region padding: " << options.feature_pool_region_padding << std::endl; + std::cout << "Training with " << options.num_threads << " threads." << std::endl; + std::cout << "Training with lambda_param: " << options.lambda_param << std::endl; + std::cout << "Training with " << options.num_test_splits << " split tests."<< std::endl; + trainer.be_verbose(); + } + + shape_predictor predictor = trainer.train(images, detections); + + return predictor; + } + + inline void train_shape_predictor ( + const std::string& dataset_filename, + const std::string& predictor_output_filename, + const shape_predictor_training_options& options + ) + { + dlib::array<array2d<unsigned char> > images; + std::vector<std::vector<full_object_detection> > objects; + load_image_dataset(images, objects, dataset_filename); + + shape_predictor predictor = train_shape_predictor_on_images(images, objects, options); + + serialize(predictor_output_filename) << predictor; + + if (options.be_verbose) + std::cout << "Training complete, saved predictor to file " << predictor_output_filename << std::endl; + } + +// ---------------------------------------------------------------------------------------- + + template <typename image_array> + inline double test_shape_predictor_with_images ( + image_array& images, + std::vector<std::vector<full_object_detection> >& detections, + std::vector<std::vector<double> >& scales, + const shape_predictor& predictor + ) + { + if (images.size() != detections.size()) + throw error("The list of images must have the same length as the list of detections."); + if (scales.size() > 0 && scales.size() != images.size()) + throw error("The list of scales must have the same length as the list of detections."); + + if (scales.size() > 0) + return test_shape_predictor(predictor, images, detections, scales); + else + return test_shape_predictor(predictor, images, detections); + } + + inline double test_shape_predictor_py ( + const std::string& dataset_filename, + const std::string& predictor_filename + ) + { + // Load the images, no scales can be provided + dlib::array<array2d<unsigned char> > images; + // This interface cannot take the scales parameter. + std::vector<std::vector<double> > scales; + std::vector<std::vector<full_object_detection> > objects; + load_image_dataset(images, objects, dataset_filename); + + // Load the shape predictor + shape_predictor predictor; + deserialize(predictor_filename) >> predictor; + + return test_shape_predictor_with_images(images, objects, scales, predictor); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SHAPE_PREDICTOR_H__ + diff --git a/ml/dlib/tools/python/src/simple_object_detector.h b/ml/dlib/tools/python/src/simple_object_detector.h new file mode 100644 index 000000000..4fceab429 --- /dev/null +++ b/ml/dlib/tools/python/src/simple_object_detector.h @@ -0,0 +1,318 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SIMPLE_ObJECT_DETECTOR_H__ +#define DLIB_SIMPLE_ObJECT_DETECTOR_H__ + +#include "dlib/image_processing/object_detector.h" +#include "dlib/string.h" +#include "dlib/image_processing/scan_fhog_pyramid.h" +#include "dlib/svm/structural_object_detection_trainer.h" +#include "dlib/geometry.h" +#include "dlib/data_io/load_image_dataset.h" +#include "dlib/image_processing/remove_unobtainable_rectangles.h" +#include "serialize_object_detector.h" +#include "dlib/svm.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + typedef object_detector<scan_fhog_pyramid<pyramid_down<6> > > simple_object_detector; + +// ---------------------------------------------------------------------------------------- + + struct simple_object_detector_training_options + { + simple_object_detector_training_options() + { + be_verbose = false; + add_left_right_image_flips = false; + num_threads = 4; + detection_window_size = 80*80; + C = 1; + epsilon = 0.01; + upsample_limit = 2; + } + + bool be_verbose; + bool add_left_right_image_flips; + unsigned long num_threads; + unsigned long detection_window_size; + double C; + double epsilon; + unsigned long upsample_limit; + }; + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline void pick_best_window_size ( + const std::vector<std::vector<rectangle> >& boxes, + unsigned long& width, + unsigned long& height, + const unsigned long target_size + ) + { + // find the average width and height + running_stats<double> avg_width, avg_height; + for (unsigned long i = 0; i < boxes.size(); ++i) + { + for (unsigned long j = 0; j < boxes[i].size(); ++j) + { + avg_width.add(boxes[i][j].width()); + avg_height.add(boxes[i][j].height()); + } + } + + // now adjust the box size so that it is about target_pixels pixels in size + double size = avg_width.mean()*avg_height.mean(); + double scale = std::sqrt(target_size/size); + + width = (unsigned long)(avg_width.mean()*scale+0.5); + height = (unsigned long)(avg_height.mean()*scale+0.5); + // make sure the width and height never round to zero. + if (width == 0) + width = 1; + if (height == 0) + height = 1; + } + + inline bool contains_any_boxes ( + const std::vector<std::vector<rectangle> >& boxes + ) + { + for (unsigned long i = 0; i < boxes.size(); ++i) + { + if (boxes[i].size() != 0) + return true; + } + return false; + } + + inline void throw_invalid_box_error_message ( + const std::string& dataset_filename, + const std::vector<std::vector<rectangle> >& removed, + const simple_object_detector_training_options& options + ) + { + + std::ostringstream sout; + // Note that the 1/16 factor is here because we will try to upsample the image + // 2 times to accommodate small boxes. We also take the max because we want to + // lower bound the size of the smallest recommended box. This is because the + // 8x8 HOG cells can't really deal with really small object boxes. + sout << "Error! An impossible set of object boxes was given for training. "; + sout << "All the boxes need to have a similar aspect ratio and also not be "; + sout << "smaller than about " << std::max<long>(20*20,options.detection_window_size/16) << " pixels in area. "; + + std::ostringstream sout2; + if (dataset_filename.size() != 0) + { + sout << "The following images contain invalid boxes:\n"; + image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, dataset_filename); + for (unsigned long i = 0; i < removed.size(); ++i) + { + if (removed[i].size() != 0) + { + const std::string imgname = data.images[i].filename; + sout2 << " " << imgname << "\n"; + } + } + } + throw error("\n"+wrap_string(sout.str()) + "\n" + sout2.str()); + } + } + +// ---------------------------------------------------------------------------------------- + + template <typename image_array> + inline simple_object_detector_py train_simple_object_detector_on_images ( + const std::string& dataset_filename, // can be "" if it's not applicable + image_array& images, + std::vector<std::vector<rectangle> >& boxes, + std::vector<std::vector<rectangle> >& ignore, + const simple_object_detector_training_options& options + ) + { + if (options.C <= 0) + throw error("Invalid C value given to train_simple_object_detector(), C must be > 0."); + if (options.epsilon <= 0) + throw error("Invalid epsilon value given to train_simple_object_detector(), epsilon must be > 0."); + + if (images.size() != boxes.size()) + throw error("The list of images must have the same length as the list of boxes."); + if (images.size() != ignore.size()) + throw error("The list of images must have the same length as the list of ignore boxes."); + + if (impl::contains_any_boxes(boxes) == false) + throw error("Error, the training dataset does not have any labeled object boxes in it."); + + typedef scan_fhog_pyramid<pyramid_down<6> > image_scanner_type; + image_scanner_type scanner; + unsigned long width, height; + impl::pick_best_window_size(boxes, width, height, options.detection_window_size); + scanner.set_detection_window_size(width, height); + structural_object_detection_trainer<image_scanner_type> trainer(scanner); + trainer.set_num_threads(options.num_threads); + trainer.set_c(options.C); + trainer.set_epsilon(options.epsilon); + if (options.be_verbose) + { + std::cout << "Training with C: " << options.C << std::endl; + std::cout << "Training with epsilon: " << options.epsilon << std::endl; + std::cout << "Training using " << options.num_threads << " threads."<< std::endl; + std::cout << "Training with sliding window " << width << " pixels wide by " << height << " pixels tall." << std::endl; + if (options.add_left_right_image_flips) + std::cout << "Training on both left and right flipped versions of images." << std::endl; + trainer.be_verbose(); + } + + unsigned long upsampling_amount = 0; + + // now make sure all the boxes are obtainable by the scanner. We will try and + // upsample the images at most two times to help make the boxes obtainable. + std::vector<std::vector<rectangle> > temp(boxes), removed; + removed = remove_unobtainable_rectangles(trainer, images, temp); + while (impl::contains_any_boxes(removed) && upsampling_amount < options.upsample_limit) + { + ++upsampling_amount; + if (options.be_verbose) + std::cout << "Upsample images..." << std::endl; + upsample_image_dataset<pyramid_down<2> >(images, boxes, ignore); + temp = boxes; + removed = remove_unobtainable_rectangles(trainer, images, temp); + } + // if we weren't able to get all the boxes to match then throw an error + if (impl::contains_any_boxes(removed)) + impl::throw_invalid_box_error_message(dataset_filename, removed, options); + + if (options.add_left_right_image_flips) + add_image_left_right_flips(images, boxes, ignore); + + simple_object_detector detector = trainer.train(images, boxes, ignore); + + if (options.be_verbose) + { + std::cout << "Training complete." << std::endl; + std::cout << "Trained with C: " << options.C << std::endl; + std::cout << "Training with epsilon: " << options.epsilon << std::endl; + std::cout << "Trained using " << options.num_threads << " threads."<< std::endl; + std::cout << "Trained with sliding window " << width << " pixels wide by " << height << " pixels tall." << std::endl; + if (upsampling_amount != 0) + { + // Unsampled images # time(s) to allow detection of small boxes + std::cout << "Upsampled images " << upsampling_amount; + std::cout << ((upsampling_amount > 1) ? " times" : " time"); + std::cout << " to allow detection of small boxes." << std::endl; + } + if (options.add_left_right_image_flips) + std::cout << "Trained on both left and right flipped versions of images." << std::endl; + } + + return simple_object_detector_py(detector, upsampling_amount); + } + +// ---------------------------------------------------------------------------------------- + + inline void train_simple_object_detector ( + const std::string& dataset_filename, + const std::string& detector_output_filename, + const simple_object_detector_training_options& options + ) + { + dlib::array<array2d<rgb_pixel> > images; + std::vector<std::vector<rectangle> > boxes, ignore; + ignore = load_image_dataset(images, boxes, dataset_filename); + + simple_object_detector_py detector = train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, options); + + save_simple_object_detector_py(detector, detector_output_filename); + + if (options.be_verbose) + std::cout << "Saved detector to file " << detector_output_filename << std::endl; + } + +// ---------------------------------------------------------------------------------------- + + struct simple_test_results + { + double precision; + double recall; + double average_precision; + }; + + template <typename image_array> + inline const simple_test_results test_simple_object_detector_with_images ( + image_array& images, + const unsigned int upsample_amount, + std::vector<std::vector<rectangle> >& boxes, + std::vector<std::vector<rectangle> >& ignore, + simple_object_detector& detector + ) + { + for (unsigned int i = 0; i < upsample_amount; ++i) + upsample_image_dataset<pyramid_down<2> >(images, boxes); + + matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore); + simple_test_results ret; + ret.precision = res(0); + ret.recall = res(1); + ret.average_precision = res(2); + return ret; + } + + inline const simple_test_results test_simple_object_detector ( + const std::string& dataset_filename, + const std::string& detector_filename, + const int upsample_amount + ) + { + // Load all the testing images + dlib::array<array2d<rgb_pixel> > images; + std::vector<std::vector<rectangle> > boxes, ignore; + ignore = load_image_dataset(images, boxes, dataset_filename); + + // Load the detector off disk (We have to use the explicit serialization here + // so that we have an open file stream) + simple_object_detector detector; + std::ifstream fin(detector_filename.c_str(), std::ios::binary); + if (!fin) + throw error("Unable to open file " + detector_filename); + deserialize(detector, fin); + + + /* Here we need a little hack to deal with whether we are going to be loading a + * simple_object_detector (possibly trained outside of Python) or a + * simple_object_detector_py (definitely trained from Python). In order to do this + * we peek into the filestream to see if there is more data after the object + * detector. If there is, it will be the version and upsampling amount. Therefore, + * by default we set the upsampling amount to -1 so that we can catch when no + * upsampling amount has been passed (numbers less than 0). If -1 is passed, we + * assume no upsampling and use 0. If a number > 0 is passed, we use that, else we + * use the upsampling amount saved in the detector file (if it exists). + */ + unsigned int final_upsampling_amount = 0; + if (fin.peek() != EOF) + { + int version = 0; + deserialize(version, fin); + if (version != 1) + throw error("Unknown simple_object_detector format."); + deserialize(final_upsampling_amount, fin); + } + if (upsample_amount >= 0) + final_upsampling_amount = upsample_amount; + + return test_simple_object_detector_with_images(images, final_upsampling_amount, boxes, ignore, detector); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SIMPLE_ObJECT_DETECTOR_H__ + diff --git a/ml/dlib/tools/python/src/simple_object_detector_py.h b/ml/dlib/tools/python/src/simple_object_detector_py.h new file mode 100644 index 000000000..0f950273d --- /dev/null +++ b/ml/dlib/tools/python/src/simple_object_detector_py.h @@ -0,0 +1,290 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SIMPLE_OBJECT_DETECTOR_PY_H__ +#define DLIB_SIMPLE_OBJECT_DETECTOR_PY_H__ + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/geometry.h> +#include <dlib/image_processing/frontal_face_detector.h> + +namespace py = pybind11; + +namespace dlib +{ + typedef object_detector<scan_fhog_pyramid<pyramid_down<6> > > simple_object_detector; + + inline void split_rect_detections ( + std::vector<rect_detection>& rect_detections, + std::vector<rectangle>& rectangles, + std::vector<double>& detection_confidences, + std::vector<unsigned long>& weight_indices + ) + { + rectangles.clear(); + detection_confidences.clear(); + weight_indices.clear(); + + for (unsigned long i = 0; i < rect_detections.size(); ++i) + { + rectangles.push_back(rect_detections[i].rect); + detection_confidences.push_back(rect_detections[i].detection_confidence); + weight_indices.push_back(rect_detections[i].weight_index); + } + } + + + inline std::vector<dlib::rectangle> run_detector_with_upscale1 ( + dlib::simple_object_detector& detector, + py::object img, + const unsigned int upsampling_amount, + const double adjust_threshold, + std::vector<double>& detection_confidences, + std::vector<unsigned long>& weight_indices + ) + { + pyramid_down<2> pyr; + + std::vector<rectangle> rectangles; + std::vector<rect_detection> rect_detections; + + if (is_gray_python_image(img)) + { + array2d<unsigned char> temp; + if (upsampling_amount == 0) + { + detector(numpy_gray_image(img), rect_detections, adjust_threshold); + split_rect_detections(rect_detections, rectangles, + detection_confidences, weight_indices); + return rectangles; + } + else + { + pyramid_up(numpy_gray_image(img), temp, pyr); + unsigned int levels = upsampling_amount-1; + while (levels > 0) + { + levels--; + pyramid_up(temp); + } + + detector(temp, rect_detections, adjust_threshold); + for (unsigned long i = 0; i < rect_detections.size(); ++i) + rect_detections[i].rect = pyr.rect_down(rect_detections[i].rect, + upsampling_amount); + split_rect_detections(rect_detections, rectangles, + detection_confidences, weight_indices); + + return rectangles; + } + } + else if (is_rgb_python_image(img)) + { + array2d<rgb_pixel> temp; + if (upsampling_amount == 0) + { + detector(numpy_rgb_image(img), rect_detections, adjust_threshold); + split_rect_detections(rect_detections, rectangles, + detection_confidences, weight_indices); + return rectangles; + } + else + { + pyramid_up(numpy_rgb_image(img), temp, pyr); + unsigned int levels = upsampling_amount-1; + while (levels > 0) + { + levels--; + pyramid_up(temp); + } + + detector(temp, rect_detections, adjust_threshold); + for (unsigned long i = 0; i < rect_detections.size(); ++i) + rect_detections[i].rect = pyr.rect_down(rect_detections[i].rect, + upsampling_amount); + split_rect_detections(rect_detections, rectangles, + detection_confidences, weight_indices); + + return rectangles; + } + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } + } + + inline std::vector<dlib::rectangle> run_detectors_with_upscale1 ( + std::vector<simple_object_detector >& detectors, + py::object img, + const unsigned int upsampling_amount, + const double adjust_threshold, + std::vector<double>& detection_confidences, + std::vector<unsigned long>& weight_indices + ) + { + pyramid_down<2> pyr; + + std::vector<rectangle> rectangles; + std::vector<rect_detection> rect_detections; + + if (is_gray_python_image(img)) + { + array2d<unsigned char> temp; + if (upsampling_amount == 0) + { + evaluate_detectors(detectors, numpy_gray_image(img), rect_detections, adjust_threshold); + split_rect_detections(rect_detections, rectangles, + detection_confidences, weight_indices); + return rectangles; + } + else + { + pyramid_up(numpy_gray_image(img), temp, pyr); + unsigned int levels = upsampling_amount-1; + while (levels > 0) + { + levels--; + pyramid_up(temp); + } + + evaluate_detectors(detectors, temp, rect_detections, adjust_threshold); + for (unsigned long i = 0; i < rect_detections.size(); ++i) + rect_detections[i].rect = pyr.rect_down(rect_detections[i].rect, + upsampling_amount); + split_rect_detections(rect_detections, rectangles, + detection_confidences, weight_indices); + + return rectangles; + } + } + else if (is_rgb_python_image(img)) + { + array2d<rgb_pixel> temp; + if (upsampling_amount == 0) + { + evaluate_detectors(detectors, numpy_rgb_image(img), rect_detections, adjust_threshold); + split_rect_detections(rect_detections, rectangles, + detection_confidences, weight_indices); + return rectangles; + } + else + { + pyramid_up(numpy_rgb_image(img), temp, pyr); + unsigned int levels = upsampling_amount-1; + while (levels > 0) + { + levels--; + pyramid_up(temp); + } + + evaluate_detectors(detectors, temp, rect_detections, adjust_threshold); + for (unsigned long i = 0; i < rect_detections.size(); ++i) + rect_detections[i].rect = pyr.rect_down(rect_detections[i].rect, + upsampling_amount); + split_rect_detections(rect_detections, rectangles, + detection_confidences, weight_indices); + + return rectangles; + } + } + else + { + throw dlib::error("Unsupported image type, must be 8bit gray or RGB image."); + } + } + + inline std::vector<dlib::rectangle> run_detector_with_upscale2 ( + dlib::simple_object_detector& detector, + py::object img, + const unsigned int upsampling_amount + + ) + { + std::vector<double> detection_confidences; + std::vector<unsigned long> weight_indices; + const double adjust_threshold = 0.0; + + return run_detector_with_upscale1(detector, img, upsampling_amount, + adjust_threshold, + detection_confidences, weight_indices); + } + + inline py::tuple run_rect_detector ( + dlib::simple_object_detector& detector, + py::object img, + const unsigned int upsampling_amount, + const double adjust_threshold) + { + py::tuple t; + + std::vector<double> detection_confidences; + std::vector<unsigned long> weight_indices; + std::vector<rectangle> rectangles; + + rectangles = run_detector_with_upscale1(detector, img, upsampling_amount, + adjust_threshold, + detection_confidences, weight_indices); + + return py::make_tuple(rectangles, + vector_to_python_list(detection_confidences), + vector_to_python_list(weight_indices)); + } + + inline py::tuple run_multiple_rect_detectors ( + py::list& detectors, + py::object img, + const unsigned int upsampling_amount, + const double adjust_threshold) + { + py::tuple t; + + std::vector<simple_object_detector > vector_detectors; + const unsigned long num_detectors = len(detectors); + // Now copy the data into dlib based objects. + for (unsigned long i = 0; i < num_detectors; ++i) + { + vector_detectors.push_back(detectors[i].cast<simple_object_detector >()); + } + + std::vector<double> detection_confidences; + std::vector<unsigned long> weight_indices; + std::vector<rectangle> rectangles; + + rectangles = run_detectors_with_upscale1(vector_detectors, img, upsampling_amount, + adjust_threshold, + detection_confidences, weight_indices); + + return py::make_tuple(rectangles, + vector_to_python_list(detection_confidences), + vector_to_python_list(weight_indices)); + } + + + + struct simple_object_detector_py + { + simple_object_detector detector; + unsigned int upsampling_amount; + + simple_object_detector_py() {} + simple_object_detector_py(simple_object_detector& _detector, unsigned int _upsampling_amount) : + detector(_detector), upsampling_amount(_upsampling_amount) {} + + std::vector<dlib::rectangle> run_detector1 (py::object img, + const unsigned int upsampling_amount_) + { + return run_detector_with_upscale2(detector, img, upsampling_amount_); + } + + std::vector<dlib::rectangle> run_detector2 (py::object img) + { + return run_detector_with_upscale2(detector, img, upsampling_amount); + } + + + }; +} + +#endif // DLIB_SIMPLE_OBJECT_DETECTOR_PY_H__ diff --git a/ml/dlib/tools/python/src/svm_c_trainer.cpp b/ml/dlib/tools/python/src/svm_c_trainer.cpp new file mode 100644 index 000000000..7b592abe7 --- /dev/null +++ b/ml/dlib/tools/python/src/svm_c_trainer.cpp @@ -0,0 +1,311 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include "testing_results.h" +#include <dlib/matrix.h> +#include <dlib/svm_threaded.h> + +using namespace dlib; +using namespace std; + +typedef matrix<double,0,1> sample_type; +typedef std::vector<std::pair<unsigned long,double> > sparse_vect; + +template <typename trainer_type> +typename trainer_type::trained_function_type train ( + const trainer_type& trainer, + const std::vector<typename trainer_type::sample_type>& samples, + const std::vector<double>& labels +) +{ + pyassert(is_binary_classification_problem(samples,labels), "Invalid inputs"); + return trainer.train(samples, labels); +} + +template <typename trainer_type> +void set_epsilon ( trainer_type& trainer, double eps) +{ + pyassert(eps > 0, "epsilon must be > 0"); + trainer.set_epsilon(eps); +} + +template <typename trainer_type> +double get_epsilon ( const trainer_type& trainer) { return trainer.get_epsilon(); } + + +template <typename trainer_type> +void set_cache_size ( trainer_type& trainer, long cache_size) +{ + pyassert(cache_size > 0, "cache size must be > 0"); + trainer.set_cache_size(cache_size); +} + +template <typename trainer_type> +long get_cache_size ( const trainer_type& trainer) { return trainer.get_cache_size(); } + + +template <typename trainer_type> +void set_c ( trainer_type& trainer, double C) +{ + pyassert(C > 0, "C must be > 0"); + trainer.set_c(C); +} + +template <typename trainer_type> +void set_c_class1 ( trainer_type& trainer, double C) +{ + pyassert(C > 0, "C must be > 0"); + trainer.set_c_class1(C); +} + +template <typename trainer_type> +void set_c_class2 ( trainer_type& trainer, double C) +{ + pyassert(C > 0, "C must be > 0"); + trainer.set_c_class2(C); +} + +template <typename trainer_type> +double get_c_class1 ( const trainer_type& trainer) { return trainer.get_c_class1(); } +template <typename trainer_type> +double get_c_class2 ( const trainer_type& trainer) { return trainer.get_c_class2(); } + +template <typename trainer_type> +py::class_<trainer_type> setup_trainer_eps ( + py::module& m, + const std::string& name +) +{ + return py::class_<trainer_type>(m, name.c_str()) + .def("train", train<trainer_type>) + .def_property("epsilon", get_epsilon<trainer_type>, set_epsilon<trainer_type>); +} + +template <typename trainer_type> +py::class_<trainer_type> setup_trainer_eps_c ( + py::module& m, + const std::string& name +) +{ + return setup_trainer_eps<trainer_type>(m, name) + .def("set_c", set_c<trainer_type>) + .def_property("c_class1", get_c_class1<trainer_type>, set_c_class1<trainer_type>) + .def_property("c_class2", get_c_class2<trainer_type>, set_c_class2<trainer_type>); +} + +template <typename trainer_type> +py::class_<trainer_type> setup_trainer_eps_c_cache ( + py::module& m, + const std::string& name +) +{ + return setup_trainer_eps_c<trainer_type>(m, name) + .def_property("cache_size", get_cache_size<trainer_type>, set_cache_size<trainer_type>); +} + +template <typename trainer_type> +void set_gamma ( + trainer_type& trainer, + double gamma +) +{ + pyassert(gamma > 0, "gamma must be > 0"); + trainer.set_kernel(typename trainer_type::kernel_type(gamma)); +} + +template <typename trainer_type> +double get_gamma ( + const trainer_type& trainer +) +{ + return trainer.get_kernel().gamma; +} + +// ---------------------------------------------------------------------------------------- + +template < + typename trainer_type + > +const binary_test _cross_validate_trainer ( + const trainer_type& trainer, + const std::vector<typename trainer_type::sample_type>& x, + const std::vector<double>& y, + const unsigned long folds +) +{ + pyassert(is_binary_classification_problem(x,y), "Training data does not make a valid training set."); + pyassert(1 < folds && folds <= x.size(), "Invalid number of folds given."); + return cross_validate_trainer(trainer, x, y, folds); +} + +template < + typename trainer_type + > +const binary_test _cross_validate_trainer_t ( + const trainer_type& trainer, + const std::vector<typename trainer_type::sample_type>& x, + const std::vector<double>& y, + const unsigned long folds, + const unsigned long num_threads +) +{ + pyassert(is_binary_classification_problem(x,y), "Training data does not make a valid training set."); + pyassert(1 < folds && folds <= x.size(), "Invalid number of folds given."); + pyassert(1 < num_threads, "The number of threads specified must not be zero."); + return cross_validate_trainer_threaded(trainer, x, y, folds, num_threads); +} + +// ---------------------------------------------------------------------------------------- + +void bind_svm_c_trainer(py::module& m) +{ + namespace py = pybind11; + + // svm_c + { + typedef svm_c_trainer<radial_basis_kernel<sample_type> > T; + setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_radial_basis") + .def(py::init()) + .def_property("gamma", get_gamma<T>, set_gamma<T>); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + { + typedef svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > T; + setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_sparse_radial_basis") + .def(py::init()) + .def_property("gamma", get_gamma<T>, set_gamma<T>); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + { + typedef svm_c_trainer<histogram_intersection_kernel<sample_type> > T; + setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_histogram_intersection") + .def(py::init()); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + { + typedef svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T; + setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_sparse_histogram_intersection") + .def(py::init()); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + // svm_c_linear + { + typedef svm_c_linear_trainer<linear_kernel<sample_type> > T; + setup_trainer_eps_c<T>(m, "svm_c_trainer_linear") + .def(py::init()) + .def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations) + .def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1) + .def_property("learns_nonnegative_weights", &T::learns_nonnegative_weights, &T::set_learns_nonnegative_weights) + .def_property_readonly("has_prior", &T::has_prior) + .def("set_prior", &T::set_prior) + .def("be_verbose", &T::be_verbose) + .def("be_quiet", &T::be_quiet); + + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + { + typedef svm_c_linear_trainer<sparse_linear_kernel<sparse_vect> > T; + setup_trainer_eps_c<T>(m, "svm_c_trainer_sparse_linear") + .def(py::init()) + .def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations) + .def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1) + .def_property("learns_nonnegative_weights", &T::learns_nonnegative_weights, &T::set_learns_nonnegative_weights) + .def_property_readonly("has_prior", &T::has_prior) + .def("set_prior", &T::set_prior) + .def("be_verbose", &T::be_verbose) + .def("be_quiet", &T::be_quiet); + + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + // rvm + { + typedef rvm_trainer<radial_basis_kernel<sample_type> > T; + setup_trainer_eps<T>(m, "rvm_trainer_radial_basis") + .def(py::init()) + .def_property("gamma", get_gamma<T>, set_gamma<T>); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + { + typedef rvm_trainer<sparse_radial_basis_kernel<sparse_vect> > T; + setup_trainer_eps<T>(m, "rvm_trainer_sparse_radial_basis") + .def(py::init()) + .def_property("gamma", get_gamma<T>, set_gamma<T>); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + { + typedef rvm_trainer<histogram_intersection_kernel<sample_type> > T; + setup_trainer_eps<T>(m, "rvm_trainer_histogram_intersection") + .def(py::init()); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + { + typedef rvm_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T; + setup_trainer_eps<T>(m, "rvm_trainer_sparse_histogram_intersection") + .def(py::init()); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + // rvm linear + { + typedef rvm_trainer<linear_kernel<sample_type> > T; + setup_trainer_eps<T>(m, "rvm_trainer_linear") + .def(py::init()); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } + + { + typedef rvm_trainer<sparse_linear_kernel<sparse_vect> > T; + setup_trainer_eps<T>(m, "rvm_trainer_sparse_linear") + .def(py::init()); + m.def("cross_validate_trainer", _cross_validate_trainer<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); + m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, + py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); + } +} + + diff --git a/ml/dlib/tools/python/src/svm_rank_trainer.cpp b/ml/dlib/tools/python/src/svm_rank_trainer.cpp new file mode 100644 index 000000000..26cf3111a --- /dev/null +++ b/ml/dlib/tools/python/src/svm_rank_trainer.cpp @@ -0,0 +1,161 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/svm.h> +#include "testing_results.h" +#include <pybind11/stl_bind.h> + +using namespace dlib; +using namespace std; +namespace py = pybind11; + +typedef matrix<double,0,1> sample_type; + + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + template <typename T> + bool operator== ( + const ranking_pair<T>&, + const ranking_pair<T>& + ) + { + pyassert(false, "It is illegal to compare ranking pair objects for equality."); + return false; + } +} + +template <typename T> +void resize(T& v, unsigned long n) { v.resize(n); } + +// ---------------------------------------------------------------------------------------- + +template <typename trainer_type> +typename trainer_type::trained_function_type train1 ( + const trainer_type& trainer, + const ranking_pair<typename trainer_type::sample_type>& sample +) +{ + typedef ranking_pair<typename trainer_type::sample_type> st; + pyassert(is_ranking_problem(std::vector<st>(1, sample)), "Invalid inputs"); + return trainer.train(sample); +} + +template <typename trainer_type> +typename trainer_type::trained_function_type train2 ( + const trainer_type& trainer, + const std::vector<ranking_pair<typename trainer_type::sample_type> >& samples +) +{ + pyassert(is_ranking_problem(samples), "Invalid inputs"); + return trainer.train(samples); +} + +template <typename trainer_type> +void set_epsilon ( trainer_type& trainer, double eps) +{ + pyassert(eps > 0, "epsilon must be > 0"); + trainer.set_epsilon(eps); +} + +template <typename trainer_type> +double get_epsilon ( const trainer_type& trainer) { return trainer.get_epsilon(); } + +template <typename trainer_type> +void set_c ( trainer_type& trainer, double C) +{ + pyassert(C > 0, "C must be > 0"); + trainer.set_c(C); +} + +template <typename trainer_type> +double get_c (const trainer_type& trainer) +{ + return trainer.get_c(); +} + + +template <typename trainer> +void add_ranker ( + py::module& m, + const char* name +) +{ + py::class_<trainer>(m, name) + .def(py::init()) + .def_property("epsilon", get_epsilon<trainer>, set_epsilon<trainer>) + .def_property("c", get_c<trainer>, set_c<trainer>) + .def_property("max_iterations", &trainer::get_max_iterations, &trainer::set_max_iterations) + .def_property("force_last_weight_to_1", &trainer::forces_last_weight_to_1, &trainer::force_last_weight_to_1) + .def_property("learns_nonnegative_weights", &trainer::learns_nonnegative_weights, &trainer::set_learns_nonnegative_weights) + .def_property_readonly("has_prior", &trainer::has_prior) + .def("train", train1<trainer>) + .def("train", train2<trainer>) + .def("set_prior", &trainer::set_prior) + .def("be_verbose", &trainer::be_verbose) + .def("be_quiet", &trainer::be_quiet); +} + +// ---------------------------------------------------------------------------------------- + +template < + typename trainer_type, + typename T + > +const ranking_test _cross_ranking_validate_trainer ( + const trainer_type& trainer, + const std::vector<ranking_pair<T> >& samples, + const unsigned long folds +) +{ + pyassert(is_ranking_problem(samples), "Training data does not make a valid training set."); + pyassert(1 < folds && folds <= samples.size(), "Invalid number of folds given."); + return cross_validate_ranking_trainer(trainer, samples, folds); +} + +// ---------------------------------------------------------------------------------------- + +void bind_svm_rank_trainer(py::module& m) +{ + py::class_<ranking_pair<sample_type> >(m, "ranking_pair") + .def(py::init()) + .def_readwrite("relevant", &ranking_pair<sample_type>::relevant) + .def_readwrite("nonrelevant", &ranking_pair<sample_type>::nonrelevant) + .def(py::pickle(&getstate<ranking_pair<sample_type>>, &setstate<ranking_pair<sample_type>>)); + + py::class_<ranking_pair<sparse_vect> >(m, "sparse_ranking_pair") + .def(py::init()) + .def_readwrite("relevant", &ranking_pair<sparse_vect>::relevant) + .def_readwrite("nonrelevant", &ranking_pair<sparse_vect>::nonrelevant) + .def(py::pickle(&getstate<ranking_pair<sparse_vect>>, &setstate<ranking_pair<sparse_vect>>)); + + py::bind_vector<ranking_pairs>(m, "ranking_pairs") + .def("clear", &ranking_pairs::clear) + .def("resize", resize<ranking_pairs>) + .def("extend", extend_vector_with_python_list<ranking_pair<sample_type>>) + .def(py::pickle(&getstate<ranking_pairs>, &setstate<ranking_pairs>)); + + py::bind_vector<sparse_ranking_pairs>(m, "sparse_ranking_pairs") + .def("clear", &sparse_ranking_pairs::clear) + .def("resize", resize<sparse_ranking_pairs>) + .def("extend", extend_vector_with_python_list<ranking_pair<sparse_vect>>) + .def(py::pickle(&getstate<sparse_ranking_pairs>, &setstate<sparse_ranking_pairs>)); + + add_ranker<svm_rank_trainer<linear_kernel<sample_type> > >(m, "svm_rank_trainer"); + add_ranker<svm_rank_trainer<sparse_linear_kernel<sparse_vect> > >(m, "svm_rank_trainer_sparse"); + + m.def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer< + svm_rank_trainer<linear_kernel<sample_type> >,sample_type>, + py::arg("trainer"), py::arg("samples"), py::arg("folds") ); + m.def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer< + svm_rank_trainer<sparse_linear_kernel<sparse_vect> > ,sparse_vect>, + py::arg("trainer"), py::arg("samples"), py::arg("folds") ); +} + + + diff --git a/ml/dlib/tools/python/src/svm_struct.cpp b/ml/dlib/tools/python/src/svm_struct.cpp new file mode 100644 index 000000000..d8ebad957 --- /dev/null +++ b/ml/dlib/tools/python/src/svm_struct.cpp @@ -0,0 +1,151 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/svm.h> + +using namespace dlib; +using namespace std; +namespace py = pybind11; + +template <typename psi_type> +class svm_struct_prob : public structural_svm_problem<matrix<double,0,1>, psi_type> +{ + typedef structural_svm_problem<matrix<double,0,1>, psi_type> base; + typedef typename base::feature_vector_type feature_vector_type; + typedef typename base::matrix_type matrix_type; + typedef typename base::scalar_type scalar_type; +public: + svm_struct_prob ( + py::object& problem_, + long num_dimensions_, + long num_samples_ + ) : + num_dimensions(num_dimensions_), + num_samples(num_samples_), + problem(problem_) + {} + + virtual long get_num_dimensions ( + ) const { return num_dimensions; } + + virtual long get_num_samples ( + ) const { return num_samples; } + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const + { + psi = problem.attr("get_truth_joint_feature_vector")(idx).template cast<feature_vector_type&>(); + } + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + scalar_type& loss, + feature_vector_type& psi + ) const + { + py::object res = problem.attr("separation_oracle")(idx,std::ref(current_solution)); + pyassert(len(res) == 2, "separation_oracle() must return two objects, the loss and the psi vector"); + py::tuple t = res.cast<py::tuple>(); + // let the user supply the output arguments in any order. + try { + loss = t[0].cast<scalar_type>(); + psi = t[1].cast<feature_vector_type&>(); + } catch(py::cast_error &e) { + psi = t[0].cast<feature_vector_type&>(); + loss = t[1].cast<scalar_type>(); + } + } + +private: + + const long num_dimensions; + const long num_samples; + py::object& problem; +}; + +// ---------------------------------------------------------------------------------------- + +template <typename psi_type> +matrix<double,0,1> solve_structural_svm_problem_impl( + py::object problem +) +{ + const double C = problem.attr("C").cast<double>(); + const bool be_verbose = py::hasattr(problem,"be_verbose") && problem.attr("be_verbose").cast<bool>(); + const bool use_sparse_feature_vectors = py::hasattr(problem,"use_sparse_feature_vectors") && + problem.attr("use_sparse_feature_vectors").cast<bool>(); + const bool learns_nonnegative_weights = py::hasattr(problem,"learns_nonnegative_weights") && + problem.attr("learns_nonnegative_weights").cast<bool>(); + + double eps = 0.001; + unsigned long max_cache_size = 10; + if (py::hasattr(problem, "epsilon")) + eps = problem.attr("epsilon").cast<double>(); + if (py::hasattr(problem, "max_cache_size")) + max_cache_size = problem.attr("max_cache_size").cast<double>(); + + const long num_samples = problem.attr("num_samples").cast<long>(); + const long num_dimensions = problem.attr("num_dimensions").cast<long>(); + + pyassert(num_samples > 0, "You can't train a Structural-SVM if you don't have any training samples."); + + if (be_verbose) + { + cout << "C: " << C << endl; + cout << "epsilon: " << eps << endl; + cout << "max_cache_size: " << max_cache_size << endl; + cout << "num_samples: " << num_samples << endl; + cout << "num_dimensions: " << num_dimensions << endl; + cout << "use_sparse_feature_vectors: " << std::boolalpha << use_sparse_feature_vectors << endl; + cout << "learns_nonnegative_weights: " << std::boolalpha << learns_nonnegative_weights << endl; + cout << endl; + } + + svm_struct_prob<psi_type> prob(problem, num_dimensions, num_samples); + prob.set_c(C); + prob.set_epsilon(eps); + prob.set_max_cache_size(max_cache_size); + if (be_verbose) + prob.be_verbose(); + + oca solver; + matrix<double,0,1> w; + if (learns_nonnegative_weights) + solver(prob, w, prob.get_num_dimensions()); + else + solver(prob, w); + return w; +} + +// ---------------------------------------------------------------------------------------- + +matrix<double,0,1> solve_structural_svm_problem( + py::object problem +) +{ + // Check if the python code is using sparse or dense vectors to represent PSI() + if (py::isinstance<matrix<double,0,1>>(problem.attr("get_truth_joint_feature_vector")(0))) + return solve_structural_svm_problem_impl<matrix<double,0,1> >(problem); + else + return solve_structural_svm_problem_impl<std::vector<std::pair<unsigned long,double> > >(problem); +} + +// ---------------------------------------------------------------------------------------- + +void bind_svm_struct(py::module& m) +{ + m.def("solve_structural_svm_problem",solve_structural_svm_problem, py::arg("problem"), +"This function solves a structural SVM problem and returns the weight vector \n\ +that defines the solution. See the example program python_examples/svm_struct.py \n\ +for documentation about how to create a proper problem object. " + ); +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/tools/python/src/testing_results.h b/ml/dlib/tools/python/src/testing_results.h new file mode 100644 index 000000000..746e2934a --- /dev/null +++ b/ml/dlib/tools/python/src/testing_results.h @@ -0,0 +1,50 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TESTING_ReSULTS_H__ +#define DLIB_TESTING_ReSULTS_H__ + +#include <dlib/matrix.h> + +struct binary_test +{ + binary_test() : class1_accuracy(0), class2_accuracy(0) {} + binary_test( + const dlib::matrix<double,1,2>& m + ) : class1_accuracy(m(0)), + class2_accuracy(m(1)) {} + + double class1_accuracy; + double class2_accuracy; +}; + +struct regression_test +{ + regression_test() = default; + regression_test( + const dlib::matrix<double,1,4>& m + ) : mean_squared_error(m(0)), + R_squared(m(1)), + mean_average_error(m(2)), + mean_error_stddev(m(3)) + {} + + double mean_squared_error = 0; + double R_squared = 0; + double mean_average_error = 0; + double mean_error_stddev = 0; +}; + +struct ranking_test +{ + ranking_test() : ranking_accuracy(0), mean_ap(0) {} + ranking_test( + const dlib::matrix<double,1,2>& m + ) : ranking_accuracy(m(0)), + mean_ap(m(1)) {} + + double ranking_accuracy; + double mean_ap; +}; + +#endif // DLIB_TESTING_ReSULTS_H__ + diff --git a/ml/dlib/tools/python/src/vector.cpp b/ml/dlib/tools/python/src/vector.cpp new file mode 100644 index 000000000..a9f81c65e --- /dev/null +++ b/ml/dlib/tools/python/src/vector.cpp @@ -0,0 +1,182 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "opaque_types.h" +#include <dlib/python.h> +#include <dlib/matrix.h> +#include <dlib/geometry/vector.h> +#include <pybind11/stl_bind.h> +#include "indexing.h" + +using namespace dlib; +using namespace std; + +typedef matrix<double,0,1> cv; + + +void cv_set_size(cv& m, long s) +{ + m.set_size(s); + m = 0; +} + +double dotprod ( const cv& a, const cv& b) +{ + return dot(a,b); +} + +string cv__str__(const cv& v) +{ + ostringstream sout; + for (long i = 0; i < v.size(); ++i) + { + sout << v(i); + if (i+1 < v.size()) + sout << "\n"; + } + return sout.str(); +} + +string cv__repr__ (const cv& v) +{ + std::ostringstream sout; + sout << "dlib.vector(["; + for (long i = 0; i < v.size(); ++i) + { + sout << v(i); + if (i+1 < v.size()) + sout << ", "; + } + sout << "])"; + return sout.str(); +} + +std::shared_ptr<cv> cv_from_object(py::object obj) +{ + try { + long nr = obj.cast<long>(); + auto temp = std::make_shared<cv>(nr); + *temp = 0; + return temp; + } catch(py::cast_error &e) { + py::list li = obj.cast<py::list>(); + const long nr = len(obj); + auto temp = std::make_shared<cv>(nr); + for ( long r = 0; r < nr; ++r) + { + (*temp)(r) = li[r].cast<double>(); + } + return temp; + } +} + +long cv__len__(cv& c) +{ + return c.size(); +} + + +void cv__setitem__(cv& c, long p, double val) +{ + if (p < 0) { + p = c.size() + p; // negative index + } + if (p > c.size()-1) { + PyErr_SetString( PyExc_IndexError, "index out of range" + ); + throw py::error_already_set(); + } + c(p) = val; +} + +double cv__getitem__(cv& m, long r) +{ + if (r < 0) { + r = m.size() + r; // negative index + } + if (r > m.size()-1 || r < 0) { + PyErr_SetString( PyExc_IndexError, "index out of range" + ); + throw py::error_already_set(); + } + return m(r); +} + + +cv cv__getitem2__(cv& m, py::slice r) +{ + size_t start, stop, step, slicelength; + if (!r.compute(m.size(), &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + + cv temp(slicelength); + + for (size_t i = 0; i < slicelength; ++i) { + temp(i) = m(start); start += step; + } + return temp; +} + +py::tuple cv_get_matrix_size(cv& m) +{ + return py::make_tuple(m.nr(), m.nc()); +} + +// ---------------------------------------------------------------------------------------- + +string point__repr__ (const point& p) +{ + std::ostringstream sout; + sout << "point(" << p.x() << ", " << p.y() << ")"; + return sout.str(); +} + +string point__str__(const point& p) +{ + std::ostringstream sout; + sout << "(" << p.x() << ", " << p.y() << ")"; + return sout.str(); +} + +long point_x(const point& p) { return p.x(); } +long point_y(const point& p) { return p.y(); } + +// ---------------------------------------------------------------------------------------- +void bind_vector(py::module& m) +{ + { + py::class_<cv, std::shared_ptr<cv>>(m, "vector", "This object represents the mathematical idea of a column vector.") + .def(py::init()) + .def("set_size", &cv_set_size) + .def("resize", &cv_set_size) + .def(py::init(&cv_from_object)) + .def("__repr__", &cv__repr__) + .def("__str__", &cv__str__) + .def("__len__", &cv__len__) + .def("__getitem__", &cv__getitem__) + .def("__getitem__", &cv__getitem2__) + .def("__setitem__", &cv__setitem__) + .def_property_readonly("shape", &cv_get_matrix_size) + .def(py::pickle(&getstate<cv>, &setstate<cv>)); + + m.def("dot", &dotprod, "Compute the dot product between two dense column vectors."); + } + { + typedef point type; + py::class_<type>(m, "point", "This object represents a single point of integer coordinates that maps directly to a dlib::point.") + .def(py::init<long,long>(), py::arg("x"), py::arg("y")) + .def("__repr__", &point__repr__) + .def("__str__", &point__str__) + .def_property("x", &point_x, [](point& p, long x){p.x()=x;}, "The x-coordinate of the point.") + .def_property("y", &point_y, [](point& p, long y){p.x()=y;}, "The y-coordinate of the point.") + .def(py::pickle(&getstate<type>, &setstate<type>)); + } + { + typedef std::vector<point> type; + py::bind_vector<type>(m, "points", "An array of point objects.") + .def("clear", &type::clear) + .def("resize", resize<type>) + .def("extend", extend_vector_with_python_list<point>) + .def(py::pickle(&getstate<type>, &setstate<type>)); + } +} diff --git a/ml/dlib/tools/python/test/.gitignore b/ml/dlib/tools/python/test/.gitignore new file mode 100644 index 000000000..bee8a64b7 --- /dev/null +++ b/ml/dlib/tools/python/test/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/ml/dlib/tools/python/test/test_array.py b/ml/dlib/tools/python/test/test_array.py new file mode 100644 index 000000000..479997ac3 --- /dev/null +++ b/ml/dlib/tools/python/test/test_array.py @@ -0,0 +1,107 @@ +from dlib import array +try: + import cPickle as pickle # Use cPickle on Python 2.7 +except ImportError: + import pickle + +try: + from types import FloatType +except ImportError: + FloatType = float + +from pytest import raises + + +def test_array_init_with_number(): + a = array(5) + assert len(a) == 5 + for i in range(5): + assert a[i] == 0 + assert type(a[i]) == FloatType + + +def test_array_init_with_negative_number(): + with raises(Exception): + array(-5) + + +def test_array_init_with_zero(): + a = array(0) + assert len(a) == 0 + + +def test_array_init_with_list(): + a = array([0, 1, 2, 3, 4]) + assert len(a) == 5 + for idx, val in enumerate(a): + assert idx == val + assert type(val) == FloatType + + +def test_array_init_with_empty_list(): + a = array([]) + assert len(a) == 0 + + +def test_array_init_without_argument(): + a = array() + assert len(a) == 0 + + +def test_array_init_with_tuple(): + a = array((0, 1, 2, 3, 4)) + for idx, val in enumerate(a): + assert idx == val + assert type(val) == FloatType + + +def test_array_serialization_empty(): + a = array() + # cPickle with protocol 2 required for Python 2.7 + # see http://pybind11.readthedocs.io/en/stable/advanced/classes.html#custom-constructors + ser = pickle.dumps(a, 2) + deser = pickle.loads(ser) + assert a == deser + + +def test_array_serialization(): + a = array([0, 1, 2, 3, 4]) + ser = pickle.dumps(a, 2) + deser = pickle.loads(ser) + assert a == deser + + +def test_array_extend(): + a = array() + a.extend([0, 1, 2, 3, 4]) + assert len(a) == 5 + for idx, val in enumerate(a): + assert idx == val + assert type(val) == FloatType + + +def test_array_string_representations_empty(): + a = array() + assert str(a) == "" + assert repr(a) == "array[]" + + +def test_array_string_representations(): + a = array([1, 2, 3]) + assert str(a) == "1\n2\n3" + assert repr(a) == "array[1, 2, 3]" + + +def test_array_clear(): + a = array(10) + a.clear() + assert len(a) == 0 + + +def test_array_resize(): + a = array(10) + a.resize(100) + assert len(a) == 100 + + for i in range(100): + assert a[i] == 0 diff --git a/ml/dlib/tools/python/test/test_global_optimization.py b/ml/dlib/tools/python/test/test_global_optimization.py new file mode 100644 index 000000000..ec320909f --- /dev/null +++ b/ml/dlib/tools/python/test/test_global_optimization.py @@ -0,0 +1,69 @@ +from dlib import find_max_global, find_min_global +import dlib +from pytest import raises +from math import sin,cos,pi,exp,sqrt,pow + + +def test_global_optimization_nargs(): + w0 = find_max_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10) + w1 = find_min_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10) + assert w0 == ([1, 1, 1], 3) + assert w1 == ([0, 0, 0], 0) + + w2 = find_max_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10) + w3 = find_min_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10) + assert w2 == ([1, 1, 1], 3) + assert w3 == ([0, 0, 0], 0) + + with raises(Exception): + find_max_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10) + with raises(Exception): + find_min_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10) + with raises(Exception): + find_max_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10) + with raises(Exception): + find_min_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10) + + +def F(a,b): + return -pow(a-2,2.0) - pow(b-4,2.0); +def G(x): + return 2-pow(x-5,2.0); + +def test_global_function_search(): + spec_F = dlib.function_spec([-10,-10], [10,10]) + spec_G = dlib.function_spec([-2], [6]) + + opt = dlib.global_function_search([spec_F, spec_G]) + + for i in range(15): + next = opt.get_next_x() + #print("next x is for function {} and has coordinates {}".format(next.function_idx, next.x)) + + if (next.function_idx == 0): + a = next.x[0] + b = next.x[1] + next.set(F(a,b)) + else: + x = next.x[0] + next.set(G(x)) + + [x,y,function_idx] = opt.get_best_function_eval() + + #print("\nbest function was {}, with y of {}, and x of {}".format(function_idx,y,x)) + + assert(abs(y-2) < 1e-7) + assert(abs(x[0]-5) < 1e-7) + assert(function_idx==1) + + + +def holder_table(x0,x1): + return -abs(sin(x0)*cos(x1)*exp(abs(1-sqrt(x0*x0+x1*x1)/pi))) + +def test_on_holder_table(): + x,y = find_min_global(holder_table, + [-10,-10], + [10,10], + 200) + assert (y - -19.2085025679) < 1e-7 diff --git a/ml/dlib/tools/python/test/test_matrix.py b/ml/dlib/tools/python/test/test_matrix.py new file mode 100644 index 000000000..cdd9bed13 --- /dev/null +++ b/ml/dlib/tools/python/test/test_matrix.py @@ -0,0 +1,100 @@ +from dlib import matrix +try: + import cPickle as pickle # Use cPickle on Python 2.7 +except ImportError: + import pickle +from pytest import raises + +try: + import numpy + have_numpy = True +except ImportError: + have_numpy = False + + +def test_matrix_empty_init(): + m = matrix() + assert m.nr() == 0 + assert m.nc() == 0 + assert m.shape == (0, 0) + assert len(m) == 0 + assert repr(m) == "< dlib.matrix containing: >" + assert str(m) == "" + + +def test_matrix_from_list(): + m = matrix([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + assert m.nr() == 3 + assert m.nc() == 3 + assert m.shape == (3, 3) + assert len(m) == 3 + assert repr(m) == "< dlib.matrix containing: \n0 1 2 \n3 4 5 \n6 7 8 >" + assert str(m) == "0 1 2 \n3 4 5 \n6 7 8" + + deser = pickle.loads(pickle.dumps(m, 2)) + + for row in range(3): + for col in range(3): + assert m[row][col] == deser[row][col] + + +def test_matrix_from_list_with_invalid_rows(): + with raises(ValueError): + matrix([[0, 1, 2], + [3, 4], + [5, 6, 7]]) + + +def test_matrix_from_list_as_column_vector(): + m = matrix([0, 1, 2]) + assert m.nr() == 3 + assert m.nc() == 1 + assert m.shape == (3, 1) + assert len(m) == 3 + assert repr(m) == "< dlib.matrix containing: \n0 \n1 \n2 >" + assert str(m) == "0 \n1 \n2" + + +if have_numpy: + def test_matrix_from_object_with_2d_shape(): + m1 = numpy.array([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + m = matrix(m1) + assert m.nr() == 3 + assert m.nc() == 3 + assert m.shape == (3, 3) + assert len(m) == 3 + assert repr(m) == "< dlib.matrix containing: \n0 1 2 \n3 4 5 \n6 7 8 >" + assert str(m) == "0 1 2 \n3 4 5 \n6 7 8" + + + def test_matrix_from_object_without_2d_shape(): + with raises(IndexError): + m1 = numpy.array([0, 1, 2]) + matrix(m1) + + +def test_matrix_from_object_without_shape(): + with raises(AttributeError): + matrix("invalid") + + +def test_matrix_set_size(): + m = matrix() + m.set_size(5, 5) + + assert m.nr() == 5 + assert m.nc() == 5 + assert m.shape == (5, 5) + assert len(m) == 5 + assert repr(m) == "< dlib.matrix containing: \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 >" + assert str(m) == "0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0" + + deser = pickle.loads(pickle.dumps(m, 2)) + + for row in range(5): + for col in range(5): + assert m[row][col] == deser[row][col] diff --git a/ml/dlib/tools/python/test/test_point.py b/ml/dlib/tools/python/test/test_point.py new file mode 100644 index 000000000..75b8c191f --- /dev/null +++ b/ml/dlib/tools/python/test/test_point.py @@ -0,0 +1,48 @@ +from dlib import point, points +try: + import cPickle as pickle # Use cPickle on Python 2.7 +except ImportError: + import pickle + + +def test_point(): + p = point(27, 42) + assert repr(p) == "point(27, 42)" + assert str(p) == "(27, 42)" + assert p.x == 27 + assert p.y == 42 + ser = pickle.dumps(p, 2) + deser = pickle.loads(ser) + assert deser.x == p.x + assert deser.y == p.y + + +def test_point_init_kwargs(): + p = point(y=27, x=42) + assert repr(p) == "point(42, 27)" + assert str(p) == "(42, 27)" + assert p.x == 42 + assert p.y == 27 + + +def test_points(): + ps = points() + + ps.resize(5) + assert len(ps) == 5 + for i in range(5): + assert ps[i].x == 0 + assert ps[i].y == 0 + + ps.clear() + assert len(ps) == 0 + + ps.extend([point(1, 2), point(3, 4)]) + assert len(ps) == 2 + + ser = pickle.dumps(ps, 2) + deser = pickle.loads(ser) + assert deser[0].x == 1 + assert deser[0].y == 2 + assert deser[1].x == 3 + assert deser[1].y == 4 diff --git a/ml/dlib/tools/python/test/test_range.py b/ml/dlib/tools/python/test/test_range.py new file mode 100644 index 000000000..c881da369 --- /dev/null +++ b/ml/dlib/tools/python/test/test_range.py @@ -0,0 +1,97 @@ +from dlib import range, ranges, rangess +try: + import cPickle as pickle # Use cPickle on Python 2.7 +except ImportError: + import pickle +from pytest import raises + + +def test_range(): + r = range(0, 10) + assert r.begin == 0 + assert r.end == 10 + assert str(r) == "0, 10" + assert repr(r) == "dlib.range(0, 10)" + assert len(r) == 10 + + ser = pickle.dumps(r, 2) + deser = pickle.loads(ser) + + for a, b in zip(r, deser): + assert a == b + + +# TODO: make this init parameterization an exception? +def test_range_wrong_order(): + r = range(5, 0) + assert r.begin == 5 + assert r.end == 0 + assert str(r) == "5, 0" + assert repr(r) == "dlib.range(5, 0)" + assert len(r) == 0 + + +def test_range_with_negative_elements(): + with raises(TypeError): + range(-1, 1) + with raises(TypeError): + range(1, -1) + + +def test_ranges(): + rs = ranges() + assert len(rs) == 0 + + rs.resize(5) + assert len(rs) == 5 + for r in rs: + assert r.begin == 0 + assert r.end == 0 + + rs.clear() + assert len(rs) == 0 + + rs.extend([range(1, 2), range(3, 4)]) + assert rs[0].begin == 1 + assert rs[0].end == 2 + assert rs[1].begin == 3 + assert rs[1].end == 4 + + ser = pickle.dumps(rs, 2) + deser = pickle.loads(ser) + assert rs == deser + + +def test_rangess(): + rss = rangess() + assert len(rss) == 0 + + rss.resize(5) + assert len(rss) == 5 + for rs in rss: + assert len(rs) == 0 + + rss.clear() + assert len(rss) == 0 + + rs1 = ranges() + rs1.append(range(1, 2)) + rs1.append(range(3, 4)) + + rs2 = ranges() + rs2.append(range(5, 6)) + rs2.append(range(7, 8)) + + rss.extend([rs1, rs2]) + assert rss[0][0].begin == 1 + assert rss[0][1].begin == 3 + assert rss[1][0].begin == 5 + assert rss[1][1].begin == 7 + assert rss[0][0].end == 2 + assert rss[0][1].end == 4 + assert rss[1][0].end == 6 + assert rss[1][1].end == 8 + + ser = pickle.dumps(rss, 2) + deser = pickle.loads(ser) + assert rss == deser diff --git a/ml/dlib/tools/python/test/test_rgb_pixel.py b/ml/dlib/tools/python/test/test_rgb_pixel.py new file mode 100644 index 000000000..0b3aaf5e9 --- /dev/null +++ b/ml/dlib/tools/python/test/test_rgb_pixel.py @@ -0,0 +1,26 @@ +from dlib import rgb_pixel + + +def test_rgb_pixel(): + p = rgb_pixel(0, 50, 100) + assert p.red == 0 + assert p.green == 50 + assert p.blue == 100 + assert str(p) == "red: 0, green: 50, blue: 100" + assert repr(p) == "rgb_pixel(0,50,100)" + + p = rgb_pixel(blue=0, red=50, green=100) + assert p.red == 50 + assert p.green == 100 + assert p.blue == 0 + assert str(p) == "red: 50, green: 100, blue: 0" + assert repr(p) == "rgb_pixel(50,100,0)" + + p.red = 100 + p.green = 0 + p.blue = 50 + assert p.red == 100 + assert p.green == 0 + assert p.blue == 50 + assert str(p) == "red: 100, green: 0, blue: 50" + assert repr(p) == "rgb_pixel(100,0,50)" diff --git a/ml/dlib/tools/python/test/test_sparse_vector.py b/ml/dlib/tools/python/test/test_sparse_vector.py new file mode 100644 index 000000000..124e68d5d --- /dev/null +++ b/ml/dlib/tools/python/test/test_sparse_vector.py @@ -0,0 +1,101 @@ +from dlib import pair, make_sparse_vector, sparse_vector, sparse_vectors, sparse_vectorss +try: + import cPickle as pickle # Use cPickle on Python 2.7 +except ImportError: + import pickle +from pytest import approx + + +def test_pair(): + p = pair(4, .9) + assert p.first == 4 + assert p.second == .9 + + p.first = 3 + p.second = .4 + + assert p.first == 3 + assert p.second == .4 + + assert str(p) == "3: 0.4" + assert repr(p) == "dlib.pair(3, 0.4)" + + deser = pickle.loads(pickle.dumps(p, 2)) + assert deser.first == p.first + assert deser.second == p.second + + +def test_sparse_vector(): + sv = sparse_vector() + sv.append(pair(3, .1)) + sv.append(pair(3, .2)) + sv.append(pair(2, .3)) + sv.append(pair(1, .4)) + + assert len(sv) == 4 + make_sparse_vector(sv) + + assert len(sv) == 3 + assert sv[0].first == 1 + assert sv[0].second == .4 + assert sv[1].first == 2 + assert sv[1].second == .3 + assert sv[2].first == 3 + assert sv[2].second == approx(.3) + + assert str(sv) == "1: 0.4\n2: 0.3\n3: 0.3" + assert repr(sv) == "< dlib.sparse_vector containing: \n1: 0.4\n2: 0.3\n3: 0.3 >" + + +def test_sparse_vectors(): + svs = sparse_vectors() + assert len(svs) == 0 + + svs.resize(5) + for sv in svs: + assert len(sv) == 0 + + svs.clear() + assert len(svs) == 0 + + svs.extend([sparse_vector([pair(1, 2), pair(3, 4)]), sparse_vector([pair(5, 6), pair(7, 8)])]) + + assert len(svs) == 2 + assert svs[0][0].first == 1 + assert svs[0][0].second == 2 + assert svs[0][1].first == 3 + assert svs[0][1].second == 4 + assert svs[1][0].first == 5 + assert svs[1][0].second == 6 + assert svs[1][1].first == 7 + assert svs[1][1].second == 8 + + deser = pickle.loads(pickle.dumps(svs, 2)) + assert deser == svs + + +def test_sparse_vectorss(): + svss = sparse_vectorss() + assert len(svss) == 0 + + svss.resize(5) + for svs in svss: + assert len(svs) == 0 + + svss.clear() + assert len(svss) == 0 + + svss.extend([sparse_vectors([sparse_vector([pair(1, 2), pair(3, 4)]), sparse_vector([pair(5, 6), pair(7, 8)])])]) + + assert len(svss) == 1 + assert svss[0][0][0].first == 1 + assert svss[0][0][0].second == 2 + assert svss[0][0][1].first == 3 + assert svss[0][0][1].second == 4 + assert svss[0][1][0].first == 5 + assert svss[0][1][0].second == 6 + assert svss[0][1][1].first == 7 + assert svss[0][1][1].second == 8 + + deser = pickle.loads(pickle.dumps(svss, 2)) + assert deser == svss diff --git a/ml/dlib/tools/python/test/test_svm_c_trainer.py b/ml/dlib/tools/python/test/test_svm_c_trainer.py new file mode 100644 index 000000000..ba9392e08 --- /dev/null +++ b/ml/dlib/tools/python/test/test_svm_c_trainer.py @@ -0,0 +1,65 @@ +from __future__ import division + +import pytest +from random import Random +from dlib import (vectors, vector, sparse_vectors, sparse_vector, pair, array, + cross_validate_trainer, + svm_c_trainer_radial_basis, + svm_c_trainer_sparse_radial_basis, + svm_c_trainer_histogram_intersection, + svm_c_trainer_sparse_histogram_intersection, + svm_c_trainer_linear, + svm_c_trainer_sparse_linear, + rvm_trainer_radial_basis, + rvm_trainer_sparse_radial_basis, + rvm_trainer_histogram_intersection, + rvm_trainer_sparse_histogram_intersection, + rvm_trainer_linear, + rvm_trainer_sparse_linear) + + +@pytest.fixture +def training_data(): + r = Random(0) + predictors = vectors() + sparse_predictors = sparse_vectors() + response = array() + for i in range(30): + for c in [-1, 1]: + response.append(c) + values = [r.random() + c * 0.5 for _ in range(3)] + predictors.append(vector(values)) + sp = sparse_vector() + for i, v in enumerate(values): + sp.append(pair(i, v)) + sparse_predictors.append(sp) + return predictors, sparse_predictors, response + + +@pytest.mark.parametrize('trainer, class1_accuracy, class2_accuracy', [ + (svm_c_trainer_radial_basis, 1.0, 1.0), + (svm_c_trainer_sparse_radial_basis, 1.0, 1.0), + (svm_c_trainer_histogram_intersection, 1.0, 1.0), + (svm_c_trainer_sparse_histogram_intersection, 1.0, 1.0), + (svm_c_trainer_linear, 1.0, 23 / 30), + (svm_c_trainer_sparse_linear, 1.0, 23 / 30), + (rvm_trainer_radial_basis, 1.0, 1.0), + (rvm_trainer_sparse_radial_basis, 1.0, 1.0), + (rvm_trainer_histogram_intersection, 1.0, 1.0), + (rvm_trainer_sparse_histogram_intersection, 1.0, 1.0), + (rvm_trainer_linear, 1.0, 0.6), + (rvm_trainer_sparse_linear, 1.0, 0.6) +]) +def test_trainers(training_data, trainer, class1_accuracy, class2_accuracy): + predictors, sparse_predictors, response = training_data + if 'sparse' in trainer.__name__: + predictors = sparse_predictors + cv = cross_validate_trainer(trainer(), predictors, response, folds=10) + assert cv.class1_accuracy == pytest.approx(class1_accuracy) + assert cv.class2_accuracy == pytest.approx(class2_accuracy) + + decision_function = trainer().train(predictors, response) + assert decision_function(predictors[2]) < 0 + assert decision_function(predictors[3]) > 0 + if 'linear' in trainer.__name__: + assert len(decision_function.weights) == 3 diff --git a/ml/dlib/tools/python/test/test_vector.py b/ml/dlib/tools/python/test/test_vector.py new file mode 100644 index 000000000..ff79ab339 --- /dev/null +++ b/ml/dlib/tools/python/test/test_vector.py @@ -0,0 +1,170 @@ +from dlib import vector, vectors, vectorss, dot +try: + import cPickle as pickle # Use cPickle on Python 2.7 +except ImportError: + import pickle +from pytest import raises + + +def test_vector_empty_init(): + v = vector() + assert len(v) == 0 + assert v.shape == (0, 1) + assert str(v) == "" + assert repr(v) == "dlib.vector([])" + + +def test_vector_init_with_number(): + v = vector(3) + assert len(v) == 3 + assert v.shape == (3, 1) + assert str(v) == "0\n0\n0" + assert repr(v) == "dlib.vector([0, 0, 0])" + + +def test_vector_set_size(): + v = vector(3) + + v.set_size(0) + assert len(v) == 0 + assert v.shape == (0, 1) + + v.resize(10) + assert len(v) == 10 + assert v.shape == (10, 1) + for i in range(10): + assert v[i] == 0 + + +def test_vector_init_with_list(): + v = vector([1, 2, 3]) + assert len(v) == 3 + assert v.shape == (3, 1) + assert str(v) == "1\n2\n3" + assert repr(v) == "dlib.vector([1, 2, 3])" + + +def test_vector_getitem(): + v = vector([1, 2, 3]) + assert v[0] == 1 + assert v[-1] == 3 + assert v[1] == v[-2] + + +def test_vector_slice(): + v = vector([1, 2, 3, 4, 5]) + v_slice = v[1:4] + assert len(v_slice) == 3 + for idx, val in enumerate([2, 3, 4]): + assert v_slice[idx] == val + + v_slice = v[-3:-1] + assert len(v_slice) == 2 + for idx, val in enumerate([3, 4]): + assert v_slice[idx] == val + + v_slice = v[1:-2] + assert len(v_slice) == 2 + for idx, val in enumerate([2, 3]): + assert v_slice[idx] == val + + +def test_vector_invalid_getitem(): + v = vector([1, 2, 3]) + with raises(IndexError): + v[-4] + with raises(IndexError): + v[3] + + +def test_vector_init_with_negative_number(): + with raises(Exception): + vector(-3) + + +def test_dot(): + v1 = vector([1, 0]) + v2 = vector([0, 1]) + v3 = vector([-1, 0]) + assert dot(v1, v1) == 1 + assert dot(v1, v2) == 0 + assert dot(v1, v3) == -1 + + +def test_vector_serialization(): + v = vector([1, 2, 3]) + ser = pickle.dumps(v, 2) + deser = pickle.loads(ser) + assert str(v) == str(deser) + + +def generate_test_vectors(): + vs = vectors() + vs.append(vector([0, 1, 2])) + vs.append(vector([3, 4, 5])) + vs.append(vector([6, 7, 8])) + assert len(vs) == 3 + return vs + + +def generate_test_vectorss(): + vss = vectorss() + vss.append(generate_test_vectors()) + vss.append(generate_test_vectors()) + vss.append(generate_test_vectors()) + assert len(vss) == 3 + return vss + + +def test_vectors_serialization(): + vs = generate_test_vectors() + ser = pickle.dumps(vs, 2) + deser = pickle.loads(ser) + assert vs == deser + + +def test_vectors_clear(): + vs = generate_test_vectors() + vs.clear() + assert len(vs) == 0 + + +def test_vectors_resize(): + vs = vectors() + vs.resize(100) + assert len(vs) == 100 + for i in range(100): + assert len(vs[i]) == 0 + + +def test_vectors_extend(): + vs = vectors() + vs.extend([vector([1, 2, 3]), vector([4, 5, 6])]) + assert len(vs) == 2 + + +def test_vectorss_serialization(): + vss = generate_test_vectorss() + ser = pickle.dumps(vss, 2) + deser = pickle.loads(ser) + assert vss == deser + + +def test_vectorss_clear(): + vss = generate_test_vectorss() + vss.clear() + assert len(vss) == 0 + + +def test_vectorss_resize(): + vss = vectorss() + vss.resize(100) + assert len(vss) == 100 + for i in range(100): + assert len(vss[i]) == 0 + + +def test_vectorss_extend(): + vss = vectorss() + vss.extend([generate_test_vectors(), generate_test_vectors()]) + assert len(vss) == 2 diff --git a/ml/dlib/tools/visual_studio_natvis/README.txt b/ml/dlib/tools/visual_studio_natvis/README.txt new file mode 100644 index 000000000..f766ac0c1 --- /dev/null +++ b/ml/dlib/tools/visual_studio_natvis/README.txt @@ -0,0 +1,12 @@ +Hi Davis, +thanks for your work on dlib! + +I have created a natvis file to have nicer debugger visualization of dlib matrices in Visual Studio (2012 - …) and I just wanted to share it with you. + +To test it, copy the file into you folder %USERPROFILE%\My Documents\Visual Studio 2015\Visualizers or %VSINSTALLDIR%\Common7\Packages\Debugger\Visualizers as described here https://msdn.microsoft.com/en-us/library/jj620914.aspx + +It’s certainly extendable, especially to include it into image watch, but currently it may help users to debug much faster. + +Feel free to share it. +Best, + Johannes Huber diff --git a/ml/dlib/tools/visual_studio_natvis/dlib.natvis b/ml/dlib/tools/visual_studio_natvis/dlib.natvis new file mode 100644 index 000000000..0c1c52060 --- /dev/null +++ b/ml/dlib/tools/visual_studio_natvis/dlib.natvis @@ -0,0 +1,51 @@ +<?xml version="1.0" encoding="utf-8"?>
+<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010">
+ <!-- dlib matrix debugger visualization in Visual Studio-->
+ <!-- Johannes Huber, SAFEmine Part of Hexagon -->
+ <!-- no warranty -->
+
+ <!-- general dlib::matrix fixed size-->
+ <Type Name="dlib::matrix<*,*,*,*>">
+ <DisplayString>{{ size= <{$T2}> x <{$T3}> }}</DisplayString>
+ <Expand>
+ <ArrayItems>
+ <Size>$T2 * $T3</Size>
+ <ValuePointer>($T1*)data.data</ValuePointer>
+ </ArrayItems>
+ </Expand>
+ </Type>
+
+ <!-- general dlib::matrix fixed rows-->
+ <Type Name="dlib::matrix<*,0,*,*>">
+ <DisplayString>{{ size={data.nr_} x <{$T2}> }}</DisplayString>
+ <Expand>
+ <ArrayItems Condition="data.data != 0">
+ <Size>data.nr_ * $T2</Size>
+ <ValuePointer>($T1*)data.data</ValuePointer>
+ </ArrayItems>
+ </Expand>
+ </Type>
+
+ <!-- general dlib::matrix fixed cols-->
+ <Type Name="dlib::matrix<*,*,0,*>">
+ <DisplayString>{{ size= <{$T2}> x {data.nc_} }}</DisplayString>
+ <Expand>
+ <ArrayItems Condition="data.data != 0">
+ <Size>$T2 * data.nc_</Size>
+ <ValuePointer>($T1*)data.data</ValuePointer>
+ </ArrayItems>
+ </Expand>
+ </Type>
+
+ <!-- general dlib::matrix dynamic size-->
+ <Type Name="dlib::matrix<*,0,0,*>">
+ <DisplayString>{{ size= {data.nc_} x {data.nc_} }}</DisplayString>
+ <Expand>
+ <ArrayItems Condition="data.data != 0">
+ <Size>data.nr_*data.nc_</Size>
+ <ValuePointer>($T1*)data.data</ValuePointer>
+ </ArrayItems>
+ </Expand>
+ </Type>
+
+</AutoVisualizer>
\ No newline at end of file |