diff options
Diffstat (limited to '')
-rw-r--r-- | src/shrpx_quic.cc | 87 |
1 files changed, 47 insertions, 40 deletions
diff --git a/src/shrpx_quic.cc b/src/shrpx_quic.cc index c52eee4..1ffd217 100644 --- a/src/shrpx_quic.cc +++ b/src/shrpx_quic.cc @@ -58,8 +58,10 @@ ngtcp2_tstamp quic_timestamp() { int quic_send_packet(const UpstreamAddr *faddr, const sockaddr *remote_sa, size_t remote_salen, const sockaddr *local_sa, size_t local_salen, const ngtcp2_pkt_info &pi, - const uint8_t *data, size_t datalen, size_t gso_size) { - iovec msg_iov = {const_cast<uint8_t *>(data), datalen}; + std::span<const uint8_t> data, size_t gso_size) { + assert(gso_size); + + iovec msg_iov = {const_cast<uint8_t *>(data.data()), data.size()}; msghdr msg{}; msg.msg_name = const_cast<sockaddr *>(remote_sa); msg.msg_namelen = remote_salen; @@ -113,7 +115,7 @@ int quic_send_packet(const UpstreamAddr *faddr, const sockaddr *remote_sa, } #ifdef UDP_SEGMENT - if (gso_size && datalen > gso_size) { + if (data.size() > gso_size) { controllen += CMSG_SPACE(sizeof(uint16_t)); cm = CMSG_NXTHDR(&msg, cm); cm->cmsg_level = SOL_UDP; @@ -170,6 +172,8 @@ int quic_send_packet(const UpstreamAddr *faddr, const sockaddr *remote_sa, << " bytes"; } + assert(static_cast<size_t>(nwrite) == data.size()); + return 0; } @@ -267,87 +271,90 @@ int generate_quic_stateless_reset_token(uint8_t *token, const ngtcp2_cid &cid, return 0; } -int generate_retry_token(uint8_t *token, size_t &tokenlen, uint32_t version, - const sockaddr *sa, socklen_t salen, - const ngtcp2_cid &retry_scid, const ngtcp2_cid &odcid, - const uint8_t *secret, size_t secretlen) { +std::optional<std::span<const uint8_t>> +generate_retry_token(std::span<uint8_t> token, uint32_t version, + const sockaddr *sa, socklen_t salen, + const ngtcp2_cid &retry_scid, const ngtcp2_cid &odcid, + std::span<const uint8_t> secret) { auto t = std::chrono::duration_cast<std::chrono::nanoseconds>( std::chrono::system_clock::now().time_since_epoch()) .count(); - auto stokenlen = ngtcp2_crypto_generate_retry_token( - token, secret, secretlen, version, sa, salen, &retry_scid, &odcid, t); - if (stokenlen < 0) { - return -1; + auto tokenlen = ngtcp2_crypto_generate_retry_token( + token.data(), secret.data(), secret.size(), version, sa, salen, + &retry_scid, &odcid, t); + if (tokenlen < 0) { + return {}; } - tokenlen = stokenlen; - - return 0; + return {{std::begin(token), static_cast<size_t>(tokenlen)}}; } -int verify_retry_token(ngtcp2_cid &odcid, const uint8_t *token, size_t tokenlen, +int verify_retry_token(ngtcp2_cid &odcid, std::span<const uint8_t> token, uint32_t version, const ngtcp2_cid &dcid, const sockaddr *sa, socklen_t salen, - const uint8_t *secret, size_t secretlen) { - + std::span<const uint8_t> secret) { auto t = std::chrono::duration_cast<std::chrono::nanoseconds>( std::chrono::system_clock::now().time_since_epoch()) .count(); - if (ngtcp2_crypto_verify_retry_token(&odcid, token, tokenlen, secret, - secretlen, version, sa, salen, &dcid, - 10 * NGTCP2_SECONDS, t) != 0) { + if (ngtcp2_crypto_verify_retry_token( + &odcid, token.data(), token.size(), secret.data(), secret.size(), + version, sa, salen, &dcid, 10 * NGTCP2_SECONDS, t) != 0) { return -1; } return 0; } -int generate_token(uint8_t *token, size_t &tokenlen, const sockaddr *sa, - size_t salen, const uint8_t *secret, size_t secretlen) { +std::optional<std::span<const uint8_t>> +generate_token(std::span<uint8_t> token, const sockaddr *sa, size_t salen, + std::span<const uint8_t> secret, uint8_t km_id) { auto t = std::chrono::duration_cast<std::chrono::nanoseconds>( std::chrono::system_clock::now().time_since_epoch()) .count(); - auto stokenlen = ngtcp2_crypto_generate_regular_token( - token, secret, secretlen, sa, salen, t); - if (stokenlen < 0) { - return -1; + auto tokenlen = ngtcp2_crypto_generate_regular_token( + token.data(), secret.data(), secret.size(), sa, salen, t); + if (tokenlen < 0) { + return {}; } - tokenlen = stokenlen; + token[tokenlen++] = km_id; - return 0; + return {{std::begin(token), static_cast<size_t>(tokenlen)}}; } -int verify_token(const uint8_t *token, size_t tokenlen, const sockaddr *sa, - socklen_t salen, const uint8_t *secret, size_t secretlen) { +int verify_token(std::span<const uint8_t> token, const sockaddr *sa, + socklen_t salen, std::span<const uint8_t> secret) { + if (token.empty()) { + return -1; + } + auto t = std::chrono::duration_cast<std::chrono::nanoseconds>( std::chrono::system_clock::now().time_since_epoch()) .count(); - if (ngtcp2_crypto_verify_regular_token(token, tokenlen, secret, secretlen, sa, - salen, 3600 * NGTCP2_SECONDS, - t) != 0) { + if (ngtcp2_crypto_verify_regular_token( + token.data(), token.size() - 1, secret.data(), secret.size(), sa, + salen, 3600 * NGTCP2_SECONDS, t) != 0) { return -1; } return 0; } -int generate_quic_connection_id_encryption_key(uint8_t *key, size_t keylen, - const uint8_t *secret, - size_t secretlen, - const uint8_t *salt, - size_t saltlen) { +int generate_quic_connection_id_encryption_key(std::span<uint8_t> key, + std::span<const uint8_t> secret, + std::span<const uint8_t> salt) { constexpr uint8_t info[] = "connection id encryption key"; ngtcp2_crypto_md sha256; ngtcp2_crypto_md_init( &sha256, reinterpret_cast<void *>(const_cast<EVP_MD *>(EVP_sha256()))); - if (ngtcp2_crypto_hkdf(key, keylen, &sha256, secret, secretlen, salt, saltlen, - info, str_size(info)) != 0) { + if (ngtcp2_crypto_hkdf(key.data(), key.size(), &sha256, secret.data(), + secret.size(), salt.data(), salt.size(), info, + str_size(info)) != 0) { return -1; } |