summaryrefslogtreecommitdiffstats
path: root/comm/third_party/botan/src/lib/tls/tls_handshake_io.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'comm/third_party/botan/src/lib/tls/tls_handshake_io.cpp')
-rw-r--r--comm/third_party/botan/src/lib/tls/tls_handshake_io.cpp480
1 files changed, 480 insertions, 0 deletions
diff --git a/comm/third_party/botan/src/lib/tls/tls_handshake_io.cpp b/comm/third_party/botan/src/lib/tls/tls_handshake_io.cpp
new file mode 100644
index 0000000000..7f9e2c86c5
--- /dev/null
+++ b/comm/third_party/botan/src/lib/tls/tls_handshake_io.cpp
@@ -0,0 +1,480 @@
+/*
+* TLS Handshake IO
+* (C) 2012,2014,2015 Jack Lloyd
+*
+* Botan is released under the Simplified BSD License (see license.txt)
+*/
+
+#include <botan/internal/tls_handshake_io.h>
+#include <botan/internal/tls_record.h>
+#include <botan/internal/tls_seq_numbers.h>
+#include <botan/tls_messages.h>
+#include <botan/exceptn.h>
+#include <botan/loadstor.h>
+#include <chrono>
+
+namespace Botan {
+
+namespace TLS {
+
+namespace {
+
+inline size_t load_be24(const uint8_t q[3])
+ {
+ return make_uint32(0,
+ q[0],
+ q[1],
+ q[2]);
+ }
+
+void store_be24(uint8_t out[3], size_t val)
+ {
+ out[0] = get_byte(1, static_cast<uint32_t>(val));
+ out[1] = get_byte(2, static_cast<uint32_t>(val));
+ out[2] = get_byte(3, static_cast<uint32_t>(val));
+ }
+
+uint64_t steady_clock_ms()
+ {
+ return std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::steady_clock::now().time_since_epoch()).count();
+ }
+
+}
+
+Protocol_Version Stream_Handshake_IO::initial_record_version() const
+ {
+ return Protocol_Version::TLS_V10;
+ }
+
+void Stream_Handshake_IO::add_record(const uint8_t record[],
+ size_t record_len,
+ Record_Type record_type, uint64_t)
+ {
+ if(record_type == HANDSHAKE)
+ {
+ m_queue.insert(m_queue.end(), record, record + record_len);
+ }
+ else if(record_type == CHANGE_CIPHER_SPEC)
+ {
+ if(record_len != 1 || record[0] != 1)
+ throw Decoding_Error("Invalid ChangeCipherSpec");
+
+ // Pretend it's a regular handshake message of zero length
+ const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
+ m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
+ }
+ else
+ throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
+ }
+
+std::pair<Handshake_Type, std::vector<uint8_t>>
+Stream_Handshake_IO::get_next_record(bool)
+ {
+ if(m_queue.size() >= 4)
+ {
+ const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
+
+ if(m_queue.size() >= length)
+ {
+ Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
+
+ if(type == HANDSHAKE_NONE)
+ throw Decoding_Error("Invalid handshake message type");
+
+ std::vector<uint8_t> contents(m_queue.begin() + 4,
+ m_queue.begin() + length);
+
+ m_queue.erase(m_queue.begin(), m_queue.begin() + length);
+
+ return std::make_pair(type, contents);
+ }
+ }
+
+ return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
+ }
+
+std::vector<uint8_t>
+Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
+ Handshake_Type type) const
+ {
+ std::vector<uint8_t> send_buf(4 + msg.size());
+
+ const size_t buf_size = msg.size();
+
+ send_buf[0] = static_cast<uint8_t>(type);
+
+ store_be24(&send_buf[1], buf_size);
+
+ if (msg.size() > 0)
+ {
+ copy_mem(&send_buf[4], msg.data(), msg.size());
+ }
+
+ return send_buf;
+ }
+
+std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
+ {
+ throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
+ }
+
+std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
+ {
+ const std::vector<uint8_t> msg_bits = msg.serialize();
+
+ if(msg.type() == HANDSHAKE_CCS)
+ {
+ m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
+ return std::vector<uint8_t>(); // not included in handshake hashes
+ }
+
+ const std::vector<uint8_t> buf = format(msg_bits, msg.type());
+ m_send_hs(HANDSHAKE, buf);
+ return buf;
+ }
+
+Protocol_Version Datagram_Handshake_IO::initial_record_version() const
+ {
+ return Protocol_Version::DTLS_V10;
+ }
+
+void Datagram_Handshake_IO::retransmit_last_flight()
+ {
+ const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
+ retransmit_flight(flight_idx);
+ }
+
+void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
+ {
+ const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
+
+ BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
+
+ uint16_t epoch = m_flight_data[flight[0]].epoch;
+
+ for(auto msg_seq : flight)
+ {
+ auto& msg = m_flight_data[msg_seq];
+
+ if(msg.epoch != epoch)
+ {
+ // Epoch gap: insert the CCS
+ std::vector<uint8_t> ccs(1, 1);
+ m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
+ }
+
+ send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
+ epoch = msg.epoch;
+ }
+ }
+
+bool Datagram_Handshake_IO::timeout_check()
+ {
+ if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
+ {
+ /*
+ If we haven't written anything yet obviously no timeout.
+ Also no timeout possible if we are mid-flight,
+ */
+ return false;
+ }
+
+ const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
+
+ if(ms_since_write < m_next_timeout)
+ return false;
+
+ retransmit_last_flight();
+
+ m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
+ return true;
+ }
+
+void Datagram_Handshake_IO::add_record(const uint8_t record[],
+ size_t record_len,
+ Record_Type record_type,
+ uint64_t record_sequence)
+ {
+ const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
+
+ if(record_type == CHANGE_CIPHER_SPEC)
+ {
+ if(record_len != 1 || record[0] != 1)
+ throw Decoding_Error("Invalid ChangeCipherSpec");
+
+ // TODO: check this is otherwise empty
+ m_ccs_epochs.insert(epoch);
+ return;
+ }
+
+ const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
+
+ while(record_len)
+ {
+ if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
+ return; // completely bogus? at least degenerate/weird
+
+ const uint8_t msg_type = record[0];
+ const size_t msg_len = load_be24(&record[1]);
+ const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
+ const size_t fragment_offset = load_be24(&record[6]);
+ const size_t fragment_length = load_be24(&record[9]);
+
+ const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
+
+ if(record_len < total_size)
+ throw Decoding_Error("Bad lengths in DTLS header");
+
+ if(message_seq >= m_in_message_seq)
+ {
+ m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
+ fragment_length,
+ fragment_offset,
+ epoch,
+ msg_type,
+ msg_len);
+ }
+ else
+ {
+ // TODO: detect retransmitted flight
+ }
+
+ record += total_size;
+ record_len -= total_size;
+ }
+ }
+
+std::pair<Handshake_Type, std::vector<uint8_t>>
+Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
+ {
+ // Expecting a message means the last flight is concluded
+ if(!m_flights.rbegin()->empty())
+ m_flights.push_back(std::vector<uint16_t>());
+
+ if(expecting_ccs)
+ {
+ if(!m_messages.empty())
+ {
+ const uint16_t current_epoch = m_messages.begin()->second.epoch();
+
+ if(m_ccs_epochs.count(current_epoch))
+ return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
+ }
+ return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
+ }
+
+ auto i = m_messages.find(m_in_message_seq);
+
+ if(i == m_messages.end() || !i->second.complete())
+ {
+ return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
+ }
+
+ m_in_message_seq += 1;
+
+ return i->second.message();
+ }
+
+void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
+ const uint8_t fragment[],
+ size_t fragment_length,
+ size_t fragment_offset,
+ uint16_t epoch,
+ uint8_t msg_type,
+ size_t msg_length)
+ {
+ if(complete())
+ return; // already have entire message, ignore this
+
+ if(m_msg_type == HANDSHAKE_NONE)
+ {
+ m_epoch = epoch;
+ m_msg_type = msg_type;
+ m_msg_length = msg_length;
+ }
+
+ if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
+ throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
+
+ if(fragment_offset > m_msg_length)
+ throw Decoding_Error("Fragment offset past end of message");
+
+ if(fragment_offset + fragment_length > m_msg_length)
+ throw Decoding_Error("Fragment overlaps past end of message");
+
+ if(fragment_offset == 0 && fragment_length == m_msg_length)
+ {
+ m_fragments.clear();
+ m_message.assign(fragment, fragment+fragment_length);
+ }
+ else
+ {
+ /*
+ * FIXME. This is a pretty lame way to do defragmentation, huge
+ * overhead with a tree node per byte.
+ *
+ * Also should confirm that all overlaps have no changes,
+ * otherwise we expose ourselves to the classic fingerprinting
+ * and IDS evasion attacks on IP fragmentation.
+ */
+ for(size_t i = 0; i != fragment_length; ++i)
+ m_fragments[fragment_offset+i] = fragment[i];
+
+ if(m_fragments.size() == m_msg_length)
+ {
+ m_message.resize(m_msg_length);
+ for(size_t i = 0; i != m_msg_length; ++i)
+ m_message[i] = m_fragments[i];
+ m_fragments.clear();
+ }
+ }
+ }
+
+bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
+ {
+ return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
+ }
+
+std::pair<Handshake_Type, std::vector<uint8_t>>
+Datagram_Handshake_IO::Handshake_Reassembly::message() const
+ {
+ if(!complete())
+ throw Internal_Error("Datagram_Handshake_IO - message not complete");
+
+ return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
+ }
+
+std::vector<uint8_t>
+Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
+ size_t frag_len,
+ uint16_t frag_offset,
+ uint16_t msg_len,
+ Handshake_Type type,
+ uint16_t msg_sequence) const
+ {
+ std::vector<uint8_t> send_buf(12 + frag_len);
+
+ send_buf[0] = static_cast<uint8_t>(type);
+
+ store_be24(&send_buf[1], msg_len);
+
+ store_be(msg_sequence, &send_buf[4]);
+
+ store_be24(&send_buf[6], frag_offset);
+ store_be24(&send_buf[9], frag_len);
+
+ if (frag_len > 0)
+ {
+ copy_mem(&send_buf[12], fragment, frag_len);
+ }
+
+ return send_buf;
+ }
+
+std::vector<uint8_t>
+Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
+ Handshake_Type type,
+ uint16_t msg_sequence) const
+ {
+ return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
+ }
+
+std::vector<uint8_t>
+Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
+ Handshake_Type type) const
+ {
+ return format_w_seq(msg, type, m_in_message_seq - 1);
+ }
+
+std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
+ {
+ return this->send_under_epoch(msg, m_seqs.current_write_epoch());
+ }
+
+std::vector<uint8_t>
+Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
+ {
+ const std::vector<uint8_t> msg_bits = msg.serialize();
+ const Handshake_Type msg_type = msg.type();
+
+ if(msg_type == HANDSHAKE_CCS)
+ {
+ m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
+ return std::vector<uint8_t>(); // not included in handshake hashes
+ }
+ else if(msg_type == HELLO_VERIFY_REQUEST)
+ {
+ // This message is not included in the handshake hashes
+ send_message(m_out_message_seq, epoch, msg_type, msg_bits);
+ m_out_message_seq += 1;
+ return std::vector<uint8_t>();
+ }
+
+ // Note: not saving CCS, instead we know it was there due to change in epoch
+ m_flights.rbegin()->push_back(m_out_message_seq);
+ m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
+
+ m_out_message_seq += 1;
+ m_last_write = steady_clock_ms();
+ m_next_timeout = m_initial_timeout;
+
+ return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
+ }
+
+std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
+ uint16_t epoch,
+ Handshake_Type msg_type,
+ const std::vector<uint8_t>& msg_bits)
+ {
+ const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
+
+ const std::vector<uint8_t> no_fragment =
+ format_w_seq(msg_bits, msg_type, msg_seq);
+
+ if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
+ {
+ m_send_hs(epoch, HANDSHAKE, no_fragment);
+ }
+ else
+ {
+ size_t frag_offset = 0;
+
+ /**
+ * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
+ * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
+ * over-estimate here. When CBC ciphers are removed this can be reduced
+ * since AEAD modes have no padding, at most 16 byte mac, and smaller
+ * per-record nonce.
+ */
+ const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
+ const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
+
+ if(m_mtu <= (header_overhead + ciphersuite_overhead))
+ throw Invalid_Argument("DTLS MTU is too small to send headers");
+
+ const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
+
+ while(frag_offset != msg_bits.size())
+ {
+ const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
+
+ const std::vector<uint8_t> frag =
+ format_fragment(&msg_bits[frag_offset],
+ frag_len,
+ static_cast<uint16_t>(frag_offset),
+ static_cast<uint16_t>(msg_bits.size()),
+ msg_type,
+ msg_seq);
+
+ m_send_hs(epoch, HANDSHAKE, frag);
+
+ frag_offset += frag_len;
+ }
+ }
+
+ return no_fragment;
+ }
+
+}
+}