diff options
Diffstat (limited to 'ml/dlib/examples/dnn_semantic_segmentation_ex.cpp')
-rw-r--r-- | ml/dlib/examples/dnn_semantic_segmentation_ex.cpp | 172 |
1 files changed, 172 insertions, 0 deletions
diff --git a/ml/dlib/examples/dnn_semantic_segmentation_ex.cpp b/ml/dlib/examples/dnn_semantic_segmentation_ex.cpp new file mode 100644 index 00000000..fa49c5a9 --- /dev/null +++ b/ml/dlib/examples/dnn_semantic_segmentation_ex.cpp @@ -0,0 +1,172 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + This example shows how to do semantic segmentation on an image using net pretrained + on the PASCAL VOC2012 dataset. For an introduction to what segmentation is, see the + accompanying header file dnn_semantic_segmentation_ex.h. + + Instructions how to run the example: + 1. Download the PASCAL VOC2012 data, and untar it somewhere. + http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar + 2. Build the dnn_semantic_segmentation_train_ex example program. + 3. Run: + ./dnn_semantic_segmentation_train_ex /path/to/VOC2012 + 4. Wait while the network is being trained. + 5. Build the dnn_semantic_segmentation_ex example program. + 6. Run: + ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images + + An alternative to steps 2-4 above is to download a pre-trained network + from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn + + It would be a good idea to become familiar with dlib's DNN tooling before reading this + example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp + before reading this example program. +*/ + +#include "dnn_semantic_segmentation_ex.h" + +#include <iostream> +#include <dlib/data_io.h> +#include <dlib/gui_widgets.h> + +using namespace std; +using namespace dlib; + +// ---------------------------------------------------------------------------------------- + +// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class +// is represented using an RGB color value. We associate each class also to an index in the +// range [0, 20], used internally by the network. To generate nice RGB representations of +// inference results, we need to be able to convert the index values to the corresponding +// RGB values. + +// Given an index in the range [0, 20], find the corresponding PASCAL VOC2012 class +// (e.g., 'dog'). +const Voc2012class& find_voc2012_class(const uint16_t& index_label) +{ + return find_voc2012_class( + [&index_label](const Voc2012class& voc2012class) + { + return index_label == voc2012class.index; + } + ); +} + +// Convert an index in the range [0, 20] to a corresponding RGB class label. +inline rgb_pixel index_label_to_rgb_label(uint16_t index_label) +{ + return find_voc2012_class(index_label).rgb_label; +} + +// Convert an image containing indexes in the range [0, 20] to a corresponding +// image containing RGB class labels. +void index_label_image_to_rgb_label_image( + const matrix<uint16_t>& index_label_image, + matrix<rgb_pixel>& rgb_label_image +) +{ + const long nr = index_label_image.nr(); + const long nc = index_label_image.nc(); + + rgb_label_image.set_size(nr, nc); + + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + rgb_label_image(r, c) = index_label_to_rgb_label(index_label_image(r, c)); + } + } +} + +// Find the most prominent class label from amongst the per-pixel predictions. +std::string get_most_prominent_non_background_classlabel(const matrix<uint16_t>& index_label_image) +{ + const long nr = index_label_image.nr(); + const long nc = index_label_image.nc(); + + std::vector<unsigned int> counters(class_count); + + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + const uint16_t label = index_label_image(r, c); + ++counters[label]; + } + } + + const auto max_element = std::max_element(counters.begin() + 1, counters.end()); + const uint16_t most_prominent_index_label = max_element - counters.begin(); + + return find_voc2012_class(most_prominent_index_label).classlabel; +} + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) try +{ + if (argc != 2) + { + cout << "You call this program like this: " << endl; + cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl; + cout << endl; + cout << "You will also need a trained 'semantic_segmentation_voc2012net.dnn' file." << endl; + cout << "You can either train it yourself (see example program" << endl; + cout << "dnn_semantic_segmentation_train_ex), or download a" << endl; + cout << "copy from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn" << endl; + return 1; + } + + // Read the file containing the trained network from the working directory. + anet_type net; + deserialize("semantic_segmentation_voc2012net.dnn") >> net; + + // Show inference results in a window. + image_window win; + + matrix<rgb_pixel> input_image; + matrix<uint16_t> index_label_image; + matrix<rgb_pixel> rgb_label_image; + + // Find supported image files. + const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1], + dlib::match_endings(".jpeg .jpg .png")); + + cout << "Found " << files.size() << " images, processing..." << endl; + + for (const file& file : files) + { + // Load the input image. + load_image(input_image, file.full_name()); + + // Create predictions for each pixel. At this point, the type of each prediction + // is an index (a value between 0 and 20). Note that the net may return an image + // that is not exactly the same size as the input. + const matrix<uint16_t> temp = net(input_image); + + // Crop the returned image to be exactly the same size as the input. + const chip_details chip_details( + centered_rect(temp.nc() / 2, temp.nr() / 2, input_image.nc(), input_image.nr()), + chip_dims(input_image.nr(), input_image.nc()) + ); + extract_image_chip(temp, chip_details, index_label_image, interpolate_nearest_neighbor()); + + // Convert the indexes to RGB values. + index_label_image_to_rgb_label_image(index_label_image, rgb_label_image); + + // Show the input image on the left, and the predicted RGB labels on the right. + win.set_image(join_rows(input_image, rgb_label_image)); + + // Find the most prominent class label from amongst the per-pixel predictions. + const std::string classlabel = get_most_prominent_non_background_classlabel(index_label_image); + + cout << file.name() << " : " << classlabel << " - hit enter to process the next image"; + cin.get(); + } +} +catch(std::exception& e) +{ + cout << e.what() << endl; +} + |