diff options
Diffstat (limited to 'src/lib/tcp/tests/tcp_test_client.h')
-rw-r--r-- | src/lib/tcp/tests/tcp_test_client.h | 437 |
1 files changed, 437 insertions, 0 deletions
diff --git a/src/lib/tcp/tests/tcp_test_client.h b/src/lib/tcp/tests/tcp_test_client.h new file mode 100644 index 0000000..2e1b30b --- /dev/null +++ b/src/lib/tcp/tests/tcp_test_client.h @@ -0,0 +1,437 @@ +// Copyright (C) 2022-2023 Internet Systems Consortium, Inc. ("ISC") +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef TCP_TEST_CLIENT_H +#define TCP_TEST_CLIENT_H + +#include <cc/data.h> +#include <asiolink/tcp_socket.h> +#include <asiolink/tls_socket.h> +#include <asiolink/testutils/test_tls.h> +#include <tcp/tcp_connection.h> +#include <tcp/tcp_stream_msg.h> +#include <boost/asio/read.hpp> +#include <boost/asio/buffer.hpp> +#include <boost/asio/ip/tcp.hpp> +#include <gtest/gtest.h> + +/// @brief Entity which can connect to the TCP server endpoint with or +/// or without TLS. +class TcpTestClient : public boost::noncopyable { + +private: + /// @brief Type of the function implementing a callback invoked by the + /// @c SocketCallback functor. + typedef std::function<void(boost::system::error_code ec, size_t length)> + SocketCallbackFunction; + + /// @brief Functor associated with the socket object. + /// + /// This functor calls a callback function specified in the constructor. + class SocketCallback { + public: + + /// @brief Constructor. + /// + /// @param socket_callback Callback to be invoked by the functor upon + /// an event associated with the socket. + SocketCallback(SocketCallbackFunction socket_callback) + : callback_(socket_callback) { + } + + /// @brief Operator called when event associated with a socket occurs. + /// + /// This operator returns immediately when received error code is + /// @c boost::system::error_code is equal to + /// @c boost::asio::error::operation_aborted, i.e. the callback is not + /// invoked. + /// + /// @param ec Error code. + /// @param length Data length. + void operator()(boost::system::error_code ec, size_t length = 0) { + if (ec.value() == boost::asio::error::operation_aborted) { + return; + } + + callback_(ec, length); + } + + private: + /// @brief Supplied callback. + SocketCallbackFunction callback_; + }; + +public: + + /// @brief Constructor. + /// + /// This constructor creates new socket instance. It doesn't connect. Call + /// start() to connect to the server. + /// + /// @param io_service IO service to be stopped on error or completion. + /// @param done_callback Function cient should invoke when it has finished + /// all its requests or failed. + /// @param tls_context + /// @param server_address string containing the IP address of the server. + /// @param port port number of the server. + explicit TcpTestClient(isc::asiolink::IOService& io_service, + std::function<void()> done_callback, + isc::asiolink::TlsContextPtr tls_context = + isc::asiolink::TlsContextPtr(), + const std::string& server_address = "127.0.0.1", + uint16_t port = 18123) + : io_service_(io_service.get_io_service()), + tls_context_(tls_context), + tcp_socket_(), tls_socket_(), + done_callback_(done_callback), + server_address_(server_address), server_port_(port), + buf_(), response_(), + receive_done_(false), expected_eof_(false), handshake_failed_(false) { + if (!tls_context_) { + tcp_socket_.reset(new isc::asiolink::TCPSocket<SocketCallback>(io_service)); + } else { + tls_socket_.reset(new isc::asiolink::TLSSocket<SocketCallback>(io_service, + tls_context)); + } + } + + bool useTls() { + return (!!tls_context_); + } + + /// @brief Destructor. + /// + /// Closes the underlying socket if it is open. + virtual ~TcpTestClient() { + close(); + } + + /// @brief Connect to the listener and initiate request processing. + /// + /// Upon successful connection, carry out the TLS handshake. If the handshake + /// completes successful start sending requests. + void start() { + isc::asiolink::TCPEndpoint endpoint(boost::asio::ip::address::from_string(server_address_), server_port_); + SocketCallback socket_cb( + [this](boost::system::error_code ec, size_t /*length */) { + receive_done_ = false; + expected_eof_ = false; + handshake_failed_ = false; + if (ec) { + // One would expect that open wouldn't return + // EINPROGRESS error code, but simply wait for the connection + // to get established before the handler is invoked. It turns out, + // however, that on some OSes the connect handler may receive this + // error code which doesn't necessarily indicate a problem. + // Making an attempt to write and read from this socket will + // typically succeed. So, we ignore this error. + if (ec.value() != boost::asio::error::in_progress) { + ADD_FAILURE() << "error occurred while connecting: " + << ec.message(); + done_callback_(); + } + } + + if (useTls()) { + SocketCallback socket_cb( + [this](boost::system::error_code ec, size_t /*length */) { + if (ec) { + handshake_failed_ = true; + done_callback_(); + } else { + sendNextRequest(); + } + }); + + tls_socket_->handshake(socket_cb); + } else { + sendNextRequest(); + } + }); + + if (useTls()) { + tls_socket_->open(&endpoint, socket_cb); + } else { + tcp_socket_->open(&endpoint, socket_cb); + } + } + + /// @brief Send request specified in textual format. + /// + /// @param request request in the textual format. + void startRequest(const std::string& request) { + requests_to_send_.push_back(request); + start(); + } + + /// @brief Send request specified in textual format. + /// + /// @param request request in the textual format. + void startRequests(const std::list<std::string>& requests) { + requests_to_send_ = requests; + start(); + } + + /// @brief Sends the next request from the list of requests to send. + void sendNextRequest() { + // If there are any requests left to send, send them. + if (!requests_to_send_.empty()) { + std::string request = requests_to_send_.front(); + requests_to_send_.pop_front(); + if (request.empty()) { + waitForEof(); + } else { + sendRequest(request); + } + } + } + + /// @brief Send a stream request. + /// + /// @param request request data to send textual format. + /// @param send_length number of bytes to send. If not zero, can be used + /// to truncate the amount of data sent. + void sendRequest(const std::string& request, const size_t send_length = 0) { + // Prepend the length of the request. + uint16_t size = static_cast<uint16_t>(request.size()); + isc::tcp::WireData wire_request; + if (!request.empty()) { + wire_request.push_back(static_cast<uint8_t>((size & 0xff00U) >> 8)); + wire_request.push_back(static_cast<uint8_t>(size & 0x00ffU)); + wire_request.insert(wire_request.end(), request.begin(), request.end()); + } + + sendPartialRequest(wire_request, send_length); + } + + /// @brief Wait for a server to close the connection. + void waitForEof() { + stream_response_.reset(new isc::tcp::TcpStreamRequest()); + receivePartialResponse(true); + } + + /// @brief Send part of the request. + /// + /// @param request part of the request to be sent. + /// @param send_length number of bytes to send. If not zero, can be used + /// to truncate the amount of data sent. + void sendPartialRequest(isc::tcp::WireData& wire_request, size_t send_length = 0) { + if (!send_length) { + send_length = wire_request.size(); + } else { + ASSERT_LE(send_length, wire_request.size()) + << "broken test, send_length exceeds wire size"; + } + + SocketCallback socket_cb( + [this, wire_request](boost::system::error_code ec, size_t bytes_transferred) mutable { + if (ec) { + if (ec.value() == boost::asio::error::operation_aborted) { + return; + + } else if ((ec.value() == boost::asio::error::try_again) || + (ec.value() == boost::asio::error::would_block)) { + // If we should try again make sure there is no garbage in the + // bytes_transferred. + bytes_transferred = 0; + } else { + ADD_FAILURE() << "error occurred while connecting: " + << ec.message(); + done_callback_(); + return; + } + } + + // Remove the part of the request which has been sent. + if (bytes_transferred > 0 && (wire_request.size() <= bytes_transferred)) { + wire_request.erase(wire_request.begin(), + (wire_request.begin() + bytes_transferred)); + } + + // Continue sending request data if there are still some data to be + // sent. + if (!wire_request.empty()) { + sendPartialRequest(wire_request); + } else { + // Request has been sent. Start receiving response. + receivePartialResponse(); + } + }); + + if (useTls()) { + tls_socket_->asyncSend(static_cast<const void *>(wire_request.data()), + send_length, socket_cb); + } else { + tcp_socket_->asyncSend(static_cast<const void *>(wire_request.data()), + send_length, socket_cb); + } + } + + /// @brief Receive response from the server. + void receivePartialResponse(bool expect_eof = false) { + SocketCallback socket_cb( + [this, expect_eof](const boost::system::error_code& ec, + std::size_t bytes_transferred) { + if (!stream_response_) { + stream_response_.reset(new isc::tcp::TcpStreamRequest()); + } + + if (ec) { + // IO service stopped so simply return. + if (ec.value() == boost::asio::error::operation_aborted) { + return; + } else if ((ec.value() == boost::asio::error::try_again) || + (ec.value() == boost::asio::error::would_block)) { + // If we should try again, make sure that there is no garbage + // in the bytes_transferred. + bytes_transferred = 0; + } else if (expect_eof) { + expected_eof_ = true; + done_callback_(); + return; + } else { + // Error occurred, bail... + ADD_FAILURE() << "client: " << this + << " error occurred while receiving TCP" + << " response from the server: " << ec.message(); + done_callback_(); + return; + } + } + + // Post received data to the current response. + if (bytes_transferred > 0) { + stream_response_->postBuffer(buf_.data(), bytes_transferred); + } + + if (stream_response_->needData()) { + // Response is incomplete, keep reading. + receivePartialResponse(); + } else { + // Response is complete, process it. + responseReceived(); + } + }); + + isc::asiolink::TCPEndpoint from; + if (useTls()) { + tls_socket_->asyncReceive(static_cast<void*>(buf_.data()), buf_.size(), 0, + &from, socket_cb); + } else { + tcp_socket_->asyncReceive(static_cast<void*>(buf_.data()), buf_.size(), 0, + &from, socket_cb); + } + } + + /// @brief Process a completed response received from the server. + virtual void responseReceived() { + /// Unpack wire data into a string. + ASSERT_NO_THROW(stream_response_->unpack()); + std::string response = stream_response_->getRequestString(); + responses_received_.push_back(response); + + // Quit if server tells us "good bye". + if (response.find("good bye", 0) != std::string::npos) { + receive_done_ = true; + done_callback_(); + return; + } + + // Clear out for the next one. + stream_response_.reset(); + sendNextRequest(); + } + + /// @brief Close connection. + void close() { + if (useTls()) { + tls_socket_->close(); + } else { + tcp_socket_->close(); + } + } + + /// @brief Returns true if the receive completed without error. + /// + /// @return True if the receive completed successfully, false + /// otherwise. + bool receiveDone() { + return (receive_done_); + } + + /// @brief Returns true if the receive ended with expected EOF + /// + /// @return True if the receive ended with EOF, false otherwise + bool expectedEof() { + return (expected_eof_); + } + + /// @brief Returns the list of received responses. + /// + /// @return list of string responses. + const std::list<std::string>& getResponses() { + return (responses_received_); + } + + bool handshakeFailed() { + return(handshake_failed_); + } + +private: + + /// @brief Holds reference to the IO service. + boost::asio::io_service& io_service_; + + /// @brief TLS context. + isc::asiolink::TlsContextPtr tls_context_; + + /// @brief TCP socket used by this connection. + std::unique_ptr<isc::asiolink::TCPSocket<SocketCallback> > tcp_socket_; + + /// @brief TLS socket used by this connection. + std::unique_ptr<isc::asiolink::TLSSocket<SocketCallback> > tls_socket_; + + /// @brief Callback to invoke when the client has finished its work or + /// failed. + std::function<void()> done_callback_; + + /// @brief IP address of the server. + std::string server_address_; + + /// @brief IP port of the server. + uint16_t server_port_; + + /// @brief Buffer into which response is written. + std::array<char, 8192> buf_; + + /// @brief Response in the textual format. + std::string response_; + + /// @brief Set to true when the receive has completed successfully. + bool receive_done_; + + /// @brief Set to true when the receive ended in EOF as expected. In other + /// words, the server closed the connection while we were reading as we + /// expected it to do. + bool expected_eof_; + + /// @brief Set to true if the TLS handshake failed. + bool handshake_failed_; + + /// @brief Pointer to the server response currently being received. + isc::tcp::TcpStreamRequestPtr stream_response_; + + /// @brief List of string requests to send. + std::list<std::string> requests_to_send_; + + /// @brief List of string responses received. + std::list<std::string> responses_received_; +}; + +/// @brief Pointer to the TcpTestClient. +typedef boost::shared_ptr<TcpTestClient> TcpTestClientPtr; + +#endif |