diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 18:00:34 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 18:00:34 +0000 |
commit | 3f619478f796eddbba6e39502fe941b285dd97b1 (patch) | |
tree | e2c7b5777f728320e5b5542b6213fd3591ba51e2 /wsrep-lib/dbsim/db_tls.cpp | |
parent | Initial commit. (diff) | |
download | mariadb-upstream.tar.xz mariadb-upstream.zip |
Adding upstream version 1:10.11.6.upstream/1%10.11.6upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'wsrep-lib/dbsim/db_tls.cpp')
-rw-r--r-- | wsrep-lib/dbsim/db_tls.cpp | 451 |
1 files changed, 451 insertions, 0 deletions
diff --git a/wsrep-lib/dbsim/db_tls.cpp b/wsrep-lib/dbsim/db_tls.cpp new file mode 100644 index 00000000..c242668c --- /dev/null +++ b/wsrep-lib/dbsim/db_tls.cpp @@ -0,0 +1,451 @@ +/* + * Copyright (C) 2020 Codership Oy <info@codership.com> + * + * This file is part of wsrep-lib. + * + * Wsrep-lib is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * Wsrep-lib is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with wsrep-lib. If not, see <https://www.gnu.org/licenses/>. + */ + +/** @file db_tls.cpp + * + * This file demonstrates the use of TLS service. It does not implement + * real encryption, but may manipulate stream bytes for testing purposes. + */ + +#include "db_tls.hpp" + +#include "wsrep/logger.hpp" + +#include <unistd.h> // read() +#include <fcntl.h> +#include <sys/types.h> +#include <sys/socket.h> // send() +#include <cassert> +#include <cerrno> +#include <cstring> + +#include <mutex> +#include <string> + +namespace +{ + class db_stream : public wsrep::tls_stream + { + public: + db_stream(int fd, int mode) + : fd_(fd) + , state_(s_initialized) + , last_error_() + , mode_(mode) + , stats_() + , is_blocking_() + { + int val(fcntl(fd_, F_GETFL, 0)); + is_blocking_ = not (val & O_NONBLOCK); + } + struct stats + { + size_t bytes_read{0}; + size_t bytes_written{0}; + }; + + /* + * in idle --| + * |-> ch -| ^ | -> want_read --| + * |-> sh -| ---- |--> | -> want_write --| + * |----------------------| + */ + enum state + { + s_initialized, + s_client_handshake, + s_server_handshake, + s_idle, + s_want_read, + s_want_write + }; + + int get_error_number() const { return last_error_; } + const void* get_error_category() const + { + return reinterpret_cast<const void*>(1); + } + + static char* get_error_message(int value, const void*) + { + return ::strerror(value); + } + + enum wsrep::tls_service::status client_handshake(); + + enum wsrep::tls_service::status server_handshake(); + + wsrep::tls_service::op_result read(void*, size_t); + + wsrep::tls_service::op_result write(const void*, size_t); + + enum state state() const { return state_; } + + int fd() const { return fd_; } + void inc_reads(size_t val) { stats_.bytes_read += val; } + void inc_writes(size_t val) { stats_.bytes_written += val; } + const stats& get_stats() const { return stats_; } + private: + enum wsrep::tls_service::status handle_handshake_read(const char* expect); + size_t determine_read_count(size_t max_count) + { + if (is_blocking_ || mode_ < 2) return max_count; + else if (::rand() % 100 == 0) return std::min(size_t(42), max_count); + else return max_count; + } + size_t determine_write_count(size_t count) + { + if (is_blocking_ || mode_ < 2) return count; + else if (::rand() % 100 == 0) return std::min(size_t(43), count); + else return count; + } + + ssize_t do_read(void* buf, size_t max_count) + { + if (is_blocking_ || mode_ < 3 ) + return ::read(fd_, buf, max_count); + else if (::rand() % 1000 == 0) return EINTR; + else return ::read(fd_, buf, max_count); + } + + ssize_t do_write(const void* buf, size_t count) + { + if (is_blocking_ || mode_ < 3) + return ::send(fd_, buf, count, MSG_NOSIGNAL); + else if (::rand() % 1000 == 0) return EINTR; + else return ::send(fd_, buf, count, MSG_NOSIGNAL); + } + + wsrep::tls_service::op_result map_success(ssize_t result) + { + if (is_blocking_ || mode_ < 2) + { + return wsrep::tls_service::op_result{ + wsrep::tls_service::success, size_t(result)}; + } + else if (::rand() % 1000 == 0) + { + wsrep::log_info() << "Success want extra read"; + state_ = s_want_read; + return wsrep::tls_service::op_result{ + wsrep::tls_service::want_read, size_t(result)}; + } + else if (::rand() % 1000 == 0) + { + wsrep::log_info() << "Success want extra write"; + state_ = s_want_write; + return wsrep::tls_service::op_result{ + wsrep::tls_service::want_write, size_t(result)}; + } + else + { + return wsrep::tls_service::op_result{ + wsrep::tls_service::success, size_t(result)}; + } + } + + wsrep::tls_service::op_result map_result(ssize_t result) + { + if (result > 0) + { + return map_success(result); + } + else if (result == 0) + { + return wsrep::tls_service::op_result{ + wsrep::tls_service::eof, 0}; + } + else if (errno == EAGAIN || errno == EWOULDBLOCK) + { + return wsrep::tls_service::op_result{ + wsrep::tls_service::want_read, 0}; + } + else + { + last_error_ = errno; + return wsrep::tls_service::op_result{ + wsrep::tls_service::error, 0}; + } + } + + void clear_error() { last_error_ = 0; } + + int fd_; + enum state state_; + int last_error_; + // Operation mode: + // 1 - simulate handshake exchange + // 2 - simulate errors and short reads + int mode_; + stats stats_; + bool is_blocking_; + }; + + enum wsrep::tls_service::status db_stream::client_handshake() + { + clear_error(); + enum wsrep::tls_service::status ret; + assert(state_ == s_initialized || + state_ == s_client_handshake || + state_ == s_want_write); + if (state_ == s_initialized) + { + (void)::send(fd_, "clie", 4, MSG_NOSIGNAL); + ret = wsrep::tls_service::want_read; + state_ = s_client_handshake; + wsrep::log_info() << this << " client handshake sent"; + stats_.bytes_written += 4; + if (not is_blocking_) return ret; + } + + if (state_ == s_client_handshake) + { + if ((ret = handle_handshake_read("serv")) == + wsrep::tls_service::success) + { + state_ = s_want_write; + ret = wsrep::tls_service::want_write; + } + if (not is_blocking_) return ret; + } + + if (state_ == s_want_write) + { + state_ = s_idle; + ret = wsrep::tls_service::success; + if (not is_blocking_) return ret; + } + + if (not is_blocking_) + { + last_error_ = EPROTO; + ret = wsrep::tls_service::error; + } + return ret; + } + + + + enum wsrep::tls_service::status db_stream::server_handshake() + { + enum wsrep::tls_service::status ret; + assert(state_ == s_initialized || + state_ == s_server_handshake || + state_ == s_want_write); + + if (state_ == s_initialized) + { + ::send(fd_, "serv", 4, MSG_NOSIGNAL); + ret = wsrep::tls_service::want_read; + state_ = s_server_handshake; + stats_.bytes_written += 4; + if (not is_blocking_) return ret; + } + + if (state_ == s_server_handshake) + { + if ((ret = handle_handshake_read("clie")) == + wsrep::tls_service::success) + { + state_ = s_want_write; + ret = wsrep::tls_service::want_write; + } + if (not is_blocking_) return ret; + } + + if (state_ == s_want_write) + { + state_ = s_idle; + ret = wsrep::tls_service::success; + if (not is_blocking_) return ret; + } + + if (not is_blocking_) + { + last_error_ = EPROTO; + ret = wsrep::tls_service::error; + } + return ret; + } + + enum wsrep::tls_service::status db_stream::handle_handshake_read( + const char* expect) + { + assert(::strlen(expect) >= 4); + char buf[4] = { }; + ssize_t read_result(::read(fd_, buf, sizeof(buf))); + if (read_result > 0) stats_.bytes_read += size_t(read_result); + enum wsrep::tls_service::status ret; + if (read_result == -1 && + (errno == EWOULDBLOCK || errno == EAGAIN)) + { + ret = wsrep::tls_service::want_read; + } + else if (read_result == 0) + { + ret = wsrep::tls_service::eof; + } + else if (read_result != 4 || ::memcmp(buf, expect, 4)) + { + last_error_ = EPROTO; + ret = wsrep::tls_service::error; + } + else + { + wsrep::log_info() << "Handshake success: " << std::string(buf, 4); + ret = wsrep::tls_service::success; + } + return ret; + } + + wsrep::tls_service::op_result db_stream::read(void* buf, size_t max_count) + { + clear_error(); + if (state_ == s_want_read) + { + state_ = s_idle; + if (max_count == 0) + return wsrep::tls_service::op_result{ + wsrep::tls_service::success, 0}; + } + max_count = determine_read_count(max_count); + ssize_t read_result(do_read(buf, max_count)); + if (read_result > 0) + { + inc_reads(size_t(read_result)); + } + return map_result(read_result); + } + + wsrep::tls_service::op_result db_stream::write( + const void* buf, size_t count) + { + clear_error(); + if (state_ == s_want_write) + { + state_ = s_idle; + if (count == 0) + return wsrep::tls_service::op_result{ + wsrep::tls_service::success, 0}; + } + count = determine_write_count(count); + ssize_t write_result(do_write(buf, count)); + if (write_result > 0) + { + inc_writes(size_t(write_result)); + } + return map_result(write_result); + } +} + + +static db_stream::stats global_stats; +std::mutex global_stats_lock; +static int global_mode; + +static void merge_to_global_stats(const db_stream::stats& stats) +{ + std::lock_guard<std::mutex> lock(global_stats_lock); + global_stats.bytes_read += stats.bytes_read; + global_stats.bytes_written += stats.bytes_written; +} + +wsrep::tls_stream* db::tls::create_tls_stream(int fd) WSREP_NOEXCEPT +{ + auto ret(new db_stream(fd, global_mode)); + wsrep::log_debug() << "New DB stream: " << ret; + return ret; +} + +void db::tls::destroy(wsrep::tls_stream* stream) WSREP_NOEXCEPT +{ + auto dbs(static_cast<db_stream*>(stream)); + merge_to_global_stats(dbs->get_stats()); + wsrep::log_debug() << "Stream destroy: " << dbs->get_stats().bytes_read + << " " << dbs->get_stats().bytes_written; + wsrep::log_debug() << "Stream destroy" << dbs; + delete dbs; +} + +int db::tls::get_error_number(const wsrep::tls_stream* stream) + const WSREP_NOEXCEPT +{ + return static_cast<const db_stream*>(stream)->get_error_number(); +} + +const void* db::tls::get_error_category(const wsrep::tls_stream* stream) + const WSREP_NOEXCEPT +{ + return static_cast<const db_stream*>(stream)->get_error_category(); +} + +const char* db::tls::get_error_message(const wsrep::tls_stream*, + int value, const void* category) + const WSREP_NOEXCEPT +{ + return db_stream::get_error_message(value, category); +} + +enum wsrep::tls_service::status +db::tls::client_handshake(wsrep::tls_stream* stream) WSREP_NOEXCEPT +{ + return static_cast<db_stream*>(stream)->client_handshake(); +} + +enum wsrep::tls_service::status +db::tls::server_handshake(wsrep::tls_stream* stream) WSREP_NOEXCEPT +{ + return static_cast<db_stream*>(stream)->server_handshake(); +} + +wsrep::tls_service::op_result db::tls::read( + wsrep::tls_stream* stream, + void* buf, size_t max_count) WSREP_NOEXCEPT +{ + return static_cast<db_stream*>(stream)->read(buf, max_count); +} + +wsrep::tls_service::op_result db::tls::write( + wsrep::tls_stream* stream, + const void* buf, size_t count) WSREP_NOEXCEPT +{ + return static_cast<db_stream*>(stream)->write(buf, count); +} + +wsrep::tls_service::status +db::tls::shutdown(wsrep::tls_stream*) WSREP_NOEXCEPT +{ + // @todo error simulation + return wsrep::tls_service::success; +} + + +void db::tls::init(int mode) +{ + global_mode = mode; +} + +std::string db::tls::stats() +{ + std::ostringstream oss; + oss << "Transport stats:\n" + << " bytes_read: " << global_stats.bytes_read << "\n" + << " bytes_written: " << global_stats.bytes_written << "\n"; + return oss.str(); +} |