/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ /* vim: set ts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ #ifndef test_io_h_ #define test_io_h_ #include #include #include #include #include #include #include "databuffer.h" #include "dummy_io.h" #include "prio.h" #include "nss_scoped_ptrs.h" #include "sslt.h" namespace nss_test { class DataBuffer; class DummyPrSocket; // Fwd decl. // Allow us to inspect a packet before it is written. class PacketFilter { public: enum Action { KEEP, // keep the original packet unmodified CHANGE, // change the packet to a different value DROP // drop the packet }; explicit PacketFilter(bool on = true) : enabled_(on) {} virtual ~PacketFilter() {} bool enabled() const { return enabled_; } virtual Action Process(const DataBuffer& input, DataBuffer* output) { if (!enabled_) { return KEEP; } return Filter(input, output); } void Enable() { enabled_ = true; } void Disable() { enabled_ = false; } // The packet filter takes input and has the option of mutating it. // // A filter that modifies the data places the modified data in *output and // returns CHANGE. A filter that does not modify data returns LEAVE, in which // case the value in *output is ignored. A Filter can return DROP, in which // case the packet is dropped (and *output is ignored). virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; private: bool enabled_; }; class DummyPrSocket : public DummyIOLayerMethods { public: DummyPrSocket(const std::string& name, SSLProtocolVariant var) : name_(name), variant_(var), peer_(), input_(), filter_(nullptr), write_error_(0) {} virtual ~DummyPrSocket() {} static PRDescIdentity LayerId(); // Create a file descriptor that will reference this object. The fd must not // live longer than this adapter; call PR_Close() before. ScopedPRFileDesc CreateFD(); std::weak_ptr& peer() { return peer_; } void SetPeer(const std::shared_ptr& p) { peer_ = p; } void SetPacketFilter(const std::shared_ptr& filter) { filter_ = filter; } // Drops peer, packet filter and any outstanding packets. void Reset(); void PacketReceived(const DataBuffer& data); int32_t Read(PRFileDesc* f, void* data, int32_t len) override; int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags, PRIntervalTime to) override; int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; void SetWriteError(PRErrorCode code) { write_error_ = code; } SSLProtocolVariant variant() const { return variant_; } bool readable() const { return !input_.empty(); } private: class Packet : public DataBuffer { public: Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {} void Advance(size_t delta) { PR_ASSERT(offset_ + delta <= len()); offset_ = std::min(len(), offset_ + delta); } size_t offset() const { return offset_; } size_t remaining() const { return len() - offset_; } private: size_t offset_; }; const std::string name_; SSLProtocolVariant variant_; std::weak_ptr peer_; std::queue input_; std::shared_ptr filter_; PRErrorCode write_error_; }; // Marker interface. class PollTarget {}; enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ }; typedef void (*PollCallback)(PollTarget*, Event); class Poller { public: static Poller* Instance(); // Get a singleton. static void Shutdown(); // Shut it down. class Timer { public: Timer(PRTime deadline, PollTarget* target, PollCallback callback) : deadline_(deadline), target_(target), callback_(callback) {} void Cancel() { callback_ = nullptr; } PRTime deadline_; PollTarget* target_; PollCallback callback_; }; void Wait(Event event, std::shared_ptr& adapter, PollTarget* target, PollCallback cb); void Cancel(Event event, std::shared_ptr& adapter); void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, std::shared_ptr* handle); bool Poll(); private: Poller() : waiters_(), timers_() {} ~Poller() {} class Waiter { public: Waiter(std::shared_ptr io) : io_(io) { memset(&targets_[0], 0, sizeof(targets_)); memset(&callbacks_[0], 0, sizeof(callbacks_)); } void WaitFor(Event event, PollCallback callback); std::shared_ptr io_; PollTarget* targets_[TIMER_EVENT]; PollCallback callbacks_[TIMER_EVENT]; }; class TimerComparator { public: bool operator()(const std::shared_ptr lhs, const std::shared_ptr rhs) { return lhs->deadline_ > rhs->deadline_; } }; static Poller* instance; std::map, std::unique_ptr> waiters_; std::priority_queue, std::vector>, TimerComparator> timers_; }; } // namespace nss_test #endif