diff options
Diffstat (limited to 'third_party/libwebrtc/p2p/base/stun_dictionary.cc')
-rw-r--r-- | third_party/libwebrtc/p2p/base/stun_dictionary.cc | 357 |
1 files changed, 357 insertions, 0 deletions
diff --git a/third_party/libwebrtc/p2p/base/stun_dictionary.cc b/third_party/libwebrtc/p2p/base/stun_dictionary.cc new file mode 100644 index 0000000000..bf6a1e49c2 --- /dev/null +++ b/third_party/libwebrtc/p2p/base/stun_dictionary.cc @@ -0,0 +1,357 @@ +/* + * Copyright 2020 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "p2p/base/stun_dictionary.h" + +#include <algorithm> +#include <deque> +#include <utility> + +#include "rtc_base/logging.h" + +namespace cricket { + +const StunAddressAttribute* StunDictionaryView::GetAddress(int key) const { + const StunAttribute* attr = GetOrNull(key, STUN_VALUE_ADDRESS); + if (attr == nullptr) { + return nullptr; + } + return reinterpret_cast<const StunAddressAttribute*>(attr); +} + +const StunUInt32Attribute* StunDictionaryView::GetUInt32(int key) const { + const StunAttribute* attr = GetOrNull(key, STUN_VALUE_UINT32); + if (attr == nullptr) { + return nullptr; + } + return reinterpret_cast<const StunUInt32Attribute*>(attr); +} + +const StunUInt64Attribute* StunDictionaryView::GetUInt64(int key) const { + const StunAttribute* attr = GetOrNull(key, STUN_VALUE_UINT64); + if (attr == nullptr) { + return nullptr; + } + return reinterpret_cast<const StunUInt64Attribute*>(attr); +} + +const StunByteStringAttribute* StunDictionaryView::GetByteString( + int key) const { + const StunAttribute* attr = GetOrNull(key, STUN_VALUE_BYTE_STRING); + if (attr == nullptr) { + return nullptr; + } + return reinterpret_cast<const StunByteStringAttribute*>(attr); +} + +const StunUInt16ListAttribute* StunDictionaryView::GetUInt16List( + int key) const { + const StunAttribute* attr = GetOrNull(key, STUN_VALUE_UINT16_LIST); + if (attr == nullptr) { + return nullptr; + } + return reinterpret_cast<const StunUInt16ListAttribute*>(attr); +} + +const StunAttribute* StunDictionaryView::GetOrNull( + int key, + absl::optional<StunAttributeValueType> type) const { + const auto it = attrs_.find(key); + if (it == attrs_.end()) { + return nullptr; + } + + if (type && it->second->value_type() != *type) { + RTC_LOG(LS_WARNING) << "Get key: " << key << " with type: " << *type + << " found different type: " + << it->second->value_type(); + return nullptr; + } + return (*it).second.get(); +} + +webrtc::RTCErrorOr< + std::pair<uint64_t, std::deque<std::unique_ptr<StunAttribute>>>> +StunDictionaryView::ParseDelta(const StunByteStringAttribute& delta) { + rtc::ByteBufferReader buf(delta.bytes(), delta.length()); + uint16_t magic; + if (!buf.ReadUInt16(&magic)) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to read magic number"); + } + if (magic != kDeltaMagic) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Invalid magic number"); + } + + uint16_t delta_version; + if (!buf.ReadUInt16(&delta_version)) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to read version"); + } + + if (delta_version != kDeltaVersion) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Unsupported delta version"); + } + + // Now read all the attributes + std::deque<std::unique_ptr<StunAttribute>> attrs; + while (buf.Length()) { + uint16_t key, length, value_type; + if (!buf.ReadUInt16(&key)) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to read attribute key"); + } + if (!buf.ReadUInt16(&length)) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to read attribute length"); + } + if (!buf.ReadUInt16(&value_type)) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to read value type"); + } + + StunAttributeValueType value_type_enum = + static_cast<StunAttributeValueType>(value_type); + std::unique_ptr<StunAttribute> attr( + StunAttribute::Create(value_type_enum, key, length, nullptr)); + if (!attr) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to create attribute"); + } + if (attr->length() != length) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Inconsistent attribute length"); + } + if (!attr->Read(&buf)) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to read attribute content"); + } + attrs.push_back(std::move(attr)); + } + + // The first attribute should be the version... + if (attrs.empty()) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Empty delta!"); + } + + if (attrs[0]->type() != kVersionKey || + attrs[0]->value_type() != STUN_VALUE_UINT64) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Missing version!"); + } + + uint64_t version_in_delta = + reinterpret_cast<const StunUInt64Attribute*>(attrs[0].get())->value(); + attrs.pop_front(); + + return std::make_pair(std::max(version_in_delta, version_in_delta), + std::move(attrs)); +} + +// Apply a delta return an StunUInt64Attribute to ack the update. +webrtc::RTCErrorOr< + std::pair<std::unique_ptr<StunUInt64Attribute>, std::vector<uint16_t>>> +StunDictionaryView::ApplyDelta(const StunByteStringAttribute& delta) { + auto parsed_delta = ParseDelta(delta); + if (!parsed_delta.ok()) { + return webrtc::RTCError(parsed_delta.error()); + } + + uint64_t version_in_delta = parsed_delta.value().first; + + // Check that update does not overflow max_bytes_stored_. + int new_bytes_stored = bytes_stored_; + for (auto& attr : parsed_delta.value().second) { + auto old_version = version_per_key_.find(attr->type()); + if (old_version == version_per_key_.end() || + version_in_delta > old_version->second) { + size_t new_length = attr->length(); + size_t old_length = GetLength(attr->type()); + if (old_version == version_per_key_.end()) { + new_length += sizeof(int64_t); + } + + new_bytes_stored = new_bytes_stored + new_length - old_length; + if (new_bytes_stored <= 0) { + RTC_LOG(LS_WARNING) + << "attr: " << attr->type() << " old_length: " << old_length + << " new_length: " << new_length + << " bytes_stored_: " << bytes_stored_ + << " new_bytes_stored: " << new_bytes_stored; + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER); + } + if (new_bytes_stored > max_bytes_stored_) { + RTC_LOG(LS_INFO) << "attr: " << attr->type() + << " old_length: " << old_length + << " new_length: " << new_length + << " bytes_stored_: " << bytes_stored_ + << " new_bytes_stored: " << new_bytes_stored; + } + } + } + if (new_bytes_stored > max_bytes_stored_) { + RTC_LOG(LS_INFO) << " bytes_stored_: " << bytes_stored_ + << " new_bytes_stored: " << new_bytes_stored; + return webrtc::RTCError(webrtc::RTCErrorType::RESOURCE_EXHAUSTED); + } + + // Apply the update. + std::vector<uint16_t> keys; + for (auto& attr : parsed_delta.value().second) { + if (version_in_delta > version_per_key_[attr->type()]) { + version_per_key_[attr->type()] = version_in_delta; + keys.push_back(attr->type()); + if (attr->value_type() == STUN_VALUE_BYTE_STRING && attr->length() == 0) { + attrs_.erase(attr->type()); + } else { + attrs_[attr->type()] = std::move(attr); + } + } + } + bytes_stored_ = new_bytes_stored; + + return std::make_pair(std::make_unique<StunUInt64Attribute>( + STUN_ATTR_GOOG_DELTA_ACK, version_in_delta), + std::move(keys)); +} + +size_t StunDictionaryView::GetLength(int key) const { + auto attr = GetOrNull(key); + if (attr != nullptr) { + return attr->length(); + } + return 0; +} + +void StunDictionaryWriter::Disable() { + disabled_ = true; +} + +void StunDictionaryWriter::Delete(int key) { + if (disabled_) { + return; + } + + if (dictionary_) { + if (dictionary_->attrs_.find(key) == dictionary_->attrs_.end()) { + return; + } + } + + // remove any pending updates. + pending_.erase( + std::remove_if(pending_.begin(), pending_.end(), + [key](const auto& p) { return p.second->type() == key; }), + pending_.end()); + + // Create tombstone. + auto tombstone = std::make_unique<StunByteStringAttribute>(key); + + // add a pending entry. + pending_.push_back(std::make_pair(++version_, tombstone.get())); + + // store the tombstone. + tombstones_[key] = std::move(tombstone); + + if (dictionary_) { + // remove value + dictionary_->attrs_.erase(key); + } +} + +void StunDictionaryWriter::Set(std::unique_ptr<StunAttribute> attr) { + if (disabled_) { + return; + } + int key = attr->type(); + // remove any pending updates. + pending_.erase( + std::remove_if(pending_.begin(), pending_.end(), + [key](const auto& p) { return p.second->type() == key; }), + pending_.end()); + + // remove any existing key. + tombstones_.erase(key); + + // create pending entry. + pending_.push_back(std::make_pair(++version_, attr.get())); + + if (dictionary_) { + // store attribute. + dictionary_->attrs_[key] = std::move(attr); + } +} + +// Create an StunByteStringAttribute containing the pending (e.g not ack:ed) +// modifications. +std::unique_ptr<StunByteStringAttribute> StunDictionaryWriter::CreateDelta() { + if (disabled_) { + return nullptr; + } + if (pending_.empty()) { + return nullptr; + } + + rtc::ByteBufferWriter buf; + buf.WriteUInt16(StunDictionaryView::kDeltaMagic); // 0,1 + buf.WriteUInt16(StunDictionaryView::kDeltaVersion); // 2,3 + + // max version in Delta. + buf.WriteUInt16(StunDictionaryView::kVersionKey); // 4,5 + buf.WriteUInt16(8); // 6,7 + buf.WriteUInt16(STUN_VALUE_UINT64); // 8,9 + buf.WriteUInt64(pending_.back().first); // 10-17 + // attributes + for (const auto& attr : pending_) { + buf.WriteUInt16(attr.second->type()); + buf.WriteUInt16(static_cast<uint16_t>(attr.second->length())); + buf.WriteUInt16(attr.second->value_type()); + if (!attr.second->Write(&buf)) { + RTC_LOG(LS_ERROR) << "Failed to write key: " << attr.second->type(); + return nullptr; + } + } + return std::make_unique<StunByteStringAttribute>(STUN_ATTR_GOOG_DELTA, + buf.Data(), buf.Length()); +} + +// Apply a delta ack, i.e prune list of pending changes. +void StunDictionaryWriter::ApplyDeltaAck(const StunUInt64Attribute& ack) { + uint64_t acked_version = ack.value(); + auto entries_to_remove = std::remove_if( + pending_.begin(), pending_.end(), + [acked_version](const auto& p) { return p.first <= acked_version; }); + + // remove tombstones. + for (auto it = entries_to_remove; it != pending_.end(); ++it) { + tombstones_.erase((*it).second->type()); + } + pending_.erase(entries_to_remove, pending_.end()); +} + +// Check if a key has a pending change (i.e a change +// that has not been acked). +bool StunDictionaryWriter::Pending(int key) const { + for (const auto& attr : pending_) { + if (attr.second->type() == key) { + return true; + } + } + return false; +} + +int StunDictionaryWriter::Pending() const { + return pending_.size(); +} + +} // namespace cricket |