diff options
Diffstat (limited to 'ml/dlib/tools/python/src/simple_object_detector.h')
-rw-r--r-- | ml/dlib/tools/python/src/simple_object_detector.h | 318 |
1 files changed, 318 insertions, 0 deletions
diff --git a/ml/dlib/tools/python/src/simple_object_detector.h b/ml/dlib/tools/python/src/simple_object_detector.h new file mode 100644 index 000000000..4fceab429 --- /dev/null +++ b/ml/dlib/tools/python/src/simple_object_detector.h @@ -0,0 +1,318 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SIMPLE_ObJECT_DETECTOR_H__ +#define DLIB_SIMPLE_ObJECT_DETECTOR_H__ + +#include "dlib/image_processing/object_detector.h" +#include "dlib/string.h" +#include "dlib/image_processing/scan_fhog_pyramid.h" +#include "dlib/svm/structural_object_detection_trainer.h" +#include "dlib/geometry.h" +#include "dlib/data_io/load_image_dataset.h" +#include "dlib/image_processing/remove_unobtainable_rectangles.h" +#include "serialize_object_detector.h" +#include "dlib/svm.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + typedef object_detector<scan_fhog_pyramid<pyramid_down<6> > > simple_object_detector; + +// ---------------------------------------------------------------------------------------- + + struct simple_object_detector_training_options + { + simple_object_detector_training_options() + { + be_verbose = false; + add_left_right_image_flips = false; + num_threads = 4; + detection_window_size = 80*80; + C = 1; + epsilon = 0.01; + upsample_limit = 2; + } + + bool be_verbose; + bool add_left_right_image_flips; + unsigned long num_threads; + unsigned long detection_window_size; + double C; + double epsilon; + unsigned long upsample_limit; + }; + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline void pick_best_window_size ( + const std::vector<std::vector<rectangle> >& boxes, + unsigned long& width, + unsigned long& height, + const unsigned long target_size + ) + { + // find the average width and height + running_stats<double> avg_width, avg_height; + for (unsigned long i = 0; i < boxes.size(); ++i) + { + for (unsigned long j = 0; j < boxes[i].size(); ++j) + { + avg_width.add(boxes[i][j].width()); + avg_height.add(boxes[i][j].height()); + } + } + + // now adjust the box size so that it is about target_pixels pixels in size + double size = avg_width.mean()*avg_height.mean(); + double scale = std::sqrt(target_size/size); + + width = (unsigned long)(avg_width.mean()*scale+0.5); + height = (unsigned long)(avg_height.mean()*scale+0.5); + // make sure the width and height never round to zero. + if (width == 0) + width = 1; + if (height == 0) + height = 1; + } + + inline bool contains_any_boxes ( + const std::vector<std::vector<rectangle> >& boxes + ) + { + for (unsigned long i = 0; i < boxes.size(); ++i) + { + if (boxes[i].size() != 0) + return true; + } + return false; + } + + inline void throw_invalid_box_error_message ( + const std::string& dataset_filename, + const std::vector<std::vector<rectangle> >& removed, + const simple_object_detector_training_options& options + ) + { + + std::ostringstream sout; + // Note that the 1/16 factor is here because we will try to upsample the image + // 2 times to accommodate small boxes. We also take the max because we want to + // lower bound the size of the smallest recommended box. This is because the + // 8x8 HOG cells can't really deal with really small object boxes. + sout << "Error! An impossible set of object boxes was given for training. "; + sout << "All the boxes need to have a similar aspect ratio and also not be "; + sout << "smaller than about " << std::max<long>(20*20,options.detection_window_size/16) << " pixels in area. "; + + std::ostringstream sout2; + if (dataset_filename.size() != 0) + { + sout << "The following images contain invalid boxes:\n"; + image_dataset_metadata::dataset data; + load_image_dataset_metadata(data, dataset_filename); + for (unsigned long i = 0; i < removed.size(); ++i) + { + if (removed[i].size() != 0) + { + const std::string imgname = data.images[i].filename; + sout2 << " " << imgname << "\n"; + } + } + } + throw error("\n"+wrap_string(sout.str()) + "\n" + sout2.str()); + } + } + +// ---------------------------------------------------------------------------------------- + + template <typename image_array> + inline simple_object_detector_py train_simple_object_detector_on_images ( + const std::string& dataset_filename, // can be "" if it's not applicable + image_array& images, + std::vector<std::vector<rectangle> >& boxes, + std::vector<std::vector<rectangle> >& ignore, + const simple_object_detector_training_options& options + ) + { + if (options.C <= 0) + throw error("Invalid C value given to train_simple_object_detector(), C must be > 0."); + if (options.epsilon <= 0) + throw error("Invalid epsilon value given to train_simple_object_detector(), epsilon must be > 0."); + + if (images.size() != boxes.size()) + throw error("The list of images must have the same length as the list of boxes."); + if (images.size() != ignore.size()) + throw error("The list of images must have the same length as the list of ignore boxes."); + + if (impl::contains_any_boxes(boxes) == false) + throw error("Error, the training dataset does not have any labeled object boxes in it."); + + typedef scan_fhog_pyramid<pyramid_down<6> > image_scanner_type; + image_scanner_type scanner; + unsigned long width, height; + impl::pick_best_window_size(boxes, width, height, options.detection_window_size); + scanner.set_detection_window_size(width, height); + structural_object_detection_trainer<image_scanner_type> trainer(scanner); + trainer.set_num_threads(options.num_threads); + trainer.set_c(options.C); + trainer.set_epsilon(options.epsilon); + if (options.be_verbose) + { + std::cout << "Training with C: " << options.C << std::endl; + std::cout << "Training with epsilon: " << options.epsilon << std::endl; + std::cout << "Training using " << options.num_threads << " threads."<< std::endl; + std::cout << "Training with sliding window " << width << " pixels wide by " << height << " pixels tall." << std::endl; + if (options.add_left_right_image_flips) + std::cout << "Training on both left and right flipped versions of images." << std::endl; + trainer.be_verbose(); + } + + unsigned long upsampling_amount = 0; + + // now make sure all the boxes are obtainable by the scanner. We will try and + // upsample the images at most two times to help make the boxes obtainable. + std::vector<std::vector<rectangle> > temp(boxes), removed; + removed = remove_unobtainable_rectangles(trainer, images, temp); + while (impl::contains_any_boxes(removed) && upsampling_amount < options.upsample_limit) + { + ++upsampling_amount; + if (options.be_verbose) + std::cout << "Upsample images..." << std::endl; + upsample_image_dataset<pyramid_down<2> >(images, boxes, ignore); + temp = boxes; + removed = remove_unobtainable_rectangles(trainer, images, temp); + } + // if we weren't able to get all the boxes to match then throw an error + if (impl::contains_any_boxes(removed)) + impl::throw_invalid_box_error_message(dataset_filename, removed, options); + + if (options.add_left_right_image_flips) + add_image_left_right_flips(images, boxes, ignore); + + simple_object_detector detector = trainer.train(images, boxes, ignore); + + if (options.be_verbose) + { + std::cout << "Training complete." << std::endl; + std::cout << "Trained with C: " << options.C << std::endl; + std::cout << "Training with epsilon: " << options.epsilon << std::endl; + std::cout << "Trained using " << options.num_threads << " threads."<< std::endl; + std::cout << "Trained with sliding window " << width << " pixels wide by " << height << " pixels tall." << std::endl; + if (upsampling_amount != 0) + { + // Unsampled images # time(s) to allow detection of small boxes + std::cout << "Upsampled images " << upsampling_amount; + std::cout << ((upsampling_amount > 1) ? " times" : " time"); + std::cout << " to allow detection of small boxes." << std::endl; + } + if (options.add_left_right_image_flips) + std::cout << "Trained on both left and right flipped versions of images." << std::endl; + } + + return simple_object_detector_py(detector, upsampling_amount); + } + +// ---------------------------------------------------------------------------------------- + + inline void train_simple_object_detector ( + const std::string& dataset_filename, + const std::string& detector_output_filename, + const simple_object_detector_training_options& options + ) + { + dlib::array<array2d<rgb_pixel> > images; + std::vector<std::vector<rectangle> > boxes, ignore; + ignore = load_image_dataset(images, boxes, dataset_filename); + + simple_object_detector_py detector = train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, options); + + save_simple_object_detector_py(detector, detector_output_filename); + + if (options.be_verbose) + std::cout << "Saved detector to file " << detector_output_filename << std::endl; + } + +// ---------------------------------------------------------------------------------------- + + struct simple_test_results + { + double precision; + double recall; + double average_precision; + }; + + template <typename image_array> + inline const simple_test_results test_simple_object_detector_with_images ( + image_array& images, + const unsigned int upsample_amount, + std::vector<std::vector<rectangle> >& boxes, + std::vector<std::vector<rectangle> >& ignore, + simple_object_detector& detector + ) + { + for (unsigned int i = 0; i < upsample_amount; ++i) + upsample_image_dataset<pyramid_down<2> >(images, boxes); + + matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore); + simple_test_results ret; + ret.precision = res(0); + ret.recall = res(1); + ret.average_precision = res(2); + return ret; + } + + inline const simple_test_results test_simple_object_detector ( + const std::string& dataset_filename, + const std::string& detector_filename, + const int upsample_amount + ) + { + // Load all the testing images + dlib::array<array2d<rgb_pixel> > images; + std::vector<std::vector<rectangle> > boxes, ignore; + ignore = load_image_dataset(images, boxes, dataset_filename); + + // Load the detector off disk (We have to use the explicit serialization here + // so that we have an open file stream) + simple_object_detector detector; + std::ifstream fin(detector_filename.c_str(), std::ios::binary); + if (!fin) + throw error("Unable to open file " + detector_filename); + deserialize(detector, fin); + + + /* Here we need a little hack to deal with whether we are going to be loading a + * simple_object_detector (possibly trained outside of Python) or a + * simple_object_detector_py (definitely trained from Python). In order to do this + * we peek into the filestream to see if there is more data after the object + * detector. If there is, it will be the version and upsampling amount. Therefore, + * by default we set the upsampling amount to -1 so that we can catch when no + * upsampling amount has been passed (numbers less than 0). If -1 is passed, we + * assume no upsampling and use 0. If a number > 0 is passed, we use that, else we + * use the upsampling amount saved in the detector file (if it exists). + */ + unsigned int final_upsampling_amount = 0; + if (fin.peek() != EOF) + { + int version = 0; + deserialize(version, fin); + if (version != 1) + throw error("Unknown simple_object_detector format."); + deserialize(final_upsampling_amount, fin); + } + if (upsample_amount >= 0) + final_upsampling_amount = upsample_amount; + + return test_simple_object_detector_with_images(images, final_upsampling_amount, boxes, ignore, detector); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SIMPLE_ObJECT_DETECTOR_H__ + |