summaryrefslogtreecommitdiffstats
path: root/src/seastar/tests/unit/loopback_socket.hh
blob: 888b06928af3895fb1dd4b4f64a92317baca3459 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
/*
 * This file is open source software, licensed to you under the terms
 * of the Apache License, Version 2.0 (the "License").  See the NOTICE file
 * distributed with this work for additional information regarding copyright
 * ownership.  You may not use this file except in compliance with the License.
 *
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*
 * Copyright (C) 2016 ScyllaDB
 */

#pragma once

#include <system_error>
#include <seastar/core/iostream.hh>
#include <seastar/core/circular_buffer.hh>
#include <seastar/core/shared_ptr.hh>
#include <seastar/core/queue.hh>
#include <seastar/core/future-util.hh>
#include <seastar/core/do_with.hh>
#include <seastar/net/stack.hh>
#include <seastar/core/reactor.hh>
#include <seastar/core/sharded.hh>

namespace seastar {

/// Fault-injection hooks for loopback sockets.
///
/// loopback_buffer asks one of these hooks before every push() (send) and
/// pop() (receive).  When a hook returns true the buffer is shut down and the
/// operation fails with an injected error.  Every default implementation
/// returns false, i.e. no error is ever injected unless a test overrides it.
struct loopback_error_injector {
    virtual ~loopback_error_injector() = default;

    virtual bool server_rcv_error() {
        return false;
    }
    virtual bool server_snd_error() {
        return false;
    }
    virtual bool client_rcv_error() {
        return false;
    }
    virtual bool client_snd_error() {
        return false;
    }
};

class loopback_buffer {
public:
    enum class type : uint8_t {
        CLIENT_TX,
        SERVER_TX
    };
private:
    bool _aborted = false;
    queue<temporary_buffer<char>> _q{1};
    loopback_error_injector* _error_injector;
    type _type;
public:
    loopback_buffer(loopback_error_injector* error_injection, type t) : _error_injector(error_injection), _type(t) {}
    future<> push(temporary_buffer<char>&& b) {
        if (_aborted) {
            return make_exception_future<>(std::system_error(EPIPE, std::system_category()));
        }
        bool error = false;
        if (_error_injector) {
            error = _type == type::CLIENT_TX ? _error_injector->client_snd_error() : _error_injector->server_snd_error();
        }
        if (error) {
            shutdown();
            return make_exception_future<>(std::runtime_error("test injected error on send"));
        }
        return _q.push_eventually(std::move(b));
    }
    future<temporary_buffer<char>> pop() {
        if (_aborted) {
            return make_exception_future<temporary_buffer<char>>(std::system_error(EPIPE, std::system_category()));
        }
        bool error = false;
        if (_error_injector) {
            error = _type == type::CLIENT_TX ? _error_injector->client_rcv_error() : _error_injector->server_rcv_error();
        }
        if (error) {
            shutdown();
            return make_exception_future<temporary_buffer<char>>(std::runtime_error("test injected error on receive"));
        }
        return _q.pop_eventually();
    }
    void shutdown() {
        _aborted = true;
        _q.abort(std::make_exception_ptr(std::system_error(EPIPE, std::system_category())));
    }
};

/// data_sink_impl that forwards each fragment of a packet into a
/// loopback_buffer living on (possibly) another shard.
class loopback_data_sink_impl : public data_sink_impl {
    // NOTE(review): this is a reference to the owning socket's _tx member, so
    // the sink must not outlive the loopback_connected_socket_impl it came from.
    foreign_ptr<lw_shared_ptr<loopback_buffer>>& _buffer;
public:
    explicit loopback_data_sink_impl(foreign_ptr<lw_shared_ptr<loopback_buffer>>& buffer)
            : _buffer(buffer) {
    }

    /// Pushes the packet's fragments one at a time, waiting for each push to
    /// complete on the buffer's owner shard before sending the next.
    future<> put(net::packet data) override {
        auto fragments = data.release();
        return do_with(std::move(fragments), [this] (std::vector<temporary_buffer<char>>& frags) {
            return do_for_each(frags, [this] (temporary_buffer<char>& frag) {
                // Only the raw bytes cross the shard boundary; a fresh
                // temporary_buffer is built from them on the owner shard.  The
                // source fragment stays alive because do_for_each waits for
                // this future before moving on.
                const char* bytes = frag.get();
                size_t len = frag.size();
                return smp::submit_to(_buffer.get_owner_shard(), [this, bytes, len] {
                    return _buffer->push(temporary_buffer<char>(bytes, len));
                });
            });
        });
    }

    /// Signals end-of-stream by pushing an empty buffer, which the reading
    /// loopback_data_source_impl treats as EOF.
    future<> close() override {
        return smp::submit_to(_buffer.get_owner_shard(), [this] {
            return _buffer->push(temporary_buffer<char>());
        });
    }
};

/// data_source_impl reading from a shard-local loopback_buffer.
///
/// Tracks whether end-of-stream was reached (a failed pop, or an empty
/// buffer) so close() only aborts the buffer when the stream was abandoned
/// early.
class loopback_data_source_impl : public data_source_impl {
    bool _eof = false;
    lw_shared_ptr<loopback_buffer> _buffer;
public:
    explicit loopback_data_source_impl(lw_shared_ptr<loopback_buffer> buffer)
            : _buffer(std::move(buffer)) {
    }

    /// Returns the next buffer, or propagates the pop() failure unchanged.
    future<temporary_buffer<char>> get() override {
        return _buffer->pop().then_wrapped([this] (future<temporary_buffer<char>>&& f) {
            if (f.failed()) {
                // Any error terminates the stream.
                _eof = true;
                return std::move(f);
            }
            // get0() consumes the future, so re-wrap the value after
            // inspecting it.
            auto data = f.get0();
            _eof = data.empty();
            return make_ready_future<temporary_buffer<char>>(std::move(data));
        });
    }

    /// Shuts the buffer down unless the stream already reached EOF.
    future<> close() override {
        if (_eof) {
            return make_ready_future<>();
        }
        _buffer->shutdown();
        return make_ready_future<>();
    }
};


class loopback_connected_socket_impl : public net::connected_socket_impl {
    foreign_ptr<lw_shared_ptr<loopback_buffer>> _tx;
    lw_shared_ptr<loopback_buffer> _rx;
public:
    loopback_connected_socket_impl(foreign_ptr<lw_shared_ptr<loopback_buffer>> tx, lw_shared_ptr<loopback_buffer> rx)
            : _tx(std::move(tx)), _rx(std::move(rx)) {
    }
    data_source source() override {
        return data_source(std::make_unique<loopback_data_source_impl>(_rx));
    }
    data_sink sink() override {
        return data_sink(std::make_unique<loopback_data_sink_impl>(_tx));
    }
    void shutdown_input() override {
        _rx->shutdown();
    }
    void shutdown_output() override {
        smp::submit_to(_tx.get_owner_shard(), [this] {
            // FIXME: who holds to _tx?
            _tx->shutdown();
        });
    }
    void set_nodelay(bool nodelay) override {
    }
    bool get_nodelay() const override {
        return true;
    }
    void set_keepalive(bool keepalive) override {}
    bool get_keepalive() const override {
        return false;
    }
    void set_keepalive_parameters(const net::keepalive_params&) override {}
    net::keepalive_params get_keepalive_parameters() const override {
        return net::tcp_keepalive_params {std::chrono::seconds(0), std::chrono::seconds(0), 0};
    }
};

/// Listening side of the loopback transport: accept() hands out connections
/// queued by loopback_connection_factory::make_new_server_connection().
class loopback_server_socket_impl : public net::server_socket_impl {
    lw_shared_ptr<queue<connected_socket>> _pending;
public:
    explicit loopback_server_socket_impl(lw_shared_ptr<queue<connected_socket>> q)
            : _pending(std::move(q)) {
    }

    /// Waits for the next pending connection.  Loopback peers have no real
    /// address, so a default-constructed socket_address is reported.
    future<connected_socket, socket_address> accept() override {
        auto with_address = [] (connected_socket&& cs) {
            return make_ready_future<connected_socket, socket_address>(std::move(cs), socket_address());
        };
        return _pending->pop_eventually().then(std::move(with_address));
    }

    /// Fails pending and future accept() calls with ECONNABORTED.
    void abort_accept() override {
        auto aborted = std::make_exception_ptr(std::system_error(ECONNABORTED, std::system_category()));
        _pending->abort(std::move(aborted));
    }
};


class loopback_connection_factory {
    unsigned _shard = 0;
    std::vector<lw_shared_ptr<queue<connected_socket>>> _pending;
public:
    loopback_connection_factory() {
        _pending.resize(smp::count);
    }
    server_socket get_server_socket() {
       if (!_pending[engine().cpu_id()]) {
           _pending[engine().cpu_id()] = make_lw_shared<queue<connected_socket>>(10);
       }
       return server_socket(std::make_unique<loopback_server_socket_impl>(_pending[engine().cpu_id()]));
    }
    future<> make_new_server_connection(foreign_ptr<lw_shared_ptr<loopback_buffer>> b1, lw_shared_ptr<loopback_buffer> b2) {
        if (!_pending[engine().cpu_id()]) {
            _pending[engine().cpu_id()] = make_lw_shared<queue<connected_socket>>(10);
        }
        return _pending[engine().cpu_id()]->push_eventually(connected_socket(std::make_unique<loopback_connected_socket_impl>(std::move(b1), b2)));
    }
    connected_socket make_new_client_connection(lw_shared_ptr<loopback_buffer> b1, foreign_ptr<lw_shared_ptr<loopback_buffer>> b2) {
        return connected_socket(std::make_unique<loopback_connected_socket_impl>(std::move(b2), b1));
    }
    unsigned next_shard() {
        return _shard++ % smp::count;
    }
    void destroy_shard(unsigned shard) {
        _pending[shard] = nullptr;
    }
};

class loopback_socket_impl : public net::socket_impl {
    loopback_connection_factory& _factory;
    loopback_error_injector* _error_injector;
    lw_shared_ptr<loopback_buffer> _b1;
    foreign_ptr<lw_shared_ptr<loopback_buffer>> _b2;
public:
    loopback_socket_impl(loopback_connection_factory& factory, loopback_error_injector* error_injector = nullptr)
            : _factory(factory), _error_injector(error_injector)
    { }
    future<connected_socket> connect(socket_address sa, socket_address local, seastar::transport proto = seastar::transport::TCP) {
        auto shard = _factory.next_shard();
        _b1 = make_lw_shared<loopback_buffer>(_error_injector, loopback_buffer::type::SERVER_TX);
        return smp::submit_to(shard, [this, b1 = make_foreign(_b1)] () mutable {
            auto b2 = make_lw_shared<loopback_buffer>(_error_injector, loopback_buffer::type::CLIENT_TX);
            _b2 = make_foreign(b2);
            return _factory.make_new_server_connection(std::move(b1), b2).then([b2] {
                return make_foreign(b2);
            });
        }).then([this, shard] (foreign_ptr<lw_shared_ptr<loopback_buffer>> b2) {
            return _factory.make_new_client_connection(_b1, std::move(b2));
        });
    }

    void shutdown() {
        _b1->shutdown();
        smp::submit_to(_b2.get_owner_shard(), [this, b2 = std::move(_b2)] {
            b2->shutdown();
        });
    }
};

}