diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/test_io.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/test_io.cc | 278 |
1 files changed, 278 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc new file mode 100644 index 0000000000..e4651a2352 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/test_io.cc @@ -0,0 +1,278 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "test_io.h" + +#include <algorithm> +#include <cassert> +#include <iostream> +#include <memory> + +#include "prerror.h" +#include "prlog.h" +#include "prthread.h" + +extern bool g_ssl_gtest_verbose; + +namespace nss_test { + +#define LOG(a) std::cerr << name_ << ": " << a << std::endl +#define LOGV(a) \ + do { \ + if (g_ssl_gtest_verbose) LOG(a); \ + } while (false) + +PRDescIdentity DummyPrSocket::LayerId() { + static PRDescIdentity id = PR_GetUniqueIdentity("dummysocket"); + return id; +} + +ScopedPRFileDesc DummyPrSocket::CreateFD() { + return DummyIOLayerMethods::CreateFD(DummyPrSocket::LayerId(), this); +} + +void DummyPrSocket::Reset() { + auto p = peer_.lock(); + peer_.reset(); + if (p) { + p->peer_.reset(); + p->Reset(); + } + while (!input_.empty()) { + input_.pop(); + } + filter_ = nullptr; + write_error_ = 0; +} + +void DummyPrSocket::PacketReceived(const DataBuffer &packet) { + input_.push(Packet(packet)); +} + +int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) { + PR_ASSERT(variant_ == ssl_variant_stream); + if (variant_ != ssl_variant_stream) { + PR_SetError(PR_INVALID_METHOD_ERROR, 0); + return -1; + } + + auto dst = peer_.lock(); + if (!dst) { + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + if (input_.empty()) { + LOGV("Read --> wouldblock " << len); + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return -1; + } + + auto &front = input_.front(); + size_t to_read = + std::min(static_cast<size_t>(len), front.len() - front.offset()); + memcpy(data, static_cast<const void *>(front.data() + front.offset()), + to_read); + front.Advance(to_read); + + if (!front.remaining()) { + input_.pop(); + } + + return static_cast<int32_t>(to_read); +} + +int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen, + int32_t flags, PRIntervalTime to) { + PR_ASSERT(flags == 0); + if (flags != 0) { + PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); + return -1; + } + + if (variant() != ssl_variant_datagram) { + return Read(f, buf, buflen); + } + + auto dst = peer_.lock(); + if (!dst) { + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + if (input_.empty()) { + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return -1; + } + + auto &front = input_.front(); + if (static_cast<size_t>(buflen) < front.len()) { + PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); + return -1; + } + + size_t count = front.len(); + memcpy(buf, front.data(), count); + + input_.pop(); + return static_cast<int32_t>(count); +} + +int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { + if (write_error_) { + PR_SetError(write_error_, 0); + return -1; + } + + auto dst = peer_.lock(); + if (!dst) { + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + DataBuffer packet(static_cast<const uint8_t *>(buf), + static_cast<size_t>(length)); + DataBuffer filtered; + PacketFilter::Action action = PacketFilter::KEEP; + if (filter_) { + LOGV("Original packet: " << packet); + action = filter_->Process(packet, &filtered); + } + switch (action) { + case PacketFilter::CHANGE: + LOG("Filtered packet: " << filtered); + dst->PacketReceived(filtered); + break; + case PacketFilter::DROP: + LOG("Drop packet"); + break; + case PacketFilter::KEEP: + dst->PacketReceived(packet); + break; + } + // libssl can't handle it if this reports something other than the length + // of what was passed in (or less, but we're not doing partial writes). + return static_cast<int32_t>(packet.len()); +} + +Poller *Poller::instance; + +Poller *Poller::Instance() { + if (!instance) instance = new Poller(); + + return instance; +} + +void Poller::Shutdown() { + delete instance; + instance = nullptr; +} + +void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter, + PollTarget *target, PollCallback cb) { + assert(event < TIMER_EVENT); + if (event >= TIMER_EVENT) return; + + std::unique_ptr<Waiter> waiter; + auto it = waiters_.find(adapter); + if (it == waiters_.end()) { + waiter.reset(new Waiter(adapter)); + } else { + waiter = std::move(it->second); + } + + waiter->targets_[event] = target; + waiter->callbacks_[event] = cb; + waiters_[adapter] = std::move(waiter); +} + +void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) { + auto it = waiters_.find(adapter); + if (it == waiters_.end()) { + return; + } + + auto &waiter = it->second; + waiter->targets_[event] = nullptr; + waiter->callbacks_[event] = nullptr; + + // Clean up if there are no callbacks. + for (size_t i = 0; i < TIMER_EVENT; ++i) { + if (waiter->callbacks_[i]) return; + } + + waiters_.erase(adapter); +} + +void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb, + std::shared_ptr<Timer> *timer) { + auto t = std::make_shared<Timer>(PR_Now() + timer_ms * 1000, target, cb); + timers_.push(t); + if (timer) *timer = t; +} + +bool Poller::Poll() { + if (g_ssl_gtest_verbose) { + std::cerr << "Poll() waiters = " << waiters_.size() + << " timers = " << timers_.size() << std::endl; + } + PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT; + PRTime now = PR_Now(); + bool fired = false; + + // Figure out the timer for the select. + if (!timers_.empty()) { + auto first_timer = timers_.top(); + if (now >= first_timer->deadline_) { + // Timer expired. + timeout = PR_INTERVAL_NO_WAIT; + } else { + timeout = + PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000); + } + } + + for (auto it = waiters_.begin(); it != waiters_.end(); ++it) { + auto &waiter = it->second; + + if (waiter->callbacks_[READABLE_EVENT]) { + if (waiter->io_->readable()) { + PollCallback callback = waiter->callbacks_[READABLE_EVENT]; + PollTarget *target = waiter->targets_[READABLE_EVENT]; + waiter->callbacks_[READABLE_EVENT] = nullptr; + waiter->targets_[READABLE_EVENT] = nullptr; + callback(target, READABLE_EVENT); + fired = true; + } + } + } + + if (fired) timeout = PR_INTERVAL_NO_WAIT; + + // Can't wait forever and also have nothing readable now. + if (timeout == PR_INTERVAL_NO_TIMEOUT) return false; + + // Sleep. + if (timeout != PR_INTERVAL_NO_WAIT) { + PR_Sleep(timeout); + } + + // Now process anything that timed out. + now = PR_Now(); + while (!timers_.empty()) { + if (now < timers_.top()->deadline_) break; + + auto timer = timers_.top(); + timers_.pop(); + if (timer->callback_) { + timer->callback_(timer->target_, TIMER_EVENT); + } + } + + return true; +} + +} // namespace nss_test |