diff options
Diffstat (limited to 'netwerk/base/TCPFastOpenLayer.cpp')
-rw-r--r-- | netwerk/base/TCPFastOpenLayer.cpp | 527 |
1 files changed, 527 insertions, 0 deletions
diff --git a/netwerk/base/TCPFastOpenLayer.cpp b/netwerk/base/TCPFastOpenLayer.cpp new file mode 100644 index 0000000000..ea5e5f11d6 --- /dev/null +++ b/netwerk/base/TCPFastOpenLayer.cpp @@ -0,0 +1,527 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim:set ts=2 sw=2 sts=2 et cindent: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "TCPFastOpenLayer.h" +#include "nsSocketTransportService2.h" +#include "prmem.h" +#include "prio.h" + +namespace mozilla { +namespace net { + +static PRDescIdentity sTCPFastOpenLayerIdentity; +static PRIOMethods sTCPFastOpenLayerMethods; +static PRIOMethods* sTCPFastOpenLayerMethodsPtr = nullptr; + +#define TFO_MAX_PACKET_SIZE_IPV4 1460 +#define TFO_MAX_PACKET_SIZE_IPV6 1440 +#define TFO_TLS_RECORD_HEADER_SIZE 22 + +/** + * For the TCP Fast Open it is necessary to send all data that can fit into the + * first packet on a single sendto function call. Consecutive calls will not + * have an effect. Therefore TCPFastOpenLayer will collect some data before + * calling sendto. Necko and nss will call PR_Write multiple times with small + * amount of data. + * TCPFastOpenLayer has 4 states: + * WAITING_FOR_CONNECT: + * This is before connect is call. A call of recv, send or getpeername will + * return PR_NOT_CONNECTED_ERROR. After connect is call the state transfers + * into COLLECT_DATA_FOR_FIRST_PACKET. + * + * COLLECT_DATA_FOR_FIRST_PACKET: + * In this state all data received by send function calls will be stored in + * a buffer. If transaction do not have any more data ready to be sent or + * the buffer is full, TCPFastOpenFinish is call. TCPFastOpenFinish sends + * the collected data using sendto function and the state transfers to + * WAITING_FOR_CONNECTCONTINUE. If an error occurs during sendto, the error + * is reported by the TCPFastOpenFinish return values. nsSocketTransfer is + * the only caller of TCPFastOpenFinish; it knows how to interpreter these + * errors. + * WAITING_FOR_CONNECTCONTINUE: + * connectcontinue transfers from this state to CONNECTED. Any other + * function (e.g. send, recv) returns PR_WOULD_BLOCK_ERROR. + * CONNECTED: + * The size of mFirstPacketBuf is 1440/1460 (RFC7413 recomends that packet + * does exceeds these sizes). SendTo does not have to consume all buffered + * data and some data can be still in mFirstPacketBuf. Before sending any + * new data we need to send the remaining buffered data. + **/ + +class TCPFastOpenSecret { + public: + TCPFastOpenSecret() + : mState(WAITING_FOR_CONNECT), mFirstPacketBufLen(0), mCondition(0) { + this->mAddr.raw.family = 0; + this->mAddr.inet.family = 0; + this->mAddr.inet.port = 0; + this->mAddr.inet.ip = 0; + this->mAddr.ipv6.family = 0; + this->mAddr.ipv6.port = 0; + this->mAddr.ipv6.flowinfo = 0; + this->mAddr.ipv6.scope_id = 0; + this->mAddr.local.family = 0; + } + + enum { + CONNECTED, + WAITING_FOR_CONNECTCONTINUE, + COLLECT_DATA_FOR_FIRST_PACKET, + WAITING_FOR_CONNECT, + SOCKET_ERROR_STATE + } mState; + PRNetAddr mAddr; + char mFirstPacketBuf[1460]; + uint16_t mFirstPacketBufLen; + PRErrorCode mCondition; +}; + +static PRStatus TCPFastOpenConnect(PRFileDesc* fd, const PRNetAddr* addr, + PRIntervalTime timeout) { + MOZ_RELEASE_ASSERT(fd->identity == sTCPFastOpenLayerIdentity); + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + TCPFastOpenSecret* secret = reinterpret_cast<TCPFastOpenSecret*>(fd->secret); + + SOCKET_LOG(("TCPFastOpenConnect state=%d.\n", secret->mState)); + + if (secret->mState != TCPFastOpenSecret::WAITING_FOR_CONNECT) { + PR_SetError(PR_IS_CONNECTED_ERROR, 0); + return PR_FAILURE; + } + + // Remember the address. It will be used for sendto or connect later. + memcpy(&secret->mAddr, addr, sizeof(secret->mAddr)); + secret->mState = TCPFastOpenSecret::COLLECT_DATA_FOR_FIRST_PACKET; + + return PR_SUCCESS; +} + +static PRInt32 TCPFastOpenSend(PRFileDesc* fd, const void* buf, PRInt32 amount, + PRIntn flags, PRIntervalTime timeout) { + MOZ_RELEASE_ASSERT(fd->identity == sTCPFastOpenLayerIdentity); + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + TCPFastOpenSecret* secret = reinterpret_cast<TCPFastOpenSecret*>(fd->secret); + + SOCKET_LOG(("TCPFastOpenSend state=%d.\n", secret->mState)); + + switch (secret->mState) { + case TCPFastOpenSecret::CONNECTED: + // Before sending new data we need to drain the data collected during tfo. + if (secret->mFirstPacketBufLen) { + SOCKET_LOG( + ("TCPFastOpenSend - %d bytes to drain from " + "mFirstPacketBufLen.\n", + secret->mFirstPacketBufLen)); + PRInt32 rv = (fd->lower->methods->send)( + fd->lower, secret->mFirstPacketBuf, secret->mFirstPacketBufLen, + 0, // flags + PR_INTERVAL_NO_WAIT); + if (rv <= 0) { + return rv; + } + secret->mFirstPacketBufLen -= rv; + if (secret->mFirstPacketBufLen) { + memmove(secret->mFirstPacketBuf, secret->mFirstPacketBuf + rv, + secret->mFirstPacketBufLen); + + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return PR_FAILURE; + } // if we drained the buffer we can fall through this checks and call + // send for the new data + } + SOCKET_LOG(("TCPFastOpenSend sending new data.\n")); + return (fd->lower->methods->send)(fd->lower, buf, amount, flags, timeout); + case TCPFastOpenSecret::WAITING_FOR_CONNECTCONTINUE: + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return -1; + case TCPFastOpenSecret::COLLECT_DATA_FOR_FIRST_PACKET: { + int32_t toSend = (secret->mAddr.raw.family == PR_AF_INET) + ? TFO_MAX_PACKET_SIZE_IPV4 + : TFO_MAX_PACKET_SIZE_IPV6; + MOZ_ASSERT(secret->mFirstPacketBufLen <= toSend); + toSend -= secret->mFirstPacketBufLen; + + SOCKET_LOG( + ("TCPFastOpenSend: amount of data in the buffer=%d; the amount" + " of additional data that can be stored=%d.\n", + secret->mFirstPacketBufLen, toSend)); + + if (!toSend) { + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return -1; + } + + toSend = (toSend > amount) ? amount : toSend; + memcpy(secret->mFirstPacketBuf + secret->mFirstPacketBufLen, buf, toSend); + secret->mFirstPacketBufLen += toSend; + return toSend; + } + case TCPFastOpenSecret::WAITING_FOR_CONNECT: + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + case TCPFastOpenSecret::SOCKET_ERROR_STATE: + PR_SetError(secret->mCondition, 0); + return -1; + } + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return PR_FAILURE; +} + +static PRInt32 TCPFastOpenWrite(PRFileDesc* fd, const void* buf, + PRInt32 amount) { + return TCPFastOpenSend(fd, buf, amount, 0, PR_INTERVAL_NO_WAIT); +} + +static PRInt32 TCPFastOpenRecv(PRFileDesc* fd, void* buf, PRInt32 amount, + PRIntn flags, PRIntervalTime timeout) { + MOZ_RELEASE_ASSERT(fd->identity == sTCPFastOpenLayerIdentity); + + TCPFastOpenSecret* secret = reinterpret_cast<TCPFastOpenSecret*>(fd->secret); + + PRInt32 rv = -1; + switch (secret->mState) { + case TCPFastOpenSecret::CONNECTED: + + if (secret->mFirstPacketBufLen) { + // TLS will not call write before receiving data from a server, + // therefore we need to force sending buffered data even during recv + // call. Otherwise It can come to a deadlock (clients waits for + // response, but the request is sitting in mFirstPacketBufLen). + SOCKET_LOG( + ("TCPFastOpenRevc - %d bytes to drain from mFirstPacketBuf.\n", + secret->mFirstPacketBufLen)); + PRInt32 rv = (fd->lower->methods->send)( + fd->lower, secret->mFirstPacketBuf, secret->mFirstPacketBufLen, + 0, // flags + PR_INTERVAL_NO_WAIT); + if (rv <= 0) { + return rv; + } + secret->mFirstPacketBufLen -= rv; + if (secret->mFirstPacketBufLen) { + memmove(secret->mFirstPacketBuf, secret->mFirstPacketBuf + rv, + secret->mFirstPacketBufLen); + } + } + rv = (fd->lower->methods->recv)(fd->lower, buf, amount, flags, timeout); + break; + case TCPFastOpenSecret::WAITING_FOR_CONNECTCONTINUE: + case TCPFastOpenSecret::COLLECT_DATA_FOR_FIRST_PACKET: + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + break; + case TCPFastOpenSecret::WAITING_FOR_CONNECT: + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + break; + case TCPFastOpenSecret::SOCKET_ERROR_STATE: + PR_SetError(secret->mCondition, 0); + } + return rv; +} + +static PRInt32 TCPFastOpenRead(PRFileDesc* fd, void* buf, PRInt32 amount) { + return TCPFastOpenRecv(fd, buf, amount, 0, PR_INTERVAL_NO_WAIT); +} + +static PRStatus TCPFastOpenConnectContinue(PRFileDesc* fd, PRInt16 out_flags) { + MOZ_RELEASE_ASSERT(fd->identity == sTCPFastOpenLayerIdentity); + + TCPFastOpenSecret* secret = reinterpret_cast<TCPFastOpenSecret*>(fd->secret); + + PRStatus rv = PR_FAILURE; + switch (secret->mState) { + case TCPFastOpenSecret::CONNECTED: + rv = PR_SUCCESS; + break; + case TCPFastOpenSecret::WAITING_FOR_CONNECT: + case TCPFastOpenSecret::COLLECT_DATA_FOR_FIRST_PACKET: + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + rv = PR_FAILURE; + break; + case TCPFastOpenSecret::WAITING_FOR_CONNECTCONTINUE: + rv = (fd->lower->methods->connectcontinue)(fd->lower, out_flags); + + SOCKET_LOG(("TCPFastOpenConnectContinue result=%d.\n", rv)); + secret->mState = TCPFastOpenSecret::CONNECTED; + break; + case TCPFastOpenSecret::SOCKET_ERROR_STATE: + PR_SetError(secret->mCondition, 0); + rv = PR_FAILURE; + } + return rv; +} + +static PRStatus TCPFastOpenClose(PRFileDesc* fd) { + if (!fd) { + return PR_FAILURE; + } + + PRFileDesc* layer = PR_PopIOLayer(fd, PR_TOP_IO_LAYER); + + MOZ_RELEASE_ASSERT(layer && layer->identity == sTCPFastOpenLayerIdentity, + "TCP Fast Open Layer not on top of stack"); + + TCPFastOpenSecret* secret = + reinterpret_cast<TCPFastOpenSecret*>(layer->secret); + layer->secret = nullptr; + layer->dtor(layer); + delete secret; + return fd->methods->close(fd); +} + +static PRStatus TCPFastOpenGetpeername(PRFileDesc* fd, PRNetAddr* addr) { + MOZ_RELEASE_ASSERT(fd); + MOZ_RELEASE_ASSERT(addr); + + MOZ_RELEASE_ASSERT(fd->identity == sTCPFastOpenLayerIdentity); + + TCPFastOpenSecret* secret = reinterpret_cast<TCPFastOpenSecret*>(fd->secret); + if (secret->mState == TCPFastOpenSecret::WAITING_FOR_CONNECT) { + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return PR_FAILURE; + } + + memcpy(addr, &secret->mAddr, sizeof(secret->mAddr)); + return PR_SUCCESS; +} + +static PRInt16 TCPFastOpenPoll(PRFileDesc* fd, PRInt16 how_flags, + PRInt16* p_out_flags) { + MOZ_RELEASE_ASSERT(fd); + MOZ_RELEASE_ASSERT(fd->identity == sTCPFastOpenLayerIdentity); + + TCPFastOpenSecret* secret = reinterpret_cast<TCPFastOpenSecret*>(fd->secret); + if (secret->mFirstPacketBufLen) { + how_flags |= PR_POLL_WRITE; + } + + return fd->lower->methods->poll(fd->lower, how_flags, p_out_flags); +} + +nsresult AttachTCPFastOpenIOLayer(PRFileDesc* fd) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + if (!sTCPFastOpenLayerMethodsPtr) { + sTCPFastOpenLayerIdentity = PR_GetUniqueIdentity("TCPFastOpen Layer"); + sTCPFastOpenLayerMethods = *PR_GetDefaultIOMethods(); + sTCPFastOpenLayerMethods.connect = TCPFastOpenConnect; + sTCPFastOpenLayerMethods.send = TCPFastOpenSend; + sTCPFastOpenLayerMethods.write = TCPFastOpenWrite; + sTCPFastOpenLayerMethods.recv = TCPFastOpenRecv; + sTCPFastOpenLayerMethods.read = TCPFastOpenRead; + sTCPFastOpenLayerMethods.connectcontinue = TCPFastOpenConnectContinue; + sTCPFastOpenLayerMethods.close = TCPFastOpenClose; + sTCPFastOpenLayerMethods.getpeername = TCPFastOpenGetpeername; + sTCPFastOpenLayerMethods.poll = TCPFastOpenPoll; + sTCPFastOpenLayerMethodsPtr = &sTCPFastOpenLayerMethods; + } + + PRFileDesc* layer = PR_CreateIOLayerStub(sTCPFastOpenLayerIdentity, + sTCPFastOpenLayerMethodsPtr); + + if (!layer) { + return NS_ERROR_FAILURE; + } + + TCPFastOpenSecret* secret = new TCPFastOpenSecret(); + + layer->secret = reinterpret_cast<PRFilePrivate*>(secret); + + PRStatus status = PR_PushIOLayer(fd, PR_NSPR_IO_LAYER, layer); + + if (status == PR_FAILURE) { + delete secret; + PR_Free(layer); // PR_CreateIOLayerStub() uses PR_Malloc(). + return NS_ERROR_FAILURE; + } + return NS_OK; +} + +void TCPFastOpenFinish(PRFileDesc* fd, PRErrorCode& err, + bool& fastOpenNotSupported, uint8_t& tfoStatus) { + PRFileDesc* tfoFd = PR_GetIdentitiesLayer(fd, sTCPFastOpenLayerIdentity); + MOZ_RELEASE_ASSERT(tfoFd); + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + TCPFastOpenSecret* secret = + reinterpret_cast<TCPFastOpenSecret*>(tfoFd->secret); + + MOZ_ASSERT(secret->mState == + TCPFastOpenSecret::COLLECT_DATA_FOR_FIRST_PACKET); + + fastOpenNotSupported = false; + tfoStatus = TFO_NOT_TRIED; + PRErrorCode result = 0; + + // If we do not have data to send with syn packet or nspr version does not + // have sendto implemented we will call normal connect. + // If sendto is not implemented it points to _PR_InvalidInt, therefore we + // check if sendto != _PR_InvalidInt. _PR_InvalidInt is exposed so we use + // reserved_fn_0 which also points to _PR_InvalidInt. + if (!secret->mFirstPacketBufLen || + (tfoFd->lower->methods->sendto == + (PRSendtoFN)tfoFd->lower->methods->reserved_fn_0)) { + // Because of the way our nsHttpTransaction dispatch work, it can happened + // that data has not been written into the socket. + // In this case we can just call connect. + PRInt32 rv = (tfoFd->lower->methods->connect)(tfoFd->lower, &secret->mAddr, + PR_INTERVAL_NO_WAIT); + if (rv == PR_SUCCESS) { + result = PR_IS_CONNECTED_ERROR; + } else { + result = PR_GetError(); + } + if (tfoFd->lower->methods->sendto == + (PRSendtoFN)tfoFd->lower->methods->reserved_fn_0) { + // sendto is not implemented, it is equal to _PR_InvalidInt! + // We will disable Fast Open. + SOCKET_LOG(("TCPFastOpenFinish - sendto not implemented.\n")); + fastOpenNotSupported = true; + tfoStatus = TFO_DISABLED; + } + } else { + // We have some data ready in the buffer we will send it with the syn + // packet. + PRInt32 rv = (tfoFd->lower->methods->sendto)( + tfoFd->lower, secret->mFirstPacketBuf, secret->mFirstPacketBufLen, + 0, // flags + &secret->mAddr, PR_INTERVAL_NO_WAIT); + + SOCKET_LOG(("TCPFastOpenFinish - sendto result=%d.\n", rv)); + if (rv > 0) { + result = PR_IN_PROGRESS_ERROR; + secret->mFirstPacketBufLen -= rv; + if (secret->mFirstPacketBufLen) { + memmove(secret->mFirstPacketBuf, secret->mFirstPacketBuf + rv, + secret->mFirstPacketBufLen); + } + tfoStatus = TFO_DATA_SENT; + } else { + result = PR_GetError(); + SOCKET_LOG(("TCPFastOpenFinish - sendto error=%d.\n", result)); + + if (result == + PR_NOT_TCP_SOCKET_ERROR) { // SendTo will return + // PR_NOT_TCP_SOCKET_ERROR if TCP Fast + // Open is turned off on Linux. + // We can call connect again. + fastOpenNotSupported = true; + rv = (tfoFd->lower->methods->connect)(tfoFd->lower, &secret->mAddr, + PR_INTERVAL_NO_WAIT); + + if (rv == PR_SUCCESS) { + result = PR_IS_CONNECTED_ERROR; + } else { + result = PR_GetError(); + } + tfoStatus = TFO_DISABLED; + } else { + tfoStatus = TFO_TRIED; + } + } + } + + if (result == PR_IN_PROGRESS_ERROR) { + secret->mState = TCPFastOpenSecret::WAITING_FOR_CONNECTCONTINUE; + } else { + // If the error is not PR_IN_PROGRESS_ERROR, change the state to CONNECT so + // that recv/send can perform recv/send on the next lower layer and pick up + // the real error. This is really important! + // The result can also be PR_IS_CONNECTED_ERROR, that should change the + // state to CONNECT anyway. + secret->mState = TCPFastOpenSecret::CONNECTED; + } + err = result; +} + +/* This function returns the size of the remaining free space in the + * first_packet_buffer. This will be used by transactions with a tls layer. For + * other transactions it is not necessary. The tls transactions make a tls + * record before writing to this layer and if the record is too big the part + * that does not have place in the mFirstPacketBuf will be saved on the tls + * layer. During TFO we cannot send more than TFO_MAX_PACKET_SIZE_IPV4/6 bytes, + * so if we have a big tls record, this record is encrypted with 0RTT key, + * tls-early-data can be rejected and than we still need to send the rest of the + * record. + */ +int32_t TCPFastOpenGetBufferSizeLeft(PRFileDesc* fd) { + PRFileDesc* tfoFd = PR_GetIdentitiesLayer(fd, sTCPFastOpenLayerIdentity); + MOZ_RELEASE_ASSERT(tfoFd); + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + TCPFastOpenSecret* secret = + reinterpret_cast<TCPFastOpenSecret*>(tfoFd->secret); + + if (secret->mState != TCPFastOpenSecret::COLLECT_DATA_FOR_FIRST_PACKET) { + return 0; + } + + int32_t sizeLeft = (secret->mAddr.raw.family == PR_AF_INET) + ? TFO_MAX_PACKET_SIZE_IPV4 + : TFO_MAX_PACKET_SIZE_IPV6; + MOZ_ASSERT(secret->mFirstPacketBufLen <= sizeLeft); + sizeLeft -= secret->mFirstPacketBufLen; + + SOCKET_LOG(("TCPFastOpenGetBufferSizeLeft=%d.\n", sizeLeft)); + + return (sizeLeft > TFO_TLS_RECORD_HEADER_SIZE) + ? sizeLeft - TFO_TLS_RECORD_HEADER_SIZE + : 0; +} + +bool TCPFastOpenGetCurrentBufferSize(PRFileDesc* fd) { + PRFileDesc* tfoFd = PR_GetIdentitiesLayer(fd, sTCPFastOpenLayerIdentity); + MOZ_RELEASE_ASSERT(tfoFd); + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + TCPFastOpenSecret* secret = + reinterpret_cast<TCPFastOpenSecret*>(tfoFd->secret); + + return secret->mFirstPacketBufLen; +} + +bool TCPFastOpenFlushBuffer(PRFileDesc* fd) { + PRFileDesc* tfoFd = PR_GetIdentitiesLayer(fd, sTCPFastOpenLayerIdentity); + MOZ_RELEASE_ASSERT(tfoFd); + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + TCPFastOpenSecret* secret = + reinterpret_cast<TCPFastOpenSecret*>(tfoFd->secret); + MOZ_ASSERT(secret->mState == TCPFastOpenSecret::CONNECTED); + + if (secret->mFirstPacketBufLen) { + SOCKET_LOG( + ("TCPFastOpenFlushBuffer - %d bytes to drain from " + "mFirstPacketBufLen.\n", + secret->mFirstPacketBufLen)); + PRInt32 rv = (tfoFd->lower->methods->send)( + tfoFd->lower, secret->mFirstPacketBuf, secret->mFirstPacketBufLen, + 0, // flags + PR_INTERVAL_NO_WAIT); + if (rv <= 0) { + PRErrorCode err = PR_GetError(); + if (err == PR_WOULD_BLOCK_ERROR) { + // We still need to send this data. + return true; + } + // There is an error, let nsSocketTransport pick it up properly. + secret->mCondition = err; + secret->mState = TCPFastOpenSecret::SOCKET_ERROR_STATE; + return false; + } + + secret->mFirstPacketBufLen -= rv; + if (secret->mFirstPacketBufLen) { + memmove(secret->mFirstPacketBuf, secret->mFirstPacketBuf + rv, + secret->mFirstPacketBufLen); + } + } + return secret->mFirstPacketBufLen; +} + +} // namespace net +} // namespace mozilla |