summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/python/serialize_pickle.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/python/serialize_pickle.h')
-rw-r--r--ml/dlib/dlib/python/serialize_pickle.h66
1 files changed, 66 insertions, 0 deletions
diff --git a/ml/dlib/dlib/python/serialize_pickle.h b/ml/dlib/dlib/python/serialize_pickle.h
new file mode 100644
index 000000000..2dc44c322
--- /dev/null
+++ b/ml/dlib/dlib/python/serialize_pickle.h
@@ -0,0 +1,66 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERIALIZE_PiCKLE_Hh_
+#define DLIB_SERIALIZE_PiCKLE_Hh_
+
+#include <dlib/serialize.h>
+#include <pybind11/pybind11.h>
+#include <sstream>
+#include <dlib/vectorstream.h>
+
+template<typename T>
+py::tuple getstate(const T& item)
+{
+ using namespace dlib;
+ std::vector<char> buf;
+ buf.reserve(5000);
+ vectorstream sout(buf);
+ serialize(item, sout);
+ return py::make_tuple(py::handle(
+ PyBytes_FromStringAndSize(buf.size()?&buf[0]:0, buf.size())));
+}
+
+template<typename T>
+T setstate(py::tuple state)
+{
+ using namespace dlib;
+ if (len(state) != 1)
+ {
+ PyErr_SetObject(PyExc_ValueError,
+ py::str("expected 1-item tuple in call to __setstate__; got {}").format(state).ptr()
+ );
+ throw py::error_already_set();
+ }
+
+ // We used to serialize by converting to a str but the boost.python routines for
+ // doing this don't work in Python 3. You end up getting an error about invalid
+ // UTF-8 encodings. So instead we access the python C interface directly and use
+ // bytes objects. However, we keep the deserialization code that worked with str
+ // for backwards compatibility with previously pickled files.
+ T item;
+ py::object obj = state[0];
+ if (py::isinstance<py::str>(obj))
+ {
+ py::str data = state[0].cast<py::str>();
+ std::string temp = data;
+ std::istringstream sin(temp);
+ deserialize(item, sin);
+ }
+ else if(PyBytes_Check(py::object(state[0]).ptr()))
+ {
+ py::object obj = state[0];
+ char* data = PyBytes_AsString(obj.ptr());
+ unsigned long num = PyBytes_Size(obj.ptr());
+ std::istringstream sin(std::string(data, num));
+ deserialize(item, sin);
+ }
+ else
+ {
+ throw error("Unable to unpickle, error in input file.");
+ }
+
+ return item;
+}
+
+#endif // DLIB_SERIALIZE_PiCKLE_Hh_
+