1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
|
// Copyright (C) 2022-2023 Internet Systems Consortium, Inc. ("ISC")
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef TCP_TEST_CLIENT_H
#define TCP_TEST_CLIENT_H
#include <cc/data.h>
#include <asiolink/tcp_socket.h>
#include <asiolink/tls_socket.h>
#include <asiolink/testutils/test_tls.h>
#include <tcp/tcp_connection.h>
#include <tcp/tcp_stream_msg.h>
#include <boost/asio/read.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <gtest/gtest.h>
/// @brief Entity which can connect to the TCP server endpoint with or
/// or without TLS.
class TcpTestClient : public boost::noncopyable {
private:
/// @brief Type of the function implementing a callback invoked by the
/// @c SocketCallback functor.
typedef std::function<void(boost::system::error_code ec, size_t length)>
SocketCallbackFunction;
/// @brief Functor associated with the socket object.
///
/// This functor calls a callback function specified in the constructor.
class SocketCallback {
public:
/// @brief Constructor.
///
/// @param socket_callback Callback to be invoked by the functor upon
/// an event associated with the socket.
SocketCallback(SocketCallbackFunction socket_callback)
: callback_(socket_callback) {
}
/// @brief Operator called when event associated with a socket occurs.
///
/// This operator returns immediately when received error code is
/// @c boost::system::error_code is equal to
/// @c boost::asio::error::operation_aborted, i.e. the callback is not
/// invoked.
///
/// @param ec Error code.
/// @param length Data length.
void operator()(boost::system::error_code ec, size_t length = 0) {
if (ec.value() == boost::asio::error::operation_aborted) {
return;
}
callback_(ec, length);
}
private:
/// @brief Supplied callback.
SocketCallbackFunction callback_;
};
public:
/// @brief Constructor.
///
/// This constructor creates new socket instance. It doesn't connect. Call
/// start() to connect to the server.
///
/// @param io_service IO service to be stopped on error or completion.
/// @param done_callback Function cient should invoke when it has finished
/// all its requests or failed.
/// @param tls_context
/// @param server_address string containing the IP address of the server.
/// @param port port number of the server.
explicit TcpTestClient(isc::asiolink::IOService& io_service,
std::function<void()> done_callback,
isc::asiolink::TlsContextPtr tls_context =
isc::asiolink::TlsContextPtr(),
const std::string& server_address = "127.0.0.1",
uint16_t port = 18123)
: io_service_(io_service.get_io_service()),
tls_context_(tls_context),
tcp_socket_(), tls_socket_(),
done_callback_(done_callback),
server_address_(server_address), server_port_(port),
buf_(), response_(),
receive_done_(false), expected_eof_(false), handshake_failed_(false) {
if (!tls_context_) {
tcp_socket_.reset(new isc::asiolink::TCPSocket<SocketCallback>(io_service));
} else {
tls_socket_.reset(new isc::asiolink::TLSSocket<SocketCallback>(io_service,
tls_context));
}
}
bool useTls() {
return (!!tls_context_);
}
/// @brief Destructor.
///
/// Closes the underlying socket if it is open.
virtual ~TcpTestClient() {
close();
}
/// @brief Connect to the listener and initiate request processing.
///
/// Upon successful connection, carry out the TLS handshake. If the handshake
/// completes successful start sending requests.
void start() {
isc::asiolink::TCPEndpoint endpoint(boost::asio::ip::address::from_string(server_address_), server_port_);
SocketCallback socket_cb(
[this](boost::system::error_code ec, size_t /*length */) {
receive_done_ = false;
expected_eof_ = false;
handshake_failed_ = false;
if (ec) {
// One would expect that open wouldn't return
// EINPROGRESS error code, but simply wait for the connection
// to get established before the handler is invoked. It turns out,
// however, that on some OSes the connect handler may receive this
// error code which doesn't necessarily indicate a problem.
// Making an attempt to write and read from this socket will
// typically succeed. So, we ignore this error.
if (ec.value() != boost::asio::error::in_progress) {
ADD_FAILURE() << "error occurred while connecting: "
<< ec.message();
done_callback_();
}
}
if (useTls()) {
SocketCallback socket_cb(
[this](boost::system::error_code ec, size_t /*length */) {
if (ec) {
handshake_failed_ = true;
done_callback_();
} else {
sendNextRequest();
}
});
tls_socket_->handshake(socket_cb);
} else {
sendNextRequest();
}
});
if (useTls()) {
tls_socket_->open(&endpoint, socket_cb);
} else {
tcp_socket_->open(&endpoint, socket_cb);
}
}
/// @brief Send request specified in textual format.
///
/// @param request request in the textual format.
void startRequest(const std::string& request) {
requests_to_send_.push_back(request);
start();
}
/// @brief Send request specified in textual format.
///
/// @param request request in the textual format.
void startRequests(const std::list<std::string>& requests) {
requests_to_send_ = requests;
start();
}
/// @brief Sends the next request from the list of requests to send.
void sendNextRequest() {
// If there are any requests left to send, send them.
if (!requests_to_send_.empty()) {
std::string request = requests_to_send_.front();
requests_to_send_.pop_front();
if (request.empty()) {
waitForEof();
} else {
sendRequest(request);
}
}
}
/// @brief Send a stream request.
///
/// @param request request data to send textual format.
/// @param send_length number of bytes to send. If not zero, can be used
/// to truncate the amount of data sent.
void sendRequest(const std::string& request, const size_t send_length = 0) {
// Prepend the length of the request.
uint16_t size = static_cast<uint16_t>(request.size());
isc::tcp::WireData wire_request;
if (!request.empty()) {
wire_request.push_back(static_cast<uint8_t>((size & 0xff00U) >> 8));
wire_request.push_back(static_cast<uint8_t>(size & 0x00ffU));
wire_request.insert(wire_request.end(), request.begin(), request.end());
}
sendPartialRequest(wire_request, send_length);
}
/// @brief Wait for a server to close the connection.
void waitForEof() {
stream_response_.reset(new isc::tcp::TcpStreamRequest());
receivePartialResponse(true);
}
/// @brief Send part of the request.
///
/// @param request part of the request to be sent.
/// @param send_length number of bytes to send. If not zero, can be used
/// to truncate the amount of data sent.
void sendPartialRequest(isc::tcp::WireData& wire_request, size_t send_length = 0) {
if (!send_length) {
send_length = wire_request.size();
} else {
ASSERT_LE(send_length, wire_request.size())
<< "broken test, send_length exceeds wire size";
}
SocketCallback socket_cb(
[this, wire_request](boost::system::error_code ec, size_t bytes_transferred) mutable {
if (ec) {
if (ec.value() == boost::asio::error::operation_aborted) {
return;
} else if ((ec.value() == boost::asio::error::try_again) ||
(ec.value() == boost::asio::error::would_block)) {
// If we should try again make sure there is no garbage in the
// bytes_transferred.
bytes_transferred = 0;
} else {
ADD_FAILURE() << "error occurred while connecting: "
<< ec.message();
done_callback_();
return;
}
}
// Remove the part of the request which has been sent.
if (bytes_transferred > 0 && (wire_request.size() <= bytes_transferred)) {
wire_request.erase(wire_request.begin(),
(wire_request.begin() + bytes_transferred));
}
// Continue sending request data if there are still some data to be
// sent.
if (!wire_request.empty()) {
sendPartialRequest(wire_request);
} else {
// Request has been sent. Start receiving response.
receivePartialResponse();
}
});
if (useTls()) {
tls_socket_->asyncSend(static_cast<const void *>(wire_request.data()),
send_length, socket_cb);
} else {
tcp_socket_->asyncSend(static_cast<const void *>(wire_request.data()),
send_length, socket_cb);
}
}
/// @brief Receive response from the server.
void receivePartialResponse(bool expect_eof = false) {
SocketCallback socket_cb(
[this, expect_eof](const boost::system::error_code& ec,
std::size_t bytes_transferred) {
if (!stream_response_) {
stream_response_.reset(new isc::tcp::TcpStreamRequest());
}
if (ec) {
// IO service stopped so simply return.
if (ec.value() == boost::asio::error::operation_aborted) {
return;
} else if ((ec.value() == boost::asio::error::try_again) ||
(ec.value() == boost::asio::error::would_block)) {
// If we should try again, make sure that there is no garbage
// in the bytes_transferred.
bytes_transferred = 0;
} else if (expect_eof) {
expected_eof_ = true;
done_callback_();
return;
} else {
// Error occurred, bail...
ADD_FAILURE() << "client: " << this
<< " error occurred while receiving TCP"
<< " response from the server: " << ec.message();
done_callback_();
return;
}
}
// Post received data to the current response.
if (bytes_transferred > 0) {
stream_response_->postBuffer(buf_.data(), bytes_transferred);
}
if (stream_response_->needData()) {
// Response is incomplete, keep reading.
receivePartialResponse();
} else {
// Response is complete, process it.
responseReceived();
}
});
isc::asiolink::TCPEndpoint from;
if (useTls()) {
tls_socket_->asyncReceive(static_cast<void*>(buf_.data()), buf_.size(), 0,
&from, socket_cb);
} else {
tcp_socket_->asyncReceive(static_cast<void*>(buf_.data()), buf_.size(), 0,
&from, socket_cb);
}
}
/// @brief Process a completed response received from the server.
virtual void responseReceived() {
/// Unpack wire data into a string.
ASSERT_NO_THROW(stream_response_->unpack());
std::string response = stream_response_->getRequestString();
responses_received_.push_back(response);
// Quit if server tells us "good bye".
if (response.find("good bye", 0) != std::string::npos) {
receive_done_ = true;
done_callback_();
return;
}
// Clear out for the next one.
stream_response_.reset();
sendNextRequest();
}
/// @brief Close connection.
void close() {
if (useTls()) {
tls_socket_->close();
} else {
tcp_socket_->close();
}
}
/// @brief Returns true if the receive completed without error.
///
/// @return True if the receive completed successfully, false
/// otherwise.
bool receiveDone() {
return (receive_done_);
}
/// @brief Returns true if the receive ended with expected EOF
///
/// @return True if the receive ended with EOF, false otherwise
bool expectedEof() {
return (expected_eof_);
}
/// @brief Returns the list of received responses.
///
/// @return list of string responses.
const std::list<std::string>& getResponses() {
return (responses_received_);
}
bool handshakeFailed() {
return(handshake_failed_);
}
private:
/// @brief Holds reference to the IO service.
boost::asio::io_service& io_service_;
/// @brief TLS context.
isc::asiolink::TlsContextPtr tls_context_;
/// @brief TCP socket used by this connection.
std::unique_ptr<isc::asiolink::TCPSocket<SocketCallback> > tcp_socket_;
/// @brief TLS socket used by this connection.
std::unique_ptr<isc::asiolink::TLSSocket<SocketCallback> > tls_socket_;
/// @brief Callback to invoke when the client has finished its work or
/// failed.
std::function<void()> done_callback_;
/// @brief IP address of the server.
std::string server_address_;
/// @brief IP port of the server.
uint16_t server_port_;
/// @brief Buffer into which response is written.
std::array<char, 8192> buf_;
/// @brief Response in the textual format.
std::string response_;
/// @brief Set to true when the receive has completed successfully.
bool receive_done_;
/// @brief Set to true when the receive ended in EOF as expected. In other
/// words, the server closed the connection while we were reading as we
/// expected it to do.
bool expected_eof_;
/// @brief Set to true if the TLS handshake failed.
bool handshake_failed_;
/// @brief Pointer to the server response currently being received.
isc::tcp::TcpStreamRequestPtr stream_response_;
/// @brief List of string requests to send.
std::list<std::string> requests_to_send_;
/// @brief List of string responses received.
std::list<std::string> responses_received_;
};
/// @brief Shared pointer alias for a @c TcpTestClient instance.
using TcpTestClientPtr = boost::shared_ptr<TcpTestClient>;
#endif
|