diff options
Diffstat (limited to '')
-rw-r--r-- | ml/dlib/examples/train_object_detector.cpp | 422 |
1 files changed, 422 insertions, 0 deletions
diff --git a/ml/dlib/examples/train_object_detector.cpp b/ml/dlib/examples/train_object_detector.cpp new file mode 100644 index 000000000..9bc0977c0 --- /dev/null +++ b/ml/dlib/examples/train_object_detector.cpp @@ -0,0 +1,422 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + + This is an example showing how you might use dlib to create a reasonably + functional command line tool for object detection. This example assumes + you are familiar with the contents of at least the following example + programs: + - fhog_object_detector_ex.cpp + - compress_stream_ex.cpp + + + + + This program is a command line tool for learning to detect objects in images. + Therefore, to create an object detector it requires a set of annotated training + images. To create this annotated data you will need to use the imglab tool + included with dlib. It is located in the tools/imglab folder and can be compiled + using the following commands. + cd tools/imglab + mkdir build + cd build + cmake .. + cmake --build . --config Release + Note that you may need to install CMake (www.cmake.org) for this to work. + + Next, let's assume you have a folder of images called /tmp/images. These images + should contain examples of the objects you want to learn to detect. You will + use the imglab tool to label these objects. Do this by typing the following + ./imglab -c mydataset.xml /tmp/images + This will create a file called mydataset.xml which simply lists the images in + /tmp/images. To annotate them run + ./imglab mydataset.xml + A window will appear showing all the images. You can use the up and down arrow + keys to cycle though the images and the mouse to label objects. In particular, + holding the shift key, left clicking, and dragging the mouse will allow you to + draw boxes around the objects you wish to detect. So next, label all the objects + with boxes. Note that it is important to label all the objects since any object + not labeled is implicitly assumed to be not an object we should detect. If there + are objects you are not sure about you should draw a box around them, then double + click the box and press i. This will cross out the box and mark it as "ignore". + The training code in dlib will then simply ignore detections matching that box. + + + Once you finish labeling objects go to the file menu, click save, and then close + the program. This will save the object boxes back to mydataset.xml. You can verify + this by opening the tool again with + ./imglab mydataset.xml + and observing that the boxes are present. + + Returning to the present example program, we can compile it using cmake just as we + did with the imglab tool. Once compiled, we can issue the command + ./train_object_detector -tv mydataset.xml + which will train an object detection model based on our labeled data. The model + will be saved to the file object_detector.svm. Once this has finished we can use + the object detector to locate objects in new images with a command like + ./train_object_detector some_image.png + This command will display some_image.png in a window and any detected objects will + be indicated by a red box. + + Finally, to make running this example easy dlib includes some training data in the + examples/faces folder. Therefore, you can test this program out by running the + following sequence of commands: + ./train_object_detector -tv examples/faces/training.xml -u1 --flip + ./train_object_detector --test examples/faces/testing.xml -u1 + ./train_object_detector examples/faces/*.jpg -u1 + That will make a face detector that performs perfectly on the test images listed in + testing.xml and then it will show you the detections on all the images. +*/ + + +#include <dlib/svm_threaded.h> +#include <dlib/string.h> +#include <dlib/gui_widgets.h> +#include <dlib/image_processing.h> +#include <dlib/data_io.h> +#include <dlib/cmd_line_parser.h> + + +#include <iostream> +#include <fstream> + + +using namespace std; +using namespace dlib; + +// ---------------------------------------------------------------------------------------- + +void pick_best_window_size ( + const std::vector<std::vector<rectangle> >& boxes, + unsigned long& width, + unsigned long& height, + const unsigned long target_size +) +/*! + ensures + - Finds the average aspect ratio of the elements of boxes and outputs a width + and height such that the aspect ratio is equal to the average and also the + area is equal to target_size. That is, the following will be approximately true: + - #width*#height == target_size + - #width/#height == the average aspect ratio of the elements of boxes. +!*/ +{ + // find the average width and height + running_stats<double> avg_width, avg_height; + for (unsigned long i = 0; i < boxes.size(); ++i) + { + for (unsigned long j = 0; j < boxes[i].size(); ++j) + { + avg_width.add(boxes[i][j].width()); + avg_height.add(boxes[i][j].height()); + } + } + + // now adjust the box size so that it is about target_pixels pixels in size + double size = avg_width.mean()*avg_height.mean(); + double scale = std::sqrt(target_size/size); + + width = (unsigned long)(avg_width.mean()*scale+0.5); + height = (unsigned long)(avg_height.mean()*scale+0.5); + // make sure the width and height never round to zero. + if (width == 0) + width = 1; + if (height == 0) + height = 1; +} + +// ---------------------------------------------------------------------------------------- + +bool contains_any_boxes ( + const std::vector<std::vector<rectangle> >& boxes +) +{ + for (unsigned long i = 0; i < boxes.size(); ++i) + { + if (boxes[i].size() != 0) + return true; + } + return false; +} + +// ---------------------------------------------------------------------------------------- + +void throw_invalid_box_error_message ( + const std::string& dataset_filename, + const std::vector<std::vector<rectangle> >& removed, + const unsigned long target_size +) +{ + image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, dataset_filename); + + std::ostringstream sout; + sout << "Error! An impossible set of object boxes was given for training. "; + sout << "All the boxes need to have a similar aspect ratio and also not be "; + sout << "smaller than about " << target_size << " pixels in area. "; + sout << "The following images contain invalid boxes:\n"; + std::ostringstream sout2; + for (unsigned long i = 0; i < removed.size(); ++i) + { + if (removed[i].size() != 0) + { + const std::string imgname = data.images[i].filename; + sout2 << " " << imgname << "\n"; + } + } + throw error("\n"+wrap_string(sout.str()) + "\n" + sout2.str()); +} + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) +{ + try + { + command_line_parser parser; + parser.add_option("h","Display this help message."); + parser.add_option("t","Train an object detector and save the detector to disk."); + parser.add_option("cross-validate", + "Perform cross-validation on an image dataset and print the results."); + parser.add_option("test", "Test a trained detector on an image dataset and print the results."); + parser.add_option("u", "Upsample each input image <arg> times. Each upsampling quadruples the number of pixels in the image (default: 0).", 1); + + parser.set_group_name("training/cross-validation sub-options"); + parser.add_option("v","Be verbose."); + parser.add_option("folds","When doing cross-validation, do <arg> folds (default: 3).",1); + parser.add_option("c","Set the SVM C parameter to <arg> (default: 1.0).",1); + parser.add_option("threads", "Use <arg> threads for training (default: 4).",1); + parser.add_option("eps", "Set training epsilon to <arg> (default: 0.01).", 1); + parser.add_option("target-size", "Set size of the sliding window to about <arg> pixels in area (default: 80*80).", 1); + parser.add_option("flip", "Add left/right flipped copies of the images into the training dataset. Useful when the objects " + "you want to detect are left/right symmetric."); + + + parser.parse(argc, argv); + + // Now we do a little command line validation. Each of the following functions + // checks something and throws an exception if the test fails. + const char* one_time_opts[] = {"h", "v", "t", "cross-validate", "c", "threads", "target-size", + "folds", "test", "eps", "u", "flip"}; + parser.check_one_time_options(one_time_opts); // Can't give an option more than once + // Make sure the arguments to these options are within valid ranges if they are supplied by the user. + parser.check_option_arg_range("c", 1e-12, 1e12); + parser.check_option_arg_range("eps", 1e-5, 1e4); + parser.check_option_arg_range("threads", 1, 1000); + parser.check_option_arg_range("folds", 2, 100); + parser.check_option_arg_range("u", 0, 8); + parser.check_option_arg_range("target-size", 4*4, 10000*10000); + const char* incompatible[] = {"t", "cross-validate", "test"}; + parser.check_incompatible_options(incompatible); + // You are only allowed to give these training_sub_ops if you also give either -t or --cross-validate. + const char* training_ops[] = {"t", "cross-validate"}; + const char* training_sub_ops[] = {"v", "c", "threads", "target-size", "eps", "flip"}; + parser.check_sub_options(training_ops, training_sub_ops); + parser.check_sub_option("cross-validate", "folds"); + + + if (parser.option("h")) + { + cout << "Usage: train_object_detector [options] <image dataset file|image file>\n"; + parser.print_options(); + + return EXIT_SUCCESS; + } + + + typedef scan_fhog_pyramid<pyramid_down<6> > image_scanner_type; + // Get the upsample option from the user but use 0 if it wasn't given. + const unsigned long upsample_amount = get_option(parser, "u", 0); + + if (parser.option("t") || parser.option("cross-validate")) + { + if (parser.number_of_arguments() != 1) + { + cout << "You must give an image dataset metadata XML file produced by the imglab tool." << endl; + cout << "\nTry the -h option for more information." << endl; + return EXIT_FAILURE; + } + + dlib::array<array2d<unsigned char> > images; + std::vector<std::vector<rectangle> > object_locations, ignore; + + cout << "Loading image dataset from metadata file " << parser[0] << endl; + ignore = load_image_dataset(images, object_locations, parser[0]); + cout << "Number of images loaded: " << images.size() << endl; + + // Get the options from the user, but use default values if they are not + // supplied. + const int threads = get_option(parser, "threads", 4); + const double C = get_option(parser, "c", 1.0); + const double eps = get_option(parser, "eps", 0.01); + unsigned int num_folds = get_option(parser, "folds", 3); + const unsigned long target_size = get_option(parser, "target-size", 80*80); + // You can't do more folds than there are images. + if (num_folds > images.size()) + num_folds = images.size(); + + // Upsample images if the user asked us to do that. + for (unsigned long i = 0; i < upsample_amount; ++i) + upsample_image_dataset<pyramid_down<2> >(images, object_locations, ignore); + + + image_scanner_type scanner; + unsigned long width, height; + pick_best_window_size(object_locations, width, height, target_size); + scanner.set_detection_window_size(width, height); + + structural_object_detection_trainer<image_scanner_type> trainer(scanner); + trainer.set_num_threads(threads); + if (parser.option("v")) + trainer.be_verbose(); + trainer.set_c(C); + trainer.set_epsilon(eps); + + // Now make sure all the boxes are obtainable by the scanner. + std::vector<std::vector<rectangle> > removed; + removed = remove_unobtainable_rectangles(trainer, images, object_locations); + // if we weren't able to get all the boxes to match then throw an error + if (contains_any_boxes(removed)) + { + unsigned long scale = upsample_amount+1; + scale = scale*scale; + throw_invalid_box_error_message(parser[0], removed, target_size/scale); + } + + if (parser.option("flip")) + add_image_left_right_flips(images, object_locations, ignore); + + if (parser.option("t")) + { + // Do the actual training and save the results into the detector object. + object_detector<image_scanner_type> detector = trainer.train(images, object_locations, ignore); + + cout << "Saving trained detector to object_detector.svm" << endl; + serialize("object_detector.svm") << detector; + + cout << "Testing detector on training data..." << endl; + cout << "Test detector (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations, ignore) << endl; + } + else + { + // shuffle the order of the training images + randomize_samples(images, object_locations); + + cout << num_folds << "-fold cross validation (precision,recall,AP): " + << cross_validate_object_detection_trainer(trainer, images, object_locations, ignore, num_folds) << endl; + } + + cout << "Parameters used: " << endl; + cout << " threads: "<< threads << endl; + cout << " C: "<< C << endl; + cout << " eps: "<< eps << endl; + cout << " target-size: "<< target_size << endl; + cout << " detection window width: "<< width << endl; + cout << " detection window height: "<< height << endl; + cout << " upsample this many times : "<< upsample_amount << endl; + if (parser.option("flip")) + cout << " trained using left/right flips." << endl; + if (parser.option("cross-validate")) + cout << " num_folds: "<< num_folds << endl; + cout << endl; + + return EXIT_SUCCESS; + } + + + + + + + + // The rest of the code is devoted to testing an already trained object detector. + + if (parser.number_of_arguments() == 0) + { + cout << "You must give an image or an image dataset metadata XML file produced by the imglab tool." << endl; + cout << "\nTry the -h option for more information." << endl; + return EXIT_FAILURE; + } + + // load a previously trained object detector and try it out on some data + ifstream fin("object_detector.svm", ios::binary); + if (!fin) + { + cout << "Can't find a trained object detector file object_detector.svm. " << endl; + cout << "You need to train one using the -t option." << endl; + cout << "\nTry the -h option for more information." << endl; + return EXIT_FAILURE; + + } + object_detector<image_scanner_type> detector; + deserialize(detector, fin); + + dlib::array<array2d<unsigned char> > images; + // Check if the command line argument is an XML file + if (tolower(right_substr(parser[0],".")) == "xml") + { + std::vector<std::vector<rectangle> > object_locations, ignore; + cout << "Loading image dataset from metadata file " << parser[0] << endl; + ignore = load_image_dataset(images, object_locations, parser[0]); + cout << "Number of images loaded: " << images.size() << endl; + + // Upsample images if the user asked us to do that. + for (unsigned long i = 0; i < upsample_amount; ++i) + upsample_image_dataset<pyramid_down<2> >(images, object_locations, ignore); + + if (parser.option("test")) + { + cout << "Testing detector on data..." << endl; + cout << "Results (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations, ignore) << endl; + return EXIT_SUCCESS; + } + } + else + { + // In this case, the user should have given some image files. So just + // load them. + images.resize(parser.number_of_arguments()); + for (unsigned long i = 0; i < images.size(); ++i) + load_image(images[i], parser[i]); + + // Upsample images if the user asked us to do that. + for (unsigned long i = 0; i < upsample_amount; ++i) + { + for (unsigned long j = 0; j < images.size(); ++j) + pyramid_up(images[j]); + } + } + + + // Test the detector on the images we loaded and display the results + // in a window. + image_window win; + for (unsigned long i = 0; i < images.size(); ++i) + { + // Run the detector on images[i] + const std::vector<rectangle> rects = detector(images[i]); + cout << "Number of detections: "<< rects.size() << endl; + + // Put the image and detections into the window. + win.clear_overlay(); + win.set_image(images[i]); + win.add_overlay(rects, rgb_pixel(255,0,0)); + + cout << "Hit enter to see the next image."; + cin.get(); + } + + + } + catch (exception& e) + { + cout << "\nexception thrown!" << endl; + cout << e.what() << endl; + cout << "\nTry the -h option for more information." << endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +// ---------------------------------------------------------------------------------------- + |