187 lines
5.3 KiB
C++
187 lines
5.3 KiB
C++
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
|
|
/* vim: set ts=2 et sw=2 tw=80: */
|
|
/* This Source Code Form is subject to the terms of the Mozilla Public
|
|
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
* You can obtain one at http://mozilla.org/MPL/2.0/. */
|
|
|
|
#ifndef test_io_h_
|
|
#define test_io_h_
|
|
|
|
#include <string.h>
|
|
#include <map>
|
|
#include <memory>
|
|
#include <ostream>
|
|
#include <queue>
|
|
#include <string>
|
|
|
|
#include "databuffer.h"
|
|
#include "dummy_io.h"
|
|
#include "prio.h"
|
|
#include "nss_scoped_ptrs.h"
|
|
#include "sslt.h"
|
|
|
|
namespace nss_test {
|
|
|
|
class DataBuffer;
|
|
class DummyPrSocket; // Fwd decl.
|
|
|
|
// Allow us to inspect a packet before it is written.
|
|
class PacketFilter {
|
|
public:
|
|
enum Action {
|
|
KEEP, // keep the original packet unmodified
|
|
CHANGE, // change the packet to a different value
|
|
DROP // drop the packet
|
|
};
|
|
explicit PacketFilter(bool on = true) : enabled_(on) {}
|
|
virtual ~PacketFilter() {}
|
|
|
|
bool enabled() const { return enabled_; }
|
|
|
|
virtual Action Process(const DataBuffer& input, DataBuffer* output) {
|
|
if (!enabled_) {
|
|
return KEEP;
|
|
}
|
|
return Filter(input, output);
|
|
}
|
|
void Enable() { enabled_ = true; }
|
|
void Disable() { enabled_ = false; }
|
|
|
|
// The packet filter takes input and has the option of mutating it.
|
|
//
|
|
// A filter that modifies the data places the modified data in *output and
|
|
// returns CHANGE. A filter that does not modify data returns LEAVE, in which
|
|
// case the value in *output is ignored. A Filter can return DROP, in which
|
|
// case the packet is dropped (and *output is ignored).
|
|
virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
|
|
|
|
private:
|
|
bool enabled_;
|
|
};
|
|
|
|
class DummyPrSocket : public DummyIOLayerMethods {
|
|
public:
|
|
DummyPrSocket(const std::string& name, SSLProtocolVariant var)
|
|
: name_(name),
|
|
variant_(var),
|
|
peer_(),
|
|
input_(),
|
|
filter_(nullptr),
|
|
write_error_(0) {}
|
|
virtual ~DummyPrSocket() {}
|
|
|
|
static PRDescIdentity LayerId();
|
|
|
|
// Create a file descriptor that will reference this object. The fd must not
|
|
// live longer than this adapter; call PR_Close() before.
|
|
ScopedPRFileDesc CreateFD();
|
|
|
|
std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
|
|
void SetPeer(const std::shared_ptr<DummyPrSocket>& p) { peer_ = p; }
|
|
void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) {
|
|
filter_ = filter;
|
|
}
|
|
// Drops peer, packet filter and any outstanding packets.
|
|
void Reset();
|
|
|
|
void PacketReceived(const DataBuffer& data);
|
|
int32_t Read(PRFileDesc* f, void* data, int32_t len) override;
|
|
int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags,
|
|
PRIntervalTime to) override;
|
|
int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override;
|
|
void SetWriteError(PRErrorCode code) { write_error_ = code; }
|
|
|
|
SSLProtocolVariant variant() const { return variant_; }
|
|
bool readable() const { return !input_.empty(); }
|
|
|
|
private:
|
|
class Packet : public DataBuffer {
|
|
public:
|
|
Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {}
|
|
|
|
void Advance(size_t delta) {
|
|
PR_ASSERT(offset_ + delta <= len());
|
|
offset_ = std::min(len(), offset_ + delta);
|
|
}
|
|
|
|
size_t offset() const { return offset_; }
|
|
size_t remaining() const { return len() - offset_; }
|
|
|
|
private:
|
|
size_t offset_;
|
|
};
|
|
|
|
const std::string name_;
|
|
SSLProtocolVariant variant_;
|
|
std::weak_ptr<DummyPrSocket> peer_;
|
|
std::queue<Packet> input_;
|
|
std::shared_ptr<PacketFilter> filter_;
|
|
PRErrorCode write_error_;
|
|
};
|
|
|
|
// Marker interface.
|
|
class PollTarget {};
|
|
|
|
enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ };
|
|
|
|
typedef void (*PollCallback)(PollTarget*, Event);
|
|
|
|
class Poller {
|
|
public:
|
|
static Poller* Instance(); // Get a singleton.
|
|
static void Shutdown(); // Shut it down.
|
|
|
|
class Timer {
|
|
public:
|
|
Timer(PRTime deadline, PollTarget* target, PollCallback callback)
|
|
: deadline_(deadline), target_(target), callback_(callback) {}
|
|
void Cancel() { callback_ = nullptr; }
|
|
|
|
PRTime deadline_;
|
|
PollTarget* target_;
|
|
PollCallback callback_;
|
|
};
|
|
|
|
void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter,
|
|
PollTarget* target, PollCallback cb);
|
|
void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter);
|
|
void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb,
|
|
std::shared_ptr<Timer>* handle);
|
|
bool Poll();
|
|
|
|
private:
|
|
Poller() : waiters_(), timers_() {}
|
|
~Poller() {}
|
|
|
|
class Waiter {
|
|
public:
|
|
Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) {
|
|
memset(&targets_[0], 0, sizeof(targets_));
|
|
memset(&callbacks_[0], 0, sizeof(callbacks_));
|
|
}
|
|
|
|
void WaitFor(Event event, PollCallback callback);
|
|
|
|
std::shared_ptr<DummyPrSocket> io_;
|
|
PollTarget* targets_[TIMER_EVENT];
|
|
PollCallback callbacks_[TIMER_EVENT];
|
|
};
|
|
|
|
class TimerComparator {
|
|
public:
|
|
bool operator()(const std::shared_ptr<Timer> lhs,
|
|
const std::shared_ptr<Timer> rhs) {
|
|
return lhs->deadline_ > rhs->deadline_;
|
|
}
|
|
};
|
|
|
|
static Poller* instance;
|
|
std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_;
|
|
std::priority_queue<std::shared_ptr<Timer>,
|
|
std::vector<std::shared_ptr<Timer>>, TimerComparator>
|
|
timers_;
|
|
};
|
|
|
|
} // namespace nss_test
|
|
|
|
#endif
|