/* * This file is open source software, licensed to you under the terms * of the Apache License, Version 2.0 (the "License"). See the NOTICE file * distributed with this work for additional information regarding copyright * ownership. You may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /* * Copyright (C) 2016 ScyllaDB */ #pragma once #include #include #include #include #include #include #include #include #include #include namespace seastar { struct loopback_error_injector { virtual ~loopback_error_injector() {}; virtual bool server_rcv_error() { return false; } virtual bool server_snd_error() { return false; } virtual bool client_rcv_error() { return false; } virtual bool client_snd_error() { return false; } }; class loopback_buffer { public: enum class type : uint8_t { CLIENT_TX, SERVER_TX }; private: bool _aborted = false; queue> _q{1}; loopback_error_injector* _error_injector; type _type; public: loopback_buffer(loopback_error_injector* error_injection, type t) : _error_injector(error_injection), _type(t) {} future<> push(temporary_buffer&& b) { if (_aborted) { return make_exception_future<>(std::system_error(EPIPE, std::system_category())); } bool error = false; if (_error_injector) { error = _type == type::CLIENT_TX ? _error_injector->client_snd_error() : _error_injector->server_snd_error(); } if (error) { shutdown(); return make_exception_future<>(std::runtime_error("test injected error on send")); } return _q.push_eventually(std::move(b)); } future> pop() { if (_aborted) { return make_exception_future>(std::system_error(EPIPE, std::system_category())); } bool error = false; if (_error_injector) { error = _type == type::CLIENT_TX ? _error_injector->client_rcv_error() : _error_injector->server_rcv_error(); } if (error) { shutdown(); return make_exception_future>(std::runtime_error("test injected error on receive")); } return _q.pop_eventually(); } void shutdown() { _aborted = true; _q.abort(std::make_exception_ptr(std::system_error(EPIPE, std::system_category()))); } }; class loopback_data_sink_impl : public data_sink_impl { foreign_ptr>& _buffer; public: explicit loopback_data_sink_impl(foreign_ptr>& buffer) : _buffer(buffer) { } future<> put(net::packet data) override { return do_with(data.release(), [this] (std::vector>& bufs) { return do_for_each(bufs, [this] (temporary_buffer& buf) { return smp::submit_to(_buffer.get_owner_shard(), [this, b = buf.get(), s = buf.size()] { return _buffer->push(temporary_buffer(b, s)); }); }); }); } future<> close() override { return smp::submit_to(_buffer.get_owner_shard(), [this] { return _buffer->push({}); }); } }; class loopback_data_source_impl : public data_source_impl { bool _eof = false; lw_shared_ptr _buffer; public: explicit loopback_data_source_impl(lw_shared_ptr buffer) : _buffer(std::move(buffer)) { } future> get() override { return _buffer->pop().then_wrapped([this] (future>&& b) { _eof = b.failed(); if (!_eof) { // future::get0() is destructive, so we have to play these games // FIXME: make future::get0() non-destructive auto&& tmp = b.get0(); _eof = tmp.empty(); b = make_ready_future>(std::move(tmp)); } return std::move(b); }); } future<> close() override { if (!_eof) { _buffer->shutdown(); } return make_ready_future<>(); } }; class loopback_connected_socket_impl : public net::connected_socket_impl { foreign_ptr> _tx; lw_shared_ptr _rx; public: loopback_connected_socket_impl(foreign_ptr> tx, lw_shared_ptr rx) : _tx(std::move(tx)), _rx(std::move(rx)) { } data_source source() override { return data_source(std::make_unique(_rx)); } data_sink sink() override { return data_sink(std::make_unique(_tx)); } void shutdown_input() override { _rx->shutdown(); } void shutdown_output() override { smp::submit_to(_tx.get_owner_shard(), [this] { // FIXME: who holds to _tx? _tx->shutdown(); }); } void set_nodelay(bool nodelay) override { } bool get_nodelay() const override { return true; } void set_keepalive(bool keepalive) override {} bool get_keepalive() const override { return false; } void set_keepalive_parameters(const net::keepalive_params&) override {} net::keepalive_params get_keepalive_parameters() const override { return net::tcp_keepalive_params {std::chrono::seconds(0), std::chrono::seconds(0), 0}; } }; class loopback_server_socket_impl : public net::server_socket_impl { lw_shared_ptr> _pending; public: explicit loopback_server_socket_impl(lw_shared_ptr> q) : _pending(std::move(q)) { } future accept() override { return _pending->pop_eventually().then([] (connected_socket&& cs) { return make_ready_future(std::move(cs), socket_address()); }); } void abort_accept() override { _pending->abort(std::make_exception_ptr(std::system_error(ECONNABORTED, std::system_category()))); } }; class loopback_connection_factory { unsigned _shard = 0; std::vector>> _pending; public: loopback_connection_factory() { _pending.resize(smp::count); } server_socket get_server_socket() { if (!_pending[engine().cpu_id()]) { _pending[engine().cpu_id()] = make_lw_shared>(10); } return server_socket(std::make_unique(_pending[engine().cpu_id()])); } future<> make_new_server_connection(foreign_ptr> b1, lw_shared_ptr b2) { if (!_pending[engine().cpu_id()]) { _pending[engine().cpu_id()] = make_lw_shared>(10); } return _pending[engine().cpu_id()]->push_eventually(connected_socket(std::make_unique(std::move(b1), b2))); } connected_socket make_new_client_connection(lw_shared_ptr b1, foreign_ptr> b2) { return connected_socket(std::make_unique(std::move(b2), b1)); } unsigned next_shard() { return _shard++ % smp::count; } void destroy_shard(unsigned shard) { _pending[shard] = nullptr; } }; class loopback_socket_impl : public net::socket_impl { loopback_connection_factory& _factory; loopback_error_injector* _error_injector; lw_shared_ptr _b1; foreign_ptr> _b2; public: loopback_socket_impl(loopback_connection_factory& factory, loopback_error_injector* error_injector = nullptr) : _factory(factory), _error_injector(error_injector) { } future connect(socket_address sa, socket_address local, seastar::transport proto = seastar::transport::TCP) { auto shard = _factory.next_shard(); _b1 = make_lw_shared(_error_injector, loopback_buffer::type::SERVER_TX); return smp::submit_to(shard, [this, b1 = make_foreign(_b1)] () mutable { auto b2 = make_lw_shared(_error_injector, loopback_buffer::type::CLIENT_TX); _b2 = make_foreign(b2); return _factory.make_new_server_connection(std::move(b1), b2).then([b2] { return make_foreign(b2); }); }).then([this, shard] (foreign_ptr> b2) { return _factory.make_new_client_connection(_b1, std::move(b2)); }); } void shutdown() { _b1->shutdown(); smp::submit_to(_b2.get_owner_shard(), [this, b2 = std::move(_b2)] { b2->shutdown(); }); } }; }