/* * Copyright (C) 2020 Codership Oy * * This file is part of wsrep-lib. * * Wsrep-lib is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 2 of the License, or * (at your option) any later version. * * Wsrep-lib is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with wsrep-lib. If not, see . */ /** @file db_tls.cpp * * This file demonstrates the use of TLS service. It does not implement * real encryption, but may manipulate stream bytes for testing purposes. */ #include "db_tls.hpp" #include "wsrep/logger.hpp" #include // read() #include #include #include // send() #include #include #include #include #include namespace { class db_stream : public wsrep::tls_stream { public: db_stream(int fd, int mode) : fd_(fd) , state_(s_initialized) , last_error_() , mode_(mode) , stats_() , is_blocking_() { int val(fcntl(fd_, F_GETFL, 0)); is_blocking_ = not (val & O_NONBLOCK); } struct stats { size_t bytes_read{0}; size_t bytes_written{0}; }; /* * in idle --| * |-> ch -| ^ | -> want_read --| * |-> sh -| ---- |--> | -> want_write --| * |----------------------| */ enum state { s_initialized, s_client_handshake, s_server_handshake, s_idle, s_want_read, s_want_write }; int get_error_number() const { return last_error_; } const void* get_error_category() const { return reinterpret_cast(1); } static char* get_error_message(int value, const void*) { return ::strerror(value); } enum wsrep::tls_service::status client_handshake(); enum wsrep::tls_service::status server_handshake(); wsrep::tls_service::op_result read(void*, size_t); wsrep::tls_service::op_result write(const void*, size_t); enum state state() const { return state_; } int fd() const { return fd_; } void inc_reads(size_t val) { stats_.bytes_read += val; } void inc_writes(size_t val) { stats_.bytes_written += val; } const stats& get_stats() const { return stats_; } private: enum wsrep::tls_service::status handle_handshake_read(const char* expect); size_t determine_read_count(size_t max_count) { if (is_blocking_ || mode_ < 2) return max_count; else if (::rand() % 100 == 0) return std::min(size_t(42), max_count); else return max_count; } size_t determine_write_count(size_t count) { if (is_blocking_ || mode_ < 2) return count; else if (::rand() % 100 == 0) return std::min(size_t(43), count); else return count; } ssize_t do_read(void* buf, size_t max_count) { if (is_blocking_ || mode_ < 3 ) return ::read(fd_, buf, max_count); else if (::rand() % 1000 == 0) return EINTR; else return ::read(fd_, buf, max_count); } ssize_t do_write(const void* buf, size_t count) { if (is_blocking_ || mode_ < 3) return ::send(fd_, buf, count, MSG_NOSIGNAL); else if (::rand() % 1000 == 0) return EINTR; else return ::send(fd_, buf, count, MSG_NOSIGNAL); } wsrep::tls_service::op_result map_success(ssize_t result) { if (is_blocking_ || mode_ < 2) { return wsrep::tls_service::op_result{ wsrep::tls_service::success, size_t(result)}; } else if (::rand() % 1000 == 0) { wsrep::log_info() << "Success want extra read"; state_ = s_want_read; return wsrep::tls_service::op_result{ wsrep::tls_service::want_read, size_t(result)}; } else if (::rand() % 1000 == 0) { wsrep::log_info() << "Success want extra write"; state_ = s_want_write; return wsrep::tls_service::op_result{ wsrep::tls_service::want_write, size_t(result)}; } else { return wsrep::tls_service::op_result{ wsrep::tls_service::success, size_t(result)}; } } wsrep::tls_service::op_result map_result(ssize_t result) { if (result > 0) { return map_success(result); } else if (result == 0) { return wsrep::tls_service::op_result{ wsrep::tls_service::eof, 0}; } else if (errno == EAGAIN || errno == EWOULDBLOCK) { return wsrep::tls_service::op_result{ wsrep::tls_service::want_read, 0}; } else { last_error_ = errno; return wsrep::tls_service::op_result{ wsrep::tls_service::error, 0}; } } void clear_error() { last_error_ = 0; } int fd_; enum state state_; int last_error_; // Operation mode: // 1 - simulate handshake exchange // 2 - simulate errors and short reads int mode_; stats stats_; bool is_blocking_; }; enum wsrep::tls_service::status db_stream::client_handshake() { clear_error(); enum wsrep::tls_service::status ret; assert(state_ == s_initialized || state_ == s_client_handshake || state_ == s_want_write); if (state_ == s_initialized) { (void)::send(fd_, "clie", 4, MSG_NOSIGNAL); ret = wsrep::tls_service::want_read; state_ = s_client_handshake; wsrep::log_info() << this << " client handshake sent"; stats_.bytes_written += 4; if (not is_blocking_) return ret; } if (state_ == s_client_handshake) { if ((ret = handle_handshake_read("serv")) == wsrep::tls_service::success) { state_ = s_want_write; ret = wsrep::tls_service::want_write; } if (not is_blocking_) return ret; } if (state_ == s_want_write) { state_ = s_idle; ret = wsrep::tls_service::success; if (not is_blocking_) return ret; } if (not is_blocking_) { last_error_ = EPROTO; ret = wsrep::tls_service::error; } return ret; } enum wsrep::tls_service::status db_stream::server_handshake() { enum wsrep::tls_service::status ret; assert(state_ == s_initialized || state_ == s_server_handshake || state_ == s_want_write); if (state_ == s_initialized) { ::send(fd_, "serv", 4, MSG_NOSIGNAL); ret = wsrep::tls_service::want_read; state_ = s_server_handshake; stats_.bytes_written += 4; if (not is_blocking_) return ret; } if (state_ == s_server_handshake) { if ((ret = handle_handshake_read("clie")) == wsrep::tls_service::success) { state_ = s_want_write; ret = wsrep::tls_service::want_write; } if (not is_blocking_) return ret; } if (state_ == s_want_write) { state_ = s_idle; ret = wsrep::tls_service::success; if (not is_blocking_) return ret; } if (not is_blocking_) { last_error_ = EPROTO; ret = wsrep::tls_service::error; } return ret; } enum wsrep::tls_service::status db_stream::handle_handshake_read( const char* expect) { assert(::strlen(expect) >= 4); char buf[4] = { }; ssize_t read_result(::read(fd_, buf, sizeof(buf))); if (read_result > 0) stats_.bytes_read += size_t(read_result); enum wsrep::tls_service::status ret; if (read_result == -1 && (errno == EWOULDBLOCK || errno == EAGAIN)) { ret = wsrep::tls_service::want_read; } else if (read_result == 0) { ret = wsrep::tls_service::eof; } else if (read_result != 4 || ::memcmp(buf, expect, 4)) { last_error_ = EPROTO; ret = wsrep::tls_service::error; } else { wsrep::log_info() << "Handshake success: " << std::string(buf, 4); ret = wsrep::tls_service::success; } return ret; } wsrep::tls_service::op_result db_stream::read(void* buf, size_t max_count) { clear_error(); if (state_ == s_want_read) { state_ = s_idle; if (max_count == 0) return wsrep::tls_service::op_result{ wsrep::tls_service::success, 0}; } max_count = determine_read_count(max_count); ssize_t read_result(do_read(buf, max_count)); if (read_result > 0) { inc_reads(size_t(read_result)); } return map_result(read_result); } wsrep::tls_service::op_result db_stream::write( const void* buf, size_t count) { clear_error(); if (state_ == s_want_write) { state_ = s_idle; if (count == 0) return wsrep::tls_service::op_result{ wsrep::tls_service::success, 0}; } count = determine_write_count(count); ssize_t write_result(do_write(buf, count)); if (write_result > 0) { inc_writes(size_t(write_result)); } return map_result(write_result); } } static db_stream::stats global_stats; std::mutex global_stats_lock; static int global_mode; static void merge_to_global_stats(const db_stream::stats& stats) { std::lock_guard lock(global_stats_lock); global_stats.bytes_read += stats.bytes_read; global_stats.bytes_written += stats.bytes_written; } wsrep::tls_stream* db::tls::create_tls_stream(int fd) WSREP_NOEXCEPT { auto ret(new db_stream(fd, global_mode)); wsrep::log_debug() << "New DB stream: " << ret; return ret; } void db::tls::destroy(wsrep::tls_stream* stream) WSREP_NOEXCEPT { auto dbs(static_cast(stream)); merge_to_global_stats(dbs->get_stats()); wsrep::log_debug() << "Stream destroy: " << dbs->get_stats().bytes_read << " " << dbs->get_stats().bytes_written; wsrep::log_debug() << "Stream destroy" << dbs; delete dbs; } int db::tls::get_error_number(const wsrep::tls_stream* stream) const WSREP_NOEXCEPT { return static_cast(stream)->get_error_number(); } const void* db::tls::get_error_category(const wsrep::tls_stream* stream) const WSREP_NOEXCEPT { return static_cast(stream)->get_error_category(); } const char* db::tls::get_error_message(const wsrep::tls_stream*, int value, const void* category) const WSREP_NOEXCEPT { return db_stream::get_error_message(value, category); } enum wsrep::tls_service::status db::tls::client_handshake(wsrep::tls_stream* stream) WSREP_NOEXCEPT { return static_cast(stream)->client_handshake(); } enum wsrep::tls_service::status db::tls::server_handshake(wsrep::tls_stream* stream) WSREP_NOEXCEPT { return static_cast(stream)->server_handshake(); } wsrep::tls_service::op_result db::tls::read( wsrep::tls_stream* stream, void* buf, size_t max_count) WSREP_NOEXCEPT { return static_cast(stream)->read(buf, max_count); } wsrep::tls_service::op_result db::tls::write( wsrep::tls_stream* stream, const void* buf, size_t count) WSREP_NOEXCEPT { return static_cast(stream)->write(buf, count); } wsrep::tls_service::status db::tls::shutdown(wsrep::tls_stream*) WSREP_NOEXCEPT { // @todo error simulation return wsrep::tls_service::success; } void db::tls::init(int mode) { global_mode = mode; } std::string db::tls::stats() { std::ostringstream oss; oss << "Transport stats:\n" << " bytes_read: " << global_stats.bytes_read << "\n" << " bytes_written: " << global_stats.bytes_written << "\n"; return oss.str(); }