summaryrefslogtreecommitdiffstats
path: root/src/lib/asiolink/tests/tls_socket_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/asiolink/tests/tls_socket_unittest.cc')
-rw-r--r--src/lib/asiolink/tests/tls_socket_unittest.cc560
1 files changed, 560 insertions, 0 deletions
diff --git a/src/lib/asiolink/tests/tls_socket_unittest.cc b/src/lib/asiolink/tests/tls_socket_unittest.cc
new file mode 100644
index 0000000..c2572c0
--- /dev/null
+++ b/src/lib/asiolink/tests/tls_socket_unittest.cc
@@ -0,0 +1,560 @@
+// Copyright (C) 2011-2021 Internet Systems Consortium, Inc. ("ISC")
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/// @brief Test of TLSSocket
+///
+/// Tests the functionality of a TLSSocket by working through an open-send-
+/// receive-close sequence and checking that the asynchronous notifications
+/// work.
+
+#include <config.h>
+#include <asiolink/asio_wrapper.h>
+#include <asiolink/io_service.h>
+#include <asiolink/tcp_endpoint.h>
+#include <asiolink/tls_socket.h>
+#include <asiolink/testutils/test_tls.h>
+#include <util/buffer.h>
+#include <util/io_utilities.h>
+
+#include <boost/shared_ptr.hpp>
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <arpa/inet.h>
+#include <cstddef>
+#include <cstdlib>
+#include <errno.h>
+#include <netinet/in.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <string>
+#include <vector>
+
+using namespace boost::asio;
+using namespace boost::asio::ip;
+using namespace isc::util;
+using namespace isc::asiolink;
+using namespace std;
+
+namespace {
+
+const char SERVER_ADDRESS[] = "127.0.0.1";
+const unsigned short SERVER_PORT = 5303;
+
+/// @todo Shouldn't we send something that is real message?
+const char OUTBOUND_DATA[] = "Data sent from client to server";
+const char INBOUND_DATA[] = "Returned data from server to client";
+}
+
+/// An instance of this object is passed to the asynchronous I/O functions
+/// and the operator() method is called when when an asynchronous I/O completes.
+/// The arguments to the completion callback are stored for later retrieval.
+class TLSCallback {
+public:
+ /// @brief Operations the server is doing
+ enum Operation {
+ ACCEPT = 0, ///< accept() was issued
+ OPEN = 1, ///< Client connected to server
+ HANDSHAKE = 2, ///< TLS handshake completed
+ READ = 3, ///< Asynchronous read completed
+ WRITE = 4, ///< Asynchronous write completed
+ NONE = 5 ///< "Not set" state
+ };
+
+ /// @brief Minimum size of buffers
+ enum {
+ MIN_SIZE = (64 * 1024 + 2) ///< 64kB + two bytes for a count
+ };
+
+ struct PrivateData {
+ PrivateData() :
+ error_code_(), length_(0), cumulative_(0), expected_(0), offset_(0),
+ name_(""), queued_(NONE), called_(NONE), data_(MIN_SIZE, 0)
+ {}
+
+ boost::system::error_code error_code_; ///< Completion error code
+ size_t length_; ///< Bytes transferred in this I/O
+ size_t cumulative_; ///< Cumulative bytes transferred
+ size_t expected_; ///< Expected amount of data
+ size_t offset_; ///< Where to put data in buffer
+ std::string name_; ///< Which of the objects this is
+ Operation queued_; ///< Queued operation
+ Operation called_; ///< Which callback called
+ std::vector<uint8_t> data_; ///< Receive buffer
+ };
+
+ /// @brief Constructor
+ ///
+ /// Constructs the object. It also creates the data member pointed to by
+ /// a shared pointer. When used as a callback object, this is copied as it
+ /// is passed into the asynchronous function. This means that there are two
+ /// objects and inspecting the one we passed in does not tell us anything.
+ ///
+ /// Therefore we use a boost::shared_ptr. When the object is copied, the
+ /// shared pointer is copied, which leaves both objects pointing to the same
+ /// data.
+ ///
+ /// @param which Which of the two callback objects this is
+ TLSCallback(std::string which) : ptr_(new PrivateData())
+ {
+ ptr_->name_ = which;
+ }
+
+ /// @brief Destructor
+ ///
+ /// No code needed, destroying the shared pointer destroys the private data.
+ virtual ~TLSCallback()
+ {}
+
+ /// @brief Client Callback Function
+ ///
+ /// Called when an asynchronous operation is completed by the client, this
+ /// stores the origin of the operation in the client_called_ data member.
+ ///
+ /// @param ec I/O completion error code passed to callback function.
+ /// @param length Number of bytes transferred
+ void operator()(boost::system::error_code ec = boost::system::error_code(),
+ size_t length = 0)
+ {
+ setCode(ec.value());
+ ptr_->called_ = ptr_->queued_;
+ ptr_->length_ = length;
+ }
+
+ /// @brief Get I/O completion error code
+ int getCode() {
+ return (ptr_->error_code_.value());
+ }
+
+ /// @brief Set I/O completion code
+ ///
+ /// @param code New value of completion code
+ void setCode(int code) {
+ ptr_->error_code_ = boost::system::error_code(code, boost::system::error_code().category());
+ }
+
+ /// @brief Get number of bytes transferred in I/O
+ size_t& length() {
+ return (ptr_->length_);
+ }
+
+ /// @brief Get cumulative number of bytes transferred in I/O
+ size_t& cumulative() {
+ return (ptr_->cumulative_);
+ }
+
+ /// @brief Get expected amount of data
+ size_t& expected() {
+ return (ptr_->expected_);
+ }
+
+ /// @brief Get offset into data
+ size_t& offset() {
+ return (ptr_->offset_);
+ }
+
+ /// @brief Get data member
+ uint8_t* data() {
+ return (&ptr_->data_[0]);
+ }
+
+ /// @brief Get flag to say what was queued
+ Operation& queued() {
+ return (ptr_->queued_);
+ }
+
+ /// @brief Get flag to say when callback was called
+ Operation& called() {
+ return (ptr_->called_);
+ }
+
+ /// @brief Return instance of callback name
+ std::string& name() {
+ return (ptr_->name_);
+ }
+
+private:
+ boost::shared_ptr<PrivateData> ptr_; ///< Pointer to private data
+};
+
+
+// Read Server Data
+//
+// Called in the part of the test that has the client send a message to the
+// server, this loops until all the data has been read (synchronously) by the
+// server.
+//
+// "All the data read" means that the server has received a message that is
+// preceded by a two-byte count field and that the total amount of data received
+// from the remote end is equal to the value in the count field plus two bytes
+// for the count field itself.
+//
+// @param stream Stream on which the server is reading data
+// @param server_cb Structure in which server data is held.
+void
+serverRead(TlsStreamImpl& stream, TLSCallback& server_cb) {
+
+ // As we may need to read multiple times, keep a count of the cumulative
+ // amount of data read and do successive reads into the appropriate part
+ // of the buffer.
+ //
+ // Note that there are no checks for buffer overflow - this is a test
+ // program and we have sized the buffer to be large enough for the test.
+ server_cb.cumulative() = 0;
+
+ bool complete = false;
+ while (!complete) {
+
+ // Read block of data and update cumulative amount of data received.
+ server_cb.length() = stream.read_some(
+ boost::asio::buffer(server_cb.data() + server_cb.cumulative(),
+ TLSCallback::MIN_SIZE - server_cb.cumulative()));
+ server_cb.cumulative() += server_cb.length();
+
+ // If we have read at least two bytes, we can work out how much we
+ // should be reading.
+ if (server_cb.cumulative() >= 2) {
+ server_cb.expected() = readUint16(server_cb.data(), server_cb.length());
+ if ((server_cb.expected() + 2) == server_cb.cumulative()) {
+
+ // Amount of data read from stream equals the size of the
+ // message (as indicated in the first two bytes of the message)
+ // plus the size of the count field. Therefore we have received
+ // all the data.
+ complete = true;
+ }
+ }
+ }
+}
+
+// Receive complete method should return true only if the count in the first
+// two bytes is equal to the size of the rest if the buffer.
+
+TEST(TLSSocket, processReceivedData) {
+ // Amount of "real" data in the buffer
+ const uint16_t PACKET_SIZE = 16382;
+
+ // Used to instantiate socket
+ IOService service;
+ TlsContextPtr context(new TlsContext(CLIENT));
+ // Socket under test
+ TLSSocket<TLSCallback> test(service, context);
+ // Buffer to check
+ uint8_t inbuff[PACKET_SIZE + 2];
+ // Where data is put
+ OutputBufferPtr outbuff(new OutputBuffer(16));
+ // Expected amount of data
+ size_t expected;
+ // Where to put next data
+ size_t offset;
+ // Cumulative data received
+ size_t cumulative;
+
+ // Set some dummy values in the buffer to check
+ for (size_t i = 0; i < sizeof(inbuff); ++i) {
+ inbuff[i] = i % 256;
+ }
+
+ // Check that the method will handle various receive sizes.
+ writeUint16(PACKET_SIZE, inbuff, sizeof(inbuff));
+
+ cumulative = 0;
+ offset = 0;
+ expected = 0;
+ outbuff->clear();
+ bool complete = test.processReceivedData(inbuff, 1, cumulative, offset,
+ expected, outbuff);
+ EXPECT_FALSE(complete);
+ EXPECT_EQ(1, cumulative);
+ EXPECT_EQ(1, offset);
+ EXPECT_EQ(0, expected);
+ EXPECT_EQ(0, outbuff->getLength());
+
+ // Now pretend that we've received one more byte.
+ complete = test.processReceivedData(inbuff, 1, cumulative, offset, expected,
+ outbuff);
+ EXPECT_FALSE(complete);
+ EXPECT_EQ(2, cumulative);
+ EXPECT_EQ(0, offset);
+ EXPECT_EQ(PACKET_SIZE, expected);
+ EXPECT_EQ(0, outbuff->getLength());
+
+ // Add another two bytes. However, this time note that we have to offset
+ // in the input buffer because it is expected that the next chunk of data
+ // from the connection will be read into the start of the buffer.
+ complete = test.processReceivedData(inbuff + cumulative, 2, cumulative,
+ offset, expected, outbuff);
+ EXPECT_FALSE(complete);
+ EXPECT_EQ(4, cumulative);
+ EXPECT_EQ(0, offset);
+ EXPECT_EQ(PACKET_SIZE, expected);
+ EXPECT_EQ(2, outbuff->getLength());
+
+ const uint8_t* dataptr = static_cast<const uint8_t*>(outbuff->getData());
+ EXPECT_TRUE(equal(inbuff + 2, inbuff + cumulative, dataptr));
+
+ // And add the remaining data. Remember that "inbuff" is "PACKET_SIZE + 2"
+ // long.
+ complete = test.processReceivedData(inbuff + cumulative,
+ PACKET_SIZE + 2 - cumulative,
+ cumulative, offset, expected, outbuff);
+ EXPECT_TRUE(complete);
+ EXPECT_EQ(PACKET_SIZE + 2, cumulative);
+ EXPECT_EQ(0, offset);
+ EXPECT_EQ(PACKET_SIZE, expected);
+ EXPECT_EQ(PACKET_SIZE, outbuff->getLength());
+ dataptr = static_cast<const uint8_t*>(outbuff->getData());
+ EXPECT_TRUE(equal(inbuff + 2, inbuff + cumulative, dataptr));
+}
+
+/// @todo Need to add a test to check the cancel() method
+
+// Tests the operation of a TLSSocket by opening it, sending an asynchronous
+// message to a server, receiving an asynchronous message from the server and
+// closing.
+TEST(TLSSocket, sequenceTest) {
+
+ // Common objects.
+ // Service object for async control
+ IOService service;
+
+ // The client - the TLSSocket being tested
+ TlsContextPtr client_ctx;
+ test::configClient(client_ctx);
+ // Socket under test
+ TLSSocket<TLSCallback> client(service, client_ctx);
+ // Async I/O callback function
+ TLSCallback client_cb("Client");
+ // Where client receives message from
+ TCPEndpoint client_remote_endpoint;
+ // Received data is put here
+ OutputBufferPtr client_buffer(new OutputBuffer(128));
+ // The server - with which the client communicates.
+ // Address of target server
+ IOAddress server_address(SERVER_ADDRESS);
+ // Server callback
+ TLSCallback server_cb("Server");
+ // Endpoint describing server
+ TCPEndpoint server_endpoint(server_address, SERVER_PORT);
+ // Address where server received message from
+ TCPEndpoint server_remote_endpoint;
+ TlsContextPtr server_ctx;
+ test::configServer(server_ctx);
+ // Stream used for server.
+ TlsStreamImpl server(service.get_io_service(), server_ctx->getContext());
+
+ // Step 1. Create the connection between the client and the server. Set
+ // up the server to accept incoming connections and have the client open
+ // a channel to it.
+
+ // Set up server - open socket and queue an accept.
+ server_cb.queued() = TLSCallback::ACCEPT;
+ server_cb.called() = TLSCallback::NONE;
+ server_cb.setCode(42); // Some error
+ tcp::acceptor acceptor(service.get_io_service(),
+ tcp::endpoint(tcp::v4(), SERVER_PORT));
+ acceptor.set_option(tcp::acceptor::reuse_address(true));
+ acceptor.async_accept(server.lowest_layer(), server_cb);
+
+ // Set up client - connect to the server.
+ client_cb.queued() = TLSCallback::OPEN;
+ client_cb.called() = TLSCallback::NONE;
+ client_cb.setCode(43); // Some error
+ EXPECT_FALSE(client.isOpenSynchronous());
+ client.open(&server_endpoint, client_cb);
+
+ // Run the open and the accept callback and check that they ran.
+ while ((server_cb.called() == TLSCallback::NONE) ||
+ (client_cb.called() == TLSCallback::NONE)) {
+ service.run_one();
+ }
+ EXPECT_EQ(TLSCallback::ACCEPT, server_cb.called());
+ EXPECT_EQ(0, server_cb.getCode());
+
+ EXPECT_EQ(TLSCallback::OPEN, client_cb.called());
+
+ // On some operating system the async_connect may return EINPROGRESS.
+ // This doesn't necessarily indicate an error. In most cases trying
+ // to asynchronously write and read from the socket would work just
+ // fine.
+ if ((client_cb.getCode()) != 0 && (client_cb.getCode() != EINPROGRESS)) {
+ ADD_FAILURE() << "expected error code of 0 or " << EINPROGRESS
+ << " as a result of async_connect, got " << client_cb.getCode();
+ }
+
+ // Perform handshake.
+ client_cb.queued() = TLSCallback::HANDSHAKE;
+ client_cb.called() = TLSCallback::NONE;
+ client_cb.setCode(43); // Some error
+ client.handshake(client_cb);
+
+ server_cb.queued() = TLSCallback::HANDSHAKE;
+ server_cb.called() = TLSCallback::NONE;
+ server_cb.setCode(42); // Some error
+ server.async_handshake(roleToImpl(TlsRole::SERVER), server_cb);
+
+ while ((server_cb.called() == TLSCallback::NONE) ||
+ (client_cb.called() == TLSCallback::NONE)) {
+ service.run_one();
+ }
+ EXPECT_EQ(TLSCallback::HANDSHAKE, client_cb.called());
+ EXPECT_EQ(0, client_cb.getCode());
+
+ EXPECT_EQ(TLSCallback::HANDSHAKE, server_cb.called());
+ EXPECT_EQ(0, server_cb.getCode());
+
+ // Step 2. Get the client to write to the server asynchronously. The
+ // server will loop reading the data synchronously.
+
+ // Write asynchronously to the server.
+ client_cb.called() = TLSCallback::NONE;
+ client_cb.queued() = TLSCallback::WRITE;
+ client_cb.setCode(143); // Arbitrary number
+ client_cb.length() = 0;
+ client.asyncSend(OUTBOUND_DATA, sizeof(OUTBOUND_DATA), &server_endpoint, client_cb);
+
+ // Wait for the client callback to complete. (Must do this first on
+ // Solaris: if we do the synchronous read first, the test hangs.)
+ while (client_cb.called() == TLSCallback::NONE) {
+ service.run_one();
+ }
+
+ // Synchronously read the data from the server.;
+ serverRead(server, server_cb);
+
+ // Check the client state
+ EXPECT_EQ(TLSCallback::WRITE, client_cb.called());
+ EXPECT_EQ(0, client_cb.getCode());
+ EXPECT_EQ(sizeof(OUTBOUND_DATA) + 2, client_cb.length());
+
+ // ... and check what the server received.
+ EXPECT_EQ(sizeof(OUTBOUND_DATA) + 2, server_cb.cumulative());
+ EXPECT_TRUE(equal(OUTBOUND_DATA,
+ (OUTBOUND_DATA + (sizeof(OUTBOUND_DATA) - 1)),
+ (server_cb.data() + 2)));
+
+ // Step 3. Get the server to write all the data asynchronously and have the
+ // client loop (asynchronously) reading the data. Note that we copy the
+ // data into the server's internal buffer in order to precede it with a two-
+ // byte count field.
+
+ // Have the server write asynchronously to the client.
+ server_cb.called() = TLSCallback::NONE;
+ server_cb.queued() = TLSCallback::WRITE;
+ server_cb.length() = 0;
+ server_cb.cumulative() = 0;
+
+ writeUint16(sizeof(INBOUND_DATA), server_cb.data(), TLSCallback::MIN_SIZE);
+ copy(INBOUND_DATA, (INBOUND_DATA + sizeof(INBOUND_DATA) - 1),
+ (server_cb.data() + 2));
+ boost::asio::async_write(server,
+ boost::asio::buffer(server_cb.data(),
+ (sizeof(INBOUND_DATA) + 2)),
+ server_cb);
+
+ // Have the client read asynchronously.
+ client_cb.called() = TLSCallback::NONE;
+ client_cb.queued() = TLSCallback::READ;
+ client_cb.length() = 0;
+ client_cb.cumulative() = 0;
+ client_cb.expected() = 0;
+ client_cb.offset() = 0;
+
+ client.asyncReceive(client_cb.data(), TLSCallback::MIN_SIZE,
+ client_cb.offset(), &client_remote_endpoint,
+ client_cb);
+
+ // Run the callbacks. Several options are possible depending on how ASIO
+ // is implemented and whether the message gets fragmented:
+ //
+ // 1) The send handler may complete immediately, regardless of whether the
+ // data has been read by the client. (This is the most likely.)
+ // 2) The send handler may only run after all the data has been read by
+ // the client. (This could happen if the client's TCP buffers were too
+ // small so the data was not transferred to the "remote" system until the
+ // remote buffer has been emptied one or more times.)
+ // 3) The client handler may be run a number of times to handle the message
+ // fragments and the server handler may run between calls of the client
+ // handler.
+ //
+ // So loop, running one handler at a time until we are certain that all the
+ // handlers have run.
+
+ bool server_complete = false;
+ bool client_complete = false;
+ while (!server_complete || !client_complete) {
+ service.run_one();
+
+ // Has the server run?
+ if (!server_complete) {
+ if (server_cb.called() != TLSCallback::NONE) {
+
+ // Yes. Check that the send completed successfully and that
+ // all the data that was expected to have been sent was in fact
+ // sent.
+ EXPECT_EQ(TLSCallback::WRITE, server_cb.called());
+ EXPECT_EQ(0, server_cb.getCode());
+ EXPECT_EQ((sizeof(INBOUND_DATA) + 2), server_cb.length());
+ server_complete = true;
+ }
+ }
+
+ // Has the client run?
+ if (!client_complete) {
+
+ if (client_cb.called() == TLSCallback::NONE) {
+ // No. Run the service another time.
+ continue;
+ }
+
+ // Client callback must have run. Check that it ran OK.
+ EXPECT_EQ(TLSCallback::READ, client_cb.called());
+ EXPECT_EQ(0, client_cb.getCode());
+
+ // Check if we need to queue another read, copying the data into
+ // the output buffer as we do so.
+ client_complete = client.processReceivedData(client_cb.data(),
+ client_cb.length(),
+ client_cb.cumulative(),
+ client_cb.offset(),
+ client_cb.expected(),
+ client_buffer);
+
+ // If the data is not complete, queue another read.
+ if (!client_complete) {
+ client_cb.called() = TLSCallback::NONE;
+ client_cb.queued() = TLSCallback::READ;
+ client_cb.length() = 0;
+ client.asyncReceive(client_cb.data(), TLSCallback::MIN_SIZE ,
+ client_cb.offset(), &client_remote_endpoint,
+ client_cb);
+ }
+ }
+ }
+
+ // Both the send and the receive have completed. Check that the received
+ // is what was sent.
+
+ // Check the client state
+ EXPECT_EQ(TLSCallback::READ, client_cb.called());
+ EXPECT_EQ(0, client_cb.getCode());
+ EXPECT_EQ(sizeof(INBOUND_DATA) + 2, client_cb.cumulative());
+ EXPECT_EQ(sizeof(INBOUND_DATA), client_buffer->getLength());
+
+ // ... and check what the server sent.
+ EXPECT_EQ(TLSCallback::WRITE, server_cb.called());
+ EXPECT_EQ(0, server_cb.getCode());
+ EXPECT_EQ(sizeof(INBOUND_DATA) + 2, server_cb.length());
+
+ // ... and that what was sent is what was received.
+ const uint8_t* received = static_cast<const uint8_t*>(client_buffer->getData());
+ EXPECT_TRUE(equal(INBOUND_DATA, (INBOUND_DATA + (sizeof(INBOUND_DATA) - 1)),
+ received));
+
+ // Close client and server.
+ EXPECT_NO_THROW(client.close());
+ EXPECT_NO_THROW(server.lowest_layer().close());
+}