diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/test_io.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/test_io.h | 187 |
1 files changed, 187 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h new file mode 100644 index 0000000000..e262fb123e --- /dev/null +++ b/security/nss/gtests/ssl_gtest/test_io.h @@ -0,0 +1,187 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef test_io_h_ +#define test_io_h_ + +#include <string.h> +#include <map> +#include <memory> +#include <ostream> +#include <queue> +#include <string> + +#include "databuffer.h" +#include "dummy_io.h" +#include "prio.h" +#include "nss_scoped_ptrs.h" +#include "sslt.h" + +namespace nss_test { + +class DataBuffer; +class DummyPrSocket; // Fwd decl. + +// Allow us to inspect a packet before it is written. +class PacketFilter { + public: + enum Action { + KEEP, // keep the original packet unmodified + CHANGE, // change the packet to a different value + DROP // drop the packet + }; + explicit PacketFilter(bool on = true) : enabled_(on) {} + virtual ~PacketFilter() {} + + bool enabled() const { return enabled_; } + + virtual Action Process(const DataBuffer& input, DataBuffer* output) { + if (!enabled_) { + return KEEP; + } + return Filter(input, output); + } + void Enable() { enabled_ = true; } + void Disable() { enabled_ = false; } + + // The packet filter takes input and has the option of mutating it. + // + // A filter that modifies the data places the modified data in *output and + // returns CHANGE. A filter that does not modify data returns LEAVE, in which + // case the value in *output is ignored. A Filter can return DROP, in which + // case the packet is dropped (and *output is ignored). + virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; + + private: + bool enabled_; +}; + +class DummyPrSocket : public DummyIOLayerMethods { + public: + DummyPrSocket(const std::string& name, SSLProtocolVariant var) + : name_(name), + variant_(var), + peer_(), + input_(), + filter_(nullptr), + write_error_(0) {} + virtual ~DummyPrSocket() {} + + static PRDescIdentity LayerId(); + + // Create a file descriptor that will reference this object. The fd must not + // live longer than this adapter; call PR_Close() before. + ScopedPRFileDesc CreateFD(); + + std::weak_ptr<DummyPrSocket>& peer() { return peer_; } + void SetPeer(const std::shared_ptr<DummyPrSocket>& p) { peer_ = p; } + void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) { + filter_ = filter; + } + // Drops peer, packet filter and any outstanding packets. + void Reset(); + + void PacketReceived(const DataBuffer& data); + int32_t Read(PRFileDesc* f, void* data, int32_t len) override; + int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags, + PRIntervalTime to) override; + int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; + void SetWriteError(PRErrorCode code) { write_error_ = code; } + + SSLProtocolVariant variant() const { return variant_; } + bool readable() const { return !input_.empty(); } + + private: + class Packet : public DataBuffer { + public: + Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {} + + void Advance(size_t delta) { + PR_ASSERT(offset_ + delta <= len()); + offset_ = std::min(len(), offset_ + delta); + } + + size_t offset() const { return offset_; } + size_t remaining() const { return len() - offset_; } + + private: + size_t offset_; + }; + + const std::string name_; + SSLProtocolVariant variant_; + std::weak_ptr<DummyPrSocket> peer_; + std::queue<Packet> input_; + std::shared_ptr<PacketFilter> filter_; + PRErrorCode write_error_; +}; + +// Marker interface. +class PollTarget {}; + +enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ }; + +typedef void (*PollCallback)(PollTarget*, Event); + +class Poller { + public: + static Poller* Instance(); // Get a singleton. + static void Shutdown(); // Shut it down. + + class Timer { + public: + Timer(PRTime deadline, PollTarget* target, PollCallback callback) + : deadline_(deadline), target_(target), callback_(callback) {} + void Cancel() { callback_ = nullptr; } + + PRTime deadline_; + PollTarget* target_; + PollCallback callback_; + }; + + void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter, + PollTarget* target, PollCallback cb); + void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter); + void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, + std::shared_ptr<Timer>* handle); + bool Poll(); + + private: + Poller() : waiters_(), timers_() {} + ~Poller() {} + + class Waiter { + public: + Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) { + memset(&targets_[0], 0, sizeof(targets_)); + memset(&callbacks_[0], 0, sizeof(callbacks_)); + } + + void WaitFor(Event event, PollCallback callback); + + std::shared_ptr<DummyPrSocket> io_; + PollTarget* targets_[TIMER_EVENT]; + PollCallback callbacks_[TIMER_EVENT]; + }; + + class TimerComparator { + public: + bool operator()(const std::shared_ptr<Timer> lhs, + const std::shared_ptr<Timer> rhs) { + return lhs->deadline_ > rhs->deadline_; + } + }; + + static Poller* instance; + std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_; + std::priority_queue<std::shared_ptr<Timer>, + std::vector<std::shared_ptr<Timer>>, TimerComparator> + timers_; +}; + +} // namespace nss_test + +#endif |