diff options
Diffstat (limited to 'security/nss/lib/ssl/dtlscon.c')
-rw-r--r-- | security/nss/lib/ssl/dtlscon.c | 1474 |
1 files changed, 1474 insertions, 0 deletions
diff --git a/security/nss/lib/ssl/dtlscon.c b/security/nss/lib/ssl/dtlscon.c new file mode 100644 index 0000000000..a4a7c998c4 --- /dev/null +++ b/security/nss/lib/ssl/dtlscon.c @@ -0,0 +1,1474 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +/* + * DTLS Protocol + */ + +#include "ssl.h" +#include "sslimpl.h" +#include "sslproto.h" +#include "dtls13con.h" + +#ifndef PR_ARRAY_SIZE +#define PR_ARRAY_SIZE(a) (sizeof(a) / sizeof((a)[0])) +#endif + +static SECStatus dtls_StartRetransmitTimer(sslSocket *ss); +static void dtls_RetransmitTimerExpiredCb(sslSocket *ss); +static SECStatus dtls_SendSavedWriteData(sslSocket *ss); +static void dtls_FinishedTimerCb(sslSocket *ss); +static void dtls_CancelAllTimers(sslSocket *ss); + +/* -28 adjusts for the IP/UDP header */ +static const PRUint16 COMMON_MTU_VALUES[] = { + 1500 - 28, /* Ethernet MTU */ + 1280 - 28, /* IPv6 minimum MTU */ + 576 - 28, /* Common assumption */ + 256 - 28 /* We're in serious trouble now */ +}; + +#define DTLS_COOKIE_BYTES 32 +/* Maximum DTLS expansion = header + IV + max CBC padding + + * maximum MAC. */ +#define DTLS_MAX_EXPANSION (DTLS_RECORD_HEADER_LENGTH + 16 + 16 + 32) + +/* List copied from ssl3con.c:cipherSuites */ +static const ssl3CipherSuite nonDTLSSuites[] = { + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_DHE_DSS_WITH_RC4_128_SHA, + TLS_ECDH_RSA_WITH_RC4_128_SHA, + TLS_ECDH_ECDSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_MD5, + TLS_RSA_WITH_RC4_128_SHA, + 0 /* End of list marker */ +}; + +/* Map back and forth between TLS and DTLS versions in wire format. + * Mapping table is: + * + * TLS DTLS + * 1.1 (0302) 1.0 (feff) + * 1.2 (0303) 1.2 (fefd) + * 1.3 (0304) 1.3 (0304) + */ +SSL3ProtocolVersion +dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv) +{ + if (tlsv == SSL_LIBRARY_VERSION_TLS_1_1) { + return SSL_LIBRARY_VERSION_DTLS_1_0_WIRE; + } + if (tlsv == SSL_LIBRARY_VERSION_TLS_1_2) { + return SSL_LIBRARY_VERSION_DTLS_1_2_WIRE; + } + if (tlsv == SSL_LIBRARY_VERSION_TLS_1_3) { + return SSL_LIBRARY_VERSION_DTLS_1_3_WIRE; + } + + /* Anything else is an error, so return + * the invalid version 0xffff. */ + return 0xffff; +} + +/* Map known DTLS versions to known TLS versions. + * - Invalid versions (< 1.0) return a version of 0 + * - Versions > known return a version one higher than we know of + * to accomodate a theoretically newer version */ +SSL3ProtocolVersion +dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv) +{ + if (MSB(dtlsv) == 0xff) { + return 0; + } + + if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_0_WIRE) { + return SSL_LIBRARY_VERSION_TLS_1_1; + } + /* Handle the skipped version of DTLS 1.1 by returning + * an error. */ + if (dtlsv == ((~0x0101) & 0xffff)) { + return 0; + } + if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_2_WIRE) { + return SSL_LIBRARY_VERSION_TLS_1_2; + } + if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_3_WIRE) { + return SSL_LIBRARY_VERSION_TLS_1_3; + } + + /* Return a fictional higher version than we know of */ + return SSL_LIBRARY_VERSION_MAX_SUPPORTED + 1; +} + +/* On this socket, Disable non-DTLS cipher suites in the argument's list */ +SECStatus +ssl3_DisableNonDTLSSuites(sslSocket *ss) +{ + const ssl3CipherSuite *suite; + + for (suite = nonDTLSSuites; *suite; ++suite) { + PORT_CheckSuccess(ssl3_CipherPrefSet(ss, *suite, PR_FALSE)); + } + return SECSuccess; +} + +/* Allocate a DTLSQueuedMessage. + * + * Called from dtls_QueueMessage() + */ +static DTLSQueuedMessage * +dtls_AllocQueuedMessage(ssl3CipherSpec *cwSpec, SSLContentType ct, + const unsigned char *data, PRUint32 len) +{ + DTLSQueuedMessage *msg; + + msg = PORT_ZNew(DTLSQueuedMessage); + if (!msg) + return NULL; + + msg->data = PORT_Alloc(len); + if (!msg->data) { + PORT_Free(msg); + return NULL; + } + PORT_Memcpy(msg->data, data, len); + + msg->len = len; + msg->cwSpec = cwSpec; + msg->type = ct; + /* Safe if we are < 1.3, since the refct is + * already very high. */ + ssl_CipherSpecAddRef(cwSpec); + + return msg; +} + +/* + * Free a handshake message + * + * Called from dtls_FreeHandshakeMessages() + */ +void +dtls_FreeHandshakeMessage(DTLSQueuedMessage *msg) +{ + if (!msg) + return; + + /* Safe if we are < 1.3, since the refct is + * already very high. */ + ssl_CipherSpecRelease(msg->cwSpec); + PORT_ZFree(msg->data, msg->len); + PORT_Free(msg); +} + +/* + * Free a list of handshake messages + * + * Called from: + * dtls_HandleHandshake() + * ssl3_DestroySSL3Info() + */ +void +dtls_FreeHandshakeMessages(PRCList *list) +{ + PRCList *cur_p; + + while (!PR_CLIST_IS_EMPTY(list)) { + cur_p = PR_LIST_TAIL(list); + PR_REMOVE_LINK(cur_p); + dtls_FreeHandshakeMessage((DTLSQueuedMessage *)cur_p); + } +} + +/* Called by dtls_HandleHandshake() and dtls_MaybeRetransmitHandshake() if a + * handshake message retransmission is detected. */ +static SECStatus +dtls_RetransmitDetected(sslSocket *ss) +{ + dtlsTimer *timer = ss->ssl3.hs.rtTimer; + SECStatus rv = SECSuccess; + + PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); + PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); + + if (timer->cb == dtls_RetransmitTimerExpiredCb) { + /* Check to see if we retransmitted recently. If so, + * suppress the triggered retransmit. This avoids + * retransmit wars after packet loss. + * This is not in RFC 5346 but it should be. + */ + if ((PR_IntervalNow() - timer->started) > + (timer->timeout / 4)) { + SSL_TRC(30, + ("%d: SSL3[%d]: Shortcutting retransmit timer", + SSL_GETPID(), ss->fd)); + + /* Cancel the timer and call the CB, + * which re-arms the timer */ + dtls_CancelTimer(ss, ss->ssl3.hs.rtTimer); + dtls_RetransmitTimerExpiredCb(ss); + } else { + SSL_TRC(30, + ("%d: SSL3[%d]: Ignoring retransmission: " + "last retransmission %dms ago, suppressed for %dms", + SSL_GETPID(), ss->fd, + PR_IntervalNow() - timer->started, + timer->timeout / 4)); + } + + } else if (timer->cb == dtls_FinishedTimerCb) { + SSL_TRC(30, ("%d: SSL3[%d]: Retransmit detected in holddown", + SSL_GETPID(), ss->fd)); + /* Retransmit the messages and re-arm the timer + * Note that we are not backing off the timer here. + * The spec isn't clear and my reasoning is that this + * may be a re-ordered packet rather than slowness, + * so let's be aggressive. */ + dtls_CancelTimer(ss, ss->ssl3.hs.rtTimer); + rv = dtls_TransmitMessageFlight(ss); + if (rv == SECSuccess) { + rv = dtls_StartHolddownTimer(ss); + } + + } else { + PORT_Assert(timer->cb == NULL); + /* ... and ignore it. */ + } + return rv; +} + +static SECStatus +dtls_HandleHandshakeMessage(sslSocket *ss, PRUint8 *data, PRBool last) +{ + ss->ssl3.hs.recvdHighWater = -1; + + return ssl3_HandleHandshakeMessage(ss, data, ss->ssl3.hs.msg_len, + last); +} + +/* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record. + * origBuf is the decrypted ssl record content and is expected to contain + * complete handshake records + * Caller must hold the handshake and RecvBuf locks. + * + * Note that this code uses msg_len for two purposes: + * + * (1) To pass the length to ssl3_HandleHandshakeMessage() + * (2) To carry the length of a message currently being reassembled + * + * However, unlike ssl3_HandleHandshake(), it is not used to carry + * the state of reassembly (i.e., whether one is in progress). That + * is carried in recvdHighWater and recvdFragments. + */ +#define OFFSET_BYTE(o) (o / 8) +#define OFFSET_MASK(o) (1 << (o % 8)) + +SECStatus +dtls_HandleHandshake(sslSocket *ss, DTLSEpoch epoch, sslSequenceNumber seqNum, + sslBuffer *origBuf) +{ + sslBuffer buf = *origBuf; + SECStatus rv = SECSuccess; + PRBool discarded = PR_FALSE; + + ss->ssl3.hs.endOfFlight = PR_FALSE; + + PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); + PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); + + while (buf.len > 0) { + PRUint8 type; + PRUint32 message_length; + PRUint16 message_seq; + PRUint32 fragment_offset; + PRUint32 fragment_length; + PRUint32 offset; + + if (buf.len < 12) { + PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); + rv = SECFailure; + goto loser; + } + + /* Parse the header */ + type = buf.buf[0]; + message_length = (buf.buf[1] << 16) | (buf.buf[2] << 8) | buf.buf[3]; + message_seq = (buf.buf[4] << 8) | buf.buf[5]; + fragment_offset = (buf.buf[6] << 16) | (buf.buf[7] << 8) | buf.buf[8]; + fragment_length = (buf.buf[9] << 16) | (buf.buf[10] << 8) | buf.buf[11]; + +#define MAX_HANDSHAKE_MSG_LEN 0x1ffff /* 128k - 1 */ + if (message_length > MAX_HANDSHAKE_MSG_LEN) { + (void)ssl3_DecodeError(ss); + PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); + rv = SECFailure; + goto loser; + } +#undef MAX_HANDSHAKE_MSG_LEN + + buf.buf += 12; + buf.len -= 12; + + /* This fragment must be complete */ + if (buf.len < fragment_length) { + PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); + rv = SECFailure; + goto loser; + } + + /* Sanity check the packet contents */ + if ((fragment_length + fragment_offset) > message_length) { + PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); + rv = SECFailure; + goto loser; + } + + /* If we're a server and we receive what appears to be a retried + * ClientHello, and we are expecting a ClientHello, move the receive + * sequence number forward. This allows for a retried ClientHello if we + * send a stateless HelloRetryRequest. */ + if (message_seq > ss->ssl3.hs.recvMessageSeq && + message_seq == 1 && + fragment_offset == 0 && + ss->ssl3.hs.ws == wait_client_hello && + (SSLHandshakeType)type == ssl_hs_client_hello) { + SSL_TRC(5, ("%d: DTLS[%d]: Received apparent 2nd ClientHello", + SSL_GETPID(), ss->fd)); + ss->ssl3.hs.recvMessageSeq = 1; + ss->ssl3.hs.helloRetry = PR_TRUE; + } + + /* There are three ways we could not be ready for this packet. + * + * 1. It's a partial next message. + * 2. It's a partial or complete message beyond the next + * 3. It's a message we've already seen + * + * If it's the complete next message we accept it right away. + * This is the common case for short messages + */ + if ((message_seq == ss->ssl3.hs.recvMessageSeq) && + (fragment_offset == 0) && + (fragment_length == message_length)) { + /* Complete next message. Process immediately */ + ss->ssl3.hs.msg_type = (SSLHandshakeType)type; + ss->ssl3.hs.msg_len = message_length; + + rv = dtls_HandleHandshakeMessage(ss, buf.buf, + buf.len == fragment_length); + if (rv != SECSuccess) { + goto loser; + } + } else { + if (message_seq < ss->ssl3.hs.recvMessageSeq) { + /* Case 3: we do an immediate retransmit if we're + * in a waiting state. */ + rv = dtls_RetransmitDetected(ss); + goto loser; + } else if (message_seq > ss->ssl3.hs.recvMessageSeq) { + /* Case 2 + * + * Ignore this message. This means we don't handle out of + * order complete messages that well, but we're still + * compliant and this probably does not happen often + * + * XXX OK for now. Maybe do something smarter at some point? + */ + SSL_TRC(10, ("%d: SSL3[%d]: dtls_HandleHandshake, discarding handshake message", + SSL_GETPID(), ss->fd)); + discarded = PR_TRUE; + } else { + PRInt32 end = fragment_offset + fragment_length; + + /* Case 1 + * + * Buffer the fragment for reassembly + */ + /* Make room for the message */ + if (ss->ssl3.hs.recvdHighWater == -1) { + PRUint32 map_length = OFFSET_BYTE(message_length) + 1; + + rv = sslBuffer_Grow(&ss->ssl3.hs.msg_body, message_length); + if (rv != SECSuccess) + goto loser; + /* Make room for the fragment map */ + rv = sslBuffer_Grow(&ss->ssl3.hs.recvdFragments, + map_length); + if (rv != SECSuccess) + goto loser; + + /* Reset the reassembly map */ + ss->ssl3.hs.recvdHighWater = 0; + PORT_Memset(ss->ssl3.hs.recvdFragments.buf, 0, + ss->ssl3.hs.recvdFragments.space); + ss->ssl3.hs.msg_type = (SSLHandshakeType)type; + ss->ssl3.hs.msg_len = message_length; + } + + /* If we have a message length mismatch, abandon the reassembly + * in progress and hope that the next retransmit will give us + * something sane + */ + if (message_length != ss->ssl3.hs.msg_len) { + ss->ssl3.hs.recvdHighWater = -1; + PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); + rv = SECFailure; + goto loser; + } + + /* Now copy this fragment into the buffer. */ + if (end > ss->ssl3.hs.recvdHighWater) { + PORT_Memcpy(ss->ssl3.hs.msg_body.buf + fragment_offset, + buf.buf, fragment_length); + } + + /* This logic is a bit tricky. We have two values for + * reassembly state: + * + * - recvdHighWater contains the highest contiguous number of + * bytes received + * - recvdFragments contains a bitmask of packets received + * above recvdHighWater + * + * This avoids having to fill in the bitmask in the common + * case of adjacent fragments received in sequence + */ + if (fragment_offset <= (unsigned int)ss->ssl3.hs.recvdHighWater) { + /* Either this is the adjacent fragment or an overlapping + * fragment */ + if (end > ss->ssl3.hs.recvdHighWater) { + ss->ssl3.hs.recvdHighWater = end; + } + } else { + for (offset = fragment_offset; offset < end; offset++) { + ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] |= + OFFSET_MASK(offset); + } + } + + /* Now figure out the new high water mark if appropriate */ + for (offset = ss->ssl3.hs.recvdHighWater; + offset < ss->ssl3.hs.msg_len; offset++) { + /* Note that this loop is not efficient, since it counts + * bit by bit. If we have a lot of out-of-order packets, + * we should optimize this */ + if (ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] & + OFFSET_MASK(offset)) { + ss->ssl3.hs.recvdHighWater++; + } else { + break; + } + } + + /* If we have all the bytes, then we are good to go */ + if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) { + rv = dtls_HandleHandshakeMessage(ss, ss->ssl3.hs.msg_body.buf, + buf.len == fragment_length); + + if (rv != SECSuccess) { + goto loser; + } + } + } + } + + buf.buf += fragment_length; + buf.len -= fragment_length; + } + + // This should never happen, but belt and suspenders. + if (rv != SECSuccess) { + PORT_Assert(0); + goto loser; + } + + /* If we processed all the fragments in this message, then mark it as remembered. + * TODO(ekr@rtfm.com): Store out of order messages for DTLS 1.3 so ACKs work + * better. Bug 1392620.*/ + if (!discarded && tls13_MaybeTls13(ss)) { + rv = dtls13_RememberFragment(ss, &ss->ssl3.hs.dtlsRcvdHandshake, + 0, 0, 0, epoch, seqNum); + } + if (rv != SECSuccess) { + goto loser; + } + + rv = dtls13_SetupAcks(ss); + +loser: + origBuf->len = 0; /* So ssl3_GatherAppDataRecord will keep looping. */ + return rv; +} + +/* Enqueue a message (either handshake or CCS) + * + * Called from: + * dtls_StageHandshakeMessage() + * ssl3_SendChangeCipherSpecs() + */ +SECStatus +dtls_QueueMessage(sslSocket *ss, SSLContentType ct, + const PRUint8 *pIn, PRInt32 nIn) +{ + SECStatus rv = SECSuccess; + DTLSQueuedMessage *msg = NULL; + ssl3CipherSpec *spec; + + PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); + PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss)); + + spec = ss->ssl3.cwSpec; + msg = dtls_AllocQueuedMessage(spec, ct, pIn, nIn); + + if (!msg) { + PORT_SetError(SEC_ERROR_NO_MEMORY); + rv = SECFailure; + } else { + PR_APPEND_LINK(&msg->link, &ss->ssl3.hs.lastMessageFlight); + } + + return rv; +} + +/* Add DTLS handshake message to the pending queue + * Empty the sendBuf buffer. + * Always set sendBuf.len to 0, even when returning SECFailure. + * + * Called from: + * ssl3_AppendHandshakeHeader() + * dtls_FlushHandshake() + */ +SECStatus +dtls_StageHandshakeMessage(sslSocket *ss) +{ + SECStatus rv = SECSuccess; + + PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); + PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss)); + + /* This function is sometimes called when no data is actually to + * be staged, so just return SECSuccess. */ + if (!ss->sec.ci.sendBuf.buf || !ss->sec.ci.sendBuf.len) + return rv; + + rv = dtls_QueueMessage(ss, ssl_ct_handshake, + ss->sec.ci.sendBuf.buf, ss->sec.ci.sendBuf.len); + + /* Whether we succeeded or failed, toss the old handshake data. */ + ss->sec.ci.sendBuf.len = 0; + return rv; +} + +/* Enqueue the handshake message in sendBuf (if any) and then + * transmit the resulting flight of handshake messages. + * + * Called from: + * ssl3_FlushHandshake() + */ +SECStatus +dtls_FlushHandshakeMessages(sslSocket *ss, PRInt32 flags) +{ + SECStatus rv = SECSuccess; + + PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); + PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss)); + + rv = dtls_StageHandshakeMessage(ss); + if (rv != SECSuccess) + return rv; + + if (!(flags & ssl_SEND_FLAG_FORCE_INTO_BUFFER)) { + rv = dtls_TransmitMessageFlight(ss); + if (rv != SECSuccess) { + return rv; + } + + if (!(flags & ssl_SEND_FLAG_NO_RETRANSMIT)) { + rv = dtls_StartRetransmitTimer(ss); + } else { + PORT_Assert(ss->version < SSL_LIBRARY_VERSION_TLS_1_3); + } + } + + return rv; +} + +/* The callback for when the retransmit timer expires + * + * Called from: + * dtls_CheckTimer() + * dtls_HandleHandshake() + */ +static void +dtls_RetransmitTimerExpiredCb(sslSocket *ss) +{ + SECStatus rv; + dtlsTimer *timer = ss->ssl3.hs.rtTimer; + ss->ssl3.hs.rtRetries++; + + if (!(ss->ssl3.hs.rtRetries % 3)) { + /* If one of the messages was potentially greater than > MTU, + * then downgrade. Do this every time we have retransmitted a + * message twice, per RFC 6347 Sec. 4.1.1 */ + dtls_SetMTU(ss, ss->ssl3.hs.maxMessageSent - 1); + } + + rv = dtls_TransmitMessageFlight(ss); + if (rv == SECSuccess) { + /* Re-arm the timer */ + timer->timeout *= 2; + if (timer->timeout > DTLS_RETRANSMIT_MAX_MS) { + timer->timeout = DTLS_RETRANSMIT_MAX_MS; + } + + timer->started = PR_IntervalNow(); + timer->cb = dtls_RetransmitTimerExpiredCb; + + SSL_TRC(30, + ("%d: SSL3[%d]: Retransmit #%d, next in %d", + SSL_GETPID(), ss->fd, + ss->ssl3.hs.rtRetries, timer->timeout)); + } + /* else: OK for now. In future maybe signal the stack that we couldn't + * transmit. For now, let the read handle any real network errors */ +} + +#define DTLS_HS_HDR_LEN 12 +#define DTLS_MIN_FRAGMENT (DTLS_HS_HDR_LEN + 1 + DTLS_MAX_EXPANSION) + +/* Encrypt and encode a handshake message fragment. Flush the data out to the + * network if there is insufficient space for any fragment. */ +static SECStatus +dtls_SendFragment(sslSocket *ss, DTLSQueuedMessage *msg, PRUint8 *data, + unsigned int len) +{ + PRInt32 sent; + SECStatus rv; + + PRINT_BUF(40, (ss, "dtls_SendFragment", data, len)); + sent = ssl3_SendRecord(ss, msg->cwSpec, msg->type, data, len, + ssl_SEND_FLAG_FORCE_INTO_BUFFER); + if (sent != len) { + if (sent != -1) { + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); + } + return SECFailure; + } + + /* If another fragment won't fit, flush. */ + if (ss->ssl3.mtu < ss->pendingBuf.len + DTLS_MIN_FRAGMENT) { + SSL_TRC(20, ("%d: DTLS[%d]: dtls_SendFragment: flush", + SSL_GETPID(), ss->fd)); + rv = dtls_SendSavedWriteData(ss); + if (rv != SECSuccess) { + return SECFailure; + } + } + return SECSuccess; +} + +/* Fragment a handshake message into multiple records and send them. */ +static SECStatus +dtls_FragmentHandshake(sslSocket *ss, DTLSQueuedMessage *msg) +{ + PRBool fragmentWritten = PR_FALSE; + PRUint16 msgSeq; + PRUint8 *fragment; + PRUint32 fragmentOffset = 0; + PRUint32 fragmentLen; + const PRUint8 *content = msg->data + DTLS_HS_HDR_LEN; + PRUint32 contentLen = msg->len - DTLS_HS_HDR_LEN; + SECStatus rv; + + /* The headers consume 12 bytes so the smallest possible message (i.e., an + * empty one) is 12 bytes. */ + PORT_Assert(msg->len >= DTLS_HS_HDR_LEN); + + /* DTLS only supports fragmenting handshaking messages. */ + PORT_Assert(msg->type == ssl_ct_handshake); + + msgSeq = (msg->data[4] << 8) | msg->data[5]; + + /* do {} while() so that empty messages are sent at least once. */ + do { + PRUint8 buf[DTLS_MAX_MTU]; /* >= than largest plausible MTU */ + PRBool hasUnackedRange; + PRUint32 end; + + hasUnackedRange = dtls_NextUnackedRange(ss, msgSeq, + fragmentOffset, contentLen, + &fragmentOffset, &end); + if (!hasUnackedRange) { + SSL_TRC(20, ("%d: SSL3[%d]: FragmentHandshake %d: all acknowledged", + SSL_GETPID(), ss->fd, msgSeq)); + break; + } + + SSL_TRC(20, ("%d: SSL3[%d]: FragmentHandshake %d: unacked=%u-%u", + SSL_GETPID(), ss->fd, msgSeq, fragmentOffset, end)); + + /* Cut down to the data we have available. */ + PORT_Assert(fragmentOffset <= contentLen); + PORT_Assert(fragmentOffset <= end); + PORT_Assert(end <= contentLen); + fragmentLen = PR_MIN(end, contentLen) - fragmentOffset; + + /* Limit further by the record size limit. Account for the header. */ + fragmentLen = PR_MIN(fragmentLen, + msg->cwSpec->recordSizeLimit - DTLS_HS_HDR_LEN); + + /* Reduce to the space remaining in the MTU. */ + fragmentLen = PR_MIN(fragmentLen, + ss->ssl3.mtu - /* MTU estimate. */ + ss->pendingBuf.len - /* Less any unsent records. */ + DTLS_MAX_EXPANSION - /* Allow for expansion. */ + DTLS_HS_HDR_LEN); /* And the handshake header. */ + PORT_Assert(fragmentLen > 0 || fragmentOffset == 0); + + /* Make totally sure that we will fit in the buffer. This should be + * impossible; DTLS_MAX_MTU should always be more than ss->ssl3.mtu. */ + if (fragmentLen >= (DTLS_MAX_MTU - DTLS_HS_HDR_LEN)) { + PORT_Assert(0); + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); + return SECFailure; + } + + if (fragmentLen == contentLen) { + fragment = msg->data; + } else { + sslBuffer tmp = SSL_BUFFER_FIXED(buf, sizeof(buf)); + + /* Construct an appropriate-sized fragment */ + /* Type, length, sequence */ + rv = sslBuffer_Append(&tmp, msg->data, 6); + if (rv != SECSuccess) { + return SECFailure; + } + /* Offset. */ + rv = sslBuffer_AppendNumber(&tmp, fragmentOffset, 3); + if (rv != SECSuccess) { + return SECFailure; + } + /* Length. */ + rv = sslBuffer_AppendNumber(&tmp, fragmentLen, 3); + if (rv != SECSuccess) { + return SECFailure; + } + /* Data. */ + rv = sslBuffer_Append(&tmp, content + fragmentOffset, fragmentLen); + if (rv != SECSuccess) { + return SECFailure; + } + + fragment = SSL_BUFFER_BASE(&tmp); + } + + /* Record that we are sending first, because encrypting + * increments the sequence number. */ + rv = dtls13_RememberFragment(ss, &ss->ssl3.hs.dtlsSentHandshake, + msgSeq, fragmentOffset, fragmentLen, + msg->cwSpec->epoch, + msg->cwSpec->nextSeqNum); + if (rv != SECSuccess) { + return SECFailure; + } + + rv = dtls_SendFragment(ss, msg, fragment, + fragmentLen + DTLS_HS_HDR_LEN); + if (rv != SECSuccess) { + return SECFailure; + } + + fragmentWritten = PR_TRUE; + fragmentOffset += fragmentLen; + } while (fragmentOffset < contentLen); + + if (!fragmentWritten) { + /* Nothing was written if we got here, so the whole message must have + * been acknowledged. Discard it. */ + SSL_TRC(10, ("%d: SSL3[%d]: FragmentHandshake %d: removed", + SSL_GETPID(), ss->fd, msgSeq)); + PR_REMOVE_LINK(&msg->link); + dtls_FreeHandshakeMessage(msg); + } + + return SECSuccess; +} + +/* Transmit a flight of handshake messages, stuffing them + * into as few records as seems reasonable. + * + * TODO: Space separate UDP packets out a little. + * + * Called from: + * dtls_FlushHandshake() + * dtls_RetransmitTimerExpiredCb() + */ +SECStatus +dtls_TransmitMessageFlight(sslSocket *ss) +{ + SECStatus rv = SECSuccess; + PRCList *msg_p; + + SSL_TRC(10, ("%d: SSL3[%d]: dtls_TransmitMessageFlight", + SSL_GETPID(), ss->fd)); + + ssl_GetXmitBufLock(ss); + ssl_GetSpecReadLock(ss); + + /* DTLS does not buffer its handshake messages in ss->pendingBuf, but rather + * in the lastMessageFlight structure. This is just a sanity check that some + * programming error hasn't inadvertantly stuffed something in + * ss->pendingBuf. This function uses ss->pendingBuf temporarily and it + * needs to be empty to start. + */ + PORT_Assert(!ss->pendingBuf.len); + + for (msg_p = PR_LIST_HEAD(&ss->ssl3.hs.lastMessageFlight); + msg_p != &ss->ssl3.hs.lastMessageFlight;) { + DTLSQueuedMessage *msg = (DTLSQueuedMessage *)msg_p; + + /* Move the pointer forward so that the functions below are free to + * remove messages from the list. */ + msg_p = PR_NEXT_LINK(msg_p); + + /* Note: This function fragments messages so that each record is close + * to full. This produces fewer records, but it means that messages can + * be quite fragmented. Adding an extra flush here would push new + * messages into new records and reduce fragmentation. */ + + if (msg->type == ssl_ct_handshake) { + rv = dtls_FragmentHandshake(ss, msg); + } else { + PORT_Assert(!tls13_MaybeTls13(ss)); + rv = dtls_SendFragment(ss, msg, msg->data, msg->len); + } + if (rv != SECSuccess) { + break; + } + } + + /* Finally, flush any data that wasn't flushed already. */ + if (rv == SECSuccess) { + rv = dtls_SendSavedWriteData(ss); + } + + /* Give up the locks */ + ssl_ReleaseSpecReadLock(ss); + ssl_ReleaseXmitBufLock(ss); + + return rv; +} + +/* Flush the data in the pendingBuf and update the max message sent + * so we can adjust the MTU estimate if we need to. + * Wrapper for ssl_SendSavedWriteData. + * + * Called from dtls_TransmitMessageFlight() + */ +static SECStatus +dtls_SendSavedWriteData(sslSocket *ss) +{ + PRInt32 sent; + + sent = ssl_SendSavedWriteData(ss); + if (sent < 0) + return SECFailure; + + /* We should always have complete writes b/c datagram sockets + * don't really block */ + if (ss->pendingBuf.len > 0) { + ssl_MapLowLevelError(SSL_ERROR_SOCKET_WRITE_FAILURE); + return SECFailure; + } + + /* Update the largest message sent so we can adjust the MTU + * estimate if necessary */ + if (sent > ss->ssl3.hs.maxMessageSent) + ss->ssl3.hs.maxMessageSent = sent; + + return SECSuccess; +} + +void +dtls_InitTimers(sslSocket *ss) +{ + unsigned int i; + dtlsTimer **timers[PR_ARRAY_SIZE(ss->ssl3.hs.timers)] = { + &ss->ssl3.hs.rtTimer, + &ss->ssl3.hs.ackTimer, + &ss->ssl3.hs.hdTimer + }; + static const char *timerLabels[] = { + "retransmit", "ack", "holddown" + }; + + PORT_Assert(PR_ARRAY_SIZE(timers) == PR_ARRAY_SIZE(timerLabels)); + for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { + *timers[i] = &ss->ssl3.hs.timers[i]; + ss->ssl3.hs.timers[i].label = timerLabels[i]; + } +} + +SECStatus +dtls_StartTimer(sslSocket *ss, dtlsTimer *timer, PRUint32 time, DTLSTimerCb cb) +{ + PORT_Assert(timer->cb == NULL); + + SSL_TRC(10, ("%d: SSL3[%d]: %s dtls_StartTimer %s timeout=%d", + SSL_GETPID(), ss->fd, SSL_ROLE(ss), timer->label, time)); + + timer->started = PR_IntervalNow(); + timer->timeout = time; + timer->cb = cb; + return SECSuccess; +} + +SECStatus +dtls_RestartTimer(sslSocket *ss, dtlsTimer *timer) +{ + timer->started = PR_IntervalNow(); + return SECSuccess; +} + +PRBool +dtls_TimerActive(sslSocket *ss, dtlsTimer *timer) +{ + return timer->cb != NULL; +} +/* Start a timer for retransmission. */ +static SECStatus +dtls_StartRetransmitTimer(sslSocket *ss) +{ + ss->ssl3.hs.rtRetries = 0; + return dtls_StartTimer(ss, ss->ssl3.hs.rtTimer, + DTLS_RETRANSMIT_INITIAL_MS, + dtls_RetransmitTimerExpiredCb); +} + +/* Start a timer for holding an old cipher spec. */ +SECStatus +dtls_StartHolddownTimer(sslSocket *ss) +{ + ss->ssl3.hs.rtRetries = 0; + return dtls_StartTimer(ss, ss->ssl3.hs.rtTimer, + DTLS_RETRANSMIT_FINISHED_MS, + dtls_FinishedTimerCb); +} + +/* Cancel a pending timer + * + * Called from: + * dtls_HandleHandshake() + * dtls_CheckTimer() + */ +void +dtls_CancelTimer(sslSocket *ss, dtlsTimer *timer) +{ + SSL_TRC(30, ("%d: SSL3[%d]: %s dtls_CancelTimer %s", + SSL_GETPID(), ss->fd, SSL_ROLE(ss), + timer->label)); + + PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); + + timer->cb = NULL; +} + +static void +dtls_CancelAllTimers(sslSocket *ss) +{ + unsigned int i; + + for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { + dtls_CancelTimer(ss, &ss->ssl3.hs.timers[i]); + } +} + +/* Check the pending timer and fire the callback if it expired + * + * Called from ssl3_GatherCompleteHandshake() + */ +void +dtls_CheckTimer(sslSocket *ss) +{ + unsigned int i; + SSL_TRC(30, ("%d: SSL3[%d]: dtls_CheckTimer (%s)", + SSL_GETPID(), ss->fd, ss->sec.isServer ? "server" : "client")); + + ssl_GetSSL3HandshakeLock(ss); + + for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { + dtlsTimer *timer = &ss->ssl3.hs.timers[i]; + if (!timer->cb) { + continue; + } + + if ((PR_IntervalNow() - timer->started) >= + PR_MillisecondsToInterval(timer->timeout)) { + /* Timer has expired */ + DTLSTimerCb cb = timer->cb; + + SSL_TRC(10, ("%d: SSL3[%d]: %s firing timer %s", + SSL_GETPID(), ss->fd, SSL_ROLE(ss), + timer->label)); + + /* Cancel the timer so that we can call the CB safely */ + dtls_CancelTimer(ss, timer); + + /* Now call the CB */ + cb(ss); + } + } + ssl_ReleaseSSL3HandshakeLock(ss); +} + +/* The callback to fire when the holddown timer for the Finished + * message expires and we can delete it + * + * Called from dtls_CheckTimer() + */ +static void +dtls_FinishedTimerCb(sslSocket *ss) +{ + dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight); +} + +/* Cancel the Finished hold-down timer and destroy the + * pending cipher spec. Note that this means that + * successive rehandshakes will fail if the Finished is + * lost. + * + * XXX OK for now. Figure out how to handle the combination + * of Finished lost and rehandshake + */ +void +dtls_RehandshakeCleanup(sslSocket *ss) +{ + /* Skip this if we are handling a second ClientHello. */ + if (ss->ssl3.hs.helloRetry) { + return; + } + PORT_Assert((ss->version < SSL_LIBRARY_VERSION_TLS_1_3)); + dtls_CancelAllTimers(ss); + dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight); + ss->ssl3.hs.sendMessageSeq = 0; + ss->ssl3.hs.recvMessageSeq = 0; +} + +/* Set the MTU to the next step less than or equal to the + * advertised value. Also used to downgrade the MTU by + * doing dtls_SetMTU(ss, biggest packet set). + * + * Passing 0 means set this to the largest MTU known + * (effectively resetting the PMTU backoff value). + * + * Called by: + * ssl3_InitState() + * dtls_RetransmitTimerExpiredCb() + */ +void +dtls_SetMTU(sslSocket *ss, PRUint16 advertised) +{ + int i; + + if (advertised == 0) { + ss->ssl3.mtu = COMMON_MTU_VALUES[0]; + SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu)); + return; + } + + for (i = 0; i < PR_ARRAY_SIZE(COMMON_MTU_VALUES); i++) { + if (COMMON_MTU_VALUES[i] <= advertised) { + ss->ssl3.mtu = COMMON_MTU_VALUES[i]; + SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu)); + return; + } + } + + /* Fallback */ + ss->ssl3.mtu = COMMON_MTU_VALUES[PR_ARRAY_SIZE(COMMON_MTU_VALUES) - 1]; + SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu)); +} + +/* Called from ssl3_HandleHandshakeMessage() when it has deciphered a + * DTLS hello_verify_request + * Caller must hold Handshake and RecvBuf locks. + */ +SECStatus +dtls_HandleHelloVerifyRequest(sslSocket *ss, PRUint8 *b, PRUint32 length) +{ + int errCode = SSL_ERROR_RX_MALFORMED_HELLO_VERIFY_REQUEST; + SECStatus rv; + SSL3ProtocolVersion temp; + SSL3AlertDescription desc = illegal_parameter; + + SSL_TRC(3, ("%d: SSL3[%d]: handle hello_verify_request handshake", + SSL_GETPID(), ss->fd)); + PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); + PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); + + if (ss->ssl3.hs.ws != wait_server_hello) { + errCode = SSL_ERROR_RX_UNEXPECTED_HELLO_VERIFY_REQUEST; + desc = unexpected_message; + goto alert_loser; + } + + dtls_ReceivedFirstMessageInFlight(ss); + + /* The version. + * + * RFC 4347 required that you verify that the server versions + * match (Section 4.2.1) in the HelloVerifyRequest and the + * ServerHello. + * + * RFC 6347 suggests (SHOULD) that servers always use 1.0 in + * HelloVerifyRequest and allows the versions not to match, + * especially when 1.2 is being negotiated. + * + * Therefore we do not do anything to enforce a match, just + * read and check that this value is sane. + */ + rv = ssl_ClientReadVersion(ss, &b, &length, &temp); + if (rv != SECSuccess) { + goto loser; /* alert has been sent */ + } + + /* Read the cookie. + * IMPORTANT: The value of ss->ssl3.hs.cookie is only valid while the + * HelloVerifyRequest message remains valid. */ + rv = ssl3_ConsumeHandshakeVariable(ss, &ss->ssl3.hs.cookie, 1, &b, &length); + if (rv != SECSuccess) { + goto loser; /* alert has been sent */ + } + if (ss->ssl3.hs.cookie.len > DTLS_COOKIE_BYTES) { + desc = decode_error; + goto alert_loser; /* malformed. */ + } + + ssl_GetXmitBufLock(ss); /*******************************/ + + /* Now re-send the client hello */ + rv = ssl3_SendClientHello(ss, client_hello_retransmit); + + ssl_ReleaseXmitBufLock(ss); /*******************************/ + + if (rv == SECSuccess) + return rv; + +alert_loser: + (void)SSL3_SendAlert(ss, alert_fatal, desc); + +loser: + ssl_MapLowLevelError(errCode); + return SECFailure; +} + +/* Initialize the DTLS anti-replay window + * + * Called from: + * ssl3_SetupPendingCipherSpec() + * ssl3_InitCipherSpec() + */ +void +dtls_InitRecvdRecords(DTLSRecvdRecords *records) +{ + PORT_Memset(records->data, 0, sizeof(records->data)); + records->left = 0; + records->right = DTLS_RECVD_RECORDS_WINDOW - 1; +} + +/* + * Has this DTLS record been received? Return values are: + * -1 -- out of range to the left + * 0 -- not received yet + * 1 -- replay + * + * Called from: ssl3_HandleRecord() + */ +int +dtls_RecordGetRecvd(const DTLSRecvdRecords *records, sslSequenceNumber seq) +{ + PRUint64 offset; + + /* Out of range to the left */ + if (seq < records->left) { + return -1; + } + + /* Out of range to the right; since we advance the window on + * receipt, that means that this packet has not been received + * yet */ + if (seq > records->right) + return 0; + + offset = seq % DTLS_RECVD_RECORDS_WINDOW; + + return !!(records->data[offset / 8] & (1 << (offset % 8))); +} + +/* Update the DTLS anti-replay window + * + * Called from ssl3_HandleRecord() + */ +void +dtls_RecordSetRecvd(DTLSRecvdRecords *records, sslSequenceNumber seq) +{ + PRUint64 offset; + + if (seq < records->left) + return; + + if (seq > records->right) { + sslSequenceNumber new_left; + sslSequenceNumber new_right; + sslSequenceNumber right; + + /* Slide to the right; this is the tricky part + * + * 1. new_top is set to have room for seq, on the + * next byte boundary by setting the right 8 + * bits of seq + * 2. new_left is set to compensate. + * 3. Zero all bits between top and new_top. Since + * this is a ring, this zeroes everything as-yet + * unseen. Because we always operate on byte + * boundaries, we can zero one byte at a time + */ + new_right = seq | 0x07; + new_left = (new_right - DTLS_RECVD_RECORDS_WINDOW) + 1; + + if (new_right > records->right + DTLS_RECVD_RECORDS_WINDOW) { + PORT_Memset(records->data, 0, sizeof(records->data)); + } else { + for (right = records->right + 8; right <= new_right; right += 8) { + offset = right % DTLS_RECVD_RECORDS_WINDOW; + records->data[offset / 8] = 0; + } + } + + records->right = new_right; + records->left = new_left; + } + + offset = seq % DTLS_RECVD_RECORDS_WINDOW; + + records->data[offset / 8] |= (1 << (offset % 8)); +} + +SECStatus +DTLS_GetHandshakeTimeout(PRFileDesc *socket, PRIntervalTime *timeout) +{ + sslSocket *ss = NULL; + PRBool found = PR_FALSE; + PRIntervalTime now = PR_IntervalNow(); + PRIntervalTime to; + unsigned int i; + + *timeout = PR_INTERVAL_NO_TIMEOUT; + + ss = ssl_FindSocket(socket); + + if (!ss) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + if (!IS_DTLS(ss)) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { + PRIntervalTime elapsed; + PRIntervalTime desired; + dtlsTimer *timer = &ss->ssl3.hs.timers[i]; + + if (!timer->cb) { + continue; + } + found = PR_TRUE; + + elapsed = now - timer->started; + desired = PR_MillisecondsToInterval(timer->timeout); + if (elapsed > desired) { + /* Timer expired */ + *timeout = PR_INTERVAL_NO_WAIT; + return SECSuccess; + } else { + to = desired - elapsed; + } + + if (*timeout > to) { + *timeout = to; + } + } + + if (!found) { + PORT_SetError(SSL_ERROR_NO_TIMERS_FOUND); + return SECFailure; + } + + return SECSuccess; +} + +PRBool +dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet) +{ +#ifndef UNSAFE_FUZZER_MODE + return version < SSL_LIBRARY_VERSION_TLS_1_3 || + firstOctet == ssl_ct_handshake || + firstOctet == ssl_ct_ack || + firstOctet == ssl_ct_alert; +#else + return PR_TRUE; +#endif +} + +PRBool +dtls_IsDtls13Ciphertext(SSL3ProtocolVersion version, PRUint8 firstOctet) +{ + // Allow no version in case we haven't negotiated one yet. + return (version == 0 || version >= SSL_LIBRARY_VERSION_TLS_1_3) && + (firstOctet & 0xe0) == 0x20; +} + +DTLSEpoch +dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr) +{ + DTLSEpoch epoch; + DTLSEpoch maxEpoch; + DTLSEpoch partial; + + if (dtls_IsLongHeader(crSpec->version, hdr[0])) { + return ((DTLSEpoch)hdr[3] << 8) | hdr[4]; + } + + /* A lot of how we recover the epoch here will depend on how we plan to + * manage KeyUpdate. In the case that we decide to install a new read spec + * as a KeyUpdate is handled, crSpec will always be the highest epoch we can + * possibly receive. That makes this easier to manage. + */ + if (dtls_IsDtls13Ciphertext(crSpec->version, hdr[0])) { + /* TODO(ekr@rtfm.com: do something with the two-bit epoch. */ + /* Use crSpec->epoch, or crSpec->epoch - 1 if the last bit differs. */ + return crSpec->epoch - ((hdr[0] ^ crSpec->epoch) & 0x3); + } + + /* dtls_GatherData should ensure that this works. */ + PORT_Assert(hdr[0] == ssl_ct_application_data); + + /* This uses the same method as is used to recover the sequence number in + * dtls_ReadSequenceNumber, except that the maximum value is set to the + * current epoch. */ + partial = hdr[1] >> 6; + maxEpoch = PR_MAX(crSpec->epoch, 3); + epoch = (maxEpoch & 0xfffc) | partial; + if (partial > (maxEpoch & 0x03)) { + epoch -= 4; + } + return epoch; +} + +static sslSequenceNumber +dtls_ReadSequenceNumber(const ssl3CipherSpec *spec, const PRUint8 *hdr) +{ + sslSequenceNumber cap; + sslSequenceNumber partial; + sslSequenceNumber seqNum; + sslSequenceNumber mask; + + if (dtls_IsLongHeader(spec->version, hdr[0])) { + static const unsigned int seqNumOffset = 5; /* type, version, epoch */ + static const unsigned int seqNumLength = 6; + sslReader r = SSL_READER(hdr + seqNumOffset, seqNumLength); + (void)sslRead_ReadNumber(&r, seqNumLength, &seqNum); + return seqNum; + } + + /* Only the least significant bits of the sequence number is available here. + * This recovers the value based on the next expected sequence number. + * + * This works by determining the maximum possible sequence number, which is + * half the range of possible values above the expected next value (the + * expected next value is in |spec->seqNum|). Then, the last part of the + * sequence number is replaced. If that causes the value to exceed the + * maximum, subtract an entire range. + */ + if (hdr[0] & 0x08) { + cap = spec->nextSeqNum + (1ULL << 15); + partial = (((sslSequenceNumber)hdr[1]) << 8) | + (sslSequenceNumber)hdr[2]; + mask = (1ULL << 16) - 1; + } else { + cap = spec->nextSeqNum + (1ULL << 7); + partial = (sslSequenceNumber)hdr[1]; + mask = (1ULL << 8) - 1; + } + seqNum = (cap & ~mask) | partial; + /* The second check prevents the value from underflowing if we get a large + * gap at the start of a connection, where this subtraction would cause the + * sequence number to wrap to near UINT64_MAX. */ + if ((partial > (cap & mask)) && (seqNum > mask)) { + seqNum -= mask + 1; + } + return seqNum; +} + +/* + * DTLS relevance checks: + * Note that this code currently ignores all out-of-epoch packets, + * which means we lose some in the case of rehandshake + + * loss/reordering. Since DTLS is explicitly unreliable, this + * seems like a good tradeoff for implementation effort and is + * consistent with the guidance of RFC 6347 Sections 4.1 and 4.2.4.1. + * + * If the packet is not relevant, this function returns PR_FALSE. If the packet + * is relevant, this function returns PR_TRUE and sets |*seqNumOut| to the + * packet sequence number (removing the epoch). + */ +PRBool +dtls_IsRelevant(sslSocket *ss, const ssl3CipherSpec *spec, + const SSL3Ciphertext *cText, + sslSequenceNumber *seqNumOut) +{ + sslSequenceNumber seqNum = dtls_ReadSequenceNumber(spec, cText->hdr); + if (dtls_RecordGetRecvd(&spec->recvdRecords, seqNum) != 0) { + SSL_TRC(10, ("%d: SSL3[%d]: dtls_IsRelevant, rejecting " + "potentially replayed packet", + SSL_GETPID(), ss->fd)); + return PR_FALSE; + } + + *seqNumOut = seqNum; + return PR_TRUE; +} + +void +dtls_ReceivedFirstMessageInFlight(sslSocket *ss) +{ + if (!IS_DTLS(ss)) + return; + + /* At this point we are advancing our state machine, so we can free our last + * flight of messages. */ + if (ss->ssl3.hs.ws != idle_handshake || + ss->version >= SSL_LIBRARY_VERSION_TLS_1_3) { + /* We need to keep our last flight around in DTLS 1.2 and below, + * so we can retransmit it in response to other people's + * retransmits. */ + dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight); + + /* Reset the timer to the initial value if the retry counter + * is 0, per RFC 6347, Sec. 4.2.4.1 */ + dtls_CancelTimer(ss, ss->ssl3.hs.rtTimer); + if (ss->ssl3.hs.rtRetries == 0) { + ss->ssl3.hs.rtTimer->timeout = DTLS_RETRANSMIT_INITIAL_MS; + } + } + + /* Empty the ACK queue (TLS 1.3 only). */ + ssl_ClearPRCList(&ss->ssl3.hs.dtlsRcvdHandshake, NULL); +} |