diff options
Diffstat (limited to 'ml/dlib/dlib/data_io')
-rw-r--r-- | ml/dlib/dlib/data_io/image_dataset_metadata.cpp | 411 | ||||
-rw-r--r-- | ml/dlib/dlib/data_io/image_dataset_metadata.h | 174 | ||||
-rw-r--r-- | ml/dlib/dlib/data_io/libsvm_io.h | 276 | ||||
-rw-r--r-- | ml/dlib/dlib/data_io/libsvm_io_abstract.h | 125 | ||||
-rw-r--r-- | ml/dlib/dlib/data_io/load_image_dataset.h | 510 | ||||
-rw-r--r-- | ml/dlib/dlib/data_io/load_image_dataset_abstract.h | 358 | ||||
-rw-r--r-- | ml/dlib/dlib/data_io/mnist.cpp | 133 | ||||
-rw-r--r-- | ml/dlib/dlib/data_io/mnist.h | 32 | ||||
-rw-r--r-- | ml/dlib/dlib/data_io/mnist_abstract.h | 46 |
9 files changed, 2065 insertions, 0 deletions
diff --git a/ml/dlib/dlib/data_io/image_dataset_metadata.cpp b/ml/dlib/dlib/data_io/image_dataset_metadata.cpp new file mode 100644 index 000000000..390ef6a0a --- /dev/null +++ b/ml/dlib/dlib/data_io/image_dataset_metadata.cpp @@ -0,0 +1,411 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_DAtASET_METADATA_CPPh_ +#define DLIB_IMAGE_DAtASET_METADATA_CPPh_ + +#include "image_dataset_metadata.h" + +#include <fstream> +#include <sstream> +#include "../compress_stream.h" +#include "../base64.h" +#include "../xml_parser.h" +#include "../string.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + namespace image_dataset_metadata + { + + // ------------------------------------------------------------------------------------ + + const std::string get_decoded_string(); + void create_image_metadata_stylesheet_file(const std::string& main_filename) + { + std::string path; + std::string::size_type pos = main_filename.find_last_of("/\\"); + if (pos != std::string::npos) + path = main_filename.substr(0,pos+1); + + std::ofstream fout((path + "image_metadata_stylesheet.xsl").c_str()); + if (!fout) + throw dlib::error("ERROR: Unable to open image_metadata_stylesheet.xsl for writing."); + + fout << get_decoded_string(); + + if (!fout) + throw dlib::error("ERROR: Unable to write to image_metadata_stylesheet.xsl."); + } + + void save_image_dataset_metadata ( + const dataset& meta, + const std::string& filename + ) + { + create_image_metadata_stylesheet_file(filename); + + const std::vector<image>& images = meta.images; + + std::ofstream fout(filename.c_str()); + if (!fout) + throw dlib::error("ERROR: Unable to open " + filename + " for writing."); + + fout << "<?xml version='1.0' encoding='ISO-8859-1'?>\n"; + fout << "<?xml-stylesheet type='text/xsl' href='image_metadata_stylesheet.xsl'?>\n"; + fout << "<dataset>\n"; + fout << "<name>" << meta.name << "</name>\n"; + fout << "<comment>" << meta.comment << "</comment>\n"; + fout << "<images>\n"; + for (unsigned long i = 0; i < images.size(); ++i) + { + fout << " <image file='" << images[i].filename << "'>\n"; + + // save all the boxes + for (unsigned long j = 0; j < images[i].boxes.size(); ++j) + { + const box& b = images[i].boxes[j]; + fout << " <box top='" << b.rect.top() << "' " + << "left='" << b.rect.left() << "' " + << "width='" << b.rect.width() << "' " + << "height='" << b.rect.height() << "'"; + if (b.difficult) + fout << " difficult='" << b.difficult << "'"; + if (b.truncated) + fout << " truncated='" << b.truncated << "'"; + if (b.occluded) + fout << " occluded='" << b.occluded << "'"; + if (b.ignore) + fout << " ignore='" << b.ignore << "'"; + if (b.angle != 0) + fout << " angle='" << b.angle << "'"; + if (b.age != 0) + fout << " age='" << b.age << "'"; + if (b.gender == FEMALE) + fout << " gender='female'"; + else if (b.gender == MALE) + fout << " gender='male'"; + if (b.pose != 0) + fout << " pose='" << b.pose << "'"; + if (b.detection_score != 0) + fout << " detection_score='" << b.detection_score << "'"; + + if (b.has_label() || b.parts.size() != 0) + { + fout << ">\n"; + + if (b.has_label()) + fout << " <label>" << b.label << "</label>\n"; + + // save all the parts + std::map<std::string,point>::const_iterator itr; + for (itr = b.parts.begin(); itr != b.parts.end(); ++itr) + { + fout << " <part name='"<< itr->first << "' x='"<< itr->second.x() <<"' y='"<< itr->second.y() <<"'/>\n"; + } + + fout << " </box>\n"; + } + else + { + fout << "/>\n"; + } + } + + + + fout << " </image>\n"; + + if (!fout) + throw dlib::error("ERROR: Unable to write to " + filename + "."); + } + fout << "</images>\n"; + fout << "</dataset>"; + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + class doc_handler : public document_handler + { + std::vector<std::string> ts; + image temp_image; + box temp_box; + + dataset& meta; + + public: + + doc_handler( + dataset& metadata_ + ): + meta(metadata_) + {} + + + virtual void start_document ( + ) + { + meta = dataset(); + ts.clear(); + temp_image = image(); + temp_box = box(); + } + + virtual void end_document ( + ) + { + } + + virtual void start_element ( + const unsigned long line_number, + const std::string& name, + const dlib::attribute_list& atts + ) + { + try + { + if (ts.size() == 0) + { + if (name != "dataset") + { + std::ostringstream sout; + sout << "Invalid XML document. Root tag must be <dataset>. Found <" << name << "> instead."; + throw dlib::error(sout.str()); + } + else + { + ts.push_back(name); + return; + } + } + + + if (name == "box") + { + if (atts.is_in_list("top")) temp_box.rect.top() = sa = atts["top"]; + else throw dlib::error("<box> missing required attribute 'top'"); + + if (atts.is_in_list("left")) temp_box.rect.left() = sa = atts["left"]; + else throw dlib::error("<box> missing required attribute 'left'"); + + if (atts.is_in_list("width")) temp_box.rect.right() = sa = atts["width"]; + else throw dlib::error("<box> missing required attribute 'width'"); + + if (atts.is_in_list("height")) temp_box.rect.bottom() = sa = atts["height"]; + else throw dlib::error("<box> missing required attribute 'height'"); + + if (atts.is_in_list("difficult")) temp_box.difficult = sa = atts["difficult"]; + if (atts.is_in_list("truncated")) temp_box.truncated = sa = atts["truncated"]; + if (atts.is_in_list("occluded")) temp_box.occluded = sa = atts["occluded"]; + if (atts.is_in_list("ignore")) temp_box.ignore = sa = atts["ignore"]; + if (atts.is_in_list("angle")) temp_box.angle = sa = atts["angle"]; + if (atts.is_in_list("age")) temp_box.age = sa = atts["age"]; + if (atts.is_in_list("gender")) + { + if (atts["gender"] == "male") + temp_box.gender = MALE; + else if (atts["gender"] == "female") + temp_box.gender = FEMALE; + else if (atts["gender"] == "unknown") + temp_box.gender = UNKNOWN; + else + throw dlib::error("Invalid gender string in box attribute."); + } + if (atts.is_in_list("pose")) temp_box.pose = sa = atts["pose"]; + if (atts.is_in_list("detection_score")) temp_box.detection_score = sa = atts["detection_score"]; + + temp_box.rect.bottom() += temp_box.rect.top()-1; + temp_box.rect.right() += temp_box.rect.left()-1; + } + else if (name == "part" && ts.back() == "box") + { + point temp; + if (atts.is_in_list("x")) temp.x() = sa = atts["x"]; + else throw dlib::error("<part> missing required attribute 'x'"); + + if (atts.is_in_list("y")) temp.y() = sa = atts["y"]; + else throw dlib::error("<part> missing required attribute 'y'"); + + if (atts.is_in_list("name")) + { + if (temp_box.parts.count(atts["name"])==0) + { + temp_box.parts[atts["name"]] = temp; + } + else + { + throw dlib::error("<part> with name '" + atts["name"] + "' is defined more than one time in a single box."); + } + } + else + { + throw dlib::error("<part> missing required attribute 'name'"); + } + } + else if (name == "image") + { + temp_image.boxes.clear(); + + if (atts.is_in_list("file")) temp_image.filename = atts["file"]; + else throw dlib::error("<image> missing required attribute 'file'"); + } + + ts.push_back(name); + } + catch (error& e) + { + throw dlib::error("Error on line " + cast_to_string(line_number) + ": " + e.what()); + } + } + + virtual void end_element ( + const unsigned long , + const std::string& name + ) + { + ts.pop_back(); + if (ts.size() == 0) + return; + + if (name == "box" && ts.back() == "image") + { + temp_image.boxes.push_back(temp_box); + temp_box = box(); + } + else if (name == "image" && ts.back() == "images") + { + meta.images.push_back(temp_image); + temp_image = image(); + } + } + + virtual void characters ( + const std::string& data + ) + { + if (ts.size() == 2 && ts[1] == "name") + { + meta.name = trim(data); + } + else if (ts.size() == 2 && ts[1] == "comment") + { + meta.comment = trim(data); + } + else if (ts.size() >= 2 && ts[ts.size()-1] == "label" && + ts[ts.size()-2] == "box") + { + temp_box.label = trim(data); + } + } + + virtual void processing_instruction ( + const unsigned long , + const std::string& , + const std::string& + ) + { + } + }; + + // ---------------------------------------------------------------------------------------- + + class xml_error_handler : public error_handler + { + public: + virtual void error ( + const unsigned long + ) { } + + virtual void fatal_error ( + const unsigned long line_number + ) + { + std::ostringstream sout; + sout << "There is a fatal error on line " << line_number << " so parsing will now halt."; + throw dlib::error(sout.str()); + } + }; + + // ------------------------------------------------------------------------------------ + + void load_image_dataset_metadata ( + dataset& meta, + const std::string& filename + ) + { + xml_error_handler eh; + doc_handler dh(meta); + + std::ifstream fin(filename.c_str()); + if (!fin) + throw dlib::error("ERROR: unable to open " + filename + " for reading."); + + xml_parser parser; + parser.add_document_handler(dh); + parser.add_error_handler(eh); + parser.parse(fin); + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'images.xsl' + const std::string get_decoded_string() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'image_metadata_stylesheet.xsl' we want to decode and return. + sout << "PFWfgmWfCHr1DkV63lbjjeY2dCc2FbHDOVh0Kd7dkvaOfRYrOG24f0x77/5iMVq8FtE3UBxtGwSd"; + sout << "1ZHOHRSHgieNoeBv8ssJQ75RRxYtFKRY3OTPX5eKQoCN9jUaUnHnR4QZtEHgmKqXSs50Yrdd+2Ah"; + sout << "gNyarPZCiR6nvqNvCjtP2MP5FxleqNf8Fylatm2KdsXmrv5K87LYVN7i7JMkmZ++cTXYSOxDmxZi"; + sout << "OiCH8funXUdF9apDW547gCjz9HOQUI6dkz5dYUeFjfp6dFugpnaJyyprFLKq048Qk7+QiL4CNF/G"; + sout << "7e0VpBw8dMpiyRNi2fSQGSZGfIAUQKKT6+rPwQoRH2spdjsdXVWj4XQAqBX87nmqMnqjMhn/Vd1s"; + sout << "W5aoC0drwRGu3Xe3gn9vBL8hBkRXcJvEy6q/lb9bYnsLemhE5Zp/+nTmTBjfT9UFYLcsmgsjC+4n"; + sout << "Bq6h9QlpuyMYqJ8RvW8pp3mFlvXc3Yg+18t5F0hSMQfaIFYAuDPU2lVzPpY+ba0B39iu9IrPCLsS"; + sout << "+tUtSNSmQ74CtzZgKKjkTMA3nwYP2SDmZE3firq42pihT7hdU5vYkes69K8AQl8WZyLPpMww+r0z"; + sout << "+veEHPlAuxF7kL3ZvVjdB+xABwwqDe0kSRHRZINYdUfJwJdfYLyDnYoMjj6afqIJZ7QOBPZ42tV5"; + sout << "3hYOQTFwTNovOastzJJXQe1kxPg1AQ8ynmfjjJZqD0xKedlyeJybP919mVAA23UryHsq9TVlabou"; + sout << "qNl3xZW/mKKktvVsd/nuH62HIv/kgomyhaEUY5HgupupBUbQFZfyljZ5bl3g3V3Y1400Z1xTM/LL"; + sout << "LJpeLdlqoGzIe/19vAN1zUUVId9F/OLNUl3Zoar63yZERSJHcsuq/Pasisp0HIGi7rfI9EIQF7C/"; + sout << "IhLKLZsJ+LOycreQGOJALZIEZHOqxYLSXG0qaPM5bQL/MQJ2OZfwEhQgYOrjaM7oPOHHEfTq5kcO"; + sout << "daMwzefKfxrF2GXbUs0bYsEXsIGwENIUKMliFaAI4qKLxxb94oc+O3BRjWueZjZty2zKawQyTHNd"; + sout << "ltFJBUzfffdZN9Wq4zbPzntkM3U6Ys4LRztx5M15dtbhFeKx5rAf2tPXT6wU01hx7EJxBJzpvoDE"; + sout << "YwEoYVDSYulRKpgk82cHFzzUDgWXbl4paFSe1L1w8r9KHr67SYJDTUG86Lrm6LJ0rw73Xp0NAFcU"; + sout << "MKpiG9g1cHW74HYbUb/yAbtVWt40eB7M637umdo2jWz/r/vP5WnfSMXEbkyWebsa1fFceg/TLWy6"; + sout << "E8OTc4XKB48h1oFIlGagOiprxho3+F3TIcxDSwA="; + + + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + + } +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_IMAGE_DAtASET_METADATA_CPPh_ + + diff --git a/ml/dlib/dlib/data_io/image_dataset_metadata.h b/ml/dlib/dlib/data_io/image_dataset_metadata.h new file mode 100644 index 000000000..3dac29ba6 --- /dev/null +++ b/ml/dlib/dlib/data_io/image_dataset_metadata.h @@ -0,0 +1,174 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_DAtASET_METADATA_Hh_ +#define DLIB_IMAGE_DAtASET_METADATA_Hh_ + +#include <string> +#include <vector> +#include "../geometry.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + namespace image_dataset_metadata + { + + // ------------------------------------------------------------------------------------ + + enum gender_t + { + UNKNOWN, + MALE, + FEMALE + }; + + // ------------------------------------------------------------------------------------ + + struct box + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an annotated rectangular area of an image. + It is typically used to mark the location of an object such as a + person, car, etc. + + The main variable of interest is rect. It gives the location of + the box. All the other variables are optional. + !*/ + + box( + ) : + difficult(false), + truncated(false), + occluded(false), + ignore(false), + pose(0), + detection_score(0), + angle(0), + gender(UNKNOWN), + age(0) + {} + + box ( + const rectangle& rect_ + ) : + rect(rect_), + difficult(false), + truncated(false), + occluded(false), + ignore(false), + pose(0), + detection_score(0), + angle(0), + gender(UNKNOWN), + age(0) + {} + + rectangle rect; + + std::map<std::string,point> parts; + + // optional fields + std::string label; + bool difficult; + bool truncated; + bool occluded; + bool ignore; + double pose; + double detection_score; + + // The angle of the object in radians. Positive values indicate that the + // object at the center of the box is rotated clockwise by angle radians. A + // value of 0 would indicate that the object is in its "standard" upright pose. + // Therefore, to make the object appear upright we would have to rotate the + // image counter-clockwise by angle radians. + double angle; + + gender_t gender; + double age; + + bool has_label() const { return label.size() != 0; } + /*! + ensures + - returns true if label metadata is present and false otherwise. + !*/ + }; + + // ------------------------------------------------------------------------------------ + + struct image + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an annotated image. + !*/ + + image() {} + image(const std::string& f) : filename(f) {} + + std::string filename; + std::vector<box> boxes; + }; + + // ------------------------------------------------------------------------------------ + + struct dataset + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a labeled set of images. In particular, it + contains the filename for each image as well as annotated boxes. + !*/ + + std::vector<image> images; + std::string comment; + std::string name; + }; + + // ------------------------------------------------------------------------------------ + + void save_image_dataset_metadata ( + const dataset& meta, + const std::string& filename + ); + /*! + ensures + - Writes the contents of the meta object to a file with the given + filename. The file will be in an XML format. + throws + - dlib::error + This exception is thrown if there is an error which prevents + this function from succeeding. + !*/ + + // ------------------------------------------------------------------------------------ + + void load_image_dataset_metadata ( + dataset& meta, + const std::string& filename + ); + /*! + ensures + - Attempts to interpret filename as a file containing XML formatted data + as produced by the save_image_dataset_metadata() function. Then + meta is loaded with the contents of the file. + throws + - dlib::error + This exception is thrown if there is an error which prevents + this function from succeeding. + !*/ + + // ------------------------------------------------------------------------------------ + + } +} + +// ---------------------------------------------------------------------------------------- + +#ifdef NO_MAKEFILE +#include "image_dataset_metadata.cpp" +#endif + +#endif // DLIB_IMAGE_DAtASET_METADATA_Hh_ + diff --git a/ml/dlib/dlib/data_io/libsvm_io.h b/ml/dlib/dlib/data_io/libsvm_io.h new file mode 100644 index 000000000..f365e82d7 --- /dev/null +++ b/ml/dlib/dlib/data_io/libsvm_io.h @@ -0,0 +1,276 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LIBSVM_iO_Hh_ +#define DLIB_LIBSVM_iO_Hh_ + +#include "libsvm_io_abstract.h" + +#include <fstream> +#include <string> +#include <utility> +#include "../algs.h" +#include "../matrix.h" +#include "../string.h" +#include "../svm/sparse_vector.h" +#include <vector> + +namespace dlib +{ + struct sample_data_io_error : public error + { + sample_data_io_error(const std::string& message): error(message) {} + }; + +// ---------------------------------------------------------------------------------------- + + template <typename sample_type, typename label_type, typename alloc1, typename alloc2> + void load_libsvm_formatted_data ( + const std::string& file_name, + std::vector<sample_type, alloc1>& samples, + std::vector<label_type, alloc2>& labels + ) + { + using namespace std; + typedef typename sample_type::value_type pair_type; + typedef typename basic_type<typename pair_type::first_type>::type key_type; + typedef typename pair_type::second_type value_type; + + // You must use unsigned integral key types in your sparse vectors + COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value); + + samples.clear(); + labels.clear(); + + ifstream fin(file_name.c_str()); + + if (!fin) + throw sample_data_io_error("Unable to open file " + file_name); + + string line; + istringstream sin; + key_type key; + value_type value; + label_type label; + sample_type sample; + long line_num = 0; + while (fin.peek() != EOF) + { + ++line_num; + getline(fin, line); + + string::size_type pos = line.find_first_not_of(" \t\r\n"); + + // ignore empty lines or comment lines + if (pos == string::npos || line[pos] == '#') + continue; + + sin.clear(); + sin.str(line); + sample.clear(); + + sin >> label; + + if (!sin) + throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name ); + + // eat whitespace + sin >> ws; + + while (sin.peek() != EOF && sin.peek() != '#') + { + + sin >> key >> ws; + + // ignore what should be a : character + if (sin.get() != ':') + throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name); + + sin >> value; + + if (sin && value != 0) + { + sample.insert(sample.end(), make_pair(key, value)); + } + + sin >> ws; + } + + samples.push_back(sample); + labels.push_back(label); + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template <typename sample_type, typename alloc> + typename enable_if<is_const_type<typename sample_type::value_type::first_type> >::type + fix_nonzero_indexing ( + std::vector<sample_type,alloc>& samples + ) + { + typedef typename sample_type::value_type pair_type; + typedef typename basic_type<typename pair_type::first_type>::type key_type; + + if (samples.size() == 0) + return; + + // figure out the min index value + key_type min_idx = samples[0].begin()->first; + for (unsigned long i = 0; i < samples.size(); ++i) + min_idx = std::min(min_idx, samples[i].begin()->first); + + // Now adjust all the samples so that their min index value is zero. + if (min_idx != 0) + { + sample_type temp; + for (unsigned long i = 0; i < samples.size(); ++i) + { + // copy samples[i] into temp but make sure it has a min index of zero. + temp.clear(); + typename sample_type::iterator j; + for (j = samples[i].begin(); j != samples[i].end(); ++j) + { + temp.insert(temp.end(), std::make_pair(j->first-min_idx, j->second)); + } + + // replace the current sample with temp. + samples[i].swap(temp); + } + } + } + +// ---------------------------------------------------------------------------------------- + +// If the "first" values in the std::pair objects are not const then we can modify them +// directly and that is what this version of fix_nonzero_indexing() does. + template <typename sample_type, typename alloc> + typename disable_if<is_const_type<typename sample_type::value_type::first_type> >::type + fix_nonzero_indexing ( + std::vector<sample_type,alloc>& samples + ) + { + typedef typename sample_type::value_type pair_type; + typedef typename basic_type<typename pair_type::first_type>::type key_type; + + if (samples.size() == 0) + return; + + // figure out the min index value + key_type min_idx = samples[0].begin()->first; + for (unsigned long i = 0; i < samples.size(); ++i) + min_idx = std::min(min_idx, samples[i].begin()->first); + + // Now adjust all the samples so that their min index value is zero. + if (min_idx != 0) + { + for (unsigned long i = 0; i < samples.size(); ++i) + { + typename sample_type::iterator j; + for (j = samples[i].begin(); j != samples[i].end(); ++j) + { + j->first -= min_idx; + } + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// This is an overload for sparse vectors + template <typename sample_type, typename label_type, typename alloc1, typename alloc2> + typename disable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector<sample_type, alloc1>& samples, + const std::vector<label_type, alloc2>& labels + ) + { + typedef typename sample_type::value_type pair_type; + typedef typename basic_type<typename pair_type::first_type>::type key_type; + + // You must use unsigned integral key types in your sparse vectors + COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value); + + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() == labels.size(), + "\t void save_libsvm_formatted_data()" + << "\n\t You have to have labels for each sample and vice versa" + << "\n\t samples.size(): " << samples.size() + << "\n\t labels.size(): " << labels.size() + ); + + + using namespace std; + ofstream fout(file_name.c_str()); + fout.precision(14); + + if (!fout) + throw sample_data_io_error("Unable to open file " + file_name); + + for (unsigned long i = 0; i < samples.size(); ++i) + { + fout << labels[i]; + + for (typename sample_type::const_iterator j = samples[i].begin(); j != samples[i].end(); ++j) + { + if (j->second != 0) + fout << " " << j->first << ":" << j->second; + } + fout << "\n"; + + if (!fout) + throw sample_data_io_error("Error while writing to file " + file_name); + } + + } + +// ---------------------------------------------------------------------------------------- + +// This is an overload for dense vectors + template <typename sample_type, typename label_type, typename alloc1, typename alloc2> + typename enable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector<sample_type, alloc1>& samples, + const std::vector<label_type, alloc2>& labels + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() == labels.size(), + "\t void save_libsvm_formatted_data()" + << "\n\t You have to have labels for each sample and vice versa" + << "\n\t samples.size(): " << samples.size() + << "\n\t labels.size(): " << labels.size() + ); + + using namespace std; + ofstream fout(file_name.c_str()); + fout.precision(14); + + if (!fout) + throw sample_data_io_error("Unable to open file " + file_name); + + for (unsigned long i = 0; i < samples.size(); ++i) + { + fout << labels[i]; + + for (long j = 0; j < samples[i].size(); ++j) + { + if (samples[i](j) != 0) + fout << " " << j << ":" << samples[i](j); + } + fout << "\n"; + + if (!fout) + throw sample_data_io_error("Error while writing to file " + file_name); + } + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LIBSVM_iO_Hh_ + diff --git a/ml/dlib/dlib/data_io/libsvm_io_abstract.h b/ml/dlib/dlib/data_io/libsvm_io_abstract.h new file mode 100644 index 000000000..88d934fdb --- /dev/null +++ b/ml/dlib/dlib/data_io/libsvm_io_abstract.h @@ -0,0 +1,125 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LIBSVM_iO_ABSTRACT_Hh_ +#ifdef DLIB_LIBSVM_iO_ABSTRACT_Hh_ + +#include <fstream> +#include <string> +#include <utility> +#include "../algs.h" +#include "../matrix.h" +#include <vector> + +namespace dlib +{ + struct sample_data_io_error : public error + { + /*! + This is the exception class used by the file IO functions defined below. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename label_type, + typename alloc1, + typename alloc2 + > + void load_libsvm_formatted_data ( + const std::string& file_name, + std::vector<sample_type, alloc1>& samples, + std::vector<label_type, alloc2>& labels + ); + /*! + requires + - sample_type must be an STL container + - sample_type::value_type == std::pair<T,U> where T is some kind of + unsigned integral type + ensures + - attempts to read a file of the given name that should contain libsvm + formatted data. We turn the data into sparse vectors and store it + in samples + - #labels.size() == #samples.size() + - for all valid i: #labels[i] is the label for #samples[i] + throws + - sample_data_io_error + This exception is thrown if there is any problem loading data from file + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename label_type, + typename alloc1, + typename alloc2 + > + void save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector<sample_type, alloc1>& samples, + const std::vector<label_type, alloc2>& labels + ); + /*! + requires + - sample_type must be an STL container + - sample_type::value_type == std::pair<T,U> where T is some kind of + unsigned integral type + - samples.size() == labels.size() + ensures + - saves the data to the given file in libsvm format + throws + - sample_data_io_error + This exception is thrown if there is any problem saving data to file + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename label_type, + typename alloc1, + typename alloc2 + > + void save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector<sample_type, alloc1>& samples, + const std::vector<label_type, alloc2>& labels + ); + /*! + requires + - sample_type == a dense matrix (i.e. dlib::matrix) + - for all valid i: is_vector(samples[i]) == true + - samples.size() == labels.size() + ensures + - saves the data to the given file in libsvm format + throws + - sample_data_io_error + This exception is thrown if there is any problem saving data to file + !*/ + +// ---------------------------------------------------------------------------------------- + + template <typename sample_type, typename alloc> + void fix_nonzero_indexing ( + std::vector<sample_type,alloc>& samples + ); + /*! + requires + - samples must only contain valid sparse vectors. The definition of + a sparse vector can be found at the top of dlib/svm/sparse_vector_abstract.h + ensures + - Adjusts the sparse vectors in samples so that they are zero-indexed. + Or in other words, assume the smallest used index value in any of the sparse + vectors is N. Then this function subtracts N from all the index values in + samples. This is useful, for example, if you load a libsvm formatted datafile + with features indexed from 1 rather than 0 and you would like to fix this. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LIBSVM_iO_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/data_io/load_image_dataset.h b/ml/dlib/dlib/data_io/load_image_dataset.h new file mode 100644 index 000000000..5664d96b2 --- /dev/null +++ b/ml/dlib/dlib/data_io/load_image_dataset.h @@ -0,0 +1,510 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LOAD_IMAGE_DaTASET_Hh_ +#define DLIB_LOAD_IMAGE_DaTASET_Hh_ + +#include "load_image_dataset_abstract.h" +#include "../misc_api.h" +#include "../dir_nav.h" +#include "../image_io.h" +#include "../array.h" +#include <vector> +#include "../geometry.h" +#include "image_dataset_metadata.h" +#include <string> +#include <set> +#include "../image_processing/full_object_detection.h" +#include <utility> +#include <limits> +#include "../image_transforms/image_pyramid.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class image_dataset_file + { + public: + image_dataset_file(const std::string& filename) + { + _skip_empty_images = false; + _have_parts = false; + _filename = filename; + _box_area_thresh = std::numeric_limits<double>::infinity(); + } + + image_dataset_file boxes_match_label( + const std::string& label + ) const + { + image_dataset_file temp(*this); + temp._labels.insert(label); + return temp; + } + + image_dataset_file skip_empty_images( + ) const + { + image_dataset_file temp(*this); + temp._skip_empty_images = true; + return temp; + } + + image_dataset_file boxes_have_parts( + ) const + { + image_dataset_file temp(*this); + temp._have_parts = true; + return temp; + } + + image_dataset_file shrink_big_images( + double new_box_area_thresh = 150*150 + ) const + { + image_dataset_file temp(*this); + temp._box_area_thresh = new_box_area_thresh; + return temp; + } + + bool should_load_box ( + const image_dataset_metadata::box& box + ) const + { + if (_have_parts && box.parts.size() == 0) + return false; + if (_labels.size() == 0) + return true; + if (_labels.count(box.label) != 0) + return true; + return false; + } + + const std::string& get_filename() const { return _filename; } + bool should_skip_empty_images() const { return _skip_empty_images; } + bool should_boxes_have_parts() const { return _have_parts; } + double box_area_thresh() const { return _box_area_thresh; } + const std::set<std::string>& get_selected_box_labels() const { return _labels; } + + private: + std::string _filename; + std::set<std::string> _labels; + bool _skip_empty_images; + bool _have_parts; + double _box_area_thresh; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<rectangle> >& object_locations, + const image_dataset_file& source + ) + { + images.clear(); + object_locations.clear(); + + std::vector<std::vector<rectangle> > ignored_rects; + + using namespace dlib::image_dataset_metadata; + dataset data; + load_image_dataset_metadata(data, source.get_filename()); + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. + locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); + + + typedef typename array_type::value_type image_type; + + + image_type img; + std::vector<rectangle> rects, ignored; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + double min_rect_size = std::numeric_limits<double>::infinity(); + rects.clear(); + ignored.clear(); + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (source.should_load_box(data.images[i].boxes[j])) + { + if (data.images[i].boxes[j].ignore) + { + ignored.push_back(data.images[i].boxes[j].rect); + } + else + { + rects.push_back(data.images[i].boxes[j].rect); + min_rect_size = std::min<double>(min_rect_size, rects.back().area()); + } + } + } + + if (!source.should_skip_empty_images() || rects.size() != 0) + { + load_image(img, data.images[i].filename); + if (rects.size() != 0) + { + // if shrinking the image would still result in the smallest box being + // bigger than the box area threshold then shrink the image. + while(min_rect_size/2/2 > source.box_area_thresh()) + { + pyramid_down<2> pyr; + pyr(img); + min_rect_size *= (1.0/2.0)*(1.0/2.0); + for (auto&& r : rects) + r = pyr.rect_down(r); + for (auto&& r : ignored) + r = pyr.rect_down(r); + } + while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) + { + pyramid_down<3> pyr; + pyr(img); + min_rect_size *= (2.0/3.0)*(2.0/3.0); + for (auto&& r : rects) + r = pyr.rect_down(r); + for (auto&& r : ignored) + r = pyr.rect_down(r); + } + } + images.push_back(img); + object_locations.push_back(rects); + ignored_rects.push_back(ignored); + } + } + + return ignored_rects; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline size_t num_non_ignored_boxes (const std::vector<mmod_rect>& rects) + { + size_t cnt = 0; + for (auto& b : rects) + { + if (!b.ignore) + cnt++; + } + return cnt; + } + } + + template < + typename array_type + > + void load_image_dataset ( + array_type& images, + std::vector<std::vector<mmod_rect> >& object_locations, + const image_dataset_file& source + ) + { + images.clear(); + object_locations.clear(); + + using namespace dlib::image_dataset_metadata; + dataset data; + load_image_dataset_metadata(data, source.get_filename()); + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. + locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); + + typedef typename array_type::value_type image_type; + + image_type img; + std::vector<mmod_rect> rects; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + double min_rect_size = std::numeric_limits<double>::infinity(); + rects.clear(); + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (source.should_load_box(data.images[i].boxes[j])) + { + if (data.images[i].boxes[j].ignore) + { + rects.push_back(ignored_mmod_rect(data.images[i].boxes[j].rect)); + } + else + { + rects.push_back(mmod_rect(data.images[i].boxes[j].rect)); + min_rect_size = std::min<double>(min_rect_size, rects.back().rect.area()); + } + rects.back().label = data.images[i].boxes[j].label; + + } + } + + if (!source.should_skip_empty_images() || impl::num_non_ignored_boxes(rects) != 0) + { + load_image(img, data.images[i].filename); + if (rects.size() != 0) + { + // if shrinking the image would still result in the smallest box being + // bigger than the box area threshold then shrink the image. + while(min_rect_size/2/2 > source.box_area_thresh()) + { + pyramid_down<2> pyr; + pyr(img); + min_rect_size *= (1.0/2.0)*(1.0/2.0); + for (auto&& r : rects) + r.rect = pyr.rect_down(r.rect); + } + while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) + { + pyramid_down<3> pyr; + pyr(img); + min_rect_size *= (2.0/3.0)*(2.0/3.0); + for (auto&& r : rects) + r.rect = pyr.rect_down(r.rect); + } + } + images.push_back(std::move(img)); + object_locations.push_back(std::move(rects)); + } + } + } + +// ---------------------------------------------------------------------------------------- + +// ******* THIS FUNCTION IS DEPRECATED, you should use another version of load_image_dataset() ******* + template < + typename image_type, + typename MM + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array<image_type,MM>& images, + std::vector<std::vector<rectangle> >& object_locations, + const std::string& filename, + const std::string& label, + bool skip_empty_images = false + ) + { + image_dataset_file f(filename); + if (label.size() != 0) + f = f.boxes_match_label(label); + if (skip_empty_images) + f = f.skip_empty_images(); + return load_image_dataset(images, object_locations, f); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<rectangle> >& object_locations, + const std::string& filename + ) + { + return load_image_dataset(images, object_locations, image_dataset_file(filename)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + void load_image_dataset ( + array_type& images, + std::vector<std::vector<mmod_rect>>& object_locations, + const std::string& filename + ) + { + load_image_dataset(images, object_locations, image_dataset_file(filename)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<full_object_detection> >& object_locations, + const image_dataset_file& source, + std::vector<std::string>& parts_list + ) + { + typedef typename array_type::value_type image_type; + parts_list.clear(); + images.clear(); + object_locations.clear(); + + using namespace dlib::image_dataset_metadata; + dataset data; + load_image_dataset_metadata(data, source.get_filename()); + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. + locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); + + + std::set<std::string> all_parts; + + // find out what parts are being used in the dataset. Store results in all_parts. + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (source.should_load_box(data.images[i].boxes[j])) + { + const std::map<std::string,point>& parts = data.images[i].boxes[j].parts; + std::map<std::string,point>::const_iterator itr; + + for (itr = parts.begin(); itr != parts.end(); ++itr) + { + all_parts.insert(itr->first); + } + } + } + } + + // make a mapping between part names and the integers [0, all_parts.size()) + std::map<std::string,int> parts_idx; + for (std::set<std::string>::iterator i = all_parts.begin(); i != all_parts.end(); ++i) + { + parts_idx[*i] = parts_list.size(); + parts_list.push_back(*i); + } + + std::vector<std::vector<rectangle> > ignored_rects; + std::vector<rectangle> ignored; + image_type img; + std::vector<full_object_detection> object_dets; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + double min_rect_size = std::numeric_limits<double>::infinity(); + object_dets.clear(); + ignored.clear(); + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (source.should_load_box(data.images[i].boxes[j])) + { + if (data.images[i].boxes[j].ignore) + { + ignored.push_back(data.images[i].boxes[j].rect); + } + else + { + std::vector<point> partlist(parts_idx.size(), OBJECT_PART_NOT_PRESENT); + + // populate partlist with all the parts present in this box. + const std::map<std::string,point>& parts = data.images[i].boxes[j].parts; + std::map<std::string,point>::const_iterator itr; + for (itr = parts.begin(); itr != parts.end(); ++itr) + { + partlist[parts_idx[itr->first]] = itr->second; + } + + object_dets.push_back(full_object_detection(data.images[i].boxes[j].rect, partlist)); + min_rect_size = std::min<double>(min_rect_size, object_dets.back().get_rect().area()); + } + } + } + + if (!source.should_skip_empty_images() || object_dets.size() != 0) + { + load_image(img, data.images[i].filename); + if (object_dets.size() != 0) + { + // if shrinking the image would still result in the smallest box being + // bigger than the box area threshold then shrink the image. + while(min_rect_size/2/2 > source.box_area_thresh()) + { + pyramid_down<2> pyr; + pyr(img); + min_rect_size *= (1.0/2.0)*(1.0/2.0); + for (auto&& r : object_dets) + { + r.get_rect() = pyr.rect_down(r.get_rect()); + for (unsigned long k = 0; k < r.num_parts(); ++k) + r.part(k) = pyr.point_down(r.part(k)); + } + for (auto&& r : ignored) + { + r = pyr.rect_down(r); + } + } + while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) + { + pyramid_down<3> pyr; + pyr(img); + min_rect_size *= (2.0/3.0)*(2.0/3.0); + for (auto&& r : object_dets) + { + r.get_rect() = pyr.rect_down(r.get_rect()); + for (unsigned long k = 0; k < r.num_parts(); ++k) + r.part(k) = pyr.point_down(r.part(k)); + } + for (auto&& r : ignored) + { + r = pyr.rect_down(r); + } + } + } + images.push_back(img); + object_locations.push_back(object_dets); + ignored_rects.push_back(ignored); + } + } + + + return ignored_rects; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<full_object_detection> >& object_locations, + const image_dataset_file& source + ) + { + std::vector<std::string> parts_list; + return load_image_dataset(images, object_locations, source, parts_list); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<full_object_detection> >& object_locations, + const std::string& filename + ) + { + std::vector<std::string> parts_list; + return load_image_dataset(images, object_locations, image_dataset_file(filename), parts_list); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LOAD_IMAGE_DaTASET_Hh_ + diff --git a/ml/dlib/dlib/data_io/load_image_dataset_abstract.h b/ml/dlib/dlib/data_io/load_image_dataset_abstract.h new file mode 100644 index 000000000..b06252098 --- /dev/null +++ b/ml/dlib/dlib/data_io/load_image_dataset_abstract.h @@ -0,0 +1,358 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ +#ifdef DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ + +#include "image_dataset_metadata.h" +#include "../array/array_kernel_abstract.h" +#include <string> +#include <vector> +#include "../image_processing/full_object_detection_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class image_dataset_file + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool used to tell the load_image_dataset() functions which + boxes and images to load from an XML based image dataset file. By default, + this object tells load_image_dataset() to load all images and object boxes. + !*/ + + public: + image_dataset_file( + const std::string& filename + ); + /*! + ensures + - #get_filename() == filename + - #should_skip_empty_images() == false + - #get_selected_box_labels().size() == 0 + This means that, initially, all boxes will be loaded. Therefore, for all + possible boxes B we have: + - #should_load_box(B) == true + - #box_area_thresh() == infinity + !*/ + + const std::string& get_filename( + ) const; + /*! + ensures + - returns the name of the XML image dataset metadata file given to this + object's constructor. + !*/ + + bool should_skip_empty_images( + ) const; + /*! + ensures + - returns true if we are supposed to skip images that don't have any + non-ignored boxes to load when loading an image dataset using + load_image_dataset(). + !*/ + + image_dataset_file boxes_match_label( + const std::string& label + ) const; + /*! + ensures + - returns a copy of *this that is identical in all respects to *this except + that label will be included in the labels set (i.e. the set returned by + get_selected_box_labels()). + !*/ + + const std::set<std::string>& get_selected_box_labels( + ) const; + /*! + ensures + - returns the set of box labels currently selected by the should_load_box() + method. Note that if the set is empty then we select all boxes. + !*/ + + image_dataset_file skip_empty_images( + ) const; + /*! + ensures + - returns a copy of *this that is identical in all respects to *this except + that #should_skip_empty_images() == true. + !*/ + + bool should_boxes_have_parts( + ) const; + /*! + ensures + - returns true if boxes must have some parts defined for them to be loaded. + !*/ + + image_dataset_file boxes_have_parts( + ) const; + /*! + ensures + - returns a copy of *this that is identical in all respects to *this except + that #should_boxes_have_parts() == true. + !*/ + + bool should_load_box ( + const image_dataset_metadata::box& box + ) const; + /*! + ensures + - returns true if we are supposed to load the given box from an image + dataset XML file. In particular, if should_load_box() returns false then + the load_image_dataset() routines will not return the box at all, neither + in the ignore rectangles list or in the primary object_locations vector. + The behavior of this function is defined as follows: + - if (should_boxes_have_parts() && boxes.parts.size() == 0) then + - returns false + - else if (get_selected_box_labels().size() == 0) then + - returns true + - else if (get_selected_box_labels().count(box.label) != 0) then + - returns true + - else + - returns false + !*/ + + image_dataset_file shrink_big_images( + double new_box_area_thresh = 150*150 + ) const; + /*! + ensures + - returns a copy of *this that is identical in all respects to *this except + that #box_area_thresh() == new_box_area_thresh + !*/ + + double box_area_thresh( + ) const; + /*! + ensures + - If the smallest non-ignored rectangle in an image has an area greater + than box_area_thresh() then we will shrink the image until the area of + the box is about equal to box_area_thresh(). This is useful if you have + a dataset containing very high resolution images and you don't want to + load it in its native high resolution. Setting the box_area_thresh() + allows you to control the resolution of the loaded images. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<rectangle> >& object_locations, + const image_dataset_file& source + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector<some generic image type> where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - This routine loads the images and their associated object boxes from the + image metadata file indicated by source.get_filename(). This metadata file + should be in the XML format used by the save_image_dataset_metadata() routine. + - #images.size() == The number of images loaded from the metadata file. This + is all the images listed in the file unless source.should_skip_empty_images() + is set to true. + - #images.size() == #object_locations.size() + - This routine is capable of loading any image format which can be read by the + load_image() routine. + - let IGNORED_RECTS denote the vector returned from this function. + - IGNORED_RECTS.size() == #object_locations.size() + - IGNORED_RECTS == a list of the rectangles which have the "ignore" flag set to + true in the input XML file. + - for all valid i: + - #images[i] == a copy of the i-th image from the dataset. + - #object_locations[i] == a vector of all the rectangles associated with + #images[i]. These are the rectangles for which source.should_load_box() + returns true and are also not marked as "ignore" in the XML file. + - IGNORED_RECTS[i] == A vector of all the rectangles associated with #images[i] + that are marked as "ignore" but not discarded by source.should_load_box(). + - if (source.should_skip_empty_images() == true) then + - #object_locations[i].size() != 0 + (i.e. we won't load images that don't end up having any object locations) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<rectangle> >& object_locations, + const std::string& filename + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector<some generic image type> where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - performs: return load_image_dataset(images, object_locations, image_dataset_file(filename)); + (i.e. it ignores box labels and therefore loads all the boxes in the dataset) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + void load_image_dataset ( + array_type& images, + std::vector<std::vector<mmod_rect> >& object_locations, + const image_dataset_file& source + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector<some generic image type> where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - This function has essentially the same behavior as the above + load_image_dataset() routines, except here we output to a vector of + mmod_rects instead of rectangles. In this case, both ignore and non-ignore + rectangles go into object_locations since mmod_rect has an ignore boolean + field that records the ignored/non-ignored state of each rectangle. We also store + a each box's string label into the mmod_rect::label field as well. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + void load_image_dataset ( + array_type& images, + std::vector<std::vector<mmod_rect> >& object_locations, + const std::string& filename + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector<some generic image type> where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - performs: load_image_dataset(images, object_locations, image_dataset_file(filename)); + (i.e. it ignores box labels and therefore loads all the boxes in the dataset) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<full_object_detection> >& object_locations, + const image_dataset_file& source, + std::vector<std::string>& parts_list + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector<some generic image type> where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - This routine loads the images and their associated object locations from the + image metadata file indicated by source.get_filename(). This metadata file + should be in the XML format used by the save_image_dataset_metadata() routine. + - The difference between this function and the version of load_image_dataset() + defined above is that this version will also load object part information and + thus fully populates the full_object_detection objects. + - #images.size() == The number of images loaded from the metadata file. This + is all the images listed in the file unless source.should_skip_empty_images() + is set to true. + - #images.size() == #object_locations.size() + - This routine is capable of loading any image format which can be read + by the load_image() routine. + - #parts_list == a vector that contains the list of object parts found in the + input file and loaded into object_locations. + - #parts_list is in lexicographic sorted order. + - let IGNORED_RECTS denote the vector returned from this function. + - IGNORED_RECTS.size() == #object_locations.size() + - IGNORED_RECTS == a list of the rectangles which have the "ignore" flag set to + true in the input XML file. + - for all valid i: + - #images[i] == a copy of the i-th image from the dataset. + - #object_locations[i] == a vector of all the rectangles associated with + #images[i]. These are the rectangles for which source.should_load_box() + returns true and are also not marked as "ignore" in the XML file. + - IGNORED_RECTS[i] == A vector of all the rectangles associated with #images[i] + that are marked as "ignore" but not discarded by source.should_load_box(). + - if (source.should_skip_empty_images() == true) then + - #object_locations[i].size() != 0 + (i.e. we won't load images that don't end up having any object locations) + - for all valid j: + - #object_locations[i][j].num_parts() == #parts_list.size() + - for all valid k: + - #object_locations[i][j].part(k) == the location of the part + with name #parts_list[k] or OBJECT_PART_NOT_PRESENT if the + part was not indicated for object #object_locations[i][j]. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<full_object_detection> >& object_locations, + const image_dataset_file& source + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector<some generic image type> where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - performs: return load_image_dataset(images, object_locations, source, parts_list); + (i.e. this function simply calls the above function and discards the output + parts_list. So it is just a convenience function you can call if you don't + care about getting the parts list.) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector<std::vector<rectangle> > load_image_dataset ( + array_type& images, + std::vector<std::vector<full_object_detection> >& object_locations, + const std::string& filename + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector<some generic image type> where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - performs: return load_image_dataset(images, object_locations, image_dataset_file(filename)); + (i.e. it ignores box labels and therefore loads all the boxes in the dataset) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/data_io/mnist.cpp b/ml/dlib/dlib/data_io/mnist.cpp new file mode 100644 index 000000000..d6a62fb67 --- /dev/null +++ b/ml/dlib/dlib/data_io/mnist.cpp @@ -0,0 +1,133 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MNIST_CPp_ +#define DLIB_MNIST_CPp_ + +#include "mnist.h" +#include <fstream> +#include "../byte_orderer.h" +#include "../uintn.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + void load_mnist_dataset ( + const std::string& folder_name, + std::vector<matrix<unsigned char> >& training_images, + std::vector<unsigned long>& training_labels, + std::vector<matrix<unsigned char> >& testing_images, + std::vector<unsigned long>& testing_labels + ) + { + using namespace std; + ifstream fin1((folder_name+"/train-images-idx3-ubyte").c_str(), ios::binary); + if (!fin1) + { + fin1.open((folder_name + "/train-images.idx3-ubyte").c_str(), ios::binary); + } + + ifstream fin2((folder_name+"/train-labels-idx1-ubyte").c_str(), ios::binary); + if (!fin2) + { + fin2.open((folder_name + "/train-labels.idx1-ubyte").c_str(), ios::binary); + } + + ifstream fin3((folder_name+"/t10k-images-idx3-ubyte").c_str(), ios::binary); + if (!fin3) + { + fin3.open((folder_name + "/t10k-images.idx3-ubyte").c_str(), ios::binary); + } + + ifstream fin4((folder_name+"/t10k-labels-idx1-ubyte").c_str(), ios::binary); + if (!fin4) + { + fin4.open((folder_name + "/t10k-labels.idx1-ubyte").c_str(), ios::binary); + } + + if (!fin1) throw error("Unable to open file train-images-idx3-ubyte or train-images.idx3-ubyte"); + if (!fin2) throw error("Unable to open file train-labels-idx1-ubyte or train-labels.idx1-ubyte"); + if (!fin3) throw error("Unable to open file t10k-images-idx3-ubyte or t10k-images.idx3-ubyte"); + if (!fin4) throw error("Unable to open file t10k-labels-idx1-ubyte or t10k-labels.idx1-ubyte"); + + byte_orderer bo; + + // make sure the files have the contents we expect. + uint32 magic, num, nr, nc, num2, num3, num4; + fin1.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); + fin1.read((char*)&num, sizeof(num)); bo.big_to_host(num); + fin1.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr); + fin1.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc); + if (magic != 2051 || num != 60000 || nr != 28 || nc != 28) + throw error("mndist dat files are corrupted."); + + fin2.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); + fin2.read((char*)&num2, sizeof(num2)); bo.big_to_host(num2); + if (magic != 2049 || num2 != 60000) + throw error("mndist dat files are corrupted."); + + fin3.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); + fin3.read((char*)&num3, sizeof(num3)); bo.big_to_host(num3); + fin3.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr); + fin3.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc); + if (magic != 2051 || num3 != 10000 || nr != 28 || nc != 28) + throw error("mndist dat files are corrupted."); + + fin4.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); + fin4.read((char*)&num4, sizeof(num4)); bo.big_to_host(num4); + if (magic != 2049 || num4 != 10000) + throw error("mndist dat files are corrupted."); + + if (!fin1) throw error("Unable to read train-images-idx3-ubyte"); + if (!fin2) throw error("Unable to read train-labels-idx1-ubyte"); + if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte"); + if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte"); + + + training_images.resize(60000); + training_labels.resize(60000); + testing_images.resize(10000); + testing_labels.resize(10000); + + for (size_t i = 0; i < training_images.size(); ++i) + { + training_images[i].set_size(nr,nc); + fin1.read((char*)&training_images[i](0,0), nr*nc); + } + for (size_t i = 0; i < training_labels.size(); ++i) + { + char l; + fin2.read(&l, 1); + training_labels[i] = l; + } + + for (size_t i = 0; i < testing_images.size(); ++i) + { + testing_images[i].set_size(nr,nc); + fin3.read((char*)&testing_images[i](0,0), nr*nc); + } + for (size_t i = 0; i < testing_labels.size(); ++i) + { + char l; + fin4.read(&l, 1); + testing_labels[i] = l; + } + + if (!fin1) throw error("Unable to read train-images-idx3-ubyte"); + if (!fin2) throw error("Unable to read train-labels-idx1-ubyte"); + if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte"); + if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte"); + + if (fin1.get() != EOF) throw error("Unexpected bytes at end of train-images-idx3-ubyte"); + if (fin2.get() != EOF) throw error("Unexpected bytes at end of train-labels-idx1-ubyte"); + if (fin3.get() != EOF) throw error("Unexpected bytes at end of t10k-images-idx3-ubyte"); + if (fin4.get() != EOF) throw error("Unexpected bytes at end of t10k-labels-idx1-ubyte"); + } +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_MNIST_CPp_ + + + diff --git a/ml/dlib/dlib/data_io/mnist.h b/ml/dlib/dlib/data_io/mnist.h new file mode 100644 index 000000000..e71be6f2b --- /dev/null +++ b/ml/dlib/dlib/data_io/mnist.h @@ -0,0 +1,32 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MNIST_Hh_ +#define DLIB_MNIST_Hh_ + +#include "mnist_abstract.h" +#include <string> +#include <vector> +#include "../matrix.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + void load_mnist_dataset ( + const std::string& folder_name, + std::vector<matrix<unsigned char> >& training_images, + std::vector<unsigned long>& training_labels, + std::vector<matrix<unsigned char> >& testing_images, + std::vector<unsigned long>& testing_labels + ); +} + +// ---------------------------------------------------------------------------------------- + +#ifdef NO_MAKEFILE +#include "mnist.cpp" +#endif + +#endif // DLIB_MNIST_Hh_ + + diff --git a/ml/dlib/dlib/data_io/mnist_abstract.h b/ml/dlib/dlib/data_io/mnist_abstract.h new file mode 100644 index 000000000..09121633e --- /dev/null +++ b/ml/dlib/dlib/data_io/mnist_abstract.h @@ -0,0 +1,46 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MNIST_ABSTRACT_Hh_ +#ifdef DLIB_MNIST_ABSTRACT_Hh_ + +#include <string> +#include <vector> +#include "../matrix.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + void load_mnist_dataset ( + const std::string& folder_name, + std::vector<matrix<unsigned char> >& training_images, + std::vector<unsigned long>& training_labels, + std::vector<matrix<unsigned char> >& testing_images, + std::vector<unsigned long>& testing_labels + ); + /*! + ensures + - Attempts to load the MNIST dataset from the hard drive. This is the dataset + of handwritten digits available from http://yann.lecun.com/exdb/mnist/. In + particular, the 4 files comprising the MNIST dataset should be present in the + folder indicated by folder_name. These four files are: + - train-images-idx3-ubyte + - train-labels-idx1-ubyte + - t10k-images-idx3-ubyte + - t10k-labels-idx1-ubyte + - #training_images == The 60,000 training images from the dataset. + - #training_labels == The labels for the contents of #training_images. + I.e. #training_labels[i] is the label of #training_images[i]. + - #testing_images == The 10,000 testing images from the dataset. + - #testing_labels == The labels for the contents of #testing_images. + I.e. #testing_labels[i] is the label of #testing_images[i]. + throws + - dlib::error if some problem prevents us from loading the data or the files + can't be found. + !*/ +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_MNIST_ABSTRACT_Hh_ + |