summaryrefslogtreecommitdiffstats
path: root/lib/base/netstring.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/base/netstring.cpp')
-rw-r--r--lib/base/netstring.cpp334
1 files changed, 334 insertions, 0 deletions
diff --git a/lib/base/netstring.cpp b/lib/base/netstring.cpp
new file mode 100644
index 0000000..60f08c2
--- /dev/null
+++ b/lib/base/netstring.cpp
@@ -0,0 +1,334 @@
+/* Icinga 2 | (c) 2012 Icinga GmbH | GPLv2+ */
+
+#include "base/netstring.hpp"
+#include "base/debug.hpp"
+#include "base/tlsstream.hpp"
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <utility>
+#include <boost/asio/buffer.hpp>
+#include <boost/asio/read.hpp>
+#include <boost/asio/spawn.hpp>
+#include <boost/asio/write.hpp>
+
+using namespace icinga;
+
+/**
+ * Reads data from a stream in netstring format.
+ *
+ * @param stream The stream to read from.
+ * @param[out] str The String that has been read from the IOQueue.
+ * @returns true if a complete String was read from the IOQueue, false otherwise.
+ * @exception invalid_argument The input stream is invalid.
+ * @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
+ */
+StreamReadStatus NetString::ReadStringFromStream(const Stream::Ptr& stream, String *str, StreamReadContext& context,
+ bool may_wait, ssize_t maxMessageLength)
+{
+ if (context.Eof)
+ return StatusEof;
+
+ if (context.MustRead) {
+ if (!context.FillFromStream(stream, may_wait)) {
+ context.Eof = true;
+ return StatusEof;
+ }
+
+ context.MustRead = false;
+ }
+
+ size_t header_length = 0;
+
+ for (size_t i = 0; i < context.Size; i++) {
+ if (context.Buffer[i] == ':') {
+ header_length = i;
+
+ /* make sure there's a header */
+ if (header_length == 0)
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
+
+ break;
+ } else if (i > 16)
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
+ }
+
+ if (header_length == 0) {
+ context.MustRead = true;
+ return StatusNeedData;
+ }
+
+ /* no leading zeros allowed */
+ if (context.Buffer[0] == '0' && isdigit(context.Buffer[1]))
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
+
+ size_t len, i;
+
+ len = 0;
+ for (i = 0; i < header_length && isdigit(context.Buffer[i]); i++) {
+ /* length specifier must have at most 9 characters */
+ if (i >= 9)
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
+
+ len = len * 10 + (context.Buffer[i] - '0');
+ }
+
+ /* read the whole message */
+ size_t data_length = len + 1;
+
+ if (maxMessageLength >= 0 && data_length > (size_t)maxMessageLength) {
+ std::stringstream errorMessage;
+ errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";
+
+ BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
+ }
+
+ char *data = context.Buffer + header_length + 1;
+
+ if (context.Size < header_length + 1 + data_length) {
+ context.MustRead = true;
+ return StatusNeedData;
+ }
+
+ if (data[len] != ',')
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
+
+ *str = String(&data[0], &data[len]);
+
+ context.DropData(header_length + 1 + len + 1);
+
+ return StatusNewItem;
+}
+
+/**
+ * Writes data into a stream using the netstring format and returns bytes written.
+ *
+ * @param stream The stream.
+ * @param str The String that is to be written.
+ *
+ * @return The amount of bytes written.
+ */
+size_t NetString::WriteStringToStream(const Stream::Ptr& stream, const String& str)
+{
+ std::ostringstream msgbuf;
+ WriteStringToStream(msgbuf, str);
+
+ String msg = msgbuf.str();
+ stream->Write(msg.CStr(), msg.GetLength());
+ return msg.GetLength();
+}
+
+/**
+ * Reads data from a stream in netstring format.
+ *
+ * @param stream The stream to read from.
+ * @returns The String that has been read from the IOQueue.
+ * @exception invalid_argument The input stream is invalid.
+ * @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
+ */
+String NetString::ReadStringFromStream(const Shared<AsioTlsStream>::Ptr& stream,
+ ssize_t maxMessageLength)
+{
+ namespace asio = boost::asio;
+
+ size_t len = 0;
+ bool leadingZero = false;
+
+ for (uint_fast8_t readBytes = 0;; ++readBytes) {
+ char byte = 0;
+
+ {
+ asio::mutable_buffer byteBuf (&byte, 1);
+ asio::read(*stream, byteBuf);
+ }
+
+ if (isdigit(byte)) {
+ if (readBytes == 9) {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
+ }
+
+ if (leadingZero) {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
+ }
+
+ len = len * 10u + size_t(byte - '0');
+
+ if (!readBytes && byte == '0') {
+ leadingZero = true;
+ }
+ } else if (byte == ':') {
+ if (!readBytes) {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
+ }
+
+ break;
+ } else {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
+ }
+ }
+
+ if (maxMessageLength >= 0 && len > maxMessageLength) {
+ std::stringstream errorMessage;
+ errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";
+
+ BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
+ }
+
+ String payload;
+
+ if (len) {
+ payload.Append(len, 0);
+
+ asio::mutable_buffer payloadBuf (&*payload.Begin(), payload.GetLength());
+ asio::read(*stream, payloadBuf);
+ }
+
+ char trailer = 0;
+
+ {
+ asio::mutable_buffer trailerBuf (&trailer, 1);
+ asio::read(*stream, trailerBuf);
+ }
+
+ if (trailer != ',') {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
+ }
+
+ return payload;
+}
+
+/**
+ * Reads data from a stream in netstring format.
+ *
+ * @param stream The stream to read from.
+ * @returns The String that has been read from the IOQueue.
+ * @exception invalid_argument The input stream is invalid.
+ * @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
+ */
+String NetString::ReadStringFromStream(const Shared<AsioTlsStream>::Ptr& stream,
+ boost::asio::yield_context yc, ssize_t maxMessageLength)
+{
+ namespace asio = boost::asio;
+
+ size_t len = 0;
+ bool leadingZero = false;
+
+ for (uint_fast8_t readBytes = 0;; ++readBytes) {
+ char byte = 0;
+
+ {
+ asio::mutable_buffer byteBuf (&byte, 1);
+ asio::async_read(*stream, byteBuf, yc);
+ }
+
+ if (isdigit(byte)) {
+ if (readBytes == 9) {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
+ }
+
+ if (leadingZero) {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
+ }
+
+ len = len * 10u + size_t(byte - '0');
+
+ if (!readBytes && byte == '0') {
+ leadingZero = true;
+ }
+ } else if (byte == ':') {
+ if (!readBytes) {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
+ }
+
+ break;
+ } else {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
+ }
+ }
+
+ if (maxMessageLength >= 0 && len > maxMessageLength) {
+ std::stringstream errorMessage;
+ errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";
+
+ BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
+ }
+
+ String payload;
+
+ if (len) {
+ payload.Append(len, 0);
+
+ asio::mutable_buffer payloadBuf (&*payload.Begin(), payload.GetLength());
+ asio::async_read(*stream, payloadBuf, yc);
+ }
+
+ char trailer = 0;
+
+ {
+ asio::mutable_buffer trailerBuf (&trailer, 1);
+ asio::async_read(*stream, trailerBuf, yc);
+ }
+
+ if (trailer != ',') {
+ BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
+ }
+
+ return payload;
+}
+
+/**
+ * Writes data into a stream using the netstring format and returns bytes written.
+ *
+ * @param stream The stream.
+ * @param str The String that is to be written.
+ *
+ * @return The amount of bytes written.
+ */
+size_t NetString::WriteStringToStream(const Shared<AsioTlsStream>::Ptr& stream, const String& str)
+{
+ namespace asio = boost::asio;
+
+ std::ostringstream msgbuf;
+ WriteStringToStream(msgbuf, str);
+
+ String msg = msgbuf.str();
+ asio::const_buffer msgBuf (msg.CStr(), msg.GetLength());
+
+ asio::write(*stream, msgBuf);
+
+ return msg.GetLength();
+}
+
+/**
+ * Writes data into a stream using the netstring format and returns bytes written.
+ *
+ * @param stream The stream.
+ * @param str The String that is to be written.
+ *
+ * @return The amount of bytes written.
+ */
+size_t NetString::WriteStringToStream(const Shared<AsioTlsStream>::Ptr& stream, const String& str, boost::asio::yield_context yc)
+{
+ namespace asio = boost::asio;
+
+ std::ostringstream msgbuf;
+ WriteStringToStream(msgbuf, str);
+
+ String msg = msgbuf.str();
+ asio::const_buffer msgBuf (msg.CStr(), msg.GetLength());
+
+ asio::async_write(*stream, msgBuf, yc);
+
+ return msg.GetLength();
+}
+
+/**
+ * Writes data into a stream using the netstring format.
+ *
+ * @param stream The stream.
+ * @param str The String that is to be written.
+ */
+void NetString::WriteStringToStream(std::ostream& stream, const String& str)
+{
+ stream << str.GetLength() << ":" << str << ",";
+}