summaryrefslogtreecommitdiffstats
path: root/ml/dlib/tools/python/src/cnn_face_detector.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/tools/python/src/cnn_face_detector.cpp')
-rw-r--r--ml/dlib/tools/python/src/cnn_face_detector.cpp183
1 files changed, 183 insertions, 0 deletions
diff --git a/ml/dlib/tools/python/src/cnn_face_detector.cpp b/ml/dlib/tools/python/src/cnn_face_detector.cpp
new file mode 100644
index 000000000..f18d99d95
--- /dev/null
+++ b/ml/dlib/tools/python/src/cnn_face_detector.cpp
@@ -0,0 +1,183 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "opaque_types.h"
+#include <dlib/python.h>
+#include <dlib/matrix.h>
+#include <dlib/dnn.h>
+#include <dlib/image_transforms.h>
+#include "indexing.h"
+#include <pybind11/stl_bind.h>
+
+using namespace dlib;
+using namespace std;
+
+namespace py = pybind11;
+
+
+class cnn_face_detection_model_v1
+{
+
+public:
+
+ cnn_face_detection_model_v1(const std::string& model_filename)
+ {
+ deserialize(model_filename) >> net;
+ }
+
+ std::vector<mmod_rect> detect (
+ py::object pyimage,
+ const int upsample_num_times
+ )
+ {
+ pyramid_down<2> pyr;
+ std::vector<mmod_rect> rects;
+
+ // Copy the data into dlib based objects
+ matrix<rgb_pixel> image;
+ if (is_gray_python_image(pyimage))
+ assign_image(image, numpy_gray_image(pyimage));
+ else if (is_rgb_python_image(pyimage))
+ assign_image(image, numpy_rgb_image(pyimage));
+ else
+ throw dlib::error("Unsupported image type, must be 8bit gray or RGB image.");
+
+ // Upsampling the image will allow us to detect smaller faces but will cause the
+ // program to use more RAM and run longer.
+ unsigned int levels = upsample_num_times;
+ while (levels > 0)
+ {
+ levels--;
+ pyramid_up(image, pyr);
+ }
+
+ auto dets = net(image);
+
+ // Scale the detection locations back to the original image size
+ // if the image was upscaled.
+ for (auto&& d : dets) {
+ d.rect = pyr.rect_down(d.rect, upsample_num_times);
+ rects.push_back(d);
+ }
+
+ return rects;
+ }
+
+ std::vector<std::vector<mmod_rect> > detect_mult (
+ py::list imgs,
+ const int upsample_num_times,
+ const int batch_size = 128
+ )
+ {
+ pyramid_down<2> pyr;
+ std::vector<matrix<rgb_pixel> > dimgs;
+ dimgs.reserve(len(imgs));
+
+ for(int i = 0; i < len(imgs); i++)
+ {
+ // Copy the data into dlib based objects
+ matrix<rgb_pixel> image;
+ py::object tmp = imgs[i].cast<py::object>();
+ if (is_gray_python_image(tmp))
+ assign_image(image, numpy_gray_image(tmp));
+ else if (is_rgb_python_image(tmp))
+ assign_image(image, numpy_rgb_image(tmp));
+ else
+ throw dlib::error("Unsupported image type, must be 8bit gray or RGB image.");
+
+ for(int i = 0; i < upsample_num_times; i++)
+ {
+ pyramid_up(image);
+ }
+ dimgs.push_back(image);
+ }
+
+ for(int i = 1; i < dimgs.size(); i++)
+ {
+ if
+ (
+ dimgs[i - 1].nc() != dimgs[i].nc() ||
+ dimgs[i - 1].nr() != dimgs[i].nr()
+ )
+ throw dlib::error("Images in list must all have the same dimensions.");
+
+ }
+
+ auto dets = net(dimgs, batch_size);
+ std::vector<std::vector<mmod_rect> > all_rects;
+
+ for(auto&& im_dets : dets)
+ {
+ std::vector<mmod_rect> rects;
+ rects.reserve(im_dets.size());
+ for (auto&& d : im_dets) {
+ d.rect = pyr.rect_down(d.rect, upsample_num_times);
+ rects.push_back(d);
+ }
+ all_rects.push_back(rects);
+ }
+
+ return all_rects;
+ }
+
+private:
+
+ template <long num_filters, typename SUBNET> using con5d = con<num_filters,5,5,2,2,SUBNET>;
+ template <long num_filters, typename SUBNET> using con5 = con<num_filters,5,5,1,1,SUBNET>;
+
+ template <typename SUBNET> using downsampler = relu<affine<con5d<32, relu<affine<con5d<32, relu<affine<con5d<16,SUBNET>>>>>>>>>;
+ template <typename SUBNET> using rcon5 = relu<affine<con5<45,SUBNET>>>;
+
+ using net_type = loss_mmod<con<1,9,9,1,1,rcon5<rcon5<rcon5<downsampler<input_rgb_image_pyramid<pyramid_down<6>>>>>>>>;
+
+ net_type net;
+};
+
+// ----------------------------------------------------------------------------------------
+
+void bind_cnn_face_detection(py::module& m)
+{
+ {
+ py::class_<cnn_face_detection_model_v1>(m, "cnn_face_detection_model_v1", "This object detects human faces in an image. The constructor loads the face detection model from a file. You can download a pre-trained model from http://dlib.net/files/mmod_human_face_detector.dat.bz2.")
+ .def(py::init<std::string>())
+ .def(
+ "__call__",
+ &cnn_face_detection_model_v1::detect_mult,
+ py::arg("imgs"), py::arg("upsample_num_times")=0, py::arg("batch_size")=128,
+ "takes a list of images as input returning a 2d list of mmod rectangles"
+ )
+ .def(
+ "__call__",
+ &cnn_face_detection_model_v1::detect,
+ py::arg("img"), py::arg("upsample_num_times")=0,
+ "Find faces in an image using a deep learning model.\n\
+ - Upsamples the image upsample_num_times before running the face \n\
+ detector."
+ );
+ }
+
+ m.def("set_dnn_prefer_smallest_algorithms", &set_dnn_prefer_smallest_algorithms, "Tells cuDNN to use slower algorithms that use less RAM.");
+
+ auto cuda = m.def_submodule("cuda", "Routines for setting CUDA specific properties.");
+ cuda.def("set_device", &dlib::cuda::set_device, py::arg("device_id"),
+ "Set the active CUDA device. It is required that 0 <= device_id < get_num_devices().");
+ cuda.def("get_device", &dlib::cuda::get_device, "Get the active CUDA device.");
+ cuda.def("get_num_devices", &dlib::cuda::get_num_devices, "Find out how many CUDA devices are available.");
+
+ {
+ typedef mmod_rect type;
+ py::class_<type>(m, "mmod_rectangle", "Wrapper around a rectangle object and a detection confidence score.")
+ .def_readwrite("rect", &type::rect)
+ .def_readwrite("confidence", &type::detection_confidence);
+ }
+ {
+ typedef std::vector<mmod_rect> type;
+ py::bind_vector<type>(m, "mmod_rectangles", "An array of mmod rectangle objects.")
+ .def("extend", extend_vector_with_python_list<mmod_rect>);
+ }
+ {
+ typedef std::vector<std::vector<mmod_rect> > type;
+ py::bind_vector<type>(m, "mmod_rectangless", "A 2D array of mmod rectangle objects.")
+ .def("extend", extend_vector_with_python_list<std::vector<mmod_rect>>);
+ }
+}