diff options
Diffstat (limited to 'ml/dlib/dlib/python')
-rw-r--r-- | ml/dlib/dlib/python/numpy.h | 214 | ||||
-rw-r--r-- | ml/dlib/dlib/python/numpy_image.h | 129 | ||||
-rw-r--r-- | ml/dlib/dlib/python/pyassert.h | 17 | ||||
-rw-r--r-- | ml/dlib/dlib/python/pybind_utils.h | 82 | ||||
-rw-r--r-- | ml/dlib/dlib/python/serialize_pickle.h | 66 |
5 files changed, 508 insertions, 0 deletions
diff --git a/ml/dlib/dlib/python/numpy.h b/ml/dlib/dlib/python/numpy.h new file mode 100644 index 000000000..9b2c1a01c --- /dev/null +++ b/ml/dlib/dlib/python/numpy.h @@ -0,0 +1,214 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYTHON_NuMPY_Hh_ +#define DLIB_PYTHON_NuMPY_Hh_ + +#include <pybind11/pybind11.h> +#include <dlib/error.h> +#include <dlib/algs.h> +#include <dlib/string.h> +#include <dlib/array.h> +#include <dlib/pixel.h> + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +template <typename TT> +void validate_numpy_array_type ( + const py::object& obj +) +{ + const char ch = obj.attr("dtype").attr("char").cast<char>(); + + using T = typename dlib::pixel_traits<TT>::basic_pixel_type; + + if (dlib::is_same_type<T,double>::value) + { + if (ch != 'd') + throw dlib::error("Expected numpy.ndarray of float64"); + } + else if (dlib::is_same_type<T,float>::value) + { + if (ch != 'f') + throw dlib::error("Expected numpy.ndarray of float32"); + } + else if (dlib::is_same_type<T,dlib::int16>::value) + { + if (ch != 'h') + throw dlib::error("Expected numpy.ndarray of int16"); + } + else if (dlib::is_same_type<T,dlib::uint16>::value) + { + if (ch != 'H') + throw dlib::error("Expected numpy.ndarray of uint16"); + } + else if (dlib::is_same_type<T,dlib::int32>::value) + { + if (ch != 'i') + throw dlib::error("Expected numpy.ndarray of int32"); + } + else if (dlib::is_same_type<T,dlib::uint32>::value) + { + if (ch != 'I') + throw dlib::error("Expected numpy.ndarray of uint32"); + } + else if (dlib::is_same_type<T,unsigned char>::value) + { + if (ch != 'B') + throw dlib::error("Expected numpy.ndarray of uint8"); + } + else if (dlib::is_same_type<T,signed char>::value) + { + if (ch != 'b') + throw dlib::error("Expected numpy.ndarray of int8"); + } + else + { + throw dlib::error("validate_numpy_array_type() called with unsupported type."); + } +} + +// ---------------------------------------------------------------------------------------- + +template <int dims> +void get_numpy_ndarray_shape ( + const py::object& obj, + long (&shape)[dims] +) +/*! + ensures + - stores the shape of the array into #shape. + - the dimension of the given numpy array is not greater than #dims. +!*/ +{ + Py_buffer pybuf; + if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES )) + throw dlib::error("Expected numpy.ndarray with shape set."); + + try + { + + if (pybuf.ndim > dims) + throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions."); + + for (int i = 0; i < dims; ++i) + { + if (i < pybuf.ndim) + shape[i] = pybuf.shape[i]; + else + shape[i] = 1; + } + } + catch(...) + { + PyBuffer_Release(&pybuf); + throw; + } + PyBuffer_Release(&pybuf); +} + +// ---------------------------------------------------------------------------------------- + +template <typename T, int dims> +void get_numpy_ndarray_parts ( + py::object& obj, + T*& data, + dlib::array<T>& contig_buf, + long (&shape)[dims] +) +/*! + ensures + - extracts the pointer to the data from the given numpy ndarray. Stores the shape + of the array into #shape. + - the dimension of the given numpy array is not greater than #dims. + - #shape[#dims-1] == pixel_traits<T>::num when #dims is greater than 2 +!*/ +{ + Py_buffer pybuf; + if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES | PyBUF_WRITABLE )) + throw dlib::error("Expected writable numpy.ndarray with shape set."); + + try + { + validate_numpy_array_type<T>(obj); + + if (pybuf.ndim > dims) + throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions."); + get_numpy_ndarray_shape(obj, shape); + + if (dlib::pixel_traits<T>::num > 1 && dlib::pixel_traits<T>::num != shape[dims-1]) + throw dlib::error("Expected numpy.ndarray with " + dlib::cast_to_string(dlib::pixel_traits<T>::num) + " channels."); + + if (PyBuffer_IsContiguous(&pybuf, 'C')) + data = (T*)pybuf.buf; + else + { + contig_buf.resize(pybuf.len); + if (PyBuffer_ToContiguous(&contig_buf[0], &pybuf, pybuf.len, 'C')) + throw dlib::error("Can't copy numpy.ndarray to a contiguous buffer."); + data = &contig_buf[0]; + } + } + catch(...) + { + PyBuffer_Release(&pybuf); + throw; + } + PyBuffer_Release(&pybuf); +} + +// ---------------------------------------------------------------------------------------- + +template <typename T, int dims> +void get_numpy_ndarray_parts ( + const py::object& obj, + const T*& data, + dlib::array<T>& contig_buf, + long (&shape)[dims] +) +/*! + ensures + - extracts the pointer to the data from the given numpy ndarray. Stores the shape + of the array into #shape. + - the dimension of the given numpy array is not greater than #dims. + - #shape[#dims-1] == pixel_traits<T>::num when #dims is greater than 2 +!*/ +{ + Py_buffer pybuf; + if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES )) + throw dlib::error("Expected numpy.ndarray with shape set."); + + try + { + validate_numpy_array_type<T>(obj); + + if (pybuf.ndim > dims) + throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions."); + get_numpy_ndarray_shape(obj, shape); + + if (dlib::pixel_traits<T>::num > 1 && dlib::pixel_traits<T>::num != shape[dims-1]) + throw dlib::error("Expected numpy.ndarray with " + dlib::cast_to_string(dlib::pixel_traits<T>::num) + " channels."); + + if (PyBuffer_IsContiguous(&pybuf, 'C')) + data = (const T*)pybuf.buf; + else + { + contig_buf.resize(pybuf.len); + if (PyBuffer_ToContiguous(&contig_buf[0], &pybuf, pybuf.len, 'C')) + throw dlib::error("Can't copy numpy.ndarray to a contiguous buffer."); + data = &contig_buf[0]; + } + } + catch(...) + { + PyBuffer_Release(&pybuf); + throw; + } + PyBuffer_Release(&pybuf); +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_PYTHON_NuMPY_Hh_ + diff --git a/ml/dlib/dlib/python/numpy_image.h b/ml/dlib/dlib/python/numpy_image.h new file mode 100644 index 000000000..49ea80317 --- /dev/null +++ b/ml/dlib/dlib/python/numpy_image.h @@ -0,0 +1,129 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYTHON_NuMPY_IMAGE_Hh_ +#define DLIB_PYTHON_NuMPY_IMAGE_Hh_ + +#include "numpy.h" +#include <dlib/pixel.h> +#include <dlib/matrix.h> +#include <dlib/array.h> + + +// ---------------------------------------------------------------------------------------- + +class numpy_gray_image +{ +public: + + numpy_gray_image() : _data(0), _nr(0), _nc(0) {} + numpy_gray_image (py::object& img) + { + long shape[2]; + get_numpy_ndarray_parts(img, _data, _contig_buf, shape); + _nr = shape[0]; + _nc = shape[1]; + } + + friend inline long num_rows(const numpy_gray_image& img) { return img._nr; } + friend inline long num_columns(const numpy_gray_image& img) { return img._nc; } + friend inline void* image_data(numpy_gray_image& img) { return img._data; } + friend inline const void* image_data(const numpy_gray_image& img) { return img._data; } + friend inline long width_step(const numpy_gray_image& img) { return img._nc*sizeof(unsigned char); } + +private: + + unsigned char* _data; + dlib::array<unsigned char> _contig_buf; + long _nr; + long _nc; +}; + +namespace dlib +{ + template <> + struct image_traits<numpy_gray_image > + { + typedef unsigned char pixel_type; + }; +} + +// ---------------------------------------------------------------------------------------- + +inline bool is_gray_python_image (py::object& img) +{ + try + { + long shape[2]; + get_numpy_ndarray_shape(img, shape); + return true; + } + catch (dlib::error&) + { + return false; + } +} + +// ---------------------------------------------------------------------------------------- + +class numpy_rgb_image +{ +public: + + numpy_rgb_image() : _data(0), _nr(0), _nc(0) {} + numpy_rgb_image (py::object& img) + { + long shape[3]; + get_numpy_ndarray_parts(img, _data, _contig_buf, shape); + _nr = shape[0]; + _nc = shape[1]; + if (shape[2] != 3) + throw dlib::error("Error, python object is not a three band image and therefore can't be a RGB image."); + } + + friend inline long num_rows(const numpy_rgb_image& img) { return img._nr; } + friend inline long num_columns(const numpy_rgb_image& img) { return img._nc; } + friend inline void* image_data(numpy_rgb_image& img) { return img._data; } + friend inline const void* image_data(const numpy_rgb_image& img) { return img._data; } + friend inline long width_step(const numpy_rgb_image& img) { return img._nc*sizeof(dlib::rgb_pixel); } + + +private: + + dlib::rgb_pixel* _data; + dlib::array<dlib::rgb_pixel> _contig_buf; + long _nr; + long _nc; +}; + +namespace dlib +{ + template <> + struct image_traits<numpy_rgb_image > + { + typedef rgb_pixel pixel_type; + }; +} + +// ---------------------------------------------------------------------------------------- + + +inline bool is_rgb_python_image (py::object& img) +{ + try + { + long shape[3]; + get_numpy_ndarray_shape(img, shape); + if (shape[2] == 3) + return true; + return false; + } + catch (dlib::error&) + { + return false; + } +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_PYTHON_NuMPY_IMAGE_Hh_ + diff --git a/ml/dlib/dlib/python/pyassert.h b/ml/dlib/dlib/python/pyassert.h new file mode 100644 index 000000000..80939f501 --- /dev/null +++ b/ml/dlib/dlib/python/pyassert.h @@ -0,0 +1,17 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYaSSERT_Hh_ +#define DLIB_PYaSSERT_Hh_ + +#include <pybind11/pybind11.h> + +#define pyassert(_exp,_message) \ + {if ( !(_exp) ) \ + { \ + namespace py = pybind11; \ + PyErr_SetString( PyExc_ValueError, _message ); \ + throw py::error_already_set(); \ + }} + +#endif // DLIB_PYaSSERT_Hh_ + diff --git a/ml/dlib/dlib/python/pybind_utils.h b/ml/dlib/dlib/python/pybind_utils.h new file mode 100644 index 000000000..7f94cf32d --- /dev/null +++ b/ml/dlib/dlib/python/pybind_utils.h @@ -0,0 +1,82 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYBIND_UtILS_Hh_ +#define DLIB_PYBIND_UtILS_Hh_ + +#include <pybind11/pybind11.h> +#include <vector> +#include <string> +#include <dlib/serialize.h> + +namespace py = pybind11; + +template <typename T> +std::vector<T> python_list_to_vector ( + const py::list& obj +) +/*! + ensures + - converts a python object into a std::vector<T> and returns it. +!*/ +{ + std::vector<T> vect(len(obj)); + for (unsigned long i = 0; i < vect.size(); ++i) + { + vect[i] = obj[i].cast<T>(); + } + return vect; +} + +template <typename T> +py::list vector_to_python_list ( + const std::vector<T>& vect +) +/*! + ensures + - converts a std::vector<T> into a python list object. +!*/ +{ + py::list obj; + for (unsigned long i = 0; i < vect.size(); ++i) + obj.append(vect[i]); + return obj; +} + +template <typename T> +void extend_vector_with_python_list ( + std::vector<T> &v, + const py::list &l +) +/*! + ensures + - appends items from a python list to the end of std::vector<T>. +!*/ +{ + for (const auto &item : l) + v.push_back(item.cast<T>()); +} + +// ---------------------------------------------------------------------------------------- + +template <typename T> +std::shared_ptr<T> load_object_from_file ( + const std::string& filename +) +/*! + ensures + - deserializes an object of type T from the given file and returns it. +!*/ +{ + std::ifstream fin(filename.c_str(), std::ios::binary); + if (!fin) + throw dlib::error("Unable to open " + filename); + auto obj = std::make_shared<T>(); + deserialize(*obj, fin); + return obj; +} + +// ---------------------------------------------------------------------------------------- + + +#endif // DLIB_PYBIND_UtILS_Hh_ + diff --git a/ml/dlib/dlib/python/serialize_pickle.h b/ml/dlib/dlib/python/serialize_pickle.h new file mode 100644 index 000000000..2dc44c322 --- /dev/null +++ b/ml/dlib/dlib/python/serialize_pickle.h @@ -0,0 +1,66 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERIALIZE_PiCKLE_Hh_ +#define DLIB_SERIALIZE_PiCKLE_Hh_ + +#include <dlib/serialize.h> +#include <pybind11/pybind11.h> +#include <sstream> +#include <dlib/vectorstream.h> + +template<typename T> +py::tuple getstate(const T& item) +{ + using namespace dlib; + std::vector<char> buf; + buf.reserve(5000); + vectorstream sout(buf); + serialize(item, sout); + return py::make_tuple(py::handle( + PyBytes_FromStringAndSize(buf.size()?&buf[0]:0, buf.size()))); +} + +template<typename T> +T setstate(py::tuple state) +{ + using namespace dlib; + if (len(state) != 1) + { + PyErr_SetObject(PyExc_ValueError, + py::str("expected 1-item tuple in call to __setstate__; got {}").format(state).ptr() + ); + throw py::error_already_set(); + } + + // We used to serialize by converting to a str but the boost.python routines for + // doing this don't work in Python 3. You end up getting an error about invalid + // UTF-8 encodings. So instead we access the python C interface directly and use + // bytes objects. However, we keep the deserialization code that worked with str + // for backwards compatibility with previously pickled files. + T item; + py::object obj = state[0]; + if (py::isinstance<py::str>(obj)) + { + py::str data = state[0].cast<py::str>(); + std::string temp = data; + std::istringstream sin(temp); + deserialize(item, sin); + } + else if(PyBytes_Check(py::object(state[0]).ptr())) + { + py::object obj = state[0]; + char* data = PyBytes_AsString(obj.ptr()); + unsigned long num = PyBytes_Size(obj.ptr()); + std::istringstream sin(std::string(data, num)); + deserialize(item, sin); + } + else + { + throw error("Unable to unpickle, error in input file."); + } + + return item; +} + +#endif // DLIB_SERIALIZE_PiCKLE_Hh_ + |