summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/python
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/python')
-rw-r--r--ml/dlib/dlib/python/numpy.h214
-rw-r--r--ml/dlib/dlib/python/numpy_image.h129
-rw-r--r--ml/dlib/dlib/python/pyassert.h17
-rw-r--r--ml/dlib/dlib/python/pybind_utils.h82
-rw-r--r--ml/dlib/dlib/python/serialize_pickle.h66
5 files changed, 508 insertions, 0 deletions
diff --git a/ml/dlib/dlib/python/numpy.h b/ml/dlib/dlib/python/numpy.h
new file mode 100644
index 000000000..9b2c1a01c
--- /dev/null
+++ b/ml/dlib/dlib/python/numpy.h
@@ -0,0 +1,214 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYTHON_NuMPY_Hh_
+#define DLIB_PYTHON_NuMPY_Hh_
+
+#include <pybind11/pybind11.h>
+#include <dlib/error.h>
+#include <dlib/algs.h>
+#include <dlib/string.h>
+#include <dlib/array.h>
+#include <dlib/pixel.h>
+
+namespace py = pybind11;
+
+// ----------------------------------------------------------------------------------------
+
+template <typename TT>
+void validate_numpy_array_type (
+ const py::object& obj
+)
+{
+ const char ch = obj.attr("dtype").attr("char").cast<char>();
+
+ using T = typename dlib::pixel_traits<TT>::basic_pixel_type;
+
+ if (dlib::is_same_type<T,double>::value)
+ {
+ if (ch != 'd')
+ throw dlib::error("Expected numpy.ndarray of float64");
+ }
+ else if (dlib::is_same_type<T,float>::value)
+ {
+ if (ch != 'f')
+ throw dlib::error("Expected numpy.ndarray of float32");
+ }
+ else if (dlib::is_same_type<T,dlib::int16>::value)
+ {
+ if (ch != 'h')
+ throw dlib::error("Expected numpy.ndarray of int16");
+ }
+ else if (dlib::is_same_type<T,dlib::uint16>::value)
+ {
+ if (ch != 'H')
+ throw dlib::error("Expected numpy.ndarray of uint16");
+ }
+ else if (dlib::is_same_type<T,dlib::int32>::value)
+ {
+ if (ch != 'i')
+ throw dlib::error("Expected numpy.ndarray of int32");
+ }
+ else if (dlib::is_same_type<T,dlib::uint32>::value)
+ {
+ if (ch != 'I')
+ throw dlib::error("Expected numpy.ndarray of uint32");
+ }
+ else if (dlib::is_same_type<T,unsigned char>::value)
+ {
+ if (ch != 'B')
+ throw dlib::error("Expected numpy.ndarray of uint8");
+ }
+ else if (dlib::is_same_type<T,signed char>::value)
+ {
+ if (ch != 'b')
+ throw dlib::error("Expected numpy.ndarray of int8");
+ }
+ else
+ {
+ throw dlib::error("validate_numpy_array_type() called with unsupported type.");
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <int dims>
+void get_numpy_ndarray_shape (
+ const py::object& obj,
+ long (&shape)[dims]
+)
+/*!
+ ensures
+ - stores the shape of the array into #shape.
+ - the dimension of the given numpy array is not greater than #dims.
+!*/
+{
+ Py_buffer pybuf;
+ if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES ))
+ throw dlib::error("Expected numpy.ndarray with shape set.");
+
+ try
+ {
+
+ if (pybuf.ndim > dims)
+ throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions.");
+
+ for (int i = 0; i < dims; ++i)
+ {
+ if (i < pybuf.ndim)
+ shape[i] = pybuf.shape[i];
+ else
+ shape[i] = 1;
+ }
+ }
+ catch(...)
+ {
+ PyBuffer_Release(&pybuf);
+ throw;
+ }
+ PyBuffer_Release(&pybuf);
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <typename T, int dims>
+void get_numpy_ndarray_parts (
+ py::object& obj,
+ T*& data,
+ dlib::array<T>& contig_buf,
+ long (&shape)[dims]
+)
+/*!
+ ensures
+ - extracts the pointer to the data from the given numpy ndarray. Stores the shape
+ of the array into #shape.
+ - the dimension of the given numpy array is not greater than #dims.
+ - #shape[#dims-1] == pixel_traits<T>::num when #dims is greater than 2
+!*/
+{
+ Py_buffer pybuf;
+ if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES | PyBUF_WRITABLE ))
+ throw dlib::error("Expected writable numpy.ndarray with shape set.");
+
+ try
+ {
+ validate_numpy_array_type<T>(obj);
+
+ if (pybuf.ndim > dims)
+ throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions.");
+ get_numpy_ndarray_shape(obj, shape);
+
+ if (dlib::pixel_traits<T>::num > 1 && dlib::pixel_traits<T>::num != shape[dims-1])
+ throw dlib::error("Expected numpy.ndarray with " + dlib::cast_to_string(dlib::pixel_traits<T>::num) + " channels.");
+
+ if (PyBuffer_IsContiguous(&pybuf, 'C'))
+ data = (T*)pybuf.buf;
+ else
+ {
+ contig_buf.resize(pybuf.len);
+ if (PyBuffer_ToContiguous(&contig_buf[0], &pybuf, pybuf.len, 'C'))
+ throw dlib::error("Can't copy numpy.ndarray to a contiguous buffer.");
+ data = &contig_buf[0];
+ }
+ }
+ catch(...)
+ {
+ PyBuffer_Release(&pybuf);
+ throw;
+ }
+ PyBuffer_Release(&pybuf);
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <typename T, int dims>
+void get_numpy_ndarray_parts (
+ const py::object& obj,
+ const T*& data,
+ dlib::array<T>& contig_buf,
+ long (&shape)[dims]
+)
+/*!
+ ensures
+ - extracts the pointer to the data from the given numpy ndarray. Stores the shape
+ of the array into #shape.
+ - the dimension of the given numpy array is not greater than #dims.
+ - #shape[#dims-1] == pixel_traits<T>::num when #dims is greater than 2
+!*/
+{
+ Py_buffer pybuf;
+ if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES ))
+ throw dlib::error("Expected numpy.ndarray with shape set.");
+
+ try
+ {
+ validate_numpy_array_type<T>(obj);
+
+ if (pybuf.ndim > dims)
+ throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions.");
+ get_numpy_ndarray_shape(obj, shape);
+
+ if (dlib::pixel_traits<T>::num > 1 && dlib::pixel_traits<T>::num != shape[dims-1])
+ throw dlib::error("Expected numpy.ndarray with " + dlib::cast_to_string(dlib::pixel_traits<T>::num) + " channels.");
+
+ if (PyBuffer_IsContiguous(&pybuf, 'C'))
+ data = (const T*)pybuf.buf;
+ else
+ {
+ contig_buf.resize(pybuf.len);
+ if (PyBuffer_ToContiguous(&contig_buf[0], &pybuf, pybuf.len, 'C'))
+ throw dlib::error("Can't copy numpy.ndarray to a contiguous buffer.");
+ data = &contig_buf[0];
+ }
+ }
+ catch(...)
+ {
+ PyBuffer_Release(&pybuf);
+ throw;
+ }
+ PyBuffer_Release(&pybuf);
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_PYTHON_NuMPY_Hh_
+
diff --git a/ml/dlib/dlib/python/numpy_image.h b/ml/dlib/dlib/python/numpy_image.h
new file mode 100644
index 000000000..49ea80317
--- /dev/null
+++ b/ml/dlib/dlib/python/numpy_image.h
@@ -0,0 +1,129 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYTHON_NuMPY_IMAGE_Hh_
+#define DLIB_PYTHON_NuMPY_IMAGE_Hh_
+
+#include "numpy.h"
+#include <dlib/pixel.h>
+#include <dlib/matrix.h>
+#include <dlib/array.h>
+
+
+// ----------------------------------------------------------------------------------------
+
+class numpy_gray_image
+{
+public:
+
+ numpy_gray_image() : _data(0), _nr(0), _nc(0) {}
+ numpy_gray_image (py::object& img)
+ {
+ long shape[2];
+ get_numpy_ndarray_parts(img, _data, _contig_buf, shape);
+ _nr = shape[0];
+ _nc = shape[1];
+ }
+
+ friend inline long num_rows(const numpy_gray_image& img) { return img._nr; }
+ friend inline long num_columns(const numpy_gray_image& img) { return img._nc; }
+ friend inline void* image_data(numpy_gray_image& img) { return img._data; }
+ friend inline const void* image_data(const numpy_gray_image& img) { return img._data; }
+ friend inline long width_step(const numpy_gray_image& img) { return img._nc*sizeof(unsigned char); }
+
+private:
+
+ unsigned char* _data;
+ dlib::array<unsigned char> _contig_buf;
+ long _nr;
+ long _nc;
+};
+
+namespace dlib
+{
+ template <>
+ struct image_traits<numpy_gray_image >
+ {
+ typedef unsigned char pixel_type;
+ };
+}
+
+// ----------------------------------------------------------------------------------------
+
+inline bool is_gray_python_image (py::object& img)
+{
+ try
+ {
+ long shape[2];
+ get_numpy_ndarray_shape(img, shape);
+ return true;
+ }
+ catch (dlib::error&)
+ {
+ return false;
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+class numpy_rgb_image
+{
+public:
+
+ numpy_rgb_image() : _data(0), _nr(0), _nc(0) {}
+ numpy_rgb_image (py::object& img)
+ {
+ long shape[3];
+ get_numpy_ndarray_parts(img, _data, _contig_buf, shape);
+ _nr = shape[0];
+ _nc = shape[1];
+ if (shape[2] != 3)
+ throw dlib::error("Error, python object is not a three band image and therefore can't be a RGB image.");
+ }
+
+ friend inline long num_rows(const numpy_rgb_image& img) { return img._nr; }
+ friend inline long num_columns(const numpy_rgb_image& img) { return img._nc; }
+ friend inline void* image_data(numpy_rgb_image& img) { return img._data; }
+ friend inline const void* image_data(const numpy_rgb_image& img) { return img._data; }
+ friend inline long width_step(const numpy_rgb_image& img) { return img._nc*sizeof(dlib::rgb_pixel); }
+
+
+private:
+
+ dlib::rgb_pixel* _data;
+ dlib::array<dlib::rgb_pixel> _contig_buf;
+ long _nr;
+ long _nc;
+};
+
+namespace dlib
+{
+ template <>
+ struct image_traits<numpy_rgb_image >
+ {
+ typedef rgb_pixel pixel_type;
+ };
+}
+
+// ----------------------------------------------------------------------------------------
+
+
+inline bool is_rgb_python_image (py::object& img)
+{
+ try
+ {
+ long shape[3];
+ get_numpy_ndarray_shape(img, shape);
+ if (shape[2] == 3)
+ return true;
+ return false;
+ }
+ catch (dlib::error&)
+ {
+ return false;
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_PYTHON_NuMPY_IMAGE_Hh_
+
diff --git a/ml/dlib/dlib/python/pyassert.h b/ml/dlib/dlib/python/pyassert.h
new file mode 100644
index 000000000..80939f501
--- /dev/null
+++ b/ml/dlib/dlib/python/pyassert.h
@@ -0,0 +1,17 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYaSSERT_Hh_
+#define DLIB_PYaSSERT_Hh_
+
+#include <pybind11/pybind11.h>
+
+#define pyassert(_exp,_message) \
+ {if ( !(_exp) ) \
+ { \
+ namespace py = pybind11; \
+ PyErr_SetString( PyExc_ValueError, _message ); \
+ throw py::error_already_set(); \
+ }}
+
+#endif // DLIB_PYaSSERT_Hh_
+
diff --git a/ml/dlib/dlib/python/pybind_utils.h b/ml/dlib/dlib/python/pybind_utils.h
new file mode 100644
index 000000000..7f94cf32d
--- /dev/null
+++ b/ml/dlib/dlib/python/pybind_utils.h
@@ -0,0 +1,82 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYBIND_UtILS_Hh_
+#define DLIB_PYBIND_UtILS_Hh_
+
+#include <pybind11/pybind11.h>
+#include <vector>
+#include <string>
+#include <dlib/serialize.h>
+
+namespace py = pybind11;
+
+template <typename T>
+std::vector<T> python_list_to_vector (
+ const py::list& obj
+)
+/*!
+ ensures
+ - converts a python object into a std::vector<T> and returns it.
+!*/
+{
+ std::vector<T> vect(len(obj));
+ for (unsigned long i = 0; i < vect.size(); ++i)
+ {
+ vect[i] = obj[i].cast<T>();
+ }
+ return vect;
+}
+
+template <typename T>
+py::list vector_to_python_list (
+ const std::vector<T>& vect
+)
+/*!
+ ensures
+ - converts a std::vector<T> into a python list object.
+!*/
+{
+ py::list obj;
+ for (unsigned long i = 0; i < vect.size(); ++i)
+ obj.append(vect[i]);
+ return obj;
+}
+
+template <typename T>
+void extend_vector_with_python_list (
+ std::vector<T> &v,
+ const py::list &l
+)
+/*!
+ ensures
+ - appends items from a python list to the end of std::vector<T>.
+!*/
+{
+ for (const auto &item : l)
+ v.push_back(item.cast<T>());
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <typename T>
+std::shared_ptr<T> load_object_from_file (
+ const std::string& filename
+)
+/*!
+ ensures
+ - deserializes an object of type T from the given file and returns it.
+!*/
+{
+ std::ifstream fin(filename.c_str(), std::ios::binary);
+ if (!fin)
+ throw dlib::error("Unable to open " + filename);
+ auto obj = std::make_shared<T>();
+ deserialize(*obj, fin);
+ return obj;
+}
+
+// ----------------------------------------------------------------------------------------
+
+
+#endif // DLIB_PYBIND_UtILS_Hh_
+
diff --git a/ml/dlib/dlib/python/serialize_pickle.h b/ml/dlib/dlib/python/serialize_pickle.h
new file mode 100644
index 000000000..2dc44c322
--- /dev/null
+++ b/ml/dlib/dlib/python/serialize_pickle.h
@@ -0,0 +1,66 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERIALIZE_PiCKLE_Hh_
+#define DLIB_SERIALIZE_PiCKLE_Hh_
+
+#include <dlib/serialize.h>
+#include <pybind11/pybind11.h>
+#include <sstream>
+#include <dlib/vectorstream.h>
+
+template<typename T>
+py::tuple getstate(const T& item)
+{
+ using namespace dlib;
+ std::vector<char> buf;
+ buf.reserve(5000);
+ vectorstream sout(buf);
+ serialize(item, sout);
+ return py::make_tuple(py::handle(
+ PyBytes_FromStringAndSize(buf.size()?&buf[0]:0, buf.size())));
+}
+
+template<typename T>
+T setstate(py::tuple state)
+{
+ using namespace dlib;
+ if (len(state) != 1)
+ {
+ PyErr_SetObject(PyExc_ValueError,
+ py::str("expected 1-item tuple in call to __setstate__; got {}").format(state).ptr()
+ );
+ throw py::error_already_set();
+ }
+
+ // We used to serialize by converting to a str but the boost.python routines for
+ // doing this don't work in Python 3. You end up getting an error about invalid
+ // UTF-8 encodings. So instead we access the python C interface directly and use
+ // bytes objects. However, we keep the deserialization code that worked with str
+ // for backwards compatibility with previously pickled files.
+ T item;
+ py::object obj = state[0];
+ if (py::isinstance<py::str>(obj))
+ {
+ py::str data = state[0].cast<py::str>();
+ std::string temp = data;
+ std::istringstream sin(temp);
+ deserialize(item, sin);
+ }
+ else if(PyBytes_Check(py::object(state[0]).ptr()))
+ {
+ py::object obj = state[0];
+ char* data = PyBytes_AsString(obj.ptr());
+ unsigned long num = PyBytes_Size(obj.ptr());
+ std::istringstream sin(std::string(data, num));
+ deserialize(item, sin);
+ }
+ else
+ {
+ throw error("Unable to unpickle, error in input file.");
+ }
+
+ return item;
+}
+
+#endif // DLIB_SERIALIZE_PiCKLE_Hh_
+