diff options
Diffstat (limited to '')
37 files changed, 9724 insertions, 0 deletions
diff --git a/netwerk/protocol/websocket/BaseWebSocketChannel.cpp b/netwerk/protocol/websocket/BaseWebSocketChannel.cpp new file mode 100644 index 0000000000..6e9589afb2 --- /dev/null +++ b/netwerk/protocol/websocket/BaseWebSocketChannel.cpp @@ -0,0 +1,378 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketLog.h" +#include "BaseWebSocketChannel.h" +#include "MainThreadUtils.h" +#include "nsILoadGroup.h" +#include "nsINode.h" +#include "nsIInterfaceRequestor.h" +#include "nsProxyRelease.h" +#include "nsStandardURL.h" +#include "LoadInfo.h" +#include "mozilla/dom/ContentChild.h" +#include "nsITransportProvider.h" + +using mozilla::dom::ContentChild; + +namespace mozilla { +namespace net { + +LazyLogModule webSocketLog("nsWebSocket"); +static uint64_t gNextWebSocketID = 0; + +// We use only 53 bits for the WebSocket serial ID so that it can be converted +// to and from a JS value without loss of precision. The upper bits of the +// WebSocket serial ID hold the process ID. The lower bits identify the +// WebSocket. +static const uint64_t kWebSocketIDTotalBits = 53; +static const uint64_t kWebSocketIDProcessBits = 22; +static const uint64_t kWebSocketIDWebSocketBits = + kWebSocketIDTotalBits - kWebSocketIDProcessBits; + +BaseWebSocketChannel::BaseWebSocketChannel() + : mWasOpened(0), + mClientSetPingInterval(0), + mClientSetPingTimeout(0), + mEncrypted(false), + mPingForced(false), + mIsServerSide(false), + mPingInterval(0), + mPingResponseTimeout(10000), + mHttpChannelId(0) { + // Generation of a unique serial ID. + uint64_t processID = 0; + if (XRE_IsContentProcess()) { + ContentChild* cc = ContentChild::GetSingleton(); + processID = cc->GetID(); + } + + uint64_t processBits = + processID & ((uint64_t(1) << kWebSocketIDProcessBits) - 1); + + // Make sure no actual webSocket ends up with mWebSocketID == 0 but less then + // what the kWebSocketIDProcessBits allows. + if (++gNextWebSocketID >= (uint64_t(1) << kWebSocketIDWebSocketBits)) { + gNextWebSocketID = 1; + } + + uint64_t webSocketBits = + gNextWebSocketID & ((uint64_t(1) << kWebSocketIDWebSocketBits) - 1); + mSerial = (processBits << kWebSocketIDWebSocketBits) | webSocketBits; +} + +BaseWebSocketChannel::~BaseWebSocketChannel() { + NS_ReleaseOnMainThread("BaseWebSocketChannel::mLoadGroup", + mLoadGroup.forget()); + NS_ReleaseOnMainThread("BaseWebSocketChannel::mLoadInfo", mLoadInfo.forget()); + nsCOMPtr<nsISerialEventTarget> target; + { + auto lock = mTargetThread.Lock(); + target.swap(*lock); + } + NS_ReleaseOnMainThread("BaseWebSocketChannel::mTargetThread", + target.forget()); +} + +//----------------------------------------------------------------------------- +// BaseWebSocketChannel::nsIWebSocketChannel +//----------------------------------------------------------------------------- + +NS_IMETHODIMP +BaseWebSocketChannel::GetOriginalURI(nsIURI** aOriginalURI) { + LOG(("BaseWebSocketChannel::GetOriginalURI() %p\n", this)); + + if (!mOriginalURI) return NS_ERROR_NOT_INITIALIZED; + *aOriginalURI = do_AddRef(mOriginalURI).take(); + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetURI(nsIURI** aURI) { + LOG(("BaseWebSocketChannel::GetURI() %p\n", this)); + + if (!mOriginalURI) return NS_ERROR_NOT_INITIALIZED; + if (mURI) { + *aURI = do_AddRef(mURI).take(); + } else { + *aURI = do_AddRef(mOriginalURI).take(); + } + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetNotificationCallbacks( + nsIInterfaceRequestor** aNotificationCallbacks) { + LOG(("BaseWebSocketChannel::GetNotificationCallbacks() %p\n", this)); + *aNotificationCallbacks = do_AddRef(mCallbacks).take(); + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::SetNotificationCallbacks( + nsIInterfaceRequestor* aNotificationCallbacks) { + LOG(("BaseWebSocketChannel::SetNotificationCallbacks() %p\n", this)); + mCallbacks = aNotificationCallbacks; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetLoadGroup(nsILoadGroup** aLoadGroup) { + LOG(("BaseWebSocketChannel::GetLoadGroup() %p\n", this)); + *aLoadGroup = do_AddRef(mLoadGroup).take(); + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::SetLoadGroup(nsILoadGroup* aLoadGroup) { + LOG(("BaseWebSocketChannel::SetLoadGroup() %p\n", this)); + mLoadGroup = aLoadGroup; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::SetLoadInfo(nsILoadInfo* aLoadInfo) { + MOZ_RELEASE_ASSERT(aLoadInfo, "loadinfo can't be null"); + mLoadInfo = aLoadInfo; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetLoadInfo(nsILoadInfo** aLoadInfo) { + *aLoadInfo = do_AddRef(mLoadInfo).take(); + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetExtensions(nsACString& aExtensions) { + LOG(("BaseWebSocketChannel::GetExtensions() %p\n", this)); + aExtensions = mNegotiatedExtensions; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetProtocol(nsACString& aProtocol) { + LOG(("BaseWebSocketChannel::GetProtocol() %p\n", this)); + aProtocol = mProtocol; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::SetProtocol(const nsACString& aProtocol) { + LOG(("BaseWebSocketChannel::SetProtocol() %p\n", this)); + mProtocol = aProtocol; /* the sub protocol */ + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetPingInterval(uint32_t* aSeconds) { + // stored in ms but should only have second resolution + MOZ_ASSERT(!(mPingInterval % 1000)); + + *aSeconds = mPingInterval / 1000; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::SetPingInterval(uint32_t aSeconds) { + MOZ_ASSERT(NS_IsMainThread()); + + if (mWasOpened) { + return NS_ERROR_IN_PROGRESS; + } + + mPingInterval = aSeconds * 1000; + mClientSetPingInterval = 1; + + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetPingTimeout(uint32_t* aSeconds) { + // stored in ms but should only have second resolution + MOZ_ASSERT(!(mPingResponseTimeout % 1000)); + + *aSeconds = mPingResponseTimeout / 1000; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::SetPingTimeout(uint32_t aSeconds) { + MOZ_ASSERT(NS_IsMainThread()); + + if (mWasOpened) { + return NS_ERROR_IN_PROGRESS; + } + + mPingResponseTimeout = aSeconds * 1000; + mClientSetPingTimeout = 1; + + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::InitLoadInfoNative( + nsINode* aLoadingNode, nsIPrincipal* aLoadingPrincipal, + nsIPrincipal* aTriggeringPrincipal, + nsICookieJarSettings* aCookieJarSettings, uint32_t aSecurityFlags, + nsContentPolicyType aContentPolicyType, uint32_t aSandboxFlags) { + mLoadInfo = new LoadInfo( + aLoadingPrincipal, aTriggeringPrincipal, aLoadingNode, aSecurityFlags, + aContentPolicyType, Maybe<mozilla::dom::ClientInfo>(), + Maybe<mozilla::dom::ServiceWorkerDescriptor>(), aSandboxFlags); + if (aCookieJarSettings) { + mLoadInfo->SetCookieJarSettings(aCookieJarSettings); + } + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::InitLoadInfo(nsINode* aLoadingNode, + nsIPrincipal* aLoadingPrincipal, + nsIPrincipal* aTriggeringPrincipal, + uint32_t aSecurityFlags, + nsContentPolicyType aContentPolicyType) { + return InitLoadInfoNative(aLoadingNode, aLoadingPrincipal, + aTriggeringPrincipal, nullptr, aSecurityFlags, + aContentPolicyType, 0); +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetSerial(uint32_t* aSerial) { + if (!aSerial) { + return NS_ERROR_FAILURE; + } + + *aSerial = mSerial; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::SetSerial(uint32_t aSerial) { + mSerial = aSerial; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::SetServerParameters( + nsITransportProvider* aProvider, const nsACString& aNegotiatedExtensions) { + MOZ_ASSERT(aProvider); + mServerTransportProvider = aProvider; + mNegotiatedExtensions = aNegotiatedExtensions; + mIsServerSide = true; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetHttpChannelId(uint64_t* aHttpChannelId) { + *aHttpChannelId = mHttpChannelId; + return NS_OK; +} + +//----------------------------------------------------------------------------- +// BaseWebSocketChannel::nsIProtocolHandler +//----------------------------------------------------------------------------- + +NS_IMETHODIMP +BaseWebSocketChannel::GetScheme(nsACString& aScheme) { + LOG(("BaseWebSocketChannel::GetScheme() %p\n", this)); + + if (mEncrypted) { + aScheme.AssignLiteral("wss"); + } else { + aScheme.AssignLiteral("ws"); + } + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::NewChannel(nsIURI* aURI, nsILoadInfo* aLoadInfo, + nsIChannel** outChannel) { + LOG(("BaseWebSocketChannel::NewChannel() %p\n", this)); + return NS_ERROR_NOT_IMPLEMENTED; +} + +NS_IMETHODIMP +BaseWebSocketChannel::AllowPort(int32_t port, const char* scheme, + bool* _retval) { + LOG(("BaseWebSocketChannel::AllowPort() %p\n", this)); + + // do not override any blacklisted ports + *_retval = false; + return NS_OK; +} + +//----------------------------------------------------------------------------- +// BaseWebSocketChannel::nsIThreadRetargetableRequest +//----------------------------------------------------------------------------- + +NS_IMETHODIMP +BaseWebSocketChannel::RetargetDeliveryTo(nsISerialEventTarget* aTargetThread) { + MOZ_ASSERT(NS_IsMainThread()); + MOZ_ASSERT(aTargetThread); + MOZ_ASSERT(!mWasOpened, "Should not be called after AsyncOpen!"); + MOZ_ASSERT(aTargetThread); + + auto lock = mTargetThread.Lock(); + MOZ_ASSERT(!lock.ref(), + "Delivery target should be set once, before AsyncOpen"); + lock.ref() = aTargetThread; + return NS_OK; +} + +NS_IMETHODIMP +BaseWebSocketChannel::GetDeliveryTarget(nsISerialEventTarget** aTargetThread) { + MOZ_ASSERT(NS_IsMainThread()); + + nsCOMPtr<nsISerialEventTarget> target = GetTargetThread(); + if (!target) { + target = GetCurrentSerialEventTarget(); + } + target.forget(aTargetThread); + return NS_OK; +} + +already_AddRefed<nsISerialEventTarget> BaseWebSocketChannel::GetTargetThread() { + nsCOMPtr<nsISerialEventTarget> target; + auto lock = mTargetThread.Lock(); + target = *lock; + return target.forget(); +} + +bool BaseWebSocketChannel::IsOnTargetThread() { + nsCOMPtr<nsISerialEventTarget> target = GetTargetThread(); + if (!target) { + MOZ_ASSERT(false); + return false; + } + + bool isOnTargetThread = false; + nsresult rv = target->IsOnCurrentThread(&isOnTargetThread); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + return NS_FAILED(rv) ? false : isOnTargetThread; +} + +BaseWebSocketChannel::ListenerAndContextContainer::ListenerAndContextContainer( + nsIWebSocketListener* aListener, nsISupports* aContext) + : mListener(aListener), mContext(aContext) { + MOZ_ASSERT(NS_IsMainThread()); + MOZ_ASSERT(mListener); +} + +BaseWebSocketChannel::ListenerAndContextContainer:: + ~ListenerAndContextContainer() { + MOZ_ASSERT(mListener); + + NS_ReleaseOnMainThread( + "BaseWebSocketChannel::ListenerAndContextContainer::mListener", + mListener.forget()); + NS_ReleaseOnMainThread( + "BaseWebSocketChannel::ListenerAndContextContainer::mContext", + mContext.forget()); +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/BaseWebSocketChannel.h b/netwerk/protocol/websocket/BaseWebSocketChannel.h new file mode 100644 index 0000000000..ee05e5bf4c --- /dev/null +++ b/netwerk/protocol/websocket/BaseWebSocketChannel.h @@ -0,0 +1,137 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_BaseWebSocketChannel_h +#define mozilla_net_BaseWebSocketChannel_h + +#include "mozilla/DataMutex.h" +#include "nsIWebSocketChannel.h" +#include "nsIWebSocketListener.h" +#include "nsIProtocolHandler.h" +#include "nsIThread.h" +#include "nsIThreadRetargetableRequest.h" +#include "nsCOMPtr.h" +#include "nsString.h" + +namespace mozilla { +namespace net { + +const static int32_t kDefaultWSPort = 80; +const static int32_t kDefaultWSSPort = 443; + +class BaseWebSocketChannel : public nsIWebSocketChannel, + public nsIProtocolHandler, + public nsIThreadRetargetableRequest { + public: + BaseWebSocketChannel(); + + NS_DECL_NSIPROTOCOLHANDLER + NS_DECL_NSITHREADRETARGETABLEREQUEST + + NS_IMETHOD QueryInterface(const nsIID& uuid, void** result) override = 0; + NS_IMETHOD_(MozExternalRefCountType) AddRef(void) override = 0; + NS_IMETHOD_(MozExternalRefCountType) Release(void) override = 0; + + // Partial implementation of nsIWebSocketChannel + // + NS_IMETHOD GetOriginalURI(nsIURI** aOriginalURI) override; + NS_IMETHOD GetURI(nsIURI** aURI) override; + NS_IMETHOD GetNotificationCallbacks( + nsIInterfaceRequestor** aNotificationCallbacks) override; + NS_IMETHOD SetNotificationCallbacks( + nsIInterfaceRequestor* aNotificationCallbacks) override; + NS_IMETHOD GetLoadGroup(nsILoadGroup** aLoadGroup) override; + NS_IMETHOD SetLoadGroup(nsILoadGroup* aLoadGroup) override; + NS_IMETHOD SetLoadInfo(nsILoadInfo* aLoadInfo) override; + NS_IMETHOD GetLoadInfo(nsILoadInfo** aLoadInfo) override; + NS_IMETHOD GetExtensions(nsACString& aExtensions) override; + NS_IMETHOD GetProtocol(nsACString& aProtocol) override; + NS_IMETHOD SetProtocol(const nsACString& aProtocol) override; + NS_IMETHOD GetPingInterval(uint32_t* aSeconds) override; + NS_IMETHOD SetPingInterval(uint32_t aSeconds) override; + NS_IMETHOD GetPingTimeout(uint32_t* aSeconds) override; + NS_IMETHOD SetPingTimeout(uint32_t aSeconds) override; + NS_IMETHOD InitLoadInfoNative(nsINode* aLoadingNode, + nsIPrincipal* aLoadingPrincipal, + nsIPrincipal* aTriggeringPrincipal, + nsICookieJarSettings* aCookieJarSettings, + uint32_t aSecurityFlags, + nsContentPolicyType aContentPolicyType, + uint32_t aSandboxFlags) override; + NS_IMETHOD InitLoadInfo(nsINode* aLoadingNode, + nsIPrincipal* aLoadingPrincipal, + nsIPrincipal* aTriggeringPrincipal, + uint32_t aSecurityFlags, + nsContentPolicyType aContentPolicyType) override; + NS_IMETHOD GetSerial(uint32_t* aSerial) override; + NS_IMETHOD SetSerial(uint32_t aSerial) override; + NS_IMETHOD SetServerParameters( + nsITransportProvider* aProvider, + const nsACString& aNegotiatedExtensions) override; + NS_IMETHOD GetHttpChannelId(uint64_t* aHttpChannelId) override; + + // Off main thread URI access. + virtual void GetEffectiveURL(nsAString& aEffectiveURL) const = 0; + virtual bool IsEncrypted() const = 0; + + already_AddRefed<nsISerialEventTarget> GetTargetThread(); + bool IsOnTargetThread(); + + class ListenerAndContextContainer final { + public: + NS_INLINE_DECL_THREADSAFE_REFCOUNTING(ListenerAndContextContainer) + + ListenerAndContextContainer(nsIWebSocketListener* aListener, + nsISupports* aContext); + + nsCOMPtr<nsIWebSocketListener> mListener; + nsCOMPtr<nsISupports> mContext; + + private: + ~ListenerAndContextContainer(); + }; + + protected: + virtual ~BaseWebSocketChannel(); + nsCOMPtr<nsIURI> mOriginalURI; + nsCOMPtr<nsIURI> mURI; + RefPtr<ListenerAndContextContainer> mListenerMT; + nsCOMPtr<nsIInterfaceRequestor> mCallbacks; + nsCOMPtr<nsILoadGroup> mLoadGroup; + nsCOMPtr<nsILoadInfo> mLoadInfo; + nsCOMPtr<nsITransportProvider> mServerTransportProvider; + + // Used to ensure atomicity of mTargetThread. + // Set before AsyncOpen via RetargetDeliveryTo or in AsyncOpen, never changed + // after AsyncOpen + DataMutex<nsCOMPtr<nsISerialEventTarget>> mTargetThread{ + "BaseWebSocketChannel::EventTargetMutex"}; + + nsCString mProtocol; + nsCString mOrigin; + + nsCString mNegotiatedExtensions; + + uint32_t mWasOpened : 1; + uint32_t mClientSetPingInterval : 1; + uint32_t mClientSetPingTimeout : 1; + + Atomic<bool> mEncrypted; + bool mPingForced; + bool mIsServerSide; + + Atomic<uint32_t> mPingInterval; /* milliseconds */ + uint32_t mPingResponseTimeout; /* milliseconds */ + + uint32_t mSerial; + + uint64_t mHttpChannelId; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_BaseWebSocketChannel_h diff --git a/netwerk/protocol/websocket/IPCTransportProvider.cpp b/netwerk/protocol/websocket/IPCTransportProvider.cpp new file mode 100644 index 0000000000..ff902ae835 --- /dev/null +++ b/netwerk/protocol/websocket/IPCTransportProvider.cpp @@ -0,0 +1,89 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/net/IPCTransportProvider.h" + +#include "IPCTransportProvider.h" +#include "nsISocketTransport.h" +#include "nsIAsyncInputStream.h" +#include "nsIAsyncOutputStream.h" + +namespace mozilla { +namespace net { + +NS_IMPL_ISUPPORTS(TransportProviderParent, nsITransportProvider, + nsIHttpUpgradeListener) + +NS_IMETHODIMP +TransportProviderParent::SetListener(nsIHttpUpgradeListener* aListener) { + MOZ_ASSERT(aListener); + mListener = aListener; + + MaybeNotify(); + + return NS_OK; +} + +NS_IMETHODIMP +TransportProviderParent::GetIPCChild( + mozilla::net::PTransportProviderChild** aChild) { + MOZ_CRASH("Don't call this in parent process"); + *aChild = nullptr; + return NS_OK; +} + +NS_IMETHODIMP +TransportProviderParent::OnTransportAvailable( + nsISocketTransport* aTransport, nsIAsyncInputStream* aSocketIn, + nsIAsyncOutputStream* aSocketOut) { + MOZ_ASSERT(aTransport && aSocketOut && aSocketOut); + mTransport = aTransport; + mSocketIn = aSocketIn; + mSocketOut = aSocketOut; + + MaybeNotify(); + + return NS_OK; +} + +NS_IMETHODIMP +TransportProviderParent::OnUpgradeFailed(nsresult aErrorCode) { return NS_OK; } + +NS_IMETHODIMP +TransportProviderParent::OnWebSocketConnectionAvailable( + WebSocketConnectionBase* aConnection) { + return NS_ERROR_NOT_IMPLEMENTED; +} + +void TransportProviderParent::MaybeNotify() { + if (!mListener || !mTransport) { + return; + } + + DebugOnly<nsresult> rv = + mListener->OnTransportAvailable(mTransport, mSocketIn, mSocketOut); + MOZ_ASSERT(NS_SUCCEEDED(rv)); +} + +NS_IMPL_ISUPPORTS(TransportProviderChild, nsITransportProvider) + +TransportProviderChild::~TransportProviderChild() { Send__delete__(this); } + +NS_IMETHODIMP +TransportProviderChild::SetListener(nsIHttpUpgradeListener* aListener) { + MOZ_CRASH("Don't call this in child process"); + return NS_OK; +} + +NS_IMETHODIMP +TransportProviderChild::GetIPCChild( + mozilla::net::PTransportProviderChild** aChild) { + *aChild = this; + return NS_OK; +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/IPCTransportProvider.h b/netwerk/protocol/websocket/IPCTransportProvider.h new file mode 100644 index 0000000000..de54be127a --- /dev/null +++ b/netwerk/protocol/websocket/IPCTransportProvider.h @@ -0,0 +1,89 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_IPCTransportProvider_h +#define mozilla_net_IPCTransportProvider_h + +#include "nsISupportsImpl.h" +#include "mozilla/net/PTransportProviderParent.h" +#include "mozilla/net/PTransportProviderChild.h" +#include "nsIHttpChannelInternal.h" +#include "nsITransportProvider.h" + +/* + * No, the ownership model for TransportProvider is that the child object is + * refcounted "normally". I.e. ipdl code doesn't hold a strong reference to + * TransportProviderChild. + * + * When TransportProviderChild goes away, it sends a __delete__ message to the + * parent. + * + * On the parent side, ipdl holds a strong reference to TransportProviderParent. + * When the actor is deallocatde it releases the reference to the + * TransportProviderParent. + * + * So effectively the child holds a strong reference to the parent, and are + * otherwise normally refcounted and have their lifetime determined by that + * refcount. + * + * The only other caveat is that the creation happens from the parent. + * So to create a TransportProvider, a constructor is sent from the parent to + * the child. At this time the child gets its first addref. + * + * A reference to the TransportProvider is then sent as part of some other + * message from the parent to the child. + * + * The receiver of that message can then grab the TransportProviderChild and + * without addreffing it, effectively using the refcount that the + * TransportProviderChild got on creation. + */ + +class nsISocketTransport; +class nsIAsyncInputStream; +class nsIAsyncOutputStream; + +namespace mozilla { +namespace net { + +class TransportProviderParent final : public PTransportProviderParent, + public nsITransportProvider, + public nsIHttpUpgradeListener { + public: + TransportProviderParent() = default; + + NS_DECL_ISUPPORTS + NS_DECL_NSITRANSPORTPROVIDER + NS_DECL_NSIHTTPUPGRADELISTENER + + void ActorDestroy(ActorDestroyReason aWhy) override {} + + private: + ~TransportProviderParent() = default; + + void MaybeNotify(); + + nsCOMPtr<nsIHttpUpgradeListener> mListener; + nsCOMPtr<nsISocketTransport> mTransport; + nsCOMPtr<nsIAsyncInputStream> mSocketIn; + nsCOMPtr<nsIAsyncOutputStream> mSocketOut; +}; + +class TransportProviderChild final : public PTransportProviderChild, + public nsITransportProvider { + public: + TransportProviderChild() = default; + + NS_DECL_ISUPPORTS + NS_DECL_NSITRANSPORTPROVIDER + + private: + ~TransportProviderChild(); +}; + +} // namespace net +} // namespace mozilla + +#endif diff --git a/netwerk/protocol/websocket/PTransportProvider.ipdl b/netwerk/protocol/websocket/PTransportProvider.ipdl new file mode 100644 index 0000000000..42ff3589a5 --- /dev/null +++ b/netwerk/protocol/websocket/PTransportProvider.ipdl @@ -0,0 +1,30 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +include protocol PNecko; + +include "mozilla/net/IPCTransportProvider.h"; + +namespace mozilla { +namespace net { + +/* + * The only thing this protocol manages is used for is passing a + * PTransportProvider object from parent to child and then back to the parent + * again. Hence there's no need for any messages on the protocol itself. + */ + +[ManualDealloc, ChildImpl="TransportProviderChild", ParentImpl="TransportProviderParent"] +async protocol PTransportProvider +{ + manager PNecko; + +parent: + async __delete__(); +}; + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/PWebSocket.ipdl b/netwerk/protocol/websocket/PWebSocket.ipdl new file mode 100644 index 0000000000..bb263049f0 --- /dev/null +++ b/netwerk/protocol/websocket/PWebSocket.ipdl @@ -0,0 +1,67 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 ft=cpp : */ + +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +include protocol PNecko; +include protocol PBrowser; +include protocol PTransportProvider; +include IPCStream; +include NeckoChannelParams; + +include "mozilla/net/WebSocketChannelParent.h"; +include "mozilla/net/WebSocketChannelChild.h"; + +using class IPC::SerializedLoadContext from "SerializedLoadContext.h"; +[RefCounted] using class nsIURI from "mozilla/ipc/URIUtils.h"; + +namespace mozilla { +namespace net { + +[ManualDealloc, ChildImpl="WebSocketChannelChild", ParentImpl="WebSocketChannelParent"] +async protocol PWebSocket +{ + manager PNecko; + +parent: + // Forwarded methods corresponding to methods on nsIWebSocketChannel + async AsyncOpen(nullable nsIURI aURI, + nsCString aOrigin, + OriginAttributes aOriginAttributes, + uint64_t aInnerWindowID, + nsCString aProtocol, + bool aSecure, + // ping values only meaningful if client set them + uint32_t aPingInterval, + bool aClientSetPingInterval, + uint32_t aPingTimeout, + bool aClientSetPingTimeout, + LoadInfoArgs aLoadInfoArgs, + PTransportProvider? aProvider, + nsCString aNegotiatedExtensions); + async Close(uint16_t code, nsCString reason); + async SendMsg(nsCString aMsg); + async SendBinaryMsg(nsCString aMsg); + async SendBinaryStream(IPCStream aStream, uint32_t aLength); + + async DeleteSelf(); + +child: + // Forwarded notifications corresponding to the nsIWebSocketListener interface + async OnStart(nsCString aProtocol, nsCString aExtensions, + nsString aEffectiveURL, bool aEncrypted, + uint64_t aHttpChannelId); + async OnStop(nsresult aStatusCode); + async OnMessageAvailable(nsCString aMsg, bool aMoreData); + async OnBinaryMessageAvailable(nsCString aMsg, bool aMoreData); + async OnAcknowledge(uint32_t aSize); + async OnServerClose(uint16_t code, nsCString aReason); + + async __delete__(); + +}; + +} //namespace net +} //namespace mozilla diff --git a/netwerk/protocol/websocket/PWebSocketConnection.ipdl b/netwerk/protocol/websocket/PWebSocketConnection.ipdl new file mode 100644 index 0000000000..b378320a2b --- /dev/null +++ b/netwerk/protocol/websocket/PWebSocketConnection.ipdl @@ -0,0 +1,33 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +include "mozilla/ipc/TransportSecurityInfoUtils.h"; + +[RefCounted] using class nsITransportSecurityInfo from "nsITransportSecurityInfo.h"; + +namespace mozilla { +namespace net { + +[ChildProc=Socket] +protocol PWebSocketConnection +{ +parent: + async OnTransportAvailable(nullable nsITransportSecurityInfo aSecurityInfo); + async OnError(nsresult aStatus); + async OnTCPClosed(); + async OnDataReceived(uint8_t[] aData); + async OnUpgradeFailed(nsresult aReason); + +child: + async WriteOutputData(uint8_t[] aData); + async StartReading(); + async DrainSocketData(); + + async __delete__(); +}; + +} //namespace net +} //namespace mozilla diff --git a/netwerk/protocol/websocket/PWebSocketEventListener.ipdl b/netwerk/protocol/websocket/PWebSocketEventListener.ipdl new file mode 100644 index 0000000000..d48224b337 --- /dev/null +++ b/netwerk/protocol/websocket/PWebSocketEventListener.ipdl @@ -0,0 +1,53 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 ft=cpp : */ + +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +include protocol PNecko; + +using class mozilla::net::WebSocketFrameData from "mozilla/net/WebSocketFrame.h"; + +namespace mozilla { +namespace net { + +[ManualDealloc] +async protocol PWebSocketEventListener +{ + manager PNecko; + +child: + async WebSocketCreated(uint32_t awebSocketSerialID, + nsString aURI, + nsCString aProtocols); + + async WebSocketOpened(uint32_t awebSocketSerialID, + nsString aEffectiveURI, + nsCString aProtocols, + nsCString aExtensions, + uint64_t aHttpChannelId); + + async WebSocketMessageAvailable(uint32_t awebSocketSerialID, + nsCString aData, + uint16_t aMessageType); + + async WebSocketClosed(uint32_t awebSocketSerialID, + bool aWasClean, + uint16_t aCode, + nsString aReason); + + async FrameReceived(uint32_t aWebSocketSerialID, + WebSocketFrameData aFrameData); + + async FrameSent(uint32_t aWebSocketSerialID, + WebSocketFrameData aFrameData); + + async __delete__(); + +parent: + async Close(); +}; + +} //namespace net +} //namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketChannel.cpp b/netwerk/protocol/websocket/WebSocketChannel.cpp new file mode 100644 index 0000000000..a1ebf85d8e --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketChannel.cpp @@ -0,0 +1,4336 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <algorithm> + +#include "WebSocketChannel.h" + +#include "WebSocketConnectionBase.h" +#include "WebSocketFrame.h" +#include "WebSocketLog.h" +#include "mozilla/Atomics.h" +#include "mozilla/Attributes.h" +#include "mozilla/Base64.h" +#include "mozilla/EndianUtils.h" +#include "mozilla/MathAlgorithms.h" +#include "mozilla/ScopeExit.h" +#include "mozilla/StaticMutex.h" +#include "mozilla/StaticPrefs_privacy.h" +#include "mozilla/Telemetry.h" +#include "mozilla/TimeStamp.h" +#include "mozilla/Utf8.h" +#include "mozilla/net/WebSocketEventService.h" +#include "nsAlgorithm.h" +#include "nsCRT.h" +#include "nsCharSeparatedTokenizer.h" +#include "nsComponentManagerUtils.h" +#include "nsError.h" +#include "nsIAsyncVerifyRedirectCallback.h" +#include "nsICancelable.h" +#include "nsIChannel.h" +#include "nsIClassOfService.h" +#include "nsICryptoHash.h" +#include "nsIDNSRecord.h" +#include "nsIDNSService.h" +#include "nsIDashboardEventNotifier.h" +#include "nsIEventTarget.h" +#include "nsIHttpChannel.h" +#include "nsIIOService.h" +#include "nsINSSErrorsService.h" +#include "nsINetworkLinkService.h" +#include "nsINode.h" +#include "nsIObserverService.h" +#include "nsIPrefBranch.h" +#include "nsIProtocolHandler.h" +#include "nsIProtocolProxyService.h" +#include "nsIProxiedChannel.h" +#include "nsIProxyInfo.h" +#include "nsIRandomGenerator.h" +#include "nsIRunnable.h" +#include "nsISocketTransport.h" +#include "nsITLSSocketControl.h" +#include "nsITransportProvider.h" +#include "nsITransportSecurityInfo.h" +#include "nsIURI.h" +#include "nsIURIMutator.h" +#include "nsNetCID.h" +#include "nsNetUtil.h" +#include "nsProxyRelease.h" +#include "nsServiceManagerUtils.h" +#include "nsSocketTransportService2.h" +#include "nsStringStream.h" +#include "nsThreadUtils.h" +#include "plbase64.h" +#include "prmem.h" +#include "prnetdb.h" +#include "zlib.h" + +// rather than slurp up all of nsIWebSocket.idl, which lives outside necko, just +// dupe one constant we need from it +#define CLOSE_GOING_AWAY 1001 + +using namespace mozilla; +using namespace mozilla::net; + +namespace mozilla::net { + +NS_IMPL_ISUPPORTS(WebSocketChannel, nsIWebSocketChannel, nsIHttpUpgradeListener, + nsIRequestObserver, nsIStreamListener, nsIProtocolHandler, + nsIInputStreamCallback, nsIOutputStreamCallback, + nsITimerCallback, nsIDNSListener, nsIProtocolProxyCallback, + nsIInterfaceRequestor, nsIChannelEventSink, + nsIThreadRetargetableRequest, nsIObserver, nsINamed) + +// We implement RFC 6455, which uses Sec-WebSocket-Version: 13 on the wire. +#define SEC_WEBSOCKET_VERSION "13" + +/* + * About SSL unsigned certificates + * + * wss will not work to a host using an unsigned certificate unless there + * is already an exception (i.e. it cannot popup a dialog asking for + * a security exception). This is similar to how an inlined img will + * fail without a dialog if fails for the same reason. This should not + * be a problem in practice as it is expected the websocket javascript + * is served from the same host as the websocket server (or of course, + * a valid cert could just be provided). + * + */ + +// some helper classes + +//----------------------------------------------------------------------------- +// FailDelayManager +// +// Stores entries (searchable by {host, port}) of connections that have recently +// failed, so we can do delay of reconnects per RFC 6455 Section 7.2.3 +//----------------------------------------------------------------------------- + +// Initial reconnect delay is randomly chosen between 200-400 ms. +// This is a gentler backoff than the 0-5 seconds the spec offhandedly suggests. +const uint32_t kWSReconnectInitialBaseDelay = 200; +const uint32_t kWSReconnectInitialRandomDelay = 200; + +// Base lifetime (in ms) of a FailDelay: kept longer if more failures occur +const uint32_t kWSReconnectBaseLifeTime = 60 * 1000; +// Maximum reconnect delay (in ms) +const uint32_t kWSReconnectMaxDelay = 60 * 1000; + +// hold record of failed connections, and calculates needed delay for reconnects +// to same host/path/port. +class FailDelay { + public: + FailDelay(nsCString address, nsCString path, int32_t port) + : mAddress(std::move(address)), mPath(std::move(path)), mPort(port) { + mLastFailure = TimeStamp::Now(); + mNextDelay = kWSReconnectInitialBaseDelay + + (rand() % kWSReconnectInitialRandomDelay); + } + + // Called to update settings when connection fails again. + void FailedAgain() { + mLastFailure = TimeStamp::Now(); + // We use a truncated exponential backoff as suggested by RFC 6455, + // but multiply by 1.5 instead of 2 to be more gradual. + mNextDelay = static_cast<uint32_t>( + std::min<double>(kWSReconnectMaxDelay, mNextDelay * 1.5)); + LOG( + ("WebSocket: FailedAgain: host=%s, path=%s, port=%d: incremented delay " + "to " + "%" PRIu32, + mAddress.get(), mPath.get(), mPort, mNextDelay)); + } + + // returns 0 if there is no need to delay (i.e. delay interval is over) + uint32_t RemainingDelay(TimeStamp rightNow) { + TimeDuration dur = rightNow - mLastFailure; + uint32_t sinceFail = (uint32_t)dur.ToMilliseconds(); + if (sinceFail > mNextDelay) return 0; + + return mNextDelay - sinceFail; + } + + bool IsExpired(TimeStamp rightNow) { + return (mLastFailure + TimeDuration::FromMilliseconds( + kWSReconnectBaseLifeTime + mNextDelay)) <= + rightNow; + } + + nsCString mAddress; // IP address (or hostname if using proxy) + nsCString mPath; + int32_t mPort; + + private: + TimeStamp mLastFailure; // Time of last failed attempt + // mLastFailure + mNextDelay is the soonest we'll allow a reconnect + uint32_t mNextDelay; // milliseconds +}; + +class FailDelayManager { + public: + FailDelayManager() { + MOZ_COUNT_CTOR(FailDelayManager); + + mDelaysDisabled = false; + + nsCOMPtr<nsIPrefBranch> prefService = + do_GetService(NS_PREFSERVICE_CONTRACTID); + if (!prefService) { + return; + } + bool boolpref = true; + nsresult rv; + rv = prefService->GetBoolPref("network.websocket.delay-failed-reconnects", + &boolpref); + if (NS_SUCCEEDED(rv) && !boolpref) { + mDelaysDisabled = true; + } + } + + ~FailDelayManager() { MOZ_COUNT_DTOR(FailDelayManager); } + + void Add(nsCString& address, nsCString& path, int32_t port) { + if (mDelaysDisabled) return; + + UniquePtr<FailDelay> record(new FailDelay(address, path, port)); + mEntries.AppendElement(std::move(record)); + } + + // Element returned may not be valid after next main thread event: don't keep + // pointer to it around + FailDelay* Lookup(nsCString& address, nsCString& path, int32_t port, + uint32_t* outIndex = nullptr) { + if (mDelaysDisabled) return nullptr; + + FailDelay* result = nullptr; + TimeStamp rightNow = TimeStamp::Now(); + + // We also remove expired entries during search: iterate from end to make + // indexing simpler + for (int32_t i = mEntries.Length() - 1; i >= 0; --i) { + FailDelay* fail = mEntries[i].get(); + if (fail->mAddress.Equals(address) && fail->mPath.Equals(path) && + fail->mPort == port) { + if (outIndex) *outIndex = i; + result = fail; + // break here: removing more entries would mess up *outIndex. + // Any remaining expired entries will be deleted next time Lookup + // finds nothing, which is the most common case anyway. + break; + } + if (fail->IsExpired(rightNow)) { + mEntries.RemoveElementAt(i); + } + } + return result; + } + + // returns true if channel connects immediately, or false if it's delayed + void DelayOrBegin(WebSocketChannel* ws) { + if (!mDelaysDisabled) { + uint32_t failIndex = 0; + FailDelay* fail = Lookup(ws->mAddress, ws->mPath, ws->mPort, &failIndex); + + if (fail) { + TimeStamp rightNow = TimeStamp::Now(); + + uint32_t remainingDelay = fail->RemainingDelay(rightNow); + if (remainingDelay) { + // reconnecting within delay interval: delay by remaining time + nsresult rv; + MutexAutoLock lock(ws->mMutex); + rv = NS_NewTimerWithCallback(getter_AddRefs(ws->mReconnectDelayTimer), + ws, remainingDelay, + nsITimer::TYPE_ONE_SHOT); + if (NS_SUCCEEDED(rv)) { + LOG( + ("WebSocket: delaying websocket [this=%p] by %lu ms, changing" + " state to CONNECTING_DELAYED", + ws, (unsigned long)remainingDelay)); + ws->mConnecting = CONNECTING_DELAYED; + return; + } + // if timer fails (which is very unlikely), drop down to BeginOpen + // call + } else if (fail->IsExpired(rightNow)) { + mEntries.RemoveElementAt(failIndex); + } + } + } + + // Delays disabled, or no previous failure, or we're reconnecting after + // scheduled delay interval has passed: connect. + ws->BeginOpen(true); + } + + // Remove() also deletes all expired entries as it iterates: better for + // battery life than using a periodic timer. + void Remove(nsCString& address, nsCString& path, int32_t port) { + TimeStamp rightNow = TimeStamp::Now(); + + // iterate from end, to make deletion indexing easier + for (int32_t i = mEntries.Length() - 1; i >= 0; --i) { + FailDelay* entry = mEntries[i].get(); + if ((entry->mAddress.Equals(address) && entry->mPath.Equals(path) && + entry->mPort == port) || + entry->IsExpired(rightNow)) { + mEntries.RemoveElementAt(i); + } + } + } + + private: + nsTArray<UniquePtr<FailDelay>> mEntries; + bool mDelaysDisabled; +}; + +//----------------------------------------------------------------------------- +// nsWSAdmissionManager +// +// 1) Ensures that only one websocket at a time is CONNECTING to a given IP +// address (or hostname, if using proxy), per RFC 6455 Section 4.1. +// 2) Delays reconnects to IP/host after connection failure, per Section 7.2.3 +//----------------------------------------------------------------------------- + +class nsWSAdmissionManager { + public: + static void Init() { + StaticMutexAutoLock lock(sLock); + if (!sManager) { + sManager = new nsWSAdmissionManager(); + } + } + + static void Shutdown() { + StaticMutexAutoLock lock(sLock); + delete sManager; + sManager = nullptr; + } + + // Determine if we will open connection immediately (returns true), or + // delay/queue the connection (returns false) + static void ConditionallyConnect(WebSocketChannel* ws) { + LOG(("Websocket: ConditionallyConnect: [this=%p]", ws)); + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + MOZ_ASSERT(ws->mConnecting == NOT_CONNECTING, "opening state"); + + StaticMutexAutoLock lock(sLock); + if (!sManager) { + return; + } + + // If there is already another WS channel connecting to this IP address, + // defer BeginOpen and mark as waiting in queue. + bool hostFound = (sManager->IndexOf(ws->mAddress, ws->mOriginSuffix) >= 0); + + uint32_t failIndex = 0; + FailDelay* fail = sManager->mFailures.Lookup(ws->mAddress, ws->mPath, + ws->mPort, &failIndex); + bool existingFail = fail != nullptr; + + // Always add ourselves to queue, even if we'll connect immediately + UniquePtr<nsOpenConn> newdata( + new nsOpenConn(ws->mAddress, ws->mOriginSuffix, existingFail, ws)); + + // If a connection has not previously failed then prioritize it over + // connections that have + if (existingFail) { + sManager->mQueue.AppendElement(std::move(newdata)); + } else { + uint32_t insertionIndex = sManager->IndexOfFirstFailure(); + MOZ_ASSERT(insertionIndex <= sManager->mQueue.Length(), + "Insertion index outside bounds"); + sManager->mQueue.InsertElementAt(insertionIndex, std::move(newdata)); + } + + if (hostFound) { + LOG( + ("Websocket: some other channel is connecting, changing state to " + "CONNECTING_QUEUED")); + ws->mConnecting = CONNECTING_QUEUED; + } else { + sManager->mFailures.DelayOrBegin(ws); + } + } + + static void OnConnected(WebSocketChannel* aChannel) { + LOG(("Websocket: OnConnected: [this=%p]", aChannel)); + + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + MOZ_ASSERT(aChannel->mConnecting == CONNECTING_IN_PROGRESS, + "Channel completed connect, but not connecting?"); + + StaticMutexAutoLock lock(sLock); + if (!sManager) { + return; + } + + LOG(("Websocket: changing state to NOT_CONNECTING")); + aChannel->mConnecting = NOT_CONNECTING; + + // Remove from queue + sManager->RemoveFromQueue(aChannel); + + // Connection succeeded, so stop keeping track of any previous failures + sManager->mFailures.Remove(aChannel->mAddress, aChannel->mPath, + aChannel->mPort); + + // Check for queued connections to same host. + // Note: still need to check for failures, since next websocket with same + // host may have different port + sManager->ConnectNext(aChannel->mAddress, aChannel->mOriginSuffix); + } + + // Called every time a websocket channel ends its session (including going + // away w/o ever successfully creating a connection) + static void OnStopSession(WebSocketChannel* aChannel, nsresult aReason) { + LOG(("Websocket: OnStopSession: [this=%p, reason=0x%08" PRIx32 "]", + aChannel, static_cast<uint32_t>(aReason))); + + StaticMutexAutoLock lock(sLock); + if (!sManager) { + return; + } + + if (NS_FAILED(aReason)) { + // Have we seen this failure before? + FailDelay* knownFailure = sManager->mFailures.Lookup( + aChannel->mAddress, aChannel->mPath, aChannel->mPort); + if (knownFailure) { + if (aReason == NS_ERROR_NOT_CONNECTED) { + // Don't count close() before connection as a network error + LOG( + ("Websocket close() before connection to %s, %s, %d completed" + " [this=%p]", + aChannel->mAddress.get(), aChannel->mPath.get(), + (int)aChannel->mPort, aChannel)); + } else { + // repeated failure to connect: increase delay for next connection + knownFailure->FailedAgain(); + } + } else { + // new connection failure: record it. + LOG(("WebSocket: connection to %s, %s, %d failed: [this=%p]", + aChannel->mAddress.get(), aChannel->mPath.get(), + (int)aChannel->mPort, aChannel)); + sManager->mFailures.Add(aChannel->mAddress, aChannel->mPath, + aChannel->mPort); + } + } + + if (NS_IsMainThread()) { + ContinueOnStopSession(aChannel, aReason); + } else { + NS_DispatchToMainThread(NS_NewRunnableFunction( + "nsWSAdmissionManager::ContinueOnStopSession", + [channel = RefPtr{aChannel}, reason = aReason]() { + StaticMutexAutoLock lock(sLock); + if (!sManager) { + return; + } + + nsWSAdmissionManager::ContinueOnStopSession(channel, reason); + })); + } + } + + static void ContinueOnStopSession(WebSocketChannel* aChannel, + nsresult aReason) { + sLock.AssertCurrentThreadOwns(); + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + if (!aChannel->mConnecting) { + return; + } + + // Only way a connecting channel may get here w/o failing is if it + // was closed with GOING_AWAY (1001) because of navigation, tab + // close, etc. +#ifdef DEBUG + { + MutexAutoLock lock(aChannel->mMutex); + MOZ_ASSERT( + NS_FAILED(aReason) || aChannel->mScriptCloseCode == CLOSE_GOING_AWAY, + "websocket closed while connecting w/o failing?"); + } +#endif + Unused << aReason; + + sManager->RemoveFromQueue(aChannel); + + bool wasNotQueued = (aChannel->mConnecting != CONNECTING_QUEUED); + LOG(("Websocket: changing state to NOT_CONNECTING")); + aChannel->mConnecting = NOT_CONNECTING; + if (wasNotQueued) { + sManager->ConnectNext(aChannel->mAddress, aChannel->mOriginSuffix); + } + } + + static void IncrementSessionCount() { + StaticMutexAutoLock lock(sLock); + if (!sManager) { + return; + } + sManager->mSessionCount++; + } + + static void DecrementSessionCount() { + StaticMutexAutoLock lock(sLock); + if (!sManager) { + return; + } + sManager->mSessionCount--; + } + + static void GetSessionCount(int32_t& aSessionCount) { + StaticMutexAutoLock lock(sLock); + if (!sManager) { + return; + } + aSessionCount = sManager->mSessionCount; + } + + private: + nsWSAdmissionManager() : mSessionCount(0) { + MOZ_COUNT_CTOR(nsWSAdmissionManager); + } + + ~nsWSAdmissionManager() { MOZ_COUNT_DTOR(nsWSAdmissionManager); } + + class nsOpenConn { + public: + nsOpenConn(nsCString& addr, nsCString& originSuffix, bool failed, + WebSocketChannel* channel) + : mAddress(addr), + mOriginSuffix(originSuffix), + mFailed(failed), + mChannel(channel) { + MOZ_COUNT_CTOR(nsOpenConn); + } + MOZ_COUNTED_DTOR(nsOpenConn) + + nsCString mAddress; + nsCString mOriginSuffix; + bool mFailed = false; + RefPtr<WebSocketChannel> mChannel; + }; + + void ConnectNext(nsCString& hostName, nsCString& originSuffix) { + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + int32_t index = IndexOf(hostName, originSuffix); + if (index >= 0) { + WebSocketChannel* chan = mQueue[index]->mChannel; + + MOZ_ASSERT(chan->mConnecting == CONNECTING_QUEUED, + "transaction not queued but in queue"); + LOG(("WebSocket: ConnectNext: found channel [this=%p] in queue", chan)); + + mFailures.DelayOrBegin(chan); + } + } + + void RemoveFromQueue(WebSocketChannel* aChannel) { + LOG(("Websocket: RemoveFromQueue: [this=%p]", aChannel)); + int32_t index = IndexOf(aChannel); + MOZ_ASSERT(index >= 0, "connection to remove not in queue"); + if (index >= 0) { + mQueue.RemoveElementAt(index); + } + } + + int32_t IndexOf(nsCString& aAddress, nsCString& aOriginSuffix) { + for (uint32_t i = 0; i < mQueue.Length(); i++) { + bool isPartitioned = StaticPrefs::privacy_partition_network_state() || + StaticPrefs::privacy_firstparty_isolate(); + if (aAddress == (mQueue[i])->mAddress && + (!isPartitioned || aOriginSuffix == (mQueue[i])->mOriginSuffix)) { + return i; + } + } + return -1; + } + + int32_t IndexOf(WebSocketChannel* aChannel) { + for (uint32_t i = 0; i < mQueue.Length(); i++) { + if (aChannel == (mQueue[i])->mChannel) return i; + } + return -1; + } + + // Returns the index of the first entry that failed, or else the last entry if + // none found + uint32_t IndexOfFirstFailure() { + for (uint32_t i = 0; i < mQueue.Length(); i++) { + if (mQueue[i]->mFailed) return i; + } + return mQueue.Length(); + } + + // SessionCount might be decremented from the main or the socket + // thread, so manage it with atomic counters + Atomic<int32_t> mSessionCount; + + // Queue for websockets that have not completed connecting yet. + // The first nsOpenConn with a given address will be either be + // CONNECTING_IN_PROGRESS or CONNECTING_DELAYED. Later ones with the same + // hostname must be CONNECTING_QUEUED. + // + // We could hash hostnames instead of using a single big vector here, but the + // dataset is expected to be small. + nsTArray<UniquePtr<nsOpenConn>> mQueue; + + FailDelayManager mFailures; + + static nsWSAdmissionManager* sManager MOZ_GUARDED_BY(sLock); + static StaticMutex sLock; +}; + +nsWSAdmissionManager* nsWSAdmissionManager::sManager; +StaticMutex nsWSAdmissionManager::sLock; + +//----------------------------------------------------------------------------- +// CallOnMessageAvailable +//----------------------------------------------------------------------------- + +class CallOnMessageAvailable final : public Runnable { + public: + CallOnMessageAvailable(WebSocketChannel* aChannel, nsACString& aData, + int32_t aLen) + : Runnable("net::CallOnMessageAvailable"), + mChannel(aChannel), + mListenerMT(aChannel->mListenerMT), + mData(aData), + mLen(aLen) {} + + NS_IMETHOD Run() override { + MOZ_ASSERT(mChannel->IsOnTargetThread()); + + if (mListenerMT) { + nsresult rv; + if (mLen < 0) { + rv = mListenerMT->mListener->OnMessageAvailable(mListenerMT->mContext, + mData); + } else { + rv = mListenerMT->mListener->OnBinaryMessageAvailable( + mListenerMT->mContext, mData); + } + if (NS_FAILED(rv)) { + LOG( + ("OnMessageAvailable or OnBinaryMessageAvailable " + "failed with 0x%08" PRIx32, + static_cast<uint32_t>(rv))); + } + } + + return NS_OK; + } + + private: + ~CallOnMessageAvailable() = default; + + RefPtr<WebSocketChannel> mChannel; + RefPtr<BaseWebSocketChannel::ListenerAndContextContainer> mListenerMT; + nsCString mData; + int32_t mLen; +}; + +//----------------------------------------------------------------------------- +// CallOnStop +//----------------------------------------------------------------------------- + +class CallOnStop final : public Runnable { + public: + CallOnStop(WebSocketChannel* aChannel, nsresult aReason) + : Runnable("net::CallOnStop"), + mChannel(aChannel), + mListenerMT(mChannel->mListenerMT), + mReason(aReason) {} + + NS_IMETHOD Run() override { + MOZ_ASSERT(mChannel->IsOnTargetThread()); + + if (mListenerMT) { + nsresult rv = + mListenerMT->mListener->OnStop(mListenerMT->mContext, mReason); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::CallOnStop " + "OnStop failed (%08" PRIx32 ")\n", + static_cast<uint32_t>(rv))); + } + mChannel->mListenerMT = nullptr; + } + + return NS_OK; + } + + private: + ~CallOnStop() = default; + + RefPtr<WebSocketChannel> mChannel; + RefPtr<BaseWebSocketChannel::ListenerAndContextContainer> mListenerMT; + nsresult mReason; +}; + +//----------------------------------------------------------------------------- +// CallOnServerClose +//----------------------------------------------------------------------------- + +class CallOnServerClose final : public Runnable { + public: + CallOnServerClose(WebSocketChannel* aChannel, uint16_t aCode, + nsACString& aReason) + : Runnable("net::CallOnServerClose"), + mChannel(aChannel), + mListenerMT(mChannel->mListenerMT), + mCode(aCode), + mReason(aReason) {} + + NS_IMETHOD Run() override { + MOZ_ASSERT(mChannel->IsOnTargetThread()); + + if (mListenerMT) { + nsresult rv = mListenerMT->mListener->OnServerClose(mListenerMT->mContext, + mCode, mReason); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::CallOnServerClose " + "OnServerClose failed (%08" PRIx32 ")\n", + static_cast<uint32_t>(rv))); + } + } + return NS_OK; + } + + private: + ~CallOnServerClose() = default; + + RefPtr<WebSocketChannel> mChannel; + RefPtr<BaseWebSocketChannel::ListenerAndContextContainer> mListenerMT; + uint16_t mCode; + nsCString mReason; +}; + +//----------------------------------------------------------------------------- +// CallAcknowledge +//----------------------------------------------------------------------------- + +class CallAcknowledge final : public Runnable { + public: + CallAcknowledge(WebSocketChannel* aChannel, uint32_t aSize) + : Runnable("net::CallAcknowledge"), + mChannel(aChannel), + mListenerMT(mChannel->mListenerMT), + mSize(aSize) {} + + NS_IMETHOD Run() override { + MOZ_ASSERT(mChannel->IsOnTargetThread()); + + LOG(("WebSocketChannel::CallAcknowledge: Size %u\n", mSize)); + if (mListenerMT) { + nsresult rv = + mListenerMT->mListener->OnAcknowledge(mListenerMT->mContext, mSize); + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel::CallAcknowledge: Acknowledge failed (%08" PRIx32 + ")\n", + static_cast<uint32_t>(rv))); + } + } + return NS_OK; + } + + private: + ~CallAcknowledge() = default; + + RefPtr<WebSocketChannel> mChannel; + RefPtr<BaseWebSocketChannel::ListenerAndContextContainer> mListenerMT; + uint32_t mSize; +}; + +//----------------------------------------------------------------------------- +// CallOnTransportAvailable +//----------------------------------------------------------------------------- + +class CallOnTransportAvailable final : public Runnable { + public: + CallOnTransportAvailable(WebSocketChannel* aChannel, + nsISocketTransport* aTransport, + nsIAsyncInputStream* aSocketIn, + nsIAsyncOutputStream* aSocketOut) + : Runnable("net::CallOnTransportAvailble"), + mChannel(aChannel), + mTransport(aTransport), + mSocketIn(aSocketIn), + mSocketOut(aSocketOut) {} + + NS_IMETHOD Run() override { + LOG(("WebSocketChannel::CallOnTransportAvailable %p\n", this)); + return mChannel->OnTransportAvailable(mTransport, mSocketIn, mSocketOut); + } + + private: + ~CallOnTransportAvailable() = default; + + RefPtr<WebSocketChannel> mChannel; + nsCOMPtr<nsISocketTransport> mTransport; + nsCOMPtr<nsIAsyncInputStream> mSocketIn; + nsCOMPtr<nsIAsyncOutputStream> mSocketOut; +}; + +//----------------------------------------------------------------------------- +// PMCECompression +//----------------------------------------------------------------------------- + +class PMCECompression { + public: + PMCECompression(bool aNoContextTakeover, int32_t aLocalMaxWindowBits, + int32_t aRemoteMaxWindowBits) + : mActive(false), + mNoContextTakeover(aNoContextTakeover), + mResetDeflater(false), + mMessageDeflated(false) { + this->mDeflater.next_in = nullptr; + this->mDeflater.avail_in = 0; + this->mDeflater.total_in = 0; + this->mDeflater.next_out = nullptr; + this->mDeflater.avail_out = 0; + this->mDeflater.total_out = 0; + this->mDeflater.msg = nullptr; + this->mDeflater.state = nullptr; + this->mDeflater.data_type = 0; + this->mDeflater.adler = 0; + this->mDeflater.reserved = 0; + this->mInflater.next_in = nullptr; + this->mInflater.avail_in = 0; + this->mInflater.total_in = 0; + this->mInflater.next_out = nullptr; + this->mInflater.avail_out = 0; + this->mInflater.total_out = 0; + this->mInflater.msg = nullptr; + this->mInflater.state = nullptr; + this->mInflater.data_type = 0; + this->mInflater.adler = 0; + this->mInflater.reserved = 0; + MOZ_COUNT_CTOR(PMCECompression); + + mDeflater.zalloc = mInflater.zalloc = Z_NULL; + mDeflater.zfree = mInflater.zfree = Z_NULL; + mDeflater.opaque = mInflater.opaque = Z_NULL; + + if (deflateInit2(&mDeflater, Z_DEFAULT_COMPRESSION, Z_DEFLATED, + -aLocalMaxWindowBits, 8, Z_DEFAULT_STRATEGY) == Z_OK) { + if (inflateInit2(&mInflater, -aRemoteMaxWindowBits) == Z_OK) { + mActive = true; + } else { + deflateEnd(&mDeflater); + } + } + } + + ~PMCECompression() { + MOZ_COUNT_DTOR(PMCECompression); + + if (mActive) { + inflateEnd(&mInflater); + deflateEnd(&mDeflater); + } + } + + bool Active() { return mActive; } + + void SetMessageDeflated() { + MOZ_ASSERT(!mMessageDeflated); + mMessageDeflated = true; + } + bool IsMessageDeflated() { return mMessageDeflated; } + + bool UsingContextTakeover() { return !mNoContextTakeover; } + + nsresult Deflate(uint8_t* data, uint32_t dataLen, nsACString& _retval) { + if (mResetDeflater || mNoContextTakeover) { + if (deflateReset(&mDeflater) != Z_OK) { + return NS_ERROR_UNEXPECTED; + } + mResetDeflater = false; + } + + mDeflater.avail_out = kBufferLen; + mDeflater.next_out = mBuffer; + mDeflater.avail_in = dataLen; + mDeflater.next_in = data; + + while (true) { + int zerr = deflate(&mDeflater, Z_SYNC_FLUSH); + + if (zerr != Z_OK) { + mResetDeflater = true; + return NS_ERROR_UNEXPECTED; + } + + uint32_t deflated = kBufferLen - mDeflater.avail_out; + if (deflated > 0) { + _retval.Append(reinterpret_cast<char*>(mBuffer), deflated); + } + + mDeflater.avail_out = kBufferLen; + mDeflater.next_out = mBuffer; + + if (mDeflater.avail_in > 0) { + continue; // There is still some data to deflate + } + + if (deflated == kBufferLen) { + continue; // There was not enough space in the buffer + } + + break; + } + + if (_retval.Length() < 4) { + MOZ_ASSERT(false, "Expected trailing not found in deflated data!"); + mResetDeflater = true; + return NS_ERROR_UNEXPECTED; + } + + _retval.Truncate(_retval.Length() - 4); + + return NS_OK; + } + + nsresult Inflate(uint8_t* data, uint32_t dataLen, nsACString& _retval) { + mMessageDeflated = false; + + Bytef trailingData[] = {0x00, 0x00, 0xFF, 0xFF}; + bool trailingDataUsed = false; + + mInflater.avail_out = kBufferLen; + mInflater.next_out = mBuffer; + mInflater.avail_in = dataLen; + mInflater.next_in = data; + + while (true) { + int zerr = inflate(&mInflater, Z_NO_FLUSH); + + if (zerr == Z_STREAM_END) { + Bytef* saveNextIn = mInflater.next_in; + uint32_t saveAvailIn = mInflater.avail_in; + Bytef* saveNextOut = mInflater.next_out; + uint32_t saveAvailOut = mInflater.avail_out; + + inflateReset(&mInflater); + + mInflater.next_in = saveNextIn; + mInflater.avail_in = saveAvailIn; + mInflater.next_out = saveNextOut; + mInflater.avail_out = saveAvailOut; + } else if (zerr != Z_OK && zerr != Z_BUF_ERROR) { + return NS_ERROR_INVALID_CONTENT_ENCODING; + } + + uint32_t inflated = kBufferLen - mInflater.avail_out; + if (inflated > 0) { + if (!_retval.Append(reinterpret_cast<char*>(mBuffer), inflated, + fallible)) { + return NS_ERROR_OUT_OF_MEMORY; + } + } + + mInflater.avail_out = kBufferLen; + mInflater.next_out = mBuffer; + + if (mInflater.avail_in > 0) { + continue; // There is still some data to inflate + } + + if (inflated == kBufferLen) { + continue; // There was not enough space in the buffer + } + + if (!trailingDataUsed) { + trailingDataUsed = true; + mInflater.avail_in = sizeof(trailingData); + mInflater.next_in = trailingData; + continue; + } + + return NS_OK; + } + } + + private: + bool mActive; + bool mNoContextTakeover; + bool mResetDeflater; + bool mMessageDeflated; + z_stream mDeflater{}; + z_stream mInflater{}; + const static uint32_t kBufferLen = 4096; + uint8_t mBuffer[kBufferLen]{0}; +}; + +//----------------------------------------------------------------------------- +// OutboundMessage +//----------------------------------------------------------------------------- + +enum WsMsgType { + kMsgTypeString = 0, + kMsgTypeBinaryString, + kMsgTypeStream, + kMsgTypePing, + kMsgTypePong, + kMsgTypeFin +}; + +static const char* msgNames[] = {"text", "binaryString", "binaryStream", + "ping", "pong", "close"}; + +class OutboundMessage { + public: + OutboundMessage(WsMsgType type, const nsACString& str) + : mMsg(mozilla::AsVariant(pString(str))), + mMsgType(type), + mDeflated(false) { + MOZ_COUNT_CTOR(OutboundMessage); + } + + OutboundMessage(nsIInputStream* stream, uint32_t length) + : mMsg(mozilla::AsVariant(StreamWithLength(stream, length))), + mMsgType(kMsgTypeStream), + mDeflated(false) { + MOZ_COUNT_CTOR(OutboundMessage); + } + + ~OutboundMessage() { + MOZ_COUNT_DTOR(OutboundMessage); + switch (mMsgType) { + case kMsgTypeString: + case kMsgTypeBinaryString: + case kMsgTypePing: + case kMsgTypePong: + break; + case kMsgTypeStream: + // for now this only gets hit if msg deleted w/o being sent + if (mMsg.as<StreamWithLength>().mStream) { + mMsg.as<StreamWithLength>().mStream->Close(); + } + break; + case kMsgTypeFin: + break; // do-nothing: avoid compiler warning + } + } + + WsMsgType GetMsgType() const { return mMsgType; } + int32_t Length() { + if (mMsg.is<pString>()) { + return mMsg.as<pString>().mValue.Length(); + } + + return mMsg.as<StreamWithLength>().mLength; + } + int32_t OrigLength() { + if (mMsg.is<pString>()) { + pString& ref = mMsg.as<pString>(); + return mDeflated ? ref.mOrigValue.Length() : ref.mValue.Length(); + } + + return mMsg.as<StreamWithLength>().mLength; + } + + uint8_t* BeginWriting() { + MOZ_ASSERT(mMsgType != kMsgTypeStream, + "Stream should have been converted to string by now"); + if (!mMsg.as<pString>().mValue.IsVoid()) { + return (uint8_t*)mMsg.as<pString>().mValue.BeginWriting(); + } + return nullptr; + } + + uint8_t* BeginReading() { + MOZ_ASSERT(mMsgType != kMsgTypeStream, + "Stream should have been converted to string by now"); + if (!mMsg.as<pString>().mValue.IsVoid()) { + return (uint8_t*)mMsg.as<pString>().mValue.BeginReading(); + } + return nullptr; + } + + uint8_t* BeginOrigReading() { + MOZ_ASSERT(mMsgType != kMsgTypeStream, + "Stream should have been converted to string by now"); + if (!mDeflated) return BeginReading(); + if (!mMsg.as<pString>().mOrigValue.IsVoid()) { + return (uint8_t*)mMsg.as<pString>().mOrigValue.BeginReading(); + } + return nullptr; + } + + nsresult ConvertStreamToString() { + MOZ_ASSERT(mMsgType == kMsgTypeStream, "Not a stream!"); + nsAutoCString temp; + { + StreamWithLength& ref = mMsg.as<StreamWithLength>(); + nsresult rv = NS_ReadInputStreamToString(ref.mStream, temp, ref.mLength); + + NS_ENSURE_SUCCESS(rv, rv); + if (temp.Length() != ref.mLength) { + return NS_ERROR_UNEXPECTED; + } + ref.mStream->Close(); + } + + mMsg = mozilla::AsVariant(pString(temp)); + mMsgType = kMsgTypeBinaryString; + + return NS_OK; + } + + bool DeflatePayload(PMCECompression* aCompressor) { + MOZ_ASSERT(mMsgType != kMsgTypeStream, + "Stream should have been converted to string by now"); + MOZ_ASSERT(!mDeflated); + + nsresult rv; + pString& ref = mMsg.as<pString>(); + if (ref.mValue.Length() == 0) { + // Empty message + return false; + } + + nsAutoCString temp; + rv = aCompressor->Deflate(BeginReading(), ref.mValue.Length(), temp); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::OutboundMessage: Deflating payload failed " + "[rv=0x%08" PRIx32 "]\n", + static_cast<uint32_t>(rv))); + return false; + } + + if (!aCompressor->UsingContextTakeover() && + temp.Length() > ref.mValue.Length()) { + // When "<local>_no_context_takeover" was negotiated, do not send deflated + // payload if it's larger that the original one. OTOH, it makes sense + // to send the larger deflated payload when the sliding window is not + // reset between messages because if we would skip some deflated block + // we would need to empty the sliding window which could affect the + // compression of the subsequent messages. + LOG( + ("WebSocketChannel::OutboundMessage: Not deflating message since the " + "deflated payload is larger than the original one [deflated=%zd, " + "original=%zd]", + temp.Length(), ref.mValue.Length())); + return false; + } + + mDeflated = true; + mMsg.as<pString>().mOrigValue = mMsg.as<pString>().mValue; + mMsg.as<pString>().mValue = temp; + return true; + } + + private: + struct pString { + nsCString mValue; + nsCString mOrigValue; + explicit pString(const nsACString& value) + : mValue(value), mOrigValue(VoidCString()) {} + }; + struct StreamWithLength { + nsCOMPtr<nsIInputStream> mStream; + uint32_t mLength; + explicit StreamWithLength(nsIInputStream* stream, uint32_t Length) + : mStream(stream), mLength(Length) {} + }; + mozilla::Variant<pString, StreamWithLength> mMsg; + WsMsgType mMsgType; + bool mDeflated; +}; + +//----------------------------------------------------------------------------- +// OutboundEnqueuer +//----------------------------------------------------------------------------- + +class OutboundEnqueuer final : public Runnable { + public: + OutboundEnqueuer(WebSocketChannel* aChannel, OutboundMessage* aMsg) + : Runnable("OutboundEnquerer"), mChannel(aChannel), mMessage(aMsg) {} + + NS_IMETHOD Run() override { + mChannel->EnqueueOutgoingMessage(mChannel->mOutgoingMessages, mMessage); + return NS_OK; + } + + private: + ~OutboundEnqueuer() = default; + + RefPtr<WebSocketChannel> mChannel; + OutboundMessage* mMessage; +}; + +//----------------------------------------------------------------------------- +// WebSocketChannel +//----------------------------------------------------------------------------- + +WebSocketChannel::WebSocketChannel() + : mPort(0), + mCloseTimeout(20000), + mOpenTimeout(20000), + mConnecting(NOT_CONNECTING), + mMaxConcurrentConnections(200), + mInnerWindowID(0), + mGotUpgradeOK(0), + mRecvdHttpUpgradeTransport(0), + mAllowPMCE(1), + mPingOutstanding(0), + mReleaseOnTransmit(0), + mDataStarted(false), + mRequestedClose(false), + mClientClosed(false), + mServerClosed(false), + mStopped(false), + mCalledOnStop(false), + mTCPClosed(false), + mOpenedHttpChannel(false), + mIncrementedSessionCount(false), + mDecrementedSessionCount(false), + mMaxMessageSize(INT32_MAX), + mStopOnClose(NS_OK), + mServerCloseCode(CLOSE_ABNORMAL), + mScriptCloseCode(0), + mFragmentOpcode(nsIWebSocketFrame::OPCODE_CONTINUATION), + mFragmentAccumulator(0), + mBuffered(0), + mBufferSize(kIncomingBufferInitialSize), + mCurrentOut(nullptr), + mCurrentOutSent(0), + mHdrOutToSend(0), + mHdrOut(nullptr), + mCompressorMutex("WebSocketChannel::mCompressorMutex"), + mDynamicOutputSize(0), + mDynamicOutput(nullptr), + mPrivateBrowsing(false), + mConnectionLogService(nullptr), + mMutex("WebSocketChannel::mMutex") { + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + LOG(("WebSocketChannel::WebSocketChannel() %p\n", this)); + + nsWSAdmissionManager::Init(); + + mFramePtr = mBuffer = static_cast<uint8_t*>(moz_xmalloc(mBufferSize)); + + nsresult rv; + mConnectionLogService = + do_GetService("@mozilla.org/network/dashboard;1", &rv); + if (NS_FAILED(rv)) LOG(("Failed to initiate dashboard service.")); + + mService = WebSocketEventService::GetOrCreate(); +} + +WebSocketChannel::~WebSocketChannel() { + LOG(("WebSocketChannel::~WebSocketChannel() %p\n", this)); + + if (mWasOpened) { + MOZ_ASSERT(mCalledOnStop, "WebSocket was opened but OnStop was not called"); + MOZ_ASSERT(mStopped, "WebSocket was opened but never stopped"); + } + MOZ_ASSERT(!mCancelable, "DNS/Proxy Request still alive at destruction"); + MOZ_ASSERT(!mConnecting, "Should not be connecting in destructor"); + + free(mBuffer); + free(mDynamicOutput); + delete mCurrentOut; + + while ((mCurrentOut = mOutgoingPingMessages.PopFront())) { + delete mCurrentOut; + } + while ((mCurrentOut = mOutgoingPongMessages.PopFront())) { + delete mCurrentOut; + } + while ((mCurrentOut = mOutgoingMessages.PopFront())) { + delete mCurrentOut; + } + + mListenerMT = nullptr; + + NS_ReleaseOnMainThread("WebSocketChannel::mService", mService.forget()); +} + +NS_IMETHODIMP +WebSocketChannel::Observe(nsISupports* subject, const char* topic, + const char16_t* data) { + LOG(("WebSocketChannel::Observe [topic=\"%s\"]\n", topic)); + + if (strcmp(topic, NS_NETWORK_LINK_TOPIC) == 0) { + nsCString converted = NS_ConvertUTF16toUTF8(data); + const char* state = converted.get(); + + if (strcmp(state, NS_NETWORK_LINK_DATA_CHANGED) == 0) { + LOG(("WebSocket: received network CHANGED event")); + + if (!mIOThread) { + // there has not been an asyncopen yet on the object and then we need + // no ping. + LOG(("WebSocket: early object, no ping needed")); + } else { + mIOThread->Dispatch( + NewRunnableMethod("net::WebSocketChannel::OnNetworkChanged", this, + &WebSocketChannel::OnNetworkChanged), + NS_DISPATCH_NORMAL); + } + } + } + + return NS_OK; +} + +nsresult WebSocketChannel::OnNetworkChanged() { + if (!mDataStarted) { + LOG(("WebSocket: data not started yet, no ping needed")); + return NS_OK; + } + + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + + LOG(("WebSocketChannel::OnNetworkChanged() - on socket thread %p", this)); + + if (mPingOutstanding) { + // If there's an outstanding ping that's expected to get a pong back + // we let that do its thing. + LOG(("WebSocket: pong already pending")); + return NS_OK; + } + + if (mPingForced) { + // avoid more than one + LOG(("WebSocket: forced ping timer already fired")); + return NS_OK; + } + + LOG(("nsWebSocketChannel:: Generating Ping as network changed\n")); + + if (!mPingTimer) { + // The ping timer is only conditionally running already. If it wasn't + // already created do it here. + mPingTimer = NS_NewTimer(); + if (!mPingTimer) { + LOG(("WebSocket: unable to create ping timer!")); + NS_WARNING("unable to create ping timer!"); + return NS_ERROR_OUT_OF_MEMORY; + } + } + // Trigger the ping timeout asap to fire off a new ping. Wait just + // a little bit to better avoid multi-triggers. + mPingForced = true; + mPingTimer->InitWithCallback(this, 200, nsITimer::TYPE_ONE_SHOT); + + return NS_OK; +} + +void WebSocketChannel::Shutdown() { nsWSAdmissionManager::Shutdown(); } + +void WebSocketChannel::GetEffectiveURL(nsAString& aEffectiveURL) const { + aEffectiveURL = mEffectiveURL; +} + +bool WebSocketChannel::IsEncrypted() const { return mEncrypted; } + +void WebSocketChannel::BeginOpen(bool aCalledFromAdmissionManager) { + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + LOG(("WebSocketChannel::BeginOpen() %p\n", this)); + + // Important that we set CONNECTING_IN_PROGRESS before any call to + // AbortSession here: ensures that any remaining queued connection(s) are + // scheduled in OnStopSession + LOG(("Websocket: changing state to CONNECTING_IN_PROGRESS")); + mConnecting = CONNECTING_IN_PROGRESS; + + if (aCalledFromAdmissionManager) { + // When called from nsWSAdmissionManager post an event to avoid potential + // re-entering of nsWSAdmissionManager and its lock. + NS_DispatchToMainThread( + NewRunnableMethod("net::WebSocketChannel::BeginOpenInternal", this, + &WebSocketChannel::BeginOpenInternal), + NS_DISPATCH_NORMAL); + } else { + BeginOpenInternal(); + } +} + +// MainThread +void WebSocketChannel::BeginOpenInternal() { + LOG(("WebSocketChannel::BeginOpenInternal() %p\n", this)); + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + nsresult rv; + + if (mRedirectCallback) { + LOG(("WebSocketChannel::BeginOpenInternal: Resuming Redirect\n")); + rv = mRedirectCallback->OnRedirectVerifyCallback(NS_OK); + mRedirectCallback = nullptr; + return; + } + + nsCOMPtr<nsIChannel> localChannel = do_QueryInterface(mChannel, &rv); + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel::BeginOpenInternal: cannot async open\n")); + AbortSession(NS_ERROR_UNEXPECTED); + return; + } + + rv = localChannel->AsyncOpen(this); + + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel::BeginOpenInternal: cannot async open\n")); + AbortSession(NS_ERROR_WEBSOCKET_CONNECTION_REFUSED); + return; + } + mOpenedHttpChannel = true; + + rv = NS_NewTimerWithCallback(getter_AddRefs(mOpenTimer), this, mOpenTimeout, + nsITimer::TYPE_ONE_SHOT); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::BeginOpenInternal: cannot initialize open " + "timer\n")); + AbortSession(NS_ERROR_UNEXPECTED); + return; + } +} + +bool WebSocketChannel::IsPersistentFramePtr() { + return (mFramePtr >= mBuffer && mFramePtr < mBuffer + mBufferSize); +} + +// Extends the internal buffer by count and returns the total +// amount of data available for read +// +// Accumulated fragment size is passed in instead of using the member +// variable beacuse when transitioning from the stack to the persistent +// read buffer we want to explicitly include them in the buffer instead +// of as already existing data. +bool WebSocketChannel::UpdateReadBuffer(uint8_t* buffer, uint32_t count, + uint32_t accumulatedFragments, + uint32_t* available) { + LOG(("WebSocketChannel::UpdateReadBuffer() %p [%p %u]\n", this, buffer, + count)); + + if (!mBuffered) mFramePtr = mBuffer; + + MOZ_ASSERT(IsPersistentFramePtr(), "update read buffer bad mFramePtr"); + MOZ_ASSERT(mFramePtr - accumulatedFragments >= mBuffer, + "reserved FramePtr bad"); + + if (mBuffered + count <= mBufferSize) { + // append to existing buffer + LOG(("WebSocketChannel: update read buffer absorbed %u\n", count)); + } else if (mBuffered + count - (mFramePtr - accumulatedFragments - mBuffer) <= + mBufferSize) { + // make room in existing buffer by shifting unused data to start + mBuffered -= (mFramePtr - mBuffer - accumulatedFragments); + LOG(("WebSocketChannel: update read buffer shifted %u\n", mBuffered)); + ::memmove(mBuffer, mFramePtr - accumulatedFragments, mBuffered); + mFramePtr = mBuffer + accumulatedFragments; + } else { + // existing buffer is not sufficient, extend it + mBufferSize += count + 8192 + mBufferSize / 3; + LOG(("WebSocketChannel: update read buffer extended to %u\n", mBufferSize)); + uint8_t* old = mBuffer; + mBuffer = (uint8_t*)realloc(mBuffer, mBufferSize); + if (!mBuffer) { + mBuffer = old; + return false; + } + mFramePtr = mBuffer + (mFramePtr - old); + } + + ::memcpy(mBuffer + mBuffered, buffer, count); + mBuffered += count; + + if (available) *available = mBuffered - (mFramePtr - mBuffer); + + return true; +} + +nsresult WebSocketChannel::ProcessInput(uint8_t* buffer, uint32_t count) { + LOG(("WebSocketChannel::ProcessInput %p [%d %d]\n", this, count, mBuffered)); + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + + nsresult rv; + + // The purpose of ping/pong is to actively probe the peer so that an + // unreachable peer is not mistaken for a period of idleness. This + // implementation accepts any application level read activity as a sign of + // life, it does not necessarily have to be a pong. + ResetPingTimer(); + + uint32_t avail; + + if (!mBuffered) { + // Most of the time we can process right off the stack buffer without + // having to accumulate anything + mFramePtr = buffer; + avail = count; + } else { + if (!UpdateReadBuffer(buffer, count, mFragmentAccumulator, &avail)) { + return NS_ERROR_FILE_TOO_BIG; + } + } + + uint8_t* payload; + uint32_t totalAvail = avail; + + while (avail >= 2) { + int64_t payloadLength64 = mFramePtr[1] & kPayloadLengthBitsMask; + uint8_t finBit = mFramePtr[0] & kFinalFragBit; + uint8_t rsvBits = mFramePtr[0] & kRsvBitsMask; + uint8_t rsvBit1 = mFramePtr[0] & kRsv1Bit; + uint8_t rsvBit2 = mFramePtr[0] & kRsv2Bit; + uint8_t rsvBit3 = mFramePtr[0] & kRsv3Bit; + uint8_t opcode = mFramePtr[0] & kOpcodeBitsMask; + uint8_t maskBit = mFramePtr[1] & kMaskBit; + uint32_t mask = 0; + + uint32_t framingLength = 2; + if (maskBit) framingLength += 4; + + if (payloadLength64 < 126) { + if (avail < framingLength) break; + } else if (payloadLength64 == 126) { + // 16 bit length field + framingLength += 2; + if (avail < framingLength) break; + + payloadLength64 = mFramePtr[2] << 8 | mFramePtr[3]; + + if (payloadLength64 < 126) { + // Section 5.2 says that the minimal number of bytes MUST + // be used to encode the length in all cases + LOG(("WebSocketChannel:: non-minimal-encoded payload length")); + return NS_ERROR_ILLEGAL_VALUE; + } + + } else { + // 64 bit length + framingLength += 8; + if (avail < framingLength) break; + + if (mFramePtr[2] & 0x80) { + // Section 4.2 says that the most significant bit MUST be + // 0. (i.e. this is really a 63 bit value) + LOG(("WebSocketChannel:: high bit of 64 bit length set")); + return NS_ERROR_ILLEGAL_VALUE; + } + + // copy this in case it is unaligned + payloadLength64 = NetworkEndian::readInt64(mFramePtr + 2); + + if (payloadLength64 <= 0xffff) { + // Section 5.2 says that the minimal number of bytes MUST + // be used to encode the length in all cases + LOG(("WebSocketChannel:: non-minimal-encoded payload length")); + return NS_ERROR_ILLEGAL_VALUE; + } + } + + payload = mFramePtr + framingLength; + avail -= framingLength; + + LOG(("WebSocketChannel::ProcessInput: payload %" PRId64 " avail %" PRIu32 + "\n", + payloadLength64, avail)); + + CheckedInt<int64_t> payloadLengthChecked(payloadLength64); + payloadLengthChecked += mFragmentAccumulator; + if (!payloadLengthChecked.isValid() || + payloadLengthChecked.value() > mMaxMessageSize) { + return NS_ERROR_FILE_TOO_BIG; + } + + uint32_t payloadLength = static_cast<uint32_t>(payloadLength64); + + if (avail < payloadLength) break; + + LOG(("WebSocketChannel::ProcessInput: Frame accumulated - opcode %d\n", + opcode)); + + if (!maskBit && mIsServerSide) { + LOG( + ("WebSocketChannel::ProcessInput: unmasked frame received " + "from client\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + + if (maskBit) { + if (!mIsServerSide) { + // The server should not be allowed to send masked frames to clients. + // But we've been allowing it for some time, so this should be + // deprecated with care. + LOG(("WebSocketChannel:: Client RECEIVING masked frame.")); + } + + mask = NetworkEndian::readUint32(payload - 4); + } + + if (mask) { + ApplyMask(mask, payload, payloadLength); + } else if (mIsServerSide) { + LOG( + ("WebSocketChannel::ProcessInput: masked frame with mask 0 received" + "from client\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + + // Control codes are required to have the fin bit set + if (!finBit && (opcode & kControlFrameMask)) { + LOG(("WebSocketChannel:: fragmented control frame code %d\n", opcode)); + return NS_ERROR_ILLEGAL_VALUE; + } + + if (rsvBits) { + // PMCE sets RSV1 bit in the first fragment when the non-control frame + // is deflated + MutexAutoLock lock(mCompressorMutex); + if (mPMCECompressor && rsvBits == kRsv1Bit && mFragmentAccumulator == 0 && + !(opcode & kControlFrameMask)) { + mPMCECompressor->SetMessageDeflated(); + LOG(("WebSocketChannel::ProcessInput: received deflated frame\n")); + } else { + LOG(("WebSocketChannel::ProcessInput: unexpected reserved bits %x\n", + rsvBits)); + return NS_ERROR_ILLEGAL_VALUE; + } + } + + if (!finBit || opcode == nsIWebSocketFrame::OPCODE_CONTINUATION) { + // This is part of a fragment response + + // Only the first frame has a non zero op code: Make sure we don't see a + // first frame while some old fragments are open + if ((mFragmentAccumulator != 0) && + (opcode != nsIWebSocketFrame::OPCODE_CONTINUATION)) { + LOG(("WebSocketChannel:: nested fragments\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + + LOG(("WebSocketChannel:: Accumulating Fragment %" PRIu32 "\n", + payloadLength)); + + if (opcode == nsIWebSocketFrame::OPCODE_CONTINUATION) { + // Make sure this continuation fragment isn't the first fragment + if (mFragmentOpcode == nsIWebSocketFrame::OPCODE_CONTINUATION) { + LOG(("WebSocketHeandler:: continuation code in first fragment\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + + // For frag > 1 move the data body back on top of the headers + // so we have contiguous stream of data + MOZ_ASSERT(mFramePtr + framingLength == payload, + "payload offset from frameptr wrong"); + ::memmove(mFramePtr, payload, avail); + payload = mFramePtr; + if (mBuffered) mBuffered -= framingLength; + } else { + mFragmentOpcode = opcode; + } + + if (finBit) { + LOG(("WebSocketChannel:: Finalizing Fragment\n")); + payload -= mFragmentAccumulator; + payloadLength += mFragmentAccumulator; + avail += mFragmentAccumulator; + mFragmentAccumulator = 0; + opcode = mFragmentOpcode; + // reset to detect if next message illegally starts with continuation + mFragmentOpcode = nsIWebSocketFrame::OPCODE_CONTINUATION; + } else { + opcode = nsIWebSocketFrame::OPCODE_CONTINUATION; + mFragmentAccumulator += payloadLength; + } + } else if (mFragmentAccumulator != 0 && !(opcode & kControlFrameMask)) { + // This frame is not part of a fragment sequence but we + // have an open fragment.. it must be a control code or else + // we have a problem + LOG(("WebSocketChannel:: illegal fragment sequence\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + + if (mServerClosed) { + LOG(("WebSocketChannel:: ignoring read frame code %d after close\n", + opcode)); + // nop + } else if (mStopped) { + LOG(("WebSocketChannel:: ignoring read frame code %d after completion\n", + opcode)); + } else if (opcode == nsIWebSocketFrame::OPCODE_TEXT) { + if (mListenerMT) { + nsCString utf8Data; + { + MutexAutoLock lock(mCompressorMutex); + bool isDeflated = + mPMCECompressor && mPMCECompressor->IsMessageDeflated(); + LOG(("WebSocketChannel:: %stext frame received\n", + isDeflated ? "deflated " : "")); + + if (isDeflated) { + rv = mPMCECompressor->Inflate(payload, payloadLength, utf8Data); + if (NS_FAILED(rv)) { + return rv; + } + LOG( + ("WebSocketChannel:: message successfully inflated " + "[origLength=%d, newLength=%zd]\n", + payloadLength, utf8Data.Length())); + } else { + if (!utf8Data.Assign((const char*)payload, payloadLength, + mozilla::fallible)) { + return NS_ERROR_OUT_OF_MEMORY; + } + } + } + + // Section 8.1 says to fail connection if invalid utf-8 in text message + if (!IsUtf8(utf8Data)) { + LOG(("WebSocketChannel:: text frame invalid utf-8\n")); + return NS_ERROR_CANNOT_CONVERT_DATA; + } + + RefPtr<WebSocketFrame> frame = mService->CreateFrameIfNeeded( + finBit, rsvBit1, rsvBit2, rsvBit3, opcode, maskBit, mask, utf8Data); + + if (frame) { + mService->FrameReceived(mSerial, mInnerWindowID, frame.forget()); + } + + if (nsCOMPtr<nsIEventTarget> target = GetTargetThread()) { + target->Dispatch(new CallOnMessageAvailable(this, utf8Data, -1), + NS_DISPATCH_NORMAL); + } else { + return NS_ERROR_UNEXPECTED; + } + if (mConnectionLogService && !mPrivateBrowsing) { + mConnectionLogService->NewMsgReceived(mHost, mSerial, count); + LOG(("Added new msg received for %s", mHost.get())); + } + } + } else if (opcode & kControlFrameMask) { + // control frames + if (payloadLength > 125) { + LOG(("WebSocketChannel:: bad control frame code %d length %d\n", opcode, + payloadLength)); + return NS_ERROR_ILLEGAL_VALUE; + } + + RefPtr<WebSocketFrame> frame = mService->CreateFrameIfNeeded( + finBit, rsvBit1, rsvBit2, rsvBit3, opcode, maskBit, mask, payload, + payloadLength); + + if (opcode == nsIWebSocketFrame::OPCODE_CLOSE) { + LOG(("WebSocketChannel:: close received\n")); + mServerClosed = true; + + mServerCloseCode = CLOSE_NO_STATUS; + if (payloadLength >= 2) { + mServerCloseCode = NetworkEndian::readUint16(payload); + LOG(("WebSocketChannel:: close recvd code %u\n", mServerCloseCode)); + uint16_t msglen = static_cast<uint16_t>(payloadLength - 2); + if (msglen > 0) { + mServerCloseReason.SetLength(msglen); + memcpy(mServerCloseReason.BeginWriting(), (const char*)payload + 2, + msglen); + + // section 8.1 says to replace received non utf-8 sequences + // (which are non-conformant to send) with u+fffd, + // but secteam feels that silently rewriting messages is + // inappropriate - so we will fail the connection instead. + if (!IsUtf8(mServerCloseReason)) { + LOG(("WebSocketChannel:: close frame invalid utf-8\n")); + return NS_ERROR_CANNOT_CONVERT_DATA; + } + + LOG(("WebSocketChannel:: close msg %s\n", + mServerCloseReason.get())); + } + } + + if (mCloseTimer) { + mCloseTimer->Cancel(); + mCloseTimer = nullptr; + } + + if (frame) { + // We send the frame immediately becuase we want to have it dispatched + // before the CallOnServerClose. + mService->FrameReceived(mSerial, mInnerWindowID, frame.forget()); + frame = nullptr; + } + + if (mListenerMT) { + if (nsCOMPtr<nsIEventTarget> target = GetTargetThread()) { + target->Dispatch(new CallOnServerClose(this, mServerCloseCode, + mServerCloseReason), + NS_DISPATCH_NORMAL); + } else { + return NS_ERROR_UNEXPECTED; + } + } + + if (mClientClosed) ReleaseSession(); + } else if (opcode == nsIWebSocketFrame::OPCODE_PING) { + LOG(("WebSocketChannel:: ping received\n")); + GeneratePong(payload, payloadLength); + } else if (opcode == nsIWebSocketFrame::OPCODE_PONG) { + // opcode OPCODE_PONG: the mere act of receiving the packet is all we + // need to do for the pong to trigger the activity timers + LOG(("WebSocketChannel:: pong received\n")); + } else { + /* unknown control frame opcode */ + LOG(("WebSocketChannel:: unknown control op code %d\n", opcode)); + return NS_ERROR_ILLEGAL_VALUE; + } + + if (mFragmentAccumulator) { + // Remove the control frame from the stream so we have a contiguous + // data buffer of reassembled fragments + LOG(("WebSocketChannel:: Removing Control From Read buffer\n")); + MOZ_ASSERT(mFramePtr + framingLength == payload, + "payload offset from frameptr wrong"); + ::memmove(mFramePtr, payload + payloadLength, avail - payloadLength); + payload = mFramePtr; + avail -= payloadLength; + if (mBuffered) mBuffered -= framingLength + payloadLength; + payloadLength = 0; + } + + if (frame) { + mService->FrameReceived(mSerial, mInnerWindowID, frame.forget()); + } + } else if (opcode == nsIWebSocketFrame::OPCODE_BINARY) { + if (mListenerMT) { + nsCString binaryData; + { + MutexAutoLock lock(mCompressorMutex); + bool isDeflated = + mPMCECompressor && mPMCECompressor->IsMessageDeflated(); + LOG(("WebSocketChannel:: %sbinary frame received\n", + isDeflated ? "deflated " : "")); + + if (isDeflated) { + rv = mPMCECompressor->Inflate(payload, payloadLength, binaryData); + if (NS_FAILED(rv)) { + return rv; + } + LOG( + ("WebSocketChannel:: message successfully inflated " + "[origLength=%d, newLength=%zd]\n", + payloadLength, binaryData.Length())); + } else { + if (!binaryData.Assign((const char*)payload, payloadLength, + mozilla::fallible)) { + return NS_ERROR_OUT_OF_MEMORY; + } + } + } + + RefPtr<WebSocketFrame> frame = + mService->CreateFrameIfNeeded(finBit, rsvBit1, rsvBit2, rsvBit3, + opcode, maskBit, mask, binaryData); + if (frame) { + mService->FrameReceived(mSerial, mInnerWindowID, frame.forget()); + } + + if (nsCOMPtr<nsIEventTarget> target = GetTargetThread()) { + target->Dispatch( + new CallOnMessageAvailable(this, binaryData, binaryData.Length()), + NS_DISPATCH_NORMAL); + } else { + return NS_ERROR_UNEXPECTED; + } + // To add the header to 'Networking Dashboard' log + if (mConnectionLogService && !mPrivateBrowsing) { + mConnectionLogService->NewMsgReceived(mHost, mSerial, count); + LOG(("Added new received msg for %s", mHost.get())); + } + } + } else if (opcode != nsIWebSocketFrame::OPCODE_CONTINUATION) { + /* unknown opcode */ + LOG(("WebSocketChannel:: unknown op code %d\n", opcode)); + return NS_ERROR_ILLEGAL_VALUE; + } + + mFramePtr = payload + payloadLength; + avail -= payloadLength; + totalAvail = avail; + } + + // Adjust the stateful buffer. If we were operating off the stack and + // now have a partial message then transition to the buffer, or if + // we were working off the buffer but no longer have any active state + // then transition to the stack + if (!IsPersistentFramePtr()) { + mBuffered = 0; + + if (mFragmentAccumulator) { + LOG(("WebSocketChannel:: Setup Buffer due to fragment")); + + if (!UpdateReadBuffer(mFramePtr - mFragmentAccumulator, + totalAvail + mFragmentAccumulator, 0, nullptr)) { + return NS_ERROR_FILE_TOO_BIG; + } + + // UpdateReadBuffer will reset the frameptr to the beginning + // of new saved state, so we need to skip past processed framgents + mFramePtr += mFragmentAccumulator; + } else if (totalAvail) { + LOG(("WebSocketChannel:: Setup Buffer due to partial frame")); + if (!UpdateReadBuffer(mFramePtr, totalAvail, 0, nullptr)) { + return NS_ERROR_FILE_TOO_BIG; + } + } + } else if (!mFragmentAccumulator && !totalAvail) { + // If we were working off a saved buffer state and there is no partial + // frame or fragment in process, then revert to stack behavior + LOG(("WebSocketChannel:: Internal buffering not needed anymore")); + mBuffered = 0; + + // release memory if we've been processing a large message + if (mBufferSize > kIncomingBufferStableSize) { + mBufferSize = kIncomingBufferStableSize; + free(mBuffer); + mBuffer = (uint8_t*)moz_xmalloc(mBufferSize); + } + } + return NS_OK; +} + +/* static */ +void WebSocketChannel::ApplyMask(uint32_t mask, uint8_t* data, uint64_t len) { + if (!data || len == 0) return; + + // Optimally we want to apply the mask 32 bits at a time, + // but the buffer might not be alligned. So we first deal with + // 0 to 3 bytes of preamble individually + + while (len && (reinterpret_cast<uintptr_t>(data) & 3)) { + *data ^= mask >> 24; + mask = RotateLeft(mask, 8); + data++; + len--; + } + + // perform mask on full words of data + + uint32_t* iData = (uint32_t*)data; + uint32_t* end = iData + (len / 4); + NetworkEndian::writeUint32(&mask, mask); + for (; iData < end; iData++) *iData ^= mask; + mask = NetworkEndian::readUint32(&mask); + data = (uint8_t*)iData; + len = len % 4; + + // There maybe up to 3 trailing bytes that need to be dealt with + // individually + + while (len) { + *data ^= mask >> 24; + mask = RotateLeft(mask, 8); + data++; + len--; + } +} + +void WebSocketChannel::GeneratePing() { + nsAutoCString buf; + buf.AssignLiteral("PING"); + EnqueueOutgoingMessage(mOutgoingPingMessages, + new OutboundMessage(kMsgTypePing, buf)); +} + +void WebSocketChannel::GeneratePong(uint8_t* payload, uint32_t len) { + nsAutoCString buf; + buf.SetLength(len); + if (buf.Length() < len) { + LOG(("WebSocketChannel::GeneratePong Allocation Failure\n")); + return; + } + + memcpy(buf.BeginWriting(), payload, len); + EnqueueOutgoingMessage(mOutgoingPongMessages, + new OutboundMessage(kMsgTypePong, buf)); +} + +void WebSocketChannel::EnqueueOutgoingMessage(nsDeque<OutboundMessage>& aQueue, + OutboundMessage* aMsg) { + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + + LOG( + ("WebSocketChannel::EnqueueOutgoingMessage %p " + "queueing msg %p [type=%s len=%d]\n", + this, aMsg, msgNames[aMsg->GetMsgType()], aMsg->Length())); + + aQueue.Push(aMsg); + if (mSocketOut) { + OnOutputStreamReady(mSocketOut); + } else { + DoEnqueueOutgoingMessage(); + } +} + +uint16_t WebSocketChannel::ResultToCloseCode(nsresult resultCode) { + if (NS_SUCCEEDED(resultCode)) return CLOSE_NORMAL; + + switch (resultCode) { + case NS_ERROR_FILE_TOO_BIG: + case NS_ERROR_OUT_OF_MEMORY: + return CLOSE_TOO_LARGE; + case NS_ERROR_CANNOT_CONVERT_DATA: + return CLOSE_INVALID_PAYLOAD; + case NS_ERROR_UNEXPECTED: + return CLOSE_INTERNAL_ERROR; + default: + return CLOSE_PROTOCOL_ERROR; + } +} + +void WebSocketChannel::PrimeNewOutgoingMessage() { + LOG(("WebSocketChannel::PrimeNewOutgoingMessage() %p\n", this)); + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + MOZ_ASSERT(!mCurrentOut, "Current message in progress"); + + nsresult rv = NS_OK; + + mCurrentOut = mOutgoingPongMessages.PopFront(); + if (mCurrentOut) { + MOZ_ASSERT(mCurrentOut->GetMsgType() == kMsgTypePong, "Not pong message!"); + } else { + mCurrentOut = mOutgoingPingMessages.PopFront(); + if (mCurrentOut) { + MOZ_ASSERT(mCurrentOut->GetMsgType() == kMsgTypePing, + "Not ping message!"); + } else { + mCurrentOut = mOutgoingMessages.PopFront(); + } + } + + if (!mCurrentOut) return; + + auto cleanupAfterFailure = + MakeScopeExit([&] { DeleteCurrentOutGoingMessage(); }); + + WsMsgType msgType = mCurrentOut->GetMsgType(); + + LOG( + ("WebSocketChannel::PrimeNewOutgoingMessage " + "%p found queued msg %p [type=%s len=%d]\n", + this, mCurrentOut, msgNames[msgType], mCurrentOut->Length())); + + mCurrentOutSent = 0; + mHdrOut = mOutHeader; + + uint8_t maskBit = mIsServerSide ? 0 : kMaskBit; + uint8_t maskSize = mIsServerSide ? 0 : 4; + + uint8_t* payload = nullptr; + + if (msgType == kMsgTypeFin) { + // This is a demand to create a close message + if (mClientClosed) { + DeleteCurrentOutGoingMessage(); + PrimeNewOutgoingMessage(); + cleanupAfterFailure.release(); + return; + } + + mClientClosed = true; + mOutHeader[0] = kFinalFragBit | nsIWebSocketFrame::OPCODE_CLOSE; + mOutHeader[1] = maskBit; + + // payload is offset 2 plus size of the mask + payload = mOutHeader + 2 + maskSize; + + // The close reason code sits in the first 2 bytes of payload + // If the channel user provided a code and reason during Close() + // and there isn't an internal error, use that. + if (NS_SUCCEEDED(mStopOnClose)) { + MutexAutoLock lock(mMutex); + if (mScriptCloseCode) { + NetworkEndian::writeUint16(payload, mScriptCloseCode); + mOutHeader[1] += 2; + mHdrOutToSend = 4 + maskSize; + if (!mScriptCloseReason.IsEmpty()) { + MOZ_ASSERT(mScriptCloseReason.Length() <= 123, + "Close Reason Too Long"); + mOutHeader[1] += mScriptCloseReason.Length(); + mHdrOutToSend += mScriptCloseReason.Length(); + memcpy(payload + 2, mScriptCloseReason.BeginReading(), + mScriptCloseReason.Length()); + } + } else { + // No close code/reason, so payload length = 0. We must still send mask + // even though it's not used. Keep payload offset so we write mask + // below. + mHdrOutToSend = 2 + maskSize; + } + } else { + NetworkEndian::writeUint16(payload, ResultToCloseCode(mStopOnClose)); + mOutHeader[1] += 2; + mHdrOutToSend = 4 + maskSize; + } + + if (mServerClosed) { + /* bidi close complete */ + mReleaseOnTransmit = 1; + } else if (NS_FAILED(mStopOnClose)) { + /* result of abort session - give up */ + StopSession(mStopOnClose); + } else { + /* wait for reciprocal close from server */ + rv = NS_NewTimerWithCallback(getter_AddRefs(mCloseTimer), this, + mCloseTimeout, nsITimer::TYPE_ONE_SHOT); + if (NS_FAILED(rv)) { + StopSession(rv); + } + } + } else { + switch (msgType) { + case kMsgTypePong: + mOutHeader[0] = kFinalFragBit | nsIWebSocketFrame::OPCODE_PONG; + break; + case kMsgTypePing: + mOutHeader[0] = kFinalFragBit | nsIWebSocketFrame::OPCODE_PING; + break; + case kMsgTypeString: + mOutHeader[0] = kFinalFragBit | nsIWebSocketFrame::OPCODE_TEXT; + break; + case kMsgTypeStream: + // HACK ALERT: read in entire stream into string. + // Will block socket transport thread if file is blocking. + // TODO: bug 704447: don't block socket thread! + rv = mCurrentOut->ConvertStreamToString(); + if (NS_FAILED(rv)) { + AbortSession(NS_ERROR_FILE_TOO_BIG); + return; + } + // Now we're a binary string + msgType = kMsgTypeBinaryString; + + // no break: fall down into binary string case + [[fallthrough]]; + + case kMsgTypeBinaryString: + mOutHeader[0] = kFinalFragBit | nsIWebSocketFrame::OPCODE_BINARY; + break; + case kMsgTypeFin: + MOZ_ASSERT(false, "unreachable"); // avoid compiler warning + break; + } + + // deflate the payload if PMCE is negotiated + MutexAutoLock lock(mCompressorMutex); + if (mPMCECompressor && + (msgType == kMsgTypeString || msgType == kMsgTypeBinaryString)) { + if (mCurrentOut->DeflatePayload(mPMCECompressor.get())) { + // The payload was deflated successfully, set RSV1 bit + mOutHeader[0] |= kRsv1Bit; + + LOG( + ("WebSocketChannel::PrimeNewOutgoingMessage %p current msg %p was " + "deflated [origLength=%d, newLength=%d].\n", + this, mCurrentOut, mCurrentOut->OrigLength(), + mCurrentOut->Length())); + } + } + + if (mCurrentOut->Length() < 126) { + mOutHeader[1] = mCurrentOut->Length() | maskBit; + mHdrOutToSend = 2 + maskSize; + } else if (mCurrentOut->Length() <= 0xffff) { + mOutHeader[1] = 126 | maskBit; + NetworkEndian::writeUint16(mOutHeader + sizeof(uint16_t), + mCurrentOut->Length()); + mHdrOutToSend = 4 + maskSize; + } else { + mOutHeader[1] = 127 | maskBit; + NetworkEndian::writeUint64(mOutHeader + 2, mCurrentOut->Length()); + mHdrOutToSend = 10 + maskSize; + } + payload = mOutHeader + mHdrOutToSend; + } + + MOZ_ASSERT(payload, "payload offset not found"); + + uint32_t mask = 0; + if (!mIsServerSide) { + // Perform the sending mask. Never use a zero mask + do { + static_assert(4 == sizeof(mask), "Size of the mask should be equal to 4"); + nsresult rv = mRandomGenerator->GenerateRandomBytesInto(mask); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::PrimeNewOutgoingMessage(): " + "GenerateRandomBytes failure %" PRIx32 "\n", + static_cast<uint32_t>(rv))); + AbortSession(rv); + return; + } + } while (!mask); + NetworkEndian::writeUint32(payload - sizeof(uint32_t), mask); + } + + LOG(("WebSocketChannel::PrimeNewOutgoingMessage() using mask %08x\n", mask)); + + // We don't mask the framing, but occasionally we stick a little payload + // data in the buffer used for the framing. Close frames are the current + // example. This data needs to be masked, but it is never more than a + // handful of bytes and might rotate the mask, so we can just do it locally. + // For real data frames we ship the bulk of the payload off to ApplyMask() + + RefPtr<WebSocketFrame> frame = mService->CreateFrameIfNeeded( + mOutHeader[0] & WebSocketChannel::kFinalFragBit, + mOutHeader[0] & WebSocketChannel::kRsv1Bit, + mOutHeader[0] & WebSocketChannel::kRsv2Bit, + mOutHeader[0] & WebSocketChannel::kRsv3Bit, + mOutHeader[0] & WebSocketChannel::kOpcodeBitsMask, + mOutHeader[1] & WebSocketChannel::kMaskBit, mask, payload, + mHdrOutToSend - (payload - mOutHeader), mCurrentOut->BeginOrigReading(), + mCurrentOut->OrigLength()); + + if (frame) { + mService->FrameSent(mSerial, mInnerWindowID, frame.forget()); + } + + if (mask) { + while (payload < (mOutHeader + mHdrOutToSend)) { + *payload ^= mask >> 24; + mask = RotateLeft(mask, 8); + payload++; + } + + // Mask the real message payloads + ApplyMask(mask, mCurrentOut->BeginWriting(), mCurrentOut->Length()); + } + + int32_t len = mCurrentOut->Length(); + + // for small frames, copy it all together for a contiguous write + if (len && len <= kCopyBreak) { + memcpy(mOutHeader + mHdrOutToSend, mCurrentOut->BeginWriting(), len); + mHdrOutToSend += len; + mCurrentOutSent = len; + } + + // Transmitting begins - mHdrOutToSend bytes from mOutHeader and + // mCurrentOut->Length() bytes from mCurrentOut. The latter may be + // coaleseced into the former for small messages or as the result of the + // compression process. + + cleanupAfterFailure.release(); +} + +void WebSocketChannel::DeleteCurrentOutGoingMessage() { + delete mCurrentOut; + mCurrentOut = nullptr; + mCurrentOutSent = 0; +} + +void WebSocketChannel::EnsureHdrOut(uint32_t size) { + LOG(("WebSocketChannel::EnsureHdrOut() %p [%d]\n", this, size)); + + if (mDynamicOutputSize < size) { + mDynamicOutputSize = size; + mDynamicOutput = (uint8_t*)moz_xrealloc(mDynamicOutput, mDynamicOutputSize); + } + + mHdrOut = mDynamicOutput; +} + +namespace { + +class RemoveObserverRunnable : public Runnable { + RefPtr<WebSocketChannel> mChannel; + + public: + explicit RemoveObserverRunnable(WebSocketChannel* aChannel) + : Runnable("net::RemoveObserverRunnable"), mChannel(aChannel) {} + + NS_IMETHOD Run() override { + nsCOMPtr<nsIObserverService> observerService = + mozilla::services::GetObserverService(); + if (!observerService) { + NS_WARNING("failed to get observer service"); + return NS_OK; + } + + observerService->RemoveObserver(mChannel, NS_NETWORK_LINK_TOPIC); + return NS_OK; + } +}; + +} // namespace + +void WebSocketChannel::CleanupConnection() { + // normally this should be called on socket thread, but it may be called + // on MainThread + + LOG(("WebSocketChannel::CleanupConnection() %p", this)); + // This needs to run on the IOThread so we don't need to lock a bunch of these + if (!mIOThread->IsOnCurrentThread()) { + mIOThread->Dispatch( + NewRunnableMethod("net::WebSocketChannel::CleanupConnection", this, + &WebSocketChannel::CleanupConnection), + NS_DISPATCH_NORMAL); + return; + } + + if (mLingeringCloseTimer) { + mLingeringCloseTimer->Cancel(); + mLingeringCloseTimer = nullptr; + } + + if (mSocketIn) { + if (mDataStarted) { + mSocketIn->AsyncWait(nullptr, 0, 0, nullptr); + } + mSocketIn = nullptr; + } + + if (mSocketOut) { + mSocketOut->AsyncWait(nullptr, 0, 0, nullptr); + mSocketOut = nullptr; + } + + if (mTransport) { + mTransport->SetSecurityCallbacks(nullptr); + mTransport->SetEventSink(nullptr, nullptr); + mTransport->Close(NS_BASE_STREAM_CLOSED); + mTransport = nullptr; + } + + if (mConnection) { + mConnection->Close(); + mConnection = nullptr; + } + + if (mConnectionLogService && !mPrivateBrowsing) { + mConnectionLogService->RemoveHost(mHost, mSerial); + } + + // The observer has to be removed on the main-thread. + NS_DispatchToMainThread(new RemoveObserverRunnable(this)); + + DecrementSessionCount(); +} + +void WebSocketChannel::StopSession(nsresult reason) { + LOG(("WebSocketChannel::StopSession() %p [%" PRIx32 "]\n", this, + static_cast<uint32_t>(reason))); + + { + MutexAutoLock lock(mMutex); + if (mStopped) { + return; + } + mStopped = true; + } + + DoStopSession(reason); +} + +void WebSocketChannel::DoStopSession(nsresult reason) { + LOG(("WebSocketChannel::DoStopSession() %p [%" PRIx32 "]\n", this, + static_cast<uint32_t>(reason))); + + // normally this should be called on socket thread, but it is ok to call it + // from OnStartRequest before the socket thread machine has gotten underway. + // If mDataStarted is false, this is called on MainThread for Close(). + // Otherwise it should be called on the IO thread + + MOZ_ASSERT(mStopped); + MOZ_ASSERT(mIOThread->IsOnCurrentThread() || mTCPClosed || !mDataStarted); + + if (!mOpenedHttpChannel) { + // The HTTP channel information will never be used in this case + NS_ReleaseOnMainThread("WebSocketChannel::mChannel", mChannel.forget()); + NS_ReleaseOnMainThread("WebSocketChannel::mHttpChannel", + mHttpChannel.forget()); + NS_ReleaseOnMainThread("WebSocketChannel::mLoadGroup", mLoadGroup.forget()); + NS_ReleaseOnMainThread("WebSocketChannel::mCallbacks", mCallbacks.forget()); + } + + if (mCloseTimer) { + mCloseTimer->Cancel(); + mCloseTimer = nullptr; + } + + // mOpenTimer must be null if mDataStarted is true and we're not on MainThread + if (mOpenTimer) { + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + mOpenTimer->Cancel(); + mOpenTimer = nullptr; + } + + { + MutexAutoLock lock(mMutex); + if (mReconnectDelayTimer) { + mReconnectDelayTimer->Cancel(); + NS_ReleaseOnMainThread("WebSocketChannel::mMutex", + mReconnectDelayTimer.forget()); + } + } + + if (mPingTimer) { + mPingTimer->Cancel(); + mPingTimer = nullptr; + } + + if (!mTCPClosed && mDataStarted) { + if (mSocketIn) { + // Drain, within reason, this socket. if we leave any data + // unconsumed (including the tcp fin) a RST will be generated + // The right thing to do here is shutdown(SHUT_WR) and then wait + // a little while to see if any data comes in.. but there is no + // reason to delay things for that when the websocket handshake + // is supposed to guarantee a quiet connection except for that fin. + + char buffer[512]; + uint32_t count = 0; + uint32_t total = 0; + nsresult rv; + do { + total += count; + rv = mSocketIn->Read(buffer, 512, &count); + if (rv != NS_BASE_STREAM_WOULD_BLOCK && (NS_FAILED(rv) || count == 0)) { + mTCPClosed = true; + } + } while (NS_SUCCEEDED(rv) && count > 0 && total < 32000); + } else if (mConnection) { + mConnection->DrainSocketData(); + } + } + + int32_t sessionCount = kLingeringCloseThreshold; + nsWSAdmissionManager::GetSessionCount(sessionCount); + + if (!mTCPClosed && (mTransport || mConnection) && + sessionCount < kLingeringCloseThreshold) { + // 7.1.1 says that the client SHOULD wait for the server to close the TCP + // connection. This is so we can reuse port numbers before 2 MSL expires, + // which is not really as much of a concern for us as the amount of state + // that might be accrued by keeping this channel object around waiting for + // the server. We handle the SHOULD by waiting a short time in the common + // case, but not waiting in the case of high concurrency. + // + // Normally this will be taken care of in AbortSession() after mTCPClosed + // is set when the server close arrives without waiting for the timeout to + // expire. + + LOG(("WebSocketChannel::DoStopSession: Wait for Server TCP close")); + + nsresult rv; + rv = NS_NewTimerWithCallback(getter_AddRefs(mLingeringCloseTimer), this, + kLingeringCloseTimeout, + nsITimer::TYPE_ONE_SHOT); + if (NS_FAILED(rv)) CleanupConnection(); + } else { + CleanupConnection(); + } + + { + MutexAutoLock lock(mMutex); + if (mCancelable) { + mCancelable->Cancel(NS_ERROR_UNEXPECTED); + mCancelable = nullptr; + } + } + + { + MutexAutoLock lock(mCompressorMutex); + mPMCECompressor = nullptr; + } + if (!mCalledOnStop) { + mCalledOnStop = true; + + nsWSAdmissionManager::OnStopSession(this, reason); + + RefPtr<CallOnStop> runnable = new CallOnStop(this, reason); + if (nsCOMPtr<nsIEventTarget> target = GetTargetThread()) { + target->Dispatch(runnable, NS_DISPATCH_NORMAL); + } + } +} + +// Called from MainThread, and called from IOThread in +// PrimeNewOutgoingMessage +void WebSocketChannel::AbortSession(nsresult reason) { + LOG(("WebSocketChannel::AbortSession() %p [reason %" PRIx32 + "] stopped = %d\n", + this, static_cast<uint32_t>(reason), !!mStopped)); + + MOZ_ASSERT(NS_FAILED(reason), "reason must be a failure!"); + + // normally this should be called on socket thread, but it is ok to call it + // from the main thread before StartWebsocketData() has completed + MOZ_ASSERT(mIOThread->IsOnCurrentThread() || !mDataStarted); + + // When we are failing we need to close the TCP connection immediately + // as per 7.1.1 + mTCPClosed = true; + + if (mLingeringCloseTimer) { + MOZ_ASSERT(mStopped, "Lingering without Stop"); + LOG(("WebSocketChannel:: Cleanup connection based on TCP Close")); + CleanupConnection(); + return; + } + + { + MutexAutoLock lock(mMutex); + if (mStopped) { + return; + } + + if ((mTransport || mConnection) && reason != NS_BASE_STREAM_CLOSED && + !mRequestedClose && !mClientClosed && !mServerClosed && mDataStarted) { + mRequestedClose = true; + mStopOnClose = reason; + mIOThread->Dispatch( + new OutboundEnqueuer(this, + new OutboundMessage(kMsgTypeFin, VoidCString())), + nsIEventTarget::DISPATCH_NORMAL); + return; + } + + mStopped = true; + } + + DoStopSession(reason); +} + +// ReleaseSession is called on orderly shutdown +void WebSocketChannel::ReleaseSession() { + LOG(("WebSocketChannel::ReleaseSession() %p stopped = %d\n", this, + !!mStopped)); + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + + StopSession(NS_OK); +} + +void WebSocketChannel::IncrementSessionCount() { + if (!mIncrementedSessionCount) { + nsWSAdmissionManager::IncrementSessionCount(); + mIncrementedSessionCount = true; + } +} + +void WebSocketChannel::DecrementSessionCount() { + // Make sure we decrement session count only once, and only if we incremented + // it. This code is thread-safe: sWebSocketAdmissions->DecrementSessionCount + // is atomic, and mIncrementedSessionCount/mDecrementedSessionCount are set at + // times when they'll never be a race condition for checking/setting them. + if (mIncrementedSessionCount && !mDecrementedSessionCount) { + nsWSAdmissionManager::DecrementSessionCount(); + mDecrementedSessionCount = true; + } +} + +namespace { +enum ExtensionParseMode { eParseServerSide, eParseClientSide }; +} + +static nsresult ParseWebSocketExtension(const nsACString& aExtension, + ExtensionParseMode aMode, + bool& aClientNoContextTakeover, + bool& aServerNoContextTakeover, + int32_t& aClientMaxWindowBits, + int32_t& aServerMaxWindowBits) { + nsCCharSeparatedTokenizer tokens(aExtension, ';'); + + if (!tokens.hasMoreTokens() || + !tokens.nextToken().EqualsLiteral("permessage-deflate")) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: " + "HTTP Sec-WebSocket-Extensions negotiated unknown value %s\n", + PromiseFlatCString(aExtension).get())); + return NS_ERROR_ILLEGAL_VALUE; + } + + aClientNoContextTakeover = aServerNoContextTakeover = false; + aClientMaxWindowBits = aServerMaxWindowBits = -1; + + while (tokens.hasMoreTokens()) { + auto token = tokens.nextToken(); + + int32_t nameEnd, valueStart; + int32_t delimPos = token.FindChar('='); + if (delimPos == kNotFound) { + nameEnd = token.Length(); + valueStart = token.Length(); + } else { + nameEnd = delimPos; + valueStart = delimPos + 1; + } + + auto paramName = Substring(token, 0, nameEnd); + auto paramValue = Substring(token, valueStart); + + if (paramName.EqualsLiteral("client_no_context_takeover")) { + if (!paramValue.IsEmpty()) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: parameter " + "client_no_context_takeover must not have value, found %s\n", + PromiseFlatCString(paramValue).get())); + return NS_ERROR_ILLEGAL_VALUE; + } + if (aClientNoContextTakeover) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: found multiple " + "parameters client_no_context_takeover\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + aClientNoContextTakeover = true; + } else if (paramName.EqualsLiteral("server_no_context_takeover")) { + if (!paramValue.IsEmpty()) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: parameter " + "server_no_context_takeover must not have value, found %s\n", + PromiseFlatCString(paramValue).get())); + return NS_ERROR_ILLEGAL_VALUE; + } + if (aServerNoContextTakeover) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: found multiple " + "parameters server_no_context_takeover\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + aServerNoContextTakeover = true; + } else if (paramName.EqualsLiteral("client_max_window_bits")) { + if (aClientMaxWindowBits != -1) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: found multiple " + "parameters client_max_window_bits\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + + if (aMode == eParseServerSide && paramValue.IsEmpty()) { + // Use -2 to indicate that "client_max_window_bits" has been parsed, + // but had no value. + aClientMaxWindowBits = -2; + } else { + nsresult errcode; + aClientMaxWindowBits = + PromiseFlatCString(paramValue).ToInteger(&errcode); + if (NS_FAILED(errcode) || aClientMaxWindowBits < 8 || + aClientMaxWindowBits > 15) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: found invalid " + "parameter client_max_window_bits %s\n", + PromiseFlatCString(paramValue).get())); + return NS_ERROR_ILLEGAL_VALUE; + } + } + } else if (paramName.EqualsLiteral("server_max_window_bits")) { + if (aServerMaxWindowBits != -1) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: found multiple " + "parameters server_max_window_bits\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + + nsresult errcode; + aServerMaxWindowBits = PromiseFlatCString(paramValue).ToInteger(&errcode); + if (NS_FAILED(errcode) || aServerMaxWindowBits < 8 || + aServerMaxWindowBits > 15) { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: found invalid " + "parameter server_max_window_bits %s\n", + PromiseFlatCString(paramValue).get())); + return NS_ERROR_ILLEGAL_VALUE; + } + } else { + LOG( + ("WebSocketChannel::ParseWebSocketExtension: found unknown " + "parameter %s\n", + PromiseFlatCString(paramName).get())); + return NS_ERROR_ILLEGAL_VALUE; + } + } + + if (aClientMaxWindowBits == -2) { + aClientMaxWindowBits = -1; + } + + return NS_OK; +} + +nsresult WebSocketChannel::HandleExtensions() { + LOG(("WebSocketChannel::HandleExtensions() %p\n", this)); + + nsresult rv; + nsAutoCString extensions; + + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + rv = mHttpChannel->GetResponseHeader("Sec-WebSocket-Extensions"_ns, + extensions); + extensions.CompressWhitespace(); + if (extensions.IsEmpty()) { + return NS_OK; + } + + LOG( + ("WebSocketChannel::HandleExtensions: received " + "Sec-WebSocket-Extensions header: %s\n", + extensions.get())); + + bool clientNoContextTakeover; + bool serverNoContextTakeover; + int32_t clientMaxWindowBits; + int32_t serverMaxWindowBits; + + rv = ParseWebSocketExtension(extensions, eParseClientSide, + clientNoContextTakeover, serverNoContextTakeover, + clientMaxWindowBits, serverMaxWindowBits); + if (NS_FAILED(rv)) { + AbortSession(rv); + return rv; + } + + if (!mAllowPMCE) { + LOG( + ("WebSocketChannel::HandleExtensions: " + "Recvd permessage-deflate which wasn't offered\n")); + AbortSession(NS_ERROR_ILLEGAL_VALUE); + return NS_ERROR_ILLEGAL_VALUE; + } + + if (clientMaxWindowBits == -1) { + clientMaxWindowBits = 15; + } + if (serverMaxWindowBits == -1) { + serverMaxWindowBits = 15; + } + + MutexAutoLock lock(mCompressorMutex); + mPMCECompressor = MakeUnique<PMCECompression>( + clientNoContextTakeover, clientMaxWindowBits, serverMaxWindowBits); + if (mPMCECompressor->Active()) { + LOG( + ("WebSocketChannel::HandleExtensions: PMCE negotiated, %susing " + "context takeover, clientMaxWindowBits=%d, " + "serverMaxWindowBits=%d\n", + clientNoContextTakeover ? "NOT " : "", clientMaxWindowBits, + serverMaxWindowBits)); + + mNegotiatedExtensions = "permessage-deflate"; + } else { + LOG( + ("WebSocketChannel::HandleExtensions: Cannot init PMCE " + "compression object\n")); + mPMCECompressor = nullptr; + AbortSession(NS_ERROR_UNEXPECTED); + return NS_ERROR_UNEXPECTED; + } + + return NS_OK; +} + +void ProcessServerWebSocketExtensions(const nsACString& aExtensions, + nsACString& aNegotiatedExtensions) { + aNegotiatedExtensions.Truncate(); + + nsCOMPtr<nsIPrefBranch> prefService = + do_GetService(NS_PREFSERVICE_CONTRACTID); + if (prefService) { + bool boolpref; + nsresult rv = prefService->GetBoolPref( + "network.websocket.extensions.permessage-deflate", &boolpref); + if (NS_SUCCEEDED(rv) && !boolpref) { + return; + } + } + + for (const auto& ext : + nsCCharSeparatedTokenizer(aExtensions, ',').ToRange()) { + bool clientNoContextTakeover; + bool serverNoContextTakeover; + int32_t clientMaxWindowBits; + int32_t serverMaxWindowBits; + + nsresult rv = ParseWebSocketExtension( + ext, eParseServerSide, clientNoContextTakeover, serverNoContextTakeover, + clientMaxWindowBits, serverMaxWindowBits); + if (NS_FAILED(rv)) { + // Ignore extensions that we can't parse + continue; + } + + aNegotiatedExtensions.AssignLiteral("permessage-deflate"); + if (clientNoContextTakeover) { + aNegotiatedExtensions.AppendLiteral(";client_no_context_takeover"); + } + if (serverNoContextTakeover) { + aNegotiatedExtensions.AppendLiteral(";server_no_context_takeover"); + } + if (clientMaxWindowBits != -1) { + aNegotiatedExtensions.AppendLiteral(";client_max_window_bits="); + aNegotiatedExtensions.AppendInt(clientMaxWindowBits); + } + if (serverMaxWindowBits != -1) { + aNegotiatedExtensions.AppendLiteral(";server_max_window_bits="); + aNegotiatedExtensions.AppendInt(serverMaxWindowBits); + } + + return; + } +} + +nsresult CalculateWebSocketHashedSecret(const nsACString& aKey, + nsACString& aHash) { + nsresult rv; + nsCString key = aKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"_ns; + nsCOMPtr<nsICryptoHash> hasher = + do_CreateInstance(NS_CRYPTO_HASH_CONTRACTID, &rv); + NS_ENSURE_SUCCESS(rv, rv); + rv = hasher->Init(nsICryptoHash::SHA1); + NS_ENSURE_SUCCESS(rv, rv); + rv = hasher->Update((const uint8_t*)key.BeginWriting(), key.Length()); + NS_ENSURE_SUCCESS(rv, rv); + return hasher->Finish(true, aHash); +} + +nsresult WebSocketChannel::SetupRequest() { + LOG(("WebSocketChannel::SetupRequest() %p\n", this)); + + nsresult rv; + + if (mLoadGroup) { + rv = mHttpChannel->SetLoadGroup(mLoadGroup); + NS_ENSURE_SUCCESS(rv, rv); + } + + rv = mHttpChannel->SetLoadFlags( + nsIRequest::LOAD_BACKGROUND | nsIRequest::INHIBIT_CACHING | + nsIRequest::LOAD_BYPASS_CACHE | nsIChannel::LOAD_BYPASS_SERVICE_WORKER); + NS_ENSURE_SUCCESS(rv, rv); + + // we never let websockets be blocked by head CSS/JS loads to avoid + // potential deadlock where server generation of CSS/JS requires + // an XHR signal. + nsCOMPtr<nsIClassOfService> cos(do_QueryInterface(mChannel)); + if (cos) { + cos->AddClassFlags(nsIClassOfService::Unblocked); + } + + // draft-ietf-hybi-thewebsocketprotocol-07 illustrates Upgrade: websocket + // in lower case, so go with that. It is technically case insensitive. + rv = mChannel->HTTPUpgrade("websocket"_ns, this); + NS_ENSURE_SUCCESS(rv, rv); + + rv = mHttpChannel->SetRequestHeader("Sec-WebSocket-Version"_ns, + nsLiteralCString(SEC_WEBSOCKET_VERSION), + false); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + + if (!mOrigin.IsEmpty()) { + rv = mHttpChannel->SetRequestHeader("Origin"_ns, mOrigin, false); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } + + if (!mProtocol.IsEmpty()) { + rv = mHttpChannel->SetRequestHeader("Sec-WebSocket-Protocol"_ns, mProtocol, + true); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } + + if (mAllowPMCE) { + rv = mHttpChannel->SetRequestHeader("Sec-WebSocket-Extensions"_ns, + "permessage-deflate"_ns, false); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } + + uint8_t* secKey; + nsAutoCString secKeyString; + + rv = mRandomGenerator->GenerateRandomBytes(16, &secKey); + NS_ENSURE_SUCCESS(rv, rv); + rv = Base64Encode(reinterpret_cast<const char*>(secKey), 16, secKeyString); + free(secKey); + if (NS_FAILED(rv)) { + return rv; + } + + rv = mHttpChannel->SetRequestHeader("Sec-WebSocket-Key"_ns, secKeyString, + false); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + LOG(("WebSocketChannel::SetupRequest: client key %s\n", secKeyString.get())); + + // prepare the value we expect to see in + // the sec-websocket-accept response header + rv = CalculateWebSocketHashedSecret(secKeyString, mHashedSecret); + NS_ENSURE_SUCCESS(rv, rv); + LOG(("WebSocketChannel::SetupRequest: expected server key %s\n", + mHashedSecret.get())); + + mHttpChannelId = mHttpChannel->ChannelId(); + + return NS_OK; +} + +nsresult WebSocketChannel::DoAdmissionDNS() { + nsresult rv; + + nsCString hostName; + rv = mURI->GetHost(hostName); + NS_ENSURE_SUCCESS(rv, rv); + mAddress = hostName; + nsCString path; + rv = mURI->GetFilePath(path); + NS_ENSURE_SUCCESS(rv, rv); + mPath = path; + rv = mURI->GetPort(&mPort); + NS_ENSURE_SUCCESS(rv, rv); + if (mPort == -1) mPort = (mEncrypted ? kDefaultWSSPort : kDefaultWSPort); + nsCOMPtr<nsIDNSService> dns = do_GetService(NS_DNSSERVICE_CONTRACTID, &rv); + NS_ENSURE_SUCCESS(rv, rv); + nsCOMPtr<nsIEventTarget> main = GetMainThreadSerialEventTarget(); + nsCOMPtr<nsICancelable> cancelable; + rv = dns->AsyncResolveNative(hostName, nsIDNSService::RESOLVE_TYPE_DEFAULT, + nsIDNSService::RESOLVE_DEFAULT_FLAGS, nullptr, + this, main, mLoadInfo->GetOriginAttributes(), + getter_AddRefs(cancelable)); + if (NS_FAILED(rv)) { + return rv; + } + + MutexAutoLock lock(mMutex); + MOZ_ASSERT(!mCancelable); + mCancelable = std::move(cancelable); + return rv; +} + +nsresult WebSocketChannel::ApplyForAdmission() { + LOG(("WebSocketChannel::ApplyForAdmission() %p\n", this)); + + // Websockets has a policy of 1 session at a time being allowed in the + // CONNECTING state per server IP address (not hostname) + + // Check to see if a proxy is being used before making DNS call + nsCOMPtr<nsIProtocolProxyService> pps = + do_GetService(NS_PROTOCOLPROXYSERVICE_CONTRACTID); + + if (!pps) { + // go straight to DNS + // expect the callback in ::OnLookupComplete + LOG(( + "WebSocketChannel::ApplyForAdmission: checking for concurrent open\n")); + return DoAdmissionDNS(); + } + + nsresult rv; + nsCOMPtr<nsICancelable> cancelable; + rv = pps->AsyncResolve( + mHttpChannel, + nsIProtocolProxyService::RESOLVE_PREFER_SOCKS_PROXY | + nsIProtocolProxyService::RESOLVE_PREFER_HTTPS_PROXY | + nsIProtocolProxyService::RESOLVE_ALWAYS_TUNNEL, + this, nullptr, getter_AddRefs(cancelable)); + + MutexAutoLock lock(mMutex); + MOZ_ASSERT(!mCancelable); + mCancelable = std::move(cancelable); + return rv; +} + +// Called after both OnStartRequest and OnTransportAvailable have +// executed. This essentially ends the handshake and starts the websockets +// protocol state machine. +nsresult WebSocketChannel::CallStartWebsocketData() { + LOG(("WebSocketChannel::CallStartWebsocketData() %p", this)); + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + if (mOpenTimer) { + mOpenTimer->Cancel(); + mOpenTimer = nullptr; + } + + nsCOMPtr<nsIEventTarget> target = GetTargetThread(); + if (target && !target->IsOnCurrentThread()) { + return target->Dispatch( + NewRunnableMethod("net::WebSocketChannel::StartWebsocketData", this, + &WebSocketChannel::StartWebsocketData), + NS_DISPATCH_NORMAL); + } + + return StartWebsocketData(); +} + +nsresult WebSocketChannel::StartWebsocketData() { + { + MutexAutoLock lock(mMutex); + LOG(("WebSocketChannel::StartWebsocketData() %p", this)); + MOZ_ASSERT(!mDataStarted, "StartWebsocketData twice"); + + if (mStopped) { + LOG( + ("WebSocketChannel::StartWebsocketData channel already closed, not " + "starting data")); + return NS_ERROR_NOT_AVAILABLE; + } + } + + RefPtr<WebSocketChannel> self = this; + mIOThread->Dispatch(NS_NewRunnableFunction( + "WebSocketChannel::StartWebsocketData", [self{std::move(self)}] { + LOG(("WebSocketChannel::DoStartWebsocketData() %p", self.get())); + + NS_DispatchToMainThread( + NewRunnableMethod("net::WebSocketChannel::NotifyOnStart", self, + &WebSocketChannel::NotifyOnStart), + NS_DISPATCH_NORMAL); + + nsresult rv = self->mConnection ? self->mConnection->StartReading() + : self->mSocketIn->AsyncWait( + self, 0, 0, self->mIOThread); + if (NS_FAILED(rv)) { + self->AbortSession(rv); + } + + if (self->mPingInterval) { + rv = self->StartPinging(); + if (NS_FAILED(rv)) { + LOG(( + "WebSocketChannel::StartWebsocketData Could not start pinging, " + "rv=0x%08" PRIx32, + static_cast<uint32_t>(rv))); + self->AbortSession(rv); + } + } + })); + + return NS_OK; +} + +void WebSocketChannel::NotifyOnStart() { + LOG(("WebSocketChannel::NotifyOnStart Notifying Listener %p", + mListenerMT ? mListenerMT->mListener.get() : nullptr)); + mDataStarted = true; + if (mListenerMT) { + nsresult rv = mListenerMT->mListener->OnStart(mListenerMT->mContext); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::NotifyOnStart " + "mListenerMT->mListener->OnStart() failed with error 0x%08" PRIx32, + static_cast<uint32_t>(rv))); + } + } +} + +nsresult WebSocketChannel::StartPinging() { + LOG(("WebSocketChannel::StartPinging() %p", this)); + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + MOZ_ASSERT(mPingInterval); + MOZ_ASSERT(!mPingTimer); + + nsresult rv; + rv = NS_NewTimerWithCallback(getter_AddRefs(mPingTimer), this, mPingInterval, + nsITimer::TYPE_ONE_SHOT); + if (NS_SUCCEEDED(rv)) { + LOG(("WebSocketChannel will generate ping after %d ms of receive silence\n", + (uint32_t)mPingInterval)); + } else { + NS_WARNING("unable to create ping timer. Carrying on."); + } + + return NS_OK; +} + +void WebSocketChannel::ReportConnectionTelemetry(nsresult aStatusCode) { + // 3 bits are used. high bit is for wss, middle bit for failed, + // and low bit for proxy.. + // 0 - 7 : ws-ok-plain, ws-ok-proxy, ws-failed-plain, ws-failed-proxy, + // wss-ok-plain, wss-ok-proxy, wss-failed-plain, wss-failed-proxy + + bool didProxy = false; + + nsCOMPtr<nsIProxyInfo> pi; + nsCOMPtr<nsIProxiedChannel> pc = do_QueryInterface(mChannel); + if (pc) pc->GetProxyInfo(getter_AddRefs(pi)); + if (pi) { + nsAutoCString proxyType; + pi->GetType(proxyType); + if (!proxyType.IsEmpty() && !proxyType.EqualsLiteral("direct")) { + didProxy = true; + } + } + + uint8_t value = + (mEncrypted ? (1 << 2) : 0) | + (!(mGotUpgradeOK && NS_SUCCEEDED(aStatusCode)) ? (1 << 1) : 0) | + (didProxy ? (1 << 0) : 0); + + LOG(("WebSocketChannel::ReportConnectionTelemetry() %p %d", this, value)); + Telemetry::Accumulate(Telemetry::WEBSOCKETS_HANDSHAKE_TYPE, value); +} + +// nsIDNSListener + +NS_IMETHODIMP +WebSocketChannel::OnLookupComplete(nsICancelable* aRequest, + nsIDNSRecord* aRecord, nsresult aStatus) { + LOG(("WebSocketChannel::OnLookupComplete() %p [%p %p %" PRIx32 "]\n", this, + aRequest, aRecord, static_cast<uint32_t>(aStatus))); + + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + { + MutexAutoLock lock(mMutex); + mCancelable = nullptr; + } + + if (mStopped) { + LOG(("WebSocketChannel::OnLookupComplete: Request Already Stopped\n")); + return NS_OK; + } + + // These failures are not fatal - we just use the hostname as the key + if (NS_FAILED(aStatus)) { + LOG(("WebSocketChannel::OnLookupComplete: No DNS Response\n")); + + // set host in case we got here without calling DoAdmissionDNS() + mURI->GetHost(mAddress); + } else { + nsCOMPtr<nsIDNSAddrRecord> record = do_QueryInterface(aRecord); + MOZ_ASSERT(record); + nsresult rv = record->GetNextAddrAsString(mAddress); + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel::OnLookupComplete: Failed GetNextAddr\n")); + } + } + + LOG(("WebSocket OnLookupComplete: Proceeding to ConditionallyConnect\n")); + nsWSAdmissionManager::ConditionallyConnect(this); + + return NS_OK; +} + +// nsIProtocolProxyCallback +NS_IMETHODIMP +WebSocketChannel::OnProxyAvailable(nsICancelable* aRequest, + nsIChannel* aChannel, nsIProxyInfo* pi, + nsresult status) { + { + MutexAutoLock lock(mMutex); + MOZ_ASSERT(!mCancelable || (aRequest == mCancelable)); + mCancelable = nullptr; + } + + if (mStopped) { + LOG(("WebSocketChannel::OnProxyAvailable: [%p] Request Already Stopped\n", + this)); + return NS_OK; + } + + nsAutoCString type; + if (NS_SUCCEEDED(status) && pi && NS_SUCCEEDED(pi->GetType(type)) && + !type.EqualsLiteral("direct")) { + LOG(("WebSocket OnProxyAvailable [%p] Proxy found skip DNS lookup\n", + this)); + // call DNS callback directly without DNS resolver + OnLookupComplete(nullptr, nullptr, NS_ERROR_FAILURE); + } else { + LOG(("WebSocketChannel::OnProxyAvailable[%p] checking DNS resolution\n", + this)); + nsresult rv = DoAdmissionDNS(); + if (NS_FAILED(rv)) { + LOG(("WebSocket OnProxyAvailable [%p] DNS lookup failed\n", this)); + // call DNS callback directly without DNS resolver + OnLookupComplete(nullptr, nullptr, NS_ERROR_FAILURE); + } + } + + // notify listener of OnProxyAvailable + LOG(("WebSocketChannel::OnProxyAvailable Notifying Listener %p", + mListenerMT ? mListenerMT->mListener.get() : nullptr)); + nsresult rv; + nsCOMPtr<nsIProtocolProxyCallback> ppc( + do_QueryInterface(mListenerMT->mListener, &rv)); + if (NS_SUCCEEDED(rv)) { + rv = ppc->OnProxyAvailable(aRequest, aChannel, pi, status); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::OnProxyAvailable notify" + " failed with error 0x%08" PRIx32, + static_cast<uint32_t>(rv))); + } + } + + return NS_OK; +} + +// nsIInterfaceRequestor + +NS_IMETHODIMP +WebSocketChannel::GetInterface(const nsIID& iid, void** result) { + LOG(("WebSocketChannel::GetInterface() %p\n", this)); + + if (iid.Equals(NS_GET_IID(nsIChannelEventSink))) { + return QueryInterface(iid, result); + } + + if (mCallbacks) return mCallbacks->GetInterface(iid, result); + + return NS_ERROR_NO_INTERFACE; +} + +// nsIChannelEventSink + +NS_IMETHODIMP +WebSocketChannel::AsyncOnChannelRedirect( + nsIChannel* oldChannel, nsIChannel* newChannel, uint32_t flags, + nsIAsyncVerifyRedirectCallback* callback) { + LOG(("WebSocketChannel::AsyncOnChannelRedirect() %p\n", this)); + + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + nsresult rv; + + nsCOMPtr<nsIURI> newuri; + rv = newChannel->GetURI(getter_AddRefs(newuri)); + NS_ENSURE_SUCCESS(rv, rv); + + // newuri is expected to be http or https + bool newuriIsHttps = newuri->SchemeIs("https"); + + // allow insecure->secure redirects for HTTP Strict Transport Security (from + // ws://FOO to https://FOO (mapped to wss://FOO) + if (!(flags & (nsIChannelEventSink::REDIRECT_INTERNAL | + nsIChannelEventSink::REDIRECT_STS_UPGRADE))) { + nsAutoCString newSpec; + rv = newuri->GetSpec(newSpec); + NS_ENSURE_SUCCESS(rv, rv); + + LOG(("WebSocketChannel: Redirect to %s denied by configuration\n", + newSpec.get())); + return NS_ERROR_FAILURE; + } + + if (mEncrypted && !newuriIsHttps) { + nsAutoCString spec; + if (NS_SUCCEEDED(newuri->GetSpec(spec))) { + LOG(("WebSocketChannel: Redirect to %s violates encryption rule\n", + spec.get())); + } + return NS_ERROR_FAILURE; + } + + nsCOMPtr<nsIHttpChannel> newHttpChannel = do_QueryInterface(newChannel, &rv); + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel: Redirect could not QI to HTTP\n")); + return rv; + } + + nsCOMPtr<nsIHttpChannelInternal> newUpgradeChannel = + do_QueryInterface(newChannel, &rv); + + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel: Redirect could not QI to HTTP Upgrade\n")); + return rv; + } + + // The redirect is likely OK + + newChannel->SetNotificationCallbacks(this); + + mEncrypted = newuriIsHttps; + rv = NS_MutateURI(newuri) + .SetScheme(mEncrypted ? "wss"_ns : "ws"_ns) + .Finalize(mURI); + + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel: Could not set the proper scheme\n")); + return rv; + } + + mHttpChannel = newHttpChannel; + mChannel = newUpgradeChannel; + rv = SetupRequest(); + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel: Redirect could not SetupRequest()\n")); + return rv; + } + + // Redirected-to URI may need to be delayed by 1-connecting-per-host and + // delay-after-fail algorithms. So hold off calling OnRedirectVerifyCallback + // until BeginOpen, when we know it's OK to proceed with new channel. + mRedirectCallback = callback; + + // Mark old channel as successfully connected so we'll clear any FailDelay + // associated with the old URI. Note: no need to also call OnStopSession: + // it's a no-op for successful, already-connected channels. + nsWSAdmissionManager::OnConnected(this); + + // ApplyForAdmission as if we were starting from fresh... + mAddress.Truncate(); + mOpenedHttpChannel = false; + rv = ApplyForAdmission(); + if (NS_FAILED(rv)) { + LOG(("WebSocketChannel: Redirect failed due to DNS failure\n")); + mRedirectCallback = nullptr; + return rv; + } + + return NS_OK; +} + +// nsITimerCallback + +NS_IMETHODIMP +WebSocketChannel::Notify(nsITimer* timer) { + LOG(("WebSocketChannel::Notify() %p [%p]\n", this, timer)); + + if (timer == mCloseTimer) { + MOZ_ASSERT(mClientClosed, "Close Timeout without local close"); + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + + mCloseTimer = nullptr; + if (mStopped || mServerClosed) { /* no longer relevant */ + return NS_OK; + } + + LOG(("WebSocketChannel:: Expecting Server Close - Timed Out\n")); + AbortSession(NS_ERROR_NET_TIMEOUT_EXTERNAL); + } else if (timer == mOpenTimer) { + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + mOpenTimer = nullptr; + LOG(("WebSocketChannel:: Connection Timed Out\n")); + if (mStopped || mServerClosed) { /* no longer relevant */ + return NS_OK; + } + + AbortSession(NS_ERROR_NET_TIMEOUT_EXTERNAL); + MOZ_PUSH_IGNORE_THREAD_SAFETY + // mReconnectDelayTimer is only modified on MainThread, we can read it + // without a lock, but ONLY if we're on MainThread! And if we're not + // on MainThread, it can't be mReconnectDelayTimer + } else if (NS_IsMainThread() && timer == mReconnectDelayTimer) { + MOZ_POP_THREAD_SAFETY + MOZ_ASSERT(mConnecting == CONNECTING_DELAYED, + "woke up from delay w/o being delayed?"); + + { + MutexAutoLock lock(mMutex); + mReconnectDelayTimer = nullptr; + } + LOG(("WebSocketChannel: connecting [this=%p] after reconnect delay", this)); + BeginOpen(false); + } else if (timer == mPingTimer) { + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + + if (mClientClosed || mServerClosed || mRequestedClose) { + // no point in worrying about ping now + mPingTimer = nullptr; + return NS_OK; + } + + if (!mPingOutstanding) { + // Ping interval must be non-null or PING was forced by OnNetworkChanged() + MOZ_ASSERT(mPingInterval || mPingForced); + LOG(("nsWebSocketChannel:: Generating Ping\n")); + mPingOutstanding = 1; + mPingForced = false; + mPingTimer->InitWithCallback(this, mPingResponseTimeout, + nsITimer::TYPE_ONE_SHOT); + GeneratePing(); + } else { + LOG(("nsWebSocketChannel:: Timed out Ping\n")); + mPingTimer = nullptr; + AbortSession(NS_ERROR_NET_TIMEOUT_EXTERNAL); + } + } else if (timer == mLingeringCloseTimer) { + LOG(("WebSocketChannel:: Lingering Close Timer")); + CleanupConnection(); + } else { + MOZ_ASSERT(0, "Unknown Timer"); + } + + return NS_OK; +} + +// nsINamed + +NS_IMETHODIMP +WebSocketChannel::GetName(nsACString& aName) { + aName.AssignLiteral("WebSocketChannel"); + return NS_OK; +} + +// nsIWebSocketChannel + +NS_IMETHODIMP +WebSocketChannel::GetSecurityInfo(nsITransportSecurityInfo** aSecurityInfo) { + LOG(("WebSocketChannel::GetSecurityInfo() %p\n", this)); + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + *aSecurityInfo = nullptr; + + if (mConnection) { + nsresult rv = mConnection->GetSecurityInfo(aSecurityInfo); + if (NS_FAILED(rv)) { + return rv; + } + return NS_OK; + } + + if (mTransport) { + nsCOMPtr<nsITLSSocketControl> tlsSocketControl; + nsresult rv = + mTransport->GetTlsSocketControl(getter_AddRefs(tlsSocketControl)); + if (NS_FAILED(rv)) { + return rv; + } + nsCOMPtr<nsITransportSecurityInfo> securityInfo( + do_QueryInterface(tlsSocketControl)); + if (securityInfo) { + securityInfo.forget(aSecurityInfo); + } + } + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannel::AsyncOpen(nsIURI* aURI, const nsACString& aOrigin, + JS::Handle<JS::Value> aOriginAttributes, + uint64_t aInnerWindowID, + nsIWebSocketListener* aListener, + nsISupports* aContext, JSContext* aCx) { + OriginAttributes attrs; + if (!aOriginAttributes.isObject() || !attrs.Init(aCx, aOriginAttributes)) { + return NS_ERROR_INVALID_ARG; + } + return AsyncOpenNative(aURI, aOrigin, attrs, aInnerWindowID, aListener, + aContext); +} + +NS_IMETHODIMP +WebSocketChannel::AsyncOpenNative(nsIURI* aURI, const nsACString& aOrigin, + const OriginAttributes& aOriginAttributes, + uint64_t aInnerWindowID, + nsIWebSocketListener* aListener, + nsISupports* aContext) { + LOG(("WebSocketChannel::AsyncOpen() %p\n", this)); + + aOriginAttributes.CreateSuffix(mOriginSuffix); + + if (!NS_IsMainThread()) { + MOZ_ASSERT(false, "not main thread"); + LOG(("WebSocketChannel::AsyncOpen() called off the main thread")); + return NS_ERROR_UNEXPECTED; + } + + if ((!aURI && !mIsServerSide) || !aListener) { + LOG(("WebSocketChannel::AsyncOpen() Uri or Listener null")); + return NS_ERROR_UNEXPECTED; + } + + if (mListenerMT || mWasOpened) return NS_ERROR_ALREADY_OPENED; + + nsresult rv; + + // Ensure target thread is set if RetargetDeliveryTo isn't called + { + auto lock = mTargetThread.Lock(); + if (!lock.ref()) { + lock.ref() = GetMainThreadSerialEventTarget(); + } + } + + mIOThread = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + if (NS_FAILED(rv)) { + NS_WARNING("unable to continue without socket transport service"); + return rv; + } + + nsCOMPtr<nsIPrefBranch> prefService; + prefService = do_GetService(NS_PREFSERVICE_CONTRACTID); + + if (prefService) { + int32_t intpref; + bool boolpref; + rv = + prefService->GetIntPref("network.websocket.max-message-size", &intpref); + if (NS_SUCCEEDED(rv)) { + mMaxMessageSize = clamped(intpref, 1024, INT32_MAX); + } + rv = prefService->GetIntPref("network.websocket.timeout.close", &intpref); + if (NS_SUCCEEDED(rv)) { + mCloseTimeout = clamped(intpref, 1, 1800) * 1000; + } + rv = prefService->GetIntPref("network.websocket.timeout.open", &intpref); + if (NS_SUCCEEDED(rv)) { + mOpenTimeout = clamped(intpref, 1, 1800) * 1000; + } + rv = prefService->GetIntPref("network.websocket.timeout.ping.request", + &intpref); + if (NS_SUCCEEDED(rv) && !mClientSetPingInterval) { + mPingInterval = clamped(intpref, 0, 86400) * 1000; + } + rv = prefService->GetIntPref("network.websocket.timeout.ping.response", + &intpref); + if (NS_SUCCEEDED(rv) && !mClientSetPingTimeout) { + mPingResponseTimeout = clamped(intpref, 1, 3600) * 1000; + } + rv = prefService->GetBoolPref( + "network.websocket.extensions.permessage-deflate", &boolpref); + if (NS_SUCCEEDED(rv)) { + mAllowPMCE = boolpref ? 1 : 0; + } + rv = prefService->GetIntPref("network.websocket.max-connections", &intpref); + if (NS_SUCCEEDED(rv)) { + mMaxConcurrentConnections = clamped(intpref, 1, 0xffff); + } + } + + int32_t sessionCount = -1; + nsWSAdmissionManager::GetSessionCount(sessionCount); + if (sessionCount >= 0) { + LOG(("WebSocketChannel::AsyncOpen %p sessionCount=%d max=%d\n", this, + sessionCount, mMaxConcurrentConnections)); + } + + if (sessionCount >= mMaxConcurrentConnections) { + LOG(("WebSocketChannel: max concurrency %d exceeded (%d)", + mMaxConcurrentConnections, sessionCount)); + + // WebSocket connections are expected to be long lived, so return + // an error here instead of queueing + return NS_ERROR_SOCKET_CREATE_FAILED; + } + + mInnerWindowID = aInnerWindowID; + mOriginalURI = aURI; + mURI = mOriginalURI; + mOrigin = aOrigin; + + if (mIsServerSide) { + // IncrementSessionCount(); + mWasOpened = 1; + mListenerMT = new ListenerAndContextContainer(aListener, aContext); + rv = mServerTransportProvider->SetListener(this); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + mServerTransportProvider = nullptr; + + return NS_OK; + } + + mURI->GetHostPort(mHost); + + mRandomGenerator = + do_GetService("@mozilla.org/security/random-generator;1", &rv); + if (NS_FAILED(rv)) { + NS_WARNING("unable to continue without random number generator"); + return rv; + } + + nsCOMPtr<nsIURI> localURI; + nsCOMPtr<nsIChannel> localChannel; + + rv = NS_MutateURI(mURI) + .SetScheme(mEncrypted ? "https"_ns : "http"_ns) + .Finalize(localURI); + NS_ENSURE_SUCCESS(rv, rv); + + nsCOMPtr<nsIIOService> ioService; + ioService = do_GetService(NS_IOSERVICE_CONTRACTID, &rv); + if (NS_FAILED(rv)) { + NS_WARNING("unable to continue without io service"); + return rv; + } + + // Ideally we'd call newChannelFromURIWithLoadInfo here, but that doesn't + // allow setting proxy uri/flags + rv = ioService->NewChannelFromURIWithProxyFlags( + localURI, mURI, + nsIProtocolProxyService::RESOLVE_PREFER_SOCKS_PROXY | + nsIProtocolProxyService::RESOLVE_PREFER_HTTPS_PROXY | + nsIProtocolProxyService::RESOLVE_ALWAYS_TUNNEL, + mLoadInfo->LoadingNode(), mLoadInfo->GetLoadingPrincipal(), + mLoadInfo->TriggeringPrincipal(), mLoadInfo->GetSecurityFlags(), + mLoadInfo->InternalContentPolicyType(), getter_AddRefs(localChannel)); + NS_ENSURE_SUCCESS(rv, rv); + + // Please note that we still call SetLoadInfo on the channel because + // we want the same instance of the loadInfo to be set on the channel. + rv = localChannel->SetLoadInfo(mLoadInfo); + NS_ENSURE_SUCCESS(rv, rv); + + // Pass most GetInterface() requests through to our instantiator, but handle + // nsIChannelEventSink in this object in order to deal with redirects + localChannel->SetNotificationCallbacks(this); + + class MOZ_STACK_CLASS CleanUpOnFailure { + public: + explicit CleanUpOnFailure(WebSocketChannel* aWebSocketChannel) + : mWebSocketChannel(aWebSocketChannel) {} + + ~CleanUpOnFailure() { + if (!mWebSocketChannel->mWasOpened) { + mWebSocketChannel->mChannel = nullptr; + mWebSocketChannel->mHttpChannel = nullptr; + } + } + + WebSocketChannel* mWebSocketChannel; + }; + + CleanUpOnFailure cuof(this); + + mChannel = do_QueryInterface(localChannel, &rv); + NS_ENSURE_SUCCESS(rv, rv); + + mHttpChannel = do_QueryInterface(localChannel, &rv); + NS_ENSURE_SUCCESS(rv, rv); + + rv = SetupRequest(); + if (NS_FAILED(rv)) return rv; + + mPrivateBrowsing = NS_UsePrivateBrowsing(localChannel); + + if (mConnectionLogService && !mPrivateBrowsing) { + mConnectionLogService->AddHost(mHost, mSerial, + BaseWebSocketChannel::mEncrypted); + } + + rv = ApplyForAdmission(); + if (NS_FAILED(rv)) return rv; + + // Register for prefs change notifications + nsCOMPtr<nsIObserverService> observerService = + mozilla::services::GetObserverService(); + if (!observerService) { + NS_WARNING("failed to get observer service"); + return NS_ERROR_FAILURE; + } + + rv = observerService->AddObserver(this, NS_NETWORK_LINK_TOPIC, false); + if (NS_WARN_IF(NS_FAILED(rv))) { + return rv; + } + + // Only set these if the open was successful: + // + mWasOpened = 1; + mListenerMT = new ListenerAndContextContainer(aListener, aContext); + IncrementSessionCount(); + + return rv; +} + +NS_IMETHODIMP +WebSocketChannel::Close(uint16_t code, const nsACString& reason) { + LOG(("WebSocketChannel::Close() %p\n", this)); + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + { + MutexAutoLock lock(mMutex); + + if (mRequestedClose) { + return NS_OK; + } + + if (mStopped) { + return NS_ERROR_NOT_AVAILABLE; + } + + // The API requires the UTF-8 string to be 123 or less bytes + if (reason.Length() > 123) return NS_ERROR_ILLEGAL_VALUE; + + mRequestedClose = true; + mScriptCloseReason = reason; + mScriptCloseCode = code; + + if (mDataStarted) { + return mIOThread->Dispatch( + new OutboundEnqueuer(this, + new OutboundMessage(kMsgTypeFin, VoidCString())), + nsIEventTarget::DISPATCH_NORMAL); + } + + mStopped = true; + } + + nsresult rv; + if (code == CLOSE_GOING_AWAY) { + // Not an error: for example, tab has closed or navigated away + LOG(("WebSocketChannel::Close() GOING_AWAY without transport.")); + rv = NS_OK; + } else { + LOG(("WebSocketChannel::Close() without transport - error.")); + rv = NS_ERROR_NOT_CONNECTED; + } + + DoStopSession(rv); + return rv; +} + +NS_IMETHODIMP +WebSocketChannel::SendMsg(const nsACString& aMsg) { + LOG(("WebSocketChannel::SendMsg() %p\n", this)); + + return SendMsgCommon(aMsg, false, aMsg.Length()); +} + +NS_IMETHODIMP +WebSocketChannel::SendBinaryMsg(const nsACString& aMsg) { + LOG(("WebSocketChannel::SendBinaryMsg() %p len=%zu\n", this, aMsg.Length())); + return SendMsgCommon(aMsg, true, aMsg.Length()); +} + +NS_IMETHODIMP +WebSocketChannel::SendBinaryStream(nsIInputStream* aStream, uint32_t aLength) { + LOG(("WebSocketChannel::SendBinaryStream() %p\n", this)); + + return SendMsgCommon(VoidCString(), true, aLength, aStream); +} + +nsresult WebSocketChannel::SendMsgCommon(const nsACString& aMsg, bool aIsBinary, + uint32_t aLength, + nsIInputStream* aStream) { + MOZ_ASSERT(IsOnTargetThread(), "not target thread"); + + if (!mDataStarted) { + LOG(("WebSocketChannel:: Error: data not started yet\n")); + return NS_ERROR_UNEXPECTED; + } + + if (mRequestedClose) { + LOG(("WebSocketChannel:: Error: send when closed\n")); + return NS_ERROR_UNEXPECTED; + } + + if (mStopped) { + LOG(("WebSocketChannel:: Error: send when stopped\n")); + return NS_ERROR_NOT_CONNECTED; + } + + MOZ_ASSERT(mMaxMessageSize >= 0, "max message size negative"); + if (aLength > static_cast<uint32_t>(mMaxMessageSize)) { + LOG(("WebSocketChannel:: Error: message too big\n")); + return NS_ERROR_FILE_TOO_BIG; + } + + if (mConnectionLogService && !mPrivateBrowsing) { + mConnectionLogService->NewMsgSent(mHost, mSerial, aLength); + LOG(("Added new msg sent for %s", mHost.get())); + } + + return mIOThread->Dispatch( + aStream + ? new OutboundEnqueuer(this, new OutboundMessage(aStream, aLength)) + : new OutboundEnqueuer( + this, + new OutboundMessage( + aIsBinary ? kMsgTypeBinaryString : kMsgTypeString, aMsg)), + nsIEventTarget::DISPATCH_NORMAL); +} + +// nsIHttpUpgradeListener + +NS_IMETHODIMP +WebSocketChannel::OnTransportAvailable(nsISocketTransport* aTransport, + nsIAsyncInputStream* aSocketIn, + nsIAsyncOutputStream* aSocketOut) { + if (!NS_IsMainThread()) { + return NS_DispatchToMainThread( + new CallOnTransportAvailable(this, aTransport, aSocketIn, aSocketOut)); + } + + LOG(("WebSocketChannel::OnTransportAvailable %p [%p %p %p] rcvdonstart=%d\n", + this, aTransport, aSocketIn, aSocketOut, mGotUpgradeOK)); + + if (mStopped) { + LOG(("WebSocketChannel::OnTransportAvailable: Already stopped")); + return NS_OK; + } + + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + MOZ_ASSERT(!mRecvdHttpUpgradeTransport, "OTA duplicated"); + MOZ_ASSERT(aSocketIn, "OTA with invalid socketIn"); + + mTransport = aTransport; + mSocketIn = aSocketIn; + mSocketOut = aSocketOut; + + nsresult rv; + rv = mTransport->SetEventSink(nullptr, nullptr); + if (NS_FAILED(rv)) return rv; + rv = mTransport->SetSecurityCallbacks(this); + if (NS_FAILED(rv)) return rv; + + return OnTransportAvailableInternal(); +} + +NS_IMETHODIMP +WebSocketChannel::OnWebSocketConnectionAvailable( + WebSocketConnectionBase* aConnection) { + if (!NS_IsMainThread()) { + RefPtr<WebSocketChannel> self = this; + RefPtr<WebSocketConnectionBase> connection = aConnection; + return NS_DispatchToMainThread(NS_NewRunnableFunction( + "WebSocketChannel::OnWebSocketConnectionAvailable", + [self, connection]() { + self->OnWebSocketConnectionAvailable(connection); + })); + } + + LOG( + ("WebSocketChannel::OnWebSocketConnectionAvailable %p [%p] " + "rcvdonstart=%d\n", + this, aConnection, mGotUpgradeOK)); + + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + MOZ_ASSERT(!mRecvdHttpUpgradeTransport, + "OnWebSocketConnectionAvailable duplicated"); + MOZ_ASSERT(aConnection); + + if (mStopped) { + LOG(("WebSocketChannel::OnWebSocketConnectionAvailable: Already stopped")); + aConnection->Close(); + return NS_OK; + } + + nsresult rv = aConnection->Init(this); + if (NS_FAILED(rv)) { + return rv; + } + + mConnection = aConnection; + // Note: mIOThread will be IPDL background thread. + mConnection->GetIoTarget(getter_AddRefs(mIOThread)); + return OnTransportAvailableInternal(); +} + +nsresult WebSocketChannel::OnTransportAvailableInternal() { + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + MOZ_ASSERT(!mRecvdHttpUpgradeTransport, + "OnWebSocketConnectionAvailable duplicated"); + MOZ_ASSERT(mSocketIn || mConnection); + + mRecvdHttpUpgradeTransport = 1; + if (mGotUpgradeOK) { + // We're now done CONNECTING, which means we can now open another, + // perhaps parallel, connection to the same host if one + // is pending + nsWSAdmissionManager::OnConnected(this); + + return CallStartWebsocketData(); + } + + if (mIsServerSide) { + if (!mNegotiatedExtensions.IsEmpty()) { + bool clientNoContextTakeover; + bool serverNoContextTakeover; + int32_t clientMaxWindowBits; + int32_t serverMaxWindowBits; + + nsresult rv = ParseWebSocketExtension( + mNegotiatedExtensions, eParseServerSide, clientNoContextTakeover, + serverNoContextTakeover, clientMaxWindowBits, serverMaxWindowBits); + MOZ_RELEASE_ASSERT(NS_SUCCEEDED(rv), "illegal value provided by server"); + + if (clientMaxWindowBits == -1) { + clientMaxWindowBits = 15; + } + if (serverMaxWindowBits == -1) { + serverMaxWindowBits = 15; + } + + MutexAutoLock lock(mCompressorMutex); + mPMCECompressor = MakeUnique<PMCECompression>( + serverNoContextTakeover, serverMaxWindowBits, clientMaxWindowBits); + if (mPMCECompressor->Active()) { + LOG( + ("WebSocketChannel::OnTransportAvailable: PMCE negotiated, %susing " + "context takeover, serverMaxWindowBits=%d, " + "clientMaxWindowBits=%d\n", + serverNoContextTakeover ? "NOT " : "", serverMaxWindowBits, + clientMaxWindowBits)); + + mNegotiatedExtensions = "permessage-deflate"; + } else { + LOG( + ("WebSocketChannel::OnTransportAvailable: Cannot init PMCE " + "compression object\n")); + mPMCECompressor = nullptr; + AbortSession(NS_ERROR_UNEXPECTED); + return NS_ERROR_UNEXPECTED; + } + } + + return CallStartWebsocketData(); + } + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannel::OnUpgradeFailed(nsresult aErrorCode) { + // When socket process is enabled, this could be called on background thread. + + LOG(("WebSocketChannel::OnUpgradeFailed() %p [aErrorCode %" PRIx32 "]", this, + static_cast<uint32_t>(aErrorCode))); + + if (mStopped) { + LOG(("WebSocketChannel::OnUpgradeFailed: Already stopped")); + return NS_OK; + } + + MOZ_ASSERT(!mRecvdHttpUpgradeTransport, "OTA already called"); + + AbortSession(aErrorCode); + return NS_OK; +} + +// nsIRequestObserver (from nsIStreamListener) + +NS_IMETHODIMP +WebSocketChannel::OnStartRequest(nsIRequest* aRequest) { + LOG(("WebSocketChannel::OnStartRequest(): %p [%p %p] recvdhttpupgrade=%d\n", + this, aRequest, mHttpChannel.get(), mRecvdHttpUpgradeTransport)); + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + MOZ_ASSERT(!mGotUpgradeOK, "OTA duplicated"); + + if (mStopped) { + LOG(("WebSocketChannel::OnStartRequest: Channel Already Done\n")); + AbortSession(NS_ERROR_WEBSOCKET_CONNECTION_REFUSED); + return NS_ERROR_WEBSOCKET_CONNECTION_REFUSED; + } + + nsresult rv; + uint32_t status; + char *val, *token; + + rv = mHttpChannel->GetResponseStatus(&status); + if (NS_FAILED(rv)) { + nsresult httpStatus; + rv = NS_ERROR_WEBSOCKET_CONNECTION_REFUSED; + + // If we failed to connect due to unsuccessful TLS handshake, we must + // propagate a specific error to mozilla::dom::WebSocketImpl so it can set + // status code to 1015. Otherwise return + // NS_ERROR_WEBSOCKET_CONNECTION_REFUSED. + if (NS_SUCCEEDED(mHttpChannel->GetStatus(&httpStatus))) { + uint32_t errorClass; + nsCOMPtr<nsINSSErrorsService> errSvc = + do_GetService("@mozilla.org/nss_errors_service;1"); + // If GetErrorClass succeeds httpStatus is TLS related failure. + if (errSvc && + NS_SUCCEEDED(errSvc->GetErrorClass(httpStatus, &errorClass))) { + rv = NS_ERROR_NET_INADEQUATE_SECURITY; + } + } + + LOG(("WebSocketChannel::OnStartRequest: No HTTP Response\n")); + AbortSession(rv); + return rv; + } + + LOG(("WebSocketChannel::OnStartRequest: HTTP status %d\n", status)); + nsCOMPtr<nsIHttpChannelInternal> internalChannel = + do_QueryInterface(mHttpChannel); + uint32_t versionMajor, versionMinor; + rv = internalChannel->GetResponseVersion(&versionMajor, &versionMinor); + if (NS_FAILED(rv) || + !((versionMajor == 1 && versionMinor != 0) || versionMajor == 2) || + (versionMajor == 1 && status != 101) || + (versionMajor == 2 && status != 200)) { + AbortSession(NS_ERROR_WEBSOCKET_CONNECTION_REFUSED); + return NS_ERROR_WEBSOCKET_CONNECTION_REFUSED; + } + + if (versionMajor == 1) { + // These are only present on http/1.x websocket upgrades + nsAutoCString respUpgrade; + rv = mHttpChannel->GetResponseHeader("Upgrade"_ns, respUpgrade); + + if (NS_SUCCEEDED(rv)) { + rv = NS_ERROR_ILLEGAL_VALUE; + if (!respUpgrade.IsEmpty()) { + val = respUpgrade.BeginWriting(); + while ((token = nsCRT::strtok(val, ", \t", &val))) { + if (nsCRT::strcasecmp(token, "Websocket") == 0) { + rv = NS_OK; + break; + } + } + } + } + + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::OnStartRequest: " + "HTTP response header Upgrade: websocket not found\n")); + AbortSession(NS_ERROR_ILLEGAL_VALUE); + return rv; + } + + nsAutoCString respConnection; + rv = mHttpChannel->GetResponseHeader("Connection"_ns, respConnection); + + if (NS_SUCCEEDED(rv)) { + rv = NS_ERROR_ILLEGAL_VALUE; + if (!respConnection.IsEmpty()) { + val = respConnection.BeginWriting(); + while ((token = nsCRT::strtok(val, ", \t", &val))) { + if (nsCRT::strcasecmp(token, "Upgrade") == 0) { + rv = NS_OK; + break; + } + } + } + } + + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::OnStartRequest: " + "HTTP response header 'Connection: Upgrade' not found\n")); + AbortSession(NS_ERROR_ILLEGAL_VALUE); + return rv; + } + + nsAutoCString respAccept; + rv = mHttpChannel->GetResponseHeader("Sec-WebSocket-Accept"_ns, respAccept); + + if (NS_FAILED(rv) || respAccept.IsEmpty() || + !respAccept.Equals(mHashedSecret)) { + LOG( + ("WebSocketChannel::OnStartRequest: " + "HTTP response header Sec-WebSocket-Accept check failed\n")); + LOG(("WebSocketChannel::OnStartRequest: Expected %s received %s\n", + mHashedSecret.get(), respAccept.get())); +#ifdef FUZZING + if (NS_FAILED(rv) || respAccept.IsEmpty()) { +#endif + AbortSession(NS_ERROR_ILLEGAL_VALUE); + return NS_ERROR_ILLEGAL_VALUE; +#ifdef FUZZING + } +#endif + } + } + + // If we sent a sub protocol header, verify the response matches. + // If response contains protocol that was not in request, fail. + // If response contained no protocol header, set to "" so the protocol + // attribute of the WebSocket JS object reflects that + if (!mProtocol.IsEmpty()) { + nsAutoCString respProtocol; + rv = mHttpChannel->GetResponseHeader("Sec-WebSocket-Protocol"_ns, + respProtocol); + if (NS_SUCCEEDED(rv)) { + rv = NS_ERROR_ILLEGAL_VALUE; + val = mProtocol.BeginWriting(); + while ((token = nsCRT::strtok(val, ", \t", &val))) { + if (strcmp(token, respProtocol.get()) == 0) { + rv = NS_OK; + break; + } + } + + if (NS_SUCCEEDED(rv)) { + LOG(("WebsocketChannel::OnStartRequest: subprotocol %s confirmed", + respProtocol.get())); + mProtocol = respProtocol; + } else { + LOG( + ("WebsocketChannel::OnStartRequest: " + "Server replied with non-matching subprotocol [%s]: aborting", + respProtocol.get())); + mProtocol.Truncate(); + AbortSession(NS_ERROR_ILLEGAL_VALUE); + return NS_ERROR_ILLEGAL_VALUE; + } + } else { + LOG( + ("WebsocketChannel::OnStartRequest " + "subprotocol [%s] not found - none returned", + mProtocol.get())); + mProtocol.Truncate(); + } + } + + rv = HandleExtensions(); + if (NS_FAILED(rv)) return rv; + + // Update mEffectiveURL for off main thread URI access. + nsCOMPtr<nsIURI> uri = mURI ? mURI : mOriginalURI; + nsAutoCString spec; + rv = uri->GetSpec(spec); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + CopyUTF8toUTF16(spec, mEffectiveURL); + + mGotUpgradeOK = 1; + if (mRecvdHttpUpgradeTransport) { + // We're now done CONNECTING, which means we can now open another, + // perhaps parallel, connection to the same host if one + // is pending + nsWSAdmissionManager::OnConnected(this); + + return CallStartWebsocketData(); + } + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannel::OnStopRequest(nsIRequest* aRequest, nsresult aStatusCode) { + LOG(("WebSocketChannel::OnStopRequest() %p [%p %p %" PRIx32 "]\n", this, + aRequest, mHttpChannel.get(), static_cast<uint32_t>(aStatusCode))); + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + // OnTransportAvailable won't be called if the request is stopped with + // an error. Abort the session now instead of waiting for timeout. + if (NS_FAILED(aStatusCode) && !mRecvdHttpUpgradeTransport) { + AbortSession(aStatusCode); + } + + ReportConnectionTelemetry(aStatusCode); + + // This is the end of the HTTP upgrade transaction, the + // upgraded streams live on + + mChannel = nullptr; + mHttpChannel = nullptr; + mLoadGroup = nullptr; + mCallbacks = nullptr; + + return NS_OK; +} + +// nsIInputStreamCallback + +NS_IMETHODIMP +WebSocketChannel::OnInputStreamReady(nsIAsyncInputStream* aStream) { + LOG(("WebSocketChannel::OnInputStreamReady() %p\n", this)); + MOZ_DIAGNOSTIC_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + + if (!mSocketIn) { // did we we clean up the socket after scheduling + // InputReady? + return NS_OK; + } + + // this is after the http upgrade - so we are speaking websockets + char buffer[2048]; + uint32_t count; + nsresult rv; + + do { + rv = mSocketIn->Read((char*)buffer, sizeof(buffer), &count); + LOG(("WebSocketChannel::OnInputStreamReady: read %u rv %" PRIx32 "\n", + count, static_cast<uint32_t>(rv))); + + if (rv == NS_BASE_STREAM_WOULD_BLOCK) { + mSocketIn->AsyncWait(this, 0, 0, mIOThread); + return NS_OK; + } + + if (NS_FAILED(rv)) { + AbortSession(rv); + return rv; + } + + if (count == 0) { + AbortSession(NS_BASE_STREAM_CLOSED); + return NS_OK; + } + + if (mStopped) { + continue; + } + + rv = ProcessInput((uint8_t*)buffer, count); + if (NS_FAILED(rv)) { + AbortSession(rv); + return rv; + } + } while (NS_SUCCEEDED(rv) && mSocketIn); + + return NS_OK; +} + +// nsIOutputStreamCallback + +NS_IMETHODIMP +WebSocketChannel::OnOutputStreamReady(nsIAsyncOutputStream* aStream) { + LOG(("WebSocketChannel::OnOutputStreamReady() %p\n", this)); + MOZ_DIAGNOSTIC_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + nsresult rv; + + if (!mCurrentOut) PrimeNewOutgoingMessage(); + + while (mCurrentOut && mSocketOut) { + const char* sndBuf; + uint32_t toSend; + uint32_t amtSent; + + if (mHdrOut) { + sndBuf = (const char*)mHdrOut; + toSend = mHdrOutToSend; + LOG( + ("WebSocketChannel::OnOutputStreamReady: " + "Try to send %u of hdr/copybreak\n", + toSend)); + } else { + sndBuf = (char*)mCurrentOut->BeginReading() + mCurrentOutSent; + toSend = mCurrentOut->Length() - mCurrentOutSent; + if (toSend > 0) { + LOG( + ("WebSocketChannel::OnOutputStreamReady [%p]: " + "Try to send %u of data\n", + this, toSend)); + } + } + + if (toSend == 0) { + amtSent = 0; + } else { + rv = mSocketOut->Write(sndBuf, toSend, &amtSent); + LOG(("WebSocketChannel::OnOutputStreamReady [%p]: write %u rv %" PRIx32 + "\n", + this, amtSent, static_cast<uint32_t>(rv))); + + if (rv == NS_BASE_STREAM_WOULD_BLOCK) { + mSocketOut->AsyncWait(this, 0, 0, mIOThread); + return NS_OK; + } + + if (NS_FAILED(rv)) { + AbortSession(rv); + return NS_OK; + } + } + + if (mHdrOut) { + if (amtSent == toSend) { + mHdrOut = nullptr; + mHdrOutToSend = 0; + } else { + mHdrOut += amtSent; + mHdrOutToSend -= amtSent; + mSocketOut->AsyncWait(this, 0, 0, mIOThread); + } + } else { + if (amtSent == toSend) { + if (!mStopped) { + if (nsCOMPtr<nsIEventTarget> target = GetTargetThread()) { + target->Dispatch( + new CallAcknowledge(this, mCurrentOut->OrigLength()), + NS_DISPATCH_NORMAL); + } else { + return NS_ERROR_UNEXPECTED; + } + } + DeleteCurrentOutGoingMessage(); + PrimeNewOutgoingMessage(); + } else { + mCurrentOutSent += amtSent; + mSocketOut->AsyncWait(this, 0, 0, mIOThread); + } + } + } + + if (mReleaseOnTransmit) ReleaseSession(); + return NS_OK; +} + +// nsIStreamListener + +NS_IMETHODIMP +WebSocketChannel::OnDataAvailable(nsIRequest* aRequest, + nsIInputStream* aInputStream, + uint64_t aOffset, uint32_t aCount) { + LOG(("WebSocketChannel::OnDataAvailable() %p [%p %p %p %" PRIu64 " %u]\n", + this, aRequest, mHttpChannel.get(), aInputStream, aOffset, aCount)); + + // This is the HTTP OnDataAvailable Method, which means this is http data in + // response to the upgrade request and there should be no http response body + // if the upgrade succeeded. This generally should be caught by a non 101 + // response code in OnStartRequest().. so we can ignore the data here + + LOG(("WebSocketChannel::OnDataAvailable: HTTP data unexpected len>=%u\n", + aCount)); + + return NS_OK; +} + +void WebSocketChannel::DoEnqueueOutgoingMessage() { + LOG(("WebSocketChannel::DoEnqueueOutgoingMessage() %p\n", this)); + MOZ_ASSERT(mIOThread->IsOnCurrentThread(), "not on right thread"); + + if (!mCurrentOut) { + PrimeNewOutgoingMessage(); + } + + while (mCurrentOut && mConnection) { + nsresult rv = NS_OK; + if (mCurrentOut->Length() - mCurrentOutSent == 0) { + LOG( + ("WebSocketChannel::DoEnqueueOutgoingMessage: " + "Try to send %u of hdr/copybreak\n", + mHdrOutToSend)); + rv = mConnection->WriteOutputData(mOutHeader, mHdrOutToSend, nullptr, 0); + } else { + LOG( + ("WebSocketChannel::DoEnqueueOutgoingMessage: " + "Try to send %u of hdr and %u of data\n", + mHdrOutToSend, mCurrentOut->Length())); + rv = mConnection->WriteOutputData(mOutHeader, mHdrOutToSend, + (uint8_t*)mCurrentOut->BeginReading(), + mCurrentOut->Length()); + } + + LOG(("WebSocketChannel::DoEnqueueOutgoingMessage: rv %" PRIx32 "\n", + static_cast<uint32_t>(rv))); + if (NS_FAILED(rv)) { + AbortSession(rv); + return; + } + + if (!mStopped) { + // TODO: Currently, we assume that data is completely written to the + // socket after sending it to socket process, but it's not true. The data + // could be queued in socket process and waiting for the socket to be able + // to write. We should implement flow control for this in bug 1726552. + if (nsCOMPtr<nsIEventTarget> target = GetTargetThread()) { + target->Dispatch(new CallAcknowledge(this, mCurrentOut->OrigLength()), + NS_DISPATCH_NORMAL); + } else { + AbortSession(NS_ERROR_UNEXPECTED); + return; + } + } + DeleteCurrentOutGoingMessage(); + PrimeNewOutgoingMessage(); + } + + if (mReleaseOnTransmit) { + ReleaseSession(); + } +} + +void WebSocketChannel::OnError(nsresult aStatus) { AbortSession(aStatus); } + +void WebSocketChannel::OnTCPClosed() { mTCPClosed = true; } + +nsresult WebSocketChannel::OnDataReceived(uint8_t* aData, uint32_t aCount) { + return ProcessInput(aData, aCount); +} + +} // namespace mozilla::net + +#undef CLOSE_GOING_AWAY diff --git a/netwerk/protocol/websocket/WebSocketChannel.h b/netwerk/protocol/websocket/WebSocketChannel.h new file mode 100644 index 0000000000..83b1fc6cd2 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketChannel.h @@ -0,0 +1,377 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketChannel_h +#define mozilla_net_WebSocketChannel_h + +#include "nsISupports.h" +#include "nsIInterfaceRequestor.h" +#include "nsIStreamListener.h" +#include "nsIAsyncInputStream.h" +#include "nsIAsyncOutputStream.h" +#include "nsITimer.h" +#include "nsIDNSListener.h" +#include "nsINamed.h" +#include "nsIObserver.h" +#include "nsIProtocolProxyCallback.h" +#include "nsIChannelEventSink.h" +#include "nsIHttpChannelInternal.h" +#include "mozilla/net/WebSocketConnectionListener.h" +#include "mozilla/Mutex.h" +#include "BaseWebSocketChannel.h" + +#include "nsCOMPtr.h" +#include "nsString.h" +#include "nsDeque.h" +#include "mozilla/Atomics.h" + +class nsIAsyncVerifyRedirectCallback; +class nsIDashboardEventNotifier; +class nsIEventTarget; +class nsIHttpChannel; +class nsIRandomGenerator; +class nsISocketTransport; +class nsIURI; + +namespace mozilla { +namespace net { + +class OutboundMessage; +class OutboundEnqueuer; +class nsWSAdmissionManager; +class PMCECompression; +class CallOnMessageAvailable; +class CallOnStop; +class CallOnServerClose; +class CallAcknowledge; +class WebSocketEventService; +class WebSocketConnectionBase; + +[[nodiscard]] extern nsresult CalculateWebSocketHashedSecret( + const nsACString& aKey, nsACString& aHash); +extern void ProcessServerWebSocketExtensions(const nsACString& aExtensions, + nsACString& aNegotiatedExtensions); + +// Used to enforce "1 connecting websocket per host" rule, and reconnect delays +enum wsConnectingState { + NOT_CONNECTING = 0, // Not yet (or no longer) trying to open connection + CONNECTING_QUEUED, // Waiting for other ws to same host to finish opening + CONNECTING_DELAYED, // Delayed by "reconnect after failure" algorithm + CONNECTING_IN_PROGRESS // Started connection: waiting for result +}; + +class WebSocketChannel : public BaseWebSocketChannel, + public nsIHttpUpgradeListener, + public nsIStreamListener, + public nsIInputStreamCallback, + public nsIOutputStreamCallback, + public nsITimerCallback, + public nsIDNSListener, + public nsIObserver, + public nsIProtocolProxyCallback, + public nsIInterfaceRequestor, + public nsIChannelEventSink, + public nsINamed, + public WebSocketConnectionListener { + friend class WebSocketFrame; + + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIHTTPUPGRADELISTENER + NS_DECL_NSIREQUESTOBSERVER + NS_DECL_NSISTREAMLISTENER + NS_DECL_NSIINPUTSTREAMCALLBACK + NS_DECL_NSIOUTPUTSTREAMCALLBACK + NS_DECL_NSITIMERCALLBACK + NS_DECL_NSIDNSLISTENER + NS_DECL_NSIPROTOCOLPROXYCALLBACK + NS_DECL_NSIINTERFACEREQUESTOR + NS_DECL_NSICHANNELEVENTSINK + NS_DECL_NSIOBSERVER + NS_DECL_NSINAMED + + // nsIWebSocketChannel methods BaseWebSocketChannel didn't implement for us + // + NS_IMETHOD AsyncOpen(nsIURI* aURI, const nsACString& aOrigin, + JS::Handle<JS::Value> aOriginAttributes, + uint64_t aWindowID, nsIWebSocketListener* aListener, + nsISupports* aContext, JSContext* aCx) override; + NS_IMETHOD AsyncOpenNative(nsIURI* aURI, const nsACString& aOrigin, + const OriginAttributes& aOriginAttributes, + uint64_t aWindowID, + nsIWebSocketListener* aListener, + nsISupports* aContext) override; + NS_IMETHOD Close(uint16_t aCode, const nsACString& aReason) override; + NS_IMETHOD SendMsg(const nsACString& aMsg) override; + NS_IMETHOD SendBinaryMsg(const nsACString& aMsg) override; + NS_IMETHOD SendBinaryStream(nsIInputStream* aStream, + uint32_t length) override; + NS_IMETHOD GetSecurityInfo(nsITransportSecurityInfo** aSecurityInfo) override; + + WebSocketChannel(); + static void Shutdown(); + + // Off main thread URI access. + void GetEffectiveURL(nsAString& aEffectiveURL) const override; + bool IsEncrypted() const override; + + nsresult OnTransportAvailableInternal(); + void OnError(nsresult aStatus) override; + void OnTCPClosed() override; + nsresult OnDataReceived(uint8_t* aData, uint32_t aCount) override; + + const static uint32_t kControlFrameMask = 0x8; + + // First byte of the header + const static uint8_t kFinalFragBit = 0x80; + const static uint8_t kRsvBitsMask = 0x70; + const static uint8_t kRsv1Bit = 0x40; + const static uint8_t kRsv2Bit = 0x20; + const static uint8_t kRsv3Bit = 0x10; + const static uint8_t kOpcodeBitsMask = 0x0F; + + // Second byte of the header + const static uint8_t kMaskBit = 0x80; + const static uint8_t kPayloadLengthBitsMask = 0x7F; + + protected: + ~WebSocketChannel() override; + + private: + friend class OutboundEnqueuer; + friend class nsWSAdmissionManager; + friend class FailDelayManager; + friend class CallOnMessageAvailable; + friend class CallOnStop; + friend class CallOnServerClose; + friend class CallAcknowledge; + + // Common send code for binary + text msgs + [[nodiscard]] nsresult SendMsgCommon(const nsACString& aMsg, bool isBinary, + uint32_t length, + nsIInputStream* aStream = nullptr); + + void EnqueueOutgoingMessage(nsDeque<OutboundMessage>& aQueue, + OutboundMessage* aMsg); + void DoEnqueueOutgoingMessage(); + + void PrimeNewOutgoingMessage(); + void DeleteCurrentOutGoingMessage(); + void GeneratePong(uint8_t* payload, uint32_t len); + void GeneratePing(); + + [[nodiscard]] nsresult OnNetworkChanged(); + [[nodiscard]] nsresult StartPinging(); + + void BeginOpen(bool aCalledFromAdmissionManager); + void BeginOpenInternal(); + [[nodiscard]] nsresult HandleExtensions(); + [[nodiscard]] nsresult SetupRequest(); + [[nodiscard]] nsresult ApplyForAdmission(); + [[nodiscard]] nsresult DoAdmissionDNS(); + [[nodiscard]] nsresult CallStartWebsocketData(); + [[nodiscard]] nsresult StartWebsocketData(); + uint16_t ResultToCloseCode(nsresult resultCode); + void ReportConnectionTelemetry(nsresult aStatusCode); + + void StopSession(nsresult reason); + void DoStopSession(nsresult reason); + void AbortSession(nsresult reason); + void ReleaseSession(); + void CleanupConnection(); + void IncrementSessionCount(); + void DecrementSessionCount(); + + void EnsureHdrOut(uint32_t size); + + static void ApplyMask(uint32_t mask, uint8_t* data, uint64_t len); + + bool IsPersistentFramePtr(); + [[nodiscard]] nsresult ProcessInput(uint8_t* buffer, uint32_t count); + [[nodiscard]] bool UpdateReadBuffer(uint8_t* buffer, uint32_t count, + uint32_t accumulatedFragments, + uint32_t* available); + + inline void ResetPingTimer() { + mPingOutstanding = 0; + if (mPingTimer) { + if (!mPingInterval) { + // The timer was created by forced ping and regular pinging is disabled, + // so cancel and null out mPingTimer. + mPingTimer->Cancel(); + mPingTimer = nullptr; + } else { + mPingTimer->SetDelay(mPingInterval); + } + } + } + + void NotifyOnStart(); + + nsCOMPtr<nsIEventTarget> mIOThread; + // Set in AsyncOpenNative and AsyncOnChannelRedirect, modified in + // DoStopSession on IO thread (.forget()). Probably ok... + nsCOMPtr<nsIHttpChannelInternal> mChannel; + nsCOMPtr<nsIHttpChannel> mHttpChannel; + + nsCOMPtr<nsICancelable> mCancelable MOZ_GUARDED_BY(mMutex); + // Mainthread only + nsCOMPtr<nsIAsyncVerifyRedirectCallback> mRedirectCallback; + // Set on Mainthread during AsyncOpen, used on IO thread and Mainthread + nsCOMPtr<nsIRandomGenerator> mRandomGenerator; + + nsCString mHashedSecret; // MainThread only + + // Used as key for connection managment: Initially set to hostname from URI, + // then to IP address (unless we're leaving DNS resolution to a proxy server) + // MainThread only + nsCString mAddress; + nsCString mPath; + int32_t mPort; // WS server port + // Secondary key for the connection queue. Used by nsWSAdmissionManager. + nsCString mOriginSuffix; // MainThread only + + // Used for off main thread access to the URI string. + // Set on MainThread in AsyncOpenNative, used on TargetThread and IO thread + nsCString mHost; + nsString mEffectiveURL; + + // Set on MainThread before multithread use, used on IO thread, cleared on + // IOThread + nsCOMPtr<nsISocketTransport> mTransport; + nsCOMPtr<nsIAsyncInputStream> mSocketIn; + nsCOMPtr<nsIAsyncOutputStream> mSocketOut; + RefPtr<WebSocketConnectionBase> mConnection; + + // Only used on IO Thread (accessed when known-to-be-null in DoStopSession + // on MainThread before mDataStarted) + nsCOMPtr<nsITimer> mCloseTimer; + // set in AsyncOpenInternal on MainThread, used on IO thread. + // No multithread use before it's set, no changes after that. + uint32_t mCloseTimeout; /* milliseconds */ + + nsCOMPtr<nsITimer> mOpenTimer; /* Mainthread only */ + uint32_t mOpenTimeout; /* milliseconds, MainThread only */ + wsConnectingState mConnecting; /* 0 if not connecting, MainThread only */ + // Set on MainThread, deleted on MainThread, used on MainThread or + // IO Thread (in DoStopSession). Mutex required to access off-main-thread. + nsCOMPtr<nsITimer> mReconnectDelayTimer MOZ_GUARDED_BY(mMutex); + + // Only touched on IOThread (DoStopSession reads it on MainThread if + // we haven't connected yet (mDataStarted==false), and it's always null + // until mDataStarted=true) + nsCOMPtr<nsITimer> mPingTimer; + + // Created in DoStopSession on IO thread (mDataStarted=true), accessed + // only from IO Thread + nsCOMPtr<nsITimer> mLingeringCloseTimer; + const static int32_t kLingeringCloseTimeout = 1000; + const static int32_t kLingeringCloseThreshold = 50; + + RefPtr<WebSocketEventService> mService; // effectively const + + int32_t + mMaxConcurrentConnections; // only used in AsyncOpenNative on MainThread + + // Set on MainThread in AsyncOpenNative; then used on IO thread + uint64_t mInnerWindowID; + + // following members are accessed only on the main thread + uint32_t mGotUpgradeOK : 1; + uint32_t mRecvdHttpUpgradeTransport : 1; + uint32_t mAllowPMCE : 1; + uint32_t : 0; // ensure these aren't mixed with the next set + + // following members are accessed only on the IO thread + uint32_t mPingOutstanding : 1; + uint32_t mReleaseOnTransmit : 1; + uint32_t : 0; + + Atomic<bool> mDataStarted; + // All changes to mRequestedClose happen under mutex, but since it's atomic, + // it can be read anywhere without a lock + Atomic<bool> mRequestedClose; + // mServer/ClientClosed are only modified on IOThread + Atomic<bool> mClientClosed; + Atomic<bool> mServerClosed; + // All changes to mStopped happen under mutex, but since it's atomic, it + // can be read anywhere without a lock + Atomic<bool> mStopped; + Atomic<bool> mCalledOnStop; + Atomic<bool> mTCPClosed; + Atomic<bool> mOpenedHttpChannel; + Atomic<bool> mIncrementedSessionCount; + Atomic<bool> mDecrementedSessionCount; + + int32_t mMaxMessageSize; // set on MainThread in AsyncOpenNative, read on IO + // thread + // Set on IOThread, or on MainThread before mDataStarted. Used on IO Thread + // (after mDataStarted) + nsresult mStopOnClose; + uint16_t mServerCloseCode; // only used on IO thread + nsCString mServerCloseReason; // only used on IO thread + uint16_t mScriptCloseCode MOZ_GUARDED_BY(mMutex); + nsCString mScriptCloseReason MOZ_GUARDED_BY(mMutex); + + // These are for the read buffers + const static uint32_t kIncomingBufferInitialSize = 16 * 1024; + // We're ok with keeping a buffer this size or smaller around for the life of + // the websocket. If a particular message needs bigger than this we'll + // increase the buffer temporarily, then drop back down to this size. + const static uint32_t kIncomingBufferStableSize = 128 * 1024; + + // Set at creation, used/modified only on IO thread + uint8_t* mFramePtr; + uint8_t* mBuffer; + uint8_t mFragmentOpcode; + uint32_t mFragmentAccumulator; + uint32_t mBuffered; + uint32_t mBufferSize; + + // These are for the send buffers + const static int32_t kCopyBreak = 1000; + + // Only used on IO thread + OutboundMessage* mCurrentOut; + uint32_t mCurrentOutSent; + nsDeque<OutboundMessage> mOutgoingMessages; + nsDeque<OutboundMessage> mOutgoingPingMessages; + nsDeque<OutboundMessage> mOutgoingPongMessages; + uint32_t mHdrOutToSend; + uint8_t* mHdrOut; + uint8_t mOutHeader[kCopyBreak + 16]{0}; + + // Set on MainThread in OnStartRequest (before mDataStarted), or in + // HandleExtensions() or OnTransportAvailableInternal(),used on IO Thread + // (after mDataStarted), cleared in DoStopSession on IOThread or on + // MainThread (if mDataStarted == false). + Mutex mCompressorMutex; + UniquePtr<PMCECompression> mPMCECompressor MOZ_GUARDED_BY(mCompressorMutex); + + // Used by EnsureHdrOut, which isn't called anywhere + uint32_t mDynamicOutputSize; + uint8_t* mDynamicOutput; + // Set on creation and AsyncOpen, read on both threads + Atomic<bool> mPrivateBrowsing; + + nsCOMPtr<nsIDashboardEventNotifier> + mConnectionLogService; // effectively const + + mozilla::Mutex mMutex; +}; + +class WebSocketSSLChannel : public WebSocketChannel { + public: + WebSocketSSLChannel() { BaseWebSocketChannel::mEncrypted = true; } + + protected: + virtual ~WebSocketSSLChannel() = default; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketChannel_h diff --git a/netwerk/protocol/websocket/WebSocketChannelChild.cpp b/netwerk/protocol/websocket/WebSocketChannelChild.cpp new file mode 100644 index 0000000000..2931b47065 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketChannelChild.cpp @@ -0,0 +1,732 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketLog.h" +#include "base/compiler_specific.h" +#include "mozilla/dom/BrowserChild.h" +#include "mozilla/net/NeckoChild.h" +#include "WebSocketChannelChild.h" +#include "nsContentUtils.h" +#include "nsIBrowserChild.h" +#include "nsNetUtil.h" +#include "mozilla/ipc/IPCStreamUtils.h" +#include "mozilla/ipc/URIUtils.h" +#include "mozilla/ipc/BackgroundUtils.h" +#include "mozilla/net/ChannelEventQueue.h" +#include "SerializedLoadContext.h" +#include "mozilla/dom/ContentChild.h" +#include "nsITransportProvider.h" + +using namespace mozilla::ipc; +using mozilla::dom::ContentChild; + +namespace mozilla { +namespace net { + +NS_IMPL_ADDREF(WebSocketChannelChild) + +NS_IMETHODIMP_(MozExternalRefCountType) WebSocketChannelChild::Release() { + MOZ_ASSERT(0 != mRefCnt, "dup release"); + --mRefCnt; + NS_LOG_RELEASE(this, mRefCnt, "WebSocketChannelChild"); + + if (mRefCnt == 1) { + MaybeReleaseIPCObject(); + return mRefCnt; + } + + if (mRefCnt == 0) { + mRefCnt = 1; /* stabilize */ + delete this; + return 0; + } + return mRefCnt; +} + +NS_INTERFACE_MAP_BEGIN(WebSocketChannelChild) + NS_INTERFACE_MAP_ENTRY(nsIWebSocketChannel) + NS_INTERFACE_MAP_ENTRY(nsIProtocolHandler) + NS_INTERFACE_MAP_ENTRY_AMBIGUOUS(nsISupports, nsIWebSocketChannel) + NS_INTERFACE_MAP_ENTRY(nsIThreadRetargetableRequest) +NS_INTERFACE_MAP_END + +WebSocketChannelChild::WebSocketChannelChild(bool aEncrypted) + : NeckoTargetHolder(nullptr), + mIPCState(Closed), + mMutex("WebSocketChannelChild::mMutex") { + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + + LOG(("WebSocketChannelChild::WebSocketChannelChild() %p\n", this)); + mEncrypted = aEncrypted; + mEventQ = new ChannelEventQueue(static_cast<nsIWebSocketChannel*>(this)); +} + +WebSocketChannelChild::~WebSocketChannelChild() { + LOG(("WebSocketChannelChild::~WebSocketChannelChild() %p\n", this)); + mEventQ->NotifyReleasingOwner(); +} + +void WebSocketChannelChild::AddIPDLReference() { + MOZ_ASSERT(NS_IsMainThread()); + + { + MutexAutoLock lock(mMutex); + MOZ_ASSERT(mIPCState == Closed, + "Attempt to retain more than one IPDL reference"); + mIPCState = Opened; + } + + AddRef(); +} + +void WebSocketChannelChild::ReleaseIPDLReference() { + MOZ_ASSERT(NS_IsMainThread()); + + { + MutexAutoLock lock(mMutex); + MOZ_ASSERT(mIPCState != Closed, + "Attempt to release nonexistent IPDL reference"); + mIPCState = Closed; + } + + Release(); +} + +void WebSocketChannelChild::MaybeReleaseIPCObject() { + { + MutexAutoLock lock(mMutex); + if (mIPCState != Opened) { + return; + } + + mIPCState = Closing; + } + + if (!NS_IsMainThread()) { + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + MOZ_ALWAYS_SUCCEEDS(target->Dispatch( + NewRunnableMethod("WebSocketChannelChild::MaybeReleaseIPCObject", this, + &WebSocketChannelChild::MaybeReleaseIPCObject), + NS_DISPATCH_NORMAL)); + return; + } + + SendDeleteSelf(); +} + +void WebSocketChannelChild::GetEffectiveURL(nsAString& aEffectiveURL) const { + aEffectiveURL = mEffectiveURL; +} + +bool WebSocketChannelChild::IsEncrypted() const { return mEncrypted; } + +class WebSocketEvent { + public: + MOZ_COUNTED_DEFAULT_CTOR(WebSocketEvent) + MOZ_COUNTED_DTOR_VIRTUAL(WebSocketEvent) + virtual void Run(WebSocketChannelChild* aChild) = 0; +}; + +class WrappedWebSocketEvent : public Runnable { + public: + WrappedWebSocketEvent(WebSocketChannelChild* aChild, + UniquePtr<WebSocketEvent>&& aWebSocketEvent) + : Runnable("net::WrappedWebSocketEvent"), + mChild(aChild), + mWebSocketEvent(std::move(aWebSocketEvent)) { + MOZ_RELEASE_ASSERT(!!mWebSocketEvent); + } + NS_IMETHOD Run() override { + mWebSocketEvent->Run(mChild); + return NS_OK; + } + + private: + RefPtr<WebSocketChannelChild> mChild; + UniquePtr<WebSocketEvent> mWebSocketEvent; +}; + +class EventTargetDispatcher : public ChannelEvent { + public: + EventTargetDispatcher(WebSocketChannelChild* aChild, + WebSocketEvent* aWebSocketEvent) + : mChild(aChild), + mWebSocketEvent(aWebSocketEvent), + mEventTarget(mChild->GetTargetThread()) {} + + void Run() override { + if (mEventTarget) { + mEventTarget->Dispatch( + new WrappedWebSocketEvent(mChild, std::move(mWebSocketEvent)), + NS_DISPATCH_NORMAL); + return; + } + } + + already_AddRefed<nsIEventTarget> GetEventTarget() override { + nsCOMPtr<nsIEventTarget> target = mEventTarget; + if (!target) { + target = GetMainThreadSerialEventTarget(); + } + return target.forget(); + } + + private: + // The lifetime of the child is ensured by ChannelEventQueue. + WebSocketChannelChild* mChild; + UniquePtr<WebSocketEvent> mWebSocketEvent; + nsCOMPtr<nsIEventTarget> mEventTarget; +}; + +class StartEvent : public WebSocketEvent { + public: + StartEvent(const nsACString& aProtocol, const nsACString& aExtensions, + const nsAString& aEffectiveURL, bool aEncrypted, + uint64_t aHttpChannelId) + : mProtocol(aProtocol), + mExtensions(aExtensions), + mEffectiveURL(aEffectiveURL), + mEncrypted(aEncrypted), + mHttpChannelId(aHttpChannelId) {} + + void Run(WebSocketChannelChild* aChild) override { + aChild->OnStart(mProtocol, mExtensions, mEffectiveURL, mEncrypted, + mHttpChannelId); + } + + private: + nsCString mProtocol; + nsCString mExtensions; + nsString mEffectiveURL; + bool mEncrypted; + uint64_t mHttpChannelId; +}; + +mozilla::ipc::IPCResult WebSocketChannelChild::RecvOnStart( + const nsACString& aProtocol, const nsACString& aExtensions, + const nsAString& aEffectiveURL, const bool& aEncrypted, + const uint64_t& aHttpChannelId) { + mEventQ->RunOrEnqueue(new EventTargetDispatcher( + this, new StartEvent(aProtocol, aExtensions, aEffectiveURL, aEncrypted, + aHttpChannelId))); + + return IPC_OK(); +} + +void WebSocketChannelChild::OnStart(const nsACString& aProtocol, + const nsACString& aExtensions, + const nsAString& aEffectiveURL, + const bool& aEncrypted, + const uint64_t& aHttpChannelId) { + LOG(("WebSocketChannelChild::RecvOnStart() %p\n", this)); + SetProtocol(aProtocol); + mNegotiatedExtensions = aExtensions; + mEffectiveURL = aEffectiveURL; + mEncrypted = aEncrypted; + mHttpChannelId = aHttpChannelId; + + if (mListenerMT) { + AutoEventEnqueuer ensureSerialDispatch(mEventQ); + nsresult rv = mListenerMT->mListener->OnStart(mListenerMT->mContext); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannelChild::OnStart " + "mListenerMT->mListener->OnStart() failed with error 0x%08" PRIx32, + static_cast<uint32_t>(rv))); + } + } +} + +class StopEvent : public WebSocketEvent { + public: + explicit StopEvent(const nsresult& aStatusCode) : mStatusCode(aStatusCode) {} + + void Run(WebSocketChannelChild* aChild) override { + aChild->OnStop(mStatusCode); + } + + private: + nsresult mStatusCode; +}; + +mozilla::ipc::IPCResult WebSocketChannelChild::RecvOnStop( + const nsresult& aStatusCode) { + mEventQ->RunOrEnqueue( + new EventTargetDispatcher(this, new StopEvent(aStatusCode))); + + return IPC_OK(); +} + +void WebSocketChannelChild::OnStop(const nsresult& aStatusCode) { + LOG(("WebSocketChannelChild::RecvOnStop() %p\n", this)); + if (mListenerMT) { + AutoEventEnqueuer ensureSerialDispatch(mEventQ); + nsresult rv = + mListenerMT->mListener->OnStop(mListenerMT->mContext, aStatusCode); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::OnStop " + "mListenerMT->mListener->OnStop() failed with error 0x%08" PRIx32, + static_cast<uint32_t>(rv))); + } + } +} + +class MessageEvent : public WebSocketEvent { + public: + MessageEvent(const nsACString& aMessage, bool aBinary) + : mMessage(aMessage), mBinary(aBinary) {} + + void Run(WebSocketChannelChild* aChild) override { + if (!mBinary) { + aChild->OnMessageAvailable(mMessage); + } else { + aChild->OnBinaryMessageAvailable(mMessage); + } + } + + private: + nsCString mMessage; + bool mBinary; +}; + +bool WebSocketChannelChild::RecvOnMessageAvailableInternal( + const nsACString& aMsg, bool aMoreData, bool aBinary) { + if (aMoreData) { + return mReceivedMsgBuffer.Append(aMsg, fallible); + } + + if (!mReceivedMsgBuffer.Append(aMsg, fallible)) { + return false; + } + + mEventQ->RunOrEnqueue(new EventTargetDispatcher( + this, new MessageEvent(mReceivedMsgBuffer, aBinary))); + mReceivedMsgBuffer.Truncate(); + return true; +} + +class OnErrorEvent : public WebSocketEvent { + public: + OnErrorEvent() = default; + + void Run(WebSocketChannelChild* aChild) override { aChild->OnError(); } +}; + +void WebSocketChannelChild::OnError() { + LOG(("WebSocketChannelChild::OnError() %p", this)); + if (mListenerMT) { + AutoEventEnqueuer ensureSerialDispatch(mEventQ); + Unused << mListenerMT->mListener->OnError(); + } +} + +mozilla::ipc::IPCResult WebSocketChannelChild::RecvOnMessageAvailable( + const nsACString& aMsg, const bool& aMoreData) { + if (!RecvOnMessageAvailableInternal(aMsg, aMoreData, false)) { + LOG(("WebSocketChannelChild %p append message failed", this)); + mEventQ->RunOrEnqueue(new EventTargetDispatcher(this, new OnErrorEvent())); + } + return IPC_OK(); +} + +void WebSocketChannelChild::OnMessageAvailable(const nsACString& aMsg) { + LOG(("WebSocketChannelChild::RecvOnMessageAvailable() %p\n", this)); + if (mListenerMT) { + AutoEventEnqueuer ensureSerialDispatch(mEventQ); + nsresult rv = + mListenerMT->mListener->OnMessageAvailable(mListenerMT->mContext, aMsg); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannelChild::OnMessageAvailable " + "mListenerMT->mListener->OnMessageAvailable() " + "failed with error 0x%08" PRIx32, + static_cast<uint32_t>(rv))); + } + } +} + +mozilla::ipc::IPCResult WebSocketChannelChild::RecvOnBinaryMessageAvailable( + const nsACString& aMsg, const bool& aMoreData) { + if (!RecvOnMessageAvailableInternal(aMsg, aMoreData, true)) { + LOG(("WebSocketChannelChild %p append message failed", this)); + mEventQ->RunOrEnqueue(new EventTargetDispatcher(this, new OnErrorEvent())); + } + return IPC_OK(); +} + +void WebSocketChannelChild::OnBinaryMessageAvailable(const nsACString& aMsg) { + LOG(("WebSocketChannelChild::RecvOnBinaryMessageAvailable() %p\n", this)); + if (mListenerMT) { + AutoEventEnqueuer ensureSerialDispatch(mEventQ); + nsresult rv = mListenerMT->mListener->OnBinaryMessageAvailable( + mListenerMT->mContext, aMsg); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannelChild::OnBinaryMessageAvailable " + "mListenerMT->mListener->OnBinaryMessageAvailable() " + "failed with error 0x%08" PRIx32, + static_cast<uint32_t>(rv))); + } + } +} + +class AcknowledgeEvent : public WebSocketEvent { + public: + explicit AcknowledgeEvent(const uint32_t& aSize) : mSize(aSize) {} + + void Run(WebSocketChannelChild* aChild) override { + aChild->OnAcknowledge(mSize); + } + + private: + uint32_t mSize; +}; + +mozilla::ipc::IPCResult WebSocketChannelChild::RecvOnAcknowledge( + const uint32_t& aSize) { + mEventQ->RunOrEnqueue( + new EventTargetDispatcher(this, new AcknowledgeEvent(aSize))); + + return IPC_OK(); +} + +void WebSocketChannelChild::OnAcknowledge(const uint32_t& aSize) { + LOG(("WebSocketChannelChild::RecvOnAcknowledge() %p\n", this)); + if (mListenerMT) { + AutoEventEnqueuer ensureSerialDispatch(mEventQ); + nsresult rv = + mListenerMT->mListener->OnAcknowledge(mListenerMT->mContext, aSize); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannel::OnAcknowledge " + "mListenerMT->mListener->OnAcknowledge() " + "failed with error 0x%08" PRIx32, + static_cast<uint32_t>(rv))); + } + } +} + +class ServerCloseEvent : public WebSocketEvent { + public: + ServerCloseEvent(const uint16_t aCode, const nsACString& aReason) + : mCode(aCode), mReason(aReason) {} + + void Run(WebSocketChannelChild* aChild) override { + aChild->OnServerClose(mCode, mReason); + } + + private: + uint16_t mCode; + nsCString mReason; +}; + +mozilla::ipc::IPCResult WebSocketChannelChild::RecvOnServerClose( + const uint16_t& aCode, const nsACString& aReason) { + mEventQ->RunOrEnqueue( + new EventTargetDispatcher(this, new ServerCloseEvent(aCode, aReason))); + + return IPC_OK(); +} + +void WebSocketChannelChild::OnServerClose(const uint16_t& aCode, + const nsACString& aReason) { + LOG(("WebSocketChannelChild::RecvOnServerClose() %p\n", this)); + if (mListenerMT) { + AutoEventEnqueuer ensureSerialDispatch(mEventQ); + DebugOnly<nsresult> rv = mListenerMT->mListener->OnServerClose( + mListenerMT->mContext, aCode, aReason); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } +} + +void WebSocketChannelChild::SetupNeckoTarget() { + mNeckoTarget = GetMainThreadSerialEventTarget(); +} + +NS_IMETHODIMP +WebSocketChannelChild::AsyncOpen(nsIURI* aURI, const nsACString& aOrigin, + JS::Handle<JS::Value> aOriginAttributes, + uint64_t aInnerWindowID, + nsIWebSocketListener* aListener, + nsISupports* aContext, JSContext* aCx) { + OriginAttributes attrs; + if (!aOriginAttributes.isObject() || !attrs.Init(aCx, aOriginAttributes)) { + return NS_ERROR_INVALID_ARG; + } + return AsyncOpenNative(aURI, aOrigin, attrs, aInnerWindowID, aListener, + aContext); +} + +NS_IMETHODIMP +WebSocketChannelChild::AsyncOpenNative( + nsIURI* aURI, const nsACString& aOrigin, + const OriginAttributes& aOriginAttributes, uint64_t aInnerWindowID, + nsIWebSocketListener* aListener, nsISupports* aContext) { + LOG(("WebSocketChannelChild::AsyncOpen() %p\n", this)); + + MOZ_ASSERT(NS_IsMainThread(), "not main thread"); + MOZ_ASSERT((aURI && !mIsServerSide) || (!aURI && mIsServerSide), + "Invalid aURI for WebSocketChannelChild::AsyncOpen"); + MOZ_ASSERT(aListener && !mListenerMT, + "Invalid state for WebSocketChannelChild::AsyncOpen"); + + mozilla::dom::BrowserChild* browserChild = nullptr; + nsCOMPtr<nsIBrowserChild> iBrowserChild; + NS_QueryNotificationCallbacks(mCallbacks, mLoadGroup, + NS_GET_IID(nsIBrowserChild), + getter_AddRefs(iBrowserChild)); + if (iBrowserChild) { + browserChild = + static_cast<mozilla::dom::BrowserChild*>(iBrowserChild.get()); + } + + ContentChild* cc = static_cast<ContentChild*>(gNeckoChild->Manager()); + if (cc->IsShuttingDown()) { + return NS_ERROR_FAILURE; + } + + // Corresponding release in DeallocPWebSocket + AddIPDLReference(); + + nsCOMPtr<nsIURI> uri; + LoadInfoArgs loadInfoArgs; + Maybe<NotNull<PTransportProviderChild*>> transportProvider; + + if (!mIsServerSide) { + uri = aURI; + nsresult rv = LoadInfoToLoadInfoArgs(mLoadInfo, &loadInfoArgs); + NS_ENSURE_SUCCESS(rv, rv); + + transportProvider = Nothing(); + } else { + MOZ_ASSERT(mServerTransportProvider); + PTransportProviderChild* ipcChild; + nsresult rv = mServerTransportProvider->GetIPCChild(&ipcChild); + NS_ENSURE_SUCCESS(rv, rv); + + transportProvider = Some(WrapNotNull(ipcChild)); + } + + // This must be called before sending constructor message. + SetupNeckoTarget(); + + if (!gNeckoChild->SendPWebSocketConstructor( + this, browserChild, IPC::SerializedLoadContext(this), mSerial)) { + return NS_ERROR_UNEXPECTED; + } + if (!SendAsyncOpen(uri, aOrigin, aOriginAttributes, aInnerWindowID, mProtocol, + mEncrypted, mPingInterval, mClientSetPingInterval, + mPingResponseTimeout, mClientSetPingTimeout, loadInfoArgs, + transportProvider, mNegotiatedExtensions)) { + return NS_ERROR_UNEXPECTED; + } + + if (mIsServerSide) { + mServerTransportProvider = nullptr; + } + + mOriginalURI = aURI; + mURI = mOriginalURI; + mListenerMT = new ListenerAndContextContainer(aListener, aContext); + mOrigin = aOrigin; + mWasOpened = 1; + + return NS_OK; +} + +class CloseEvent : public Runnable { + public: + CloseEvent(WebSocketChannelChild* aChild, uint16_t aCode, + const nsACString& aReason) + : Runnable("net::CloseEvent"), + mChild(aChild), + mCode(aCode), + mReason(aReason) { + MOZ_RELEASE_ASSERT(!NS_IsMainThread()); + MOZ_ASSERT(aChild); + } + NS_IMETHOD Run() override { + MOZ_RELEASE_ASSERT(NS_IsMainThread()); + mChild->Close(mCode, mReason); + return NS_OK; + } + + private: + RefPtr<WebSocketChannelChild> mChild; + uint16_t mCode; + nsCString mReason; +}; + +NS_IMETHODIMP +WebSocketChannelChild::Close(uint16_t code, const nsACString& reason) { + if (!NS_IsMainThread()) { + MOZ_RELEASE_ASSERT(IsOnTargetThread()); + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + return target->Dispatch(new CloseEvent(this, code, reason), + NS_DISPATCH_NORMAL); + } + LOG(("WebSocketChannelChild::Close() %p\n", this)); + + { + MutexAutoLock lock(mMutex); + if (mIPCState != Opened) { + return NS_ERROR_UNEXPECTED; + } + } + + if (!SendClose(code, reason)) { + return NS_ERROR_UNEXPECTED; + } + + return NS_OK; +} + +class MsgEvent : public Runnable { + public: + MsgEvent(WebSocketChannelChild* aChild, const nsACString& aMsg, + bool aBinaryMsg) + : Runnable("net::MsgEvent"), + mChild(aChild), + mMsg(aMsg), + mBinaryMsg(aBinaryMsg) { + MOZ_RELEASE_ASSERT(!NS_IsMainThread()); + MOZ_ASSERT(aChild); + } + NS_IMETHOD Run() override { + MOZ_RELEASE_ASSERT(NS_IsMainThread()); + if (mBinaryMsg) { + mChild->SendBinaryMsg(mMsg); + } else { + mChild->SendMsg(mMsg); + } + return NS_OK; + } + + private: + RefPtr<WebSocketChannelChild> mChild; + nsCString mMsg; + bool mBinaryMsg; +}; + +NS_IMETHODIMP +WebSocketChannelChild::SendMsg(const nsACString& aMsg) { + if (!NS_IsMainThread()) { + MOZ_RELEASE_ASSERT(IsOnTargetThread()); + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + return target->Dispatch(new MsgEvent(this, aMsg, false), + NS_DISPATCH_NORMAL); + } + LOG(("WebSocketChannelChild::SendMsg() %p\n", this)); + + { + MutexAutoLock lock(mMutex); + if (mIPCState != Opened) { + return NS_ERROR_UNEXPECTED; + } + } + + if (!SendSendMsg(aMsg)) { + return NS_ERROR_UNEXPECTED; + } + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannelChild::SendBinaryMsg(const nsACString& aMsg) { + if (!NS_IsMainThread()) { + MOZ_RELEASE_ASSERT(IsOnTargetThread()); + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + return target->Dispatch(new MsgEvent(this, aMsg, true), NS_DISPATCH_NORMAL); + } + LOG(("WebSocketChannelChild::SendBinaryMsg() %p\n", this)); + + { + MutexAutoLock lock(mMutex); + if (mIPCState != Opened) { + return NS_ERROR_UNEXPECTED; + } + } + + if (!SendSendBinaryMsg(aMsg)) { + return NS_ERROR_UNEXPECTED; + } + + return NS_OK; +} + +class BinaryStreamEvent : public Runnable { + public: + BinaryStreamEvent(WebSocketChannelChild* aChild, nsIInputStream* aStream, + uint32_t aLength) + : Runnable("net::BinaryStreamEvent"), + mChild(aChild), + mStream(aStream), + mLength(aLength) { + MOZ_RELEASE_ASSERT(!NS_IsMainThread()); + MOZ_ASSERT(aChild); + } + NS_IMETHOD Run() override { + MOZ_ASSERT(NS_IsMainThread()); + nsresult rv = mChild->SendBinaryStream(mStream, mLength); + if (NS_FAILED(rv)) { + LOG( + ("WebSocketChannelChild::BinaryStreamEvent %p " + "SendBinaryStream failed (%08" PRIx32 ")\n", + this, static_cast<uint32_t>(rv))); + } + return NS_OK; + } + + private: + RefPtr<WebSocketChannelChild> mChild; + nsCOMPtr<nsIInputStream> mStream; + uint32_t mLength; +}; + +NS_IMETHODIMP +WebSocketChannelChild::SendBinaryStream(nsIInputStream* aStream, + uint32_t aLength) { + if (!NS_IsMainThread()) { + MOZ_RELEASE_ASSERT(IsOnTargetThread()); + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + return target->Dispatch(new BinaryStreamEvent(this, aStream, aLength), + NS_DISPATCH_NORMAL); + } + + LOG(("WebSocketChannelChild::SendBinaryStream() %p\n", this)); + + IPCStream ipcStream; + if (NS_WARN_IF(!mozilla::ipc::SerializeIPCStream(do_AddRef(aStream), + ipcStream, + /* aAllowLazy */ false))) { + return NS_ERROR_UNEXPECTED; + } + + { + MutexAutoLock lock(mMutex); + if (mIPCState != Opened) { + return NS_ERROR_UNEXPECTED; + } + } + + if (!SendSendBinaryStream(ipcStream, aLength)) { + return NS_ERROR_UNEXPECTED; + } + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannelChild::GetSecurityInfo( + nsITransportSecurityInfo** aSecurityInfo) { + LOG(("WebSocketChannelChild::GetSecurityInfo() %p\n", this)); + return NS_ERROR_NOT_AVAILABLE; +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketChannelChild.h b/netwerk/protocol/websocket/WebSocketChannelChild.h new file mode 100644 index 0000000000..0221b9fdde --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketChannelChild.h @@ -0,0 +1,116 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketChannelChild_h +#define mozilla_net_WebSocketChannelChild_h + +#include "mozilla/net/NeckoTargetHolder.h" +#include "mozilla/net/PWebSocketChild.h" +#include "mozilla/net/BaseWebSocketChannel.h" +#include "nsString.h" + +namespace mozilla { + +namespace net { + +class ChannelEvent; +class ChannelEventQueue; +class MessageEvent; + +class WebSocketChannelChild final : public BaseWebSocketChannel, + public PWebSocketChild, + public NeckoTargetHolder { + friend class PWebSocketChild; + + public: + explicit WebSocketChannelChild(bool aEncrypted); + + NS_DECL_THREADSAFE_ISUPPORTS + + // nsIWebSocketChannel methods BaseWebSocketChannel didn't implement for us + // + NS_IMETHOD AsyncOpen(nsIURI* aURI, const nsACString& aOrigin, + JS::Handle<JS::Value> aOriginAttributes, + uint64_t aInnerWindowID, nsIWebSocketListener* aListener, + nsISupports* aContext, JSContext* aCx) override; + NS_IMETHOD AsyncOpenNative(nsIURI* aURI, const nsACString& aOrigin, + const OriginAttributes& aOriginAttributes, + uint64_t aInnerWindowID, + nsIWebSocketListener* aListener, + nsISupports* aContext) override; + NS_IMETHOD Close(uint16_t code, const nsACString& reason) override; + NS_IMETHOD SendMsg(const nsACString& aMsg) override; + NS_IMETHOD SendBinaryMsg(const nsACString& aMsg) override; + NS_IMETHOD SendBinaryStream(nsIInputStream* aStream, + uint32_t aLength) override; + NS_IMETHOD GetSecurityInfo(nsITransportSecurityInfo** aSecurityInfo) override; + + void AddIPDLReference(); + void ReleaseIPDLReference(); + + // Off main thread URI access. + void GetEffectiveURL(nsAString& aEffectiveURL) const override; + bool IsEncrypted() const override; + + private: + ~WebSocketChannelChild(); + + mozilla::ipc::IPCResult RecvOnStart(const nsACString& aProtocol, + const nsACString& aExtensions, + const nsAString& aEffectiveURL, + const bool& aEncrypted, + const uint64_t& aHttpChannelId); + mozilla::ipc::IPCResult RecvOnStop(const nsresult& aStatusCode); + mozilla::ipc::IPCResult RecvOnMessageAvailable(const nsACString& aMsg, + const bool& aMoreData); + mozilla::ipc::IPCResult RecvOnBinaryMessageAvailable(const nsACString& aMsg, + const bool& aMoreData); + mozilla::ipc::IPCResult RecvOnAcknowledge(const uint32_t& aSize); + mozilla::ipc::IPCResult RecvOnServerClose(const uint16_t& aCode, + const nsACString& aReason); + + void OnStart(const nsACString& aProtocol, const nsACString& aExtensions, + const nsAString& aEffectiveURL, const bool& aEncrypted, + const uint64_t& aHttpChannelId); + void OnStop(const nsresult& aStatusCode); + void OnMessageAvailable(const nsACString& aMsg); + void OnBinaryMessageAvailable(const nsACString& aMsg); + void OnAcknowledge(const uint32_t& aSize); + void OnServerClose(const uint16_t& aCode, const nsACString& aReason); + void AsyncOpenFailed(); + + void MaybeReleaseIPCObject(); + + // This function tries to get a labeled event target for |mNeckoTarget|. + void SetupNeckoTarget(); + + bool RecvOnMessageAvailableInternal(const nsACString& aMsg, bool aMoreData, + bool aBinary); + + void OnError(); + + RefPtr<ChannelEventQueue> mEventQ; + nsString mEffectiveURL; + nsCString mReceivedMsgBuffer; + + // This variable is protected by mutex. + enum { Opened, Closing, Closed } mIPCState; + + mozilla::Mutex mMutex MOZ_UNANNOTATED; + + friend class StartEvent; + friend class StopEvent; + friend class MessageEvent; + friend class AcknowledgeEvent; + friend class ServerCloseEvent; + friend class AsyncOpenFailedEvent; + friend class OnErrorEvent; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketChannelChild_h diff --git a/netwerk/protocol/websocket/WebSocketChannelParent.cpp b/netwerk/protocol/websocket/WebSocketChannelParent.cpp new file mode 100644 index 0000000000..b690f3ac27 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketChannelParent.cpp @@ -0,0 +1,360 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketLog.h" +#include "WebSocketChannelParent.h" +#include "nsIAuthPromptProvider.h" +#include "mozilla/dom/ContentParent.h" +#include "mozilla/ipc/InputStreamUtils.h" +#include "mozilla/ipc/URIUtils.h" +#include "mozilla/ipc/BackgroundUtils.h" +#include "SerializedLoadContext.h" +#include "mozilla/net/NeckoCommon.h" +#include "mozilla/net/WebSocketChannel.h" +#include "nsComponentManagerUtils.h" +#include "IPCTransportProvider.h" +#include "mozilla/net/ChannelEventQueue.h" + +using namespace mozilla::ipc; + +namespace mozilla { +namespace net { + +NS_IMPL_ISUPPORTS(WebSocketChannelParent, nsIWebSocketListener, + nsIInterfaceRequestor) + +WebSocketChannelParent::WebSocketChannelParent( + nsIAuthPromptProvider* aAuthProvider, nsILoadContext* aLoadContext, + PBOverrideStatus aOverrideStatus, uint32_t aSerial) + : mAuthProvider(aAuthProvider), + mLoadContext(aLoadContext), + mSerial(aSerial) { + // Websocket channels can't have a private browsing override + MOZ_ASSERT_IF(!aLoadContext, aOverrideStatus == kPBOverride_Unset); +} +//----------------------------------------------------------------------------- +// WebSocketChannelParent::PWebSocketChannelParent +//----------------------------------------------------------------------------- + +mozilla::ipc::IPCResult WebSocketChannelParent::RecvDeleteSelf() { + LOG(("WebSocketChannelParent::RecvDeleteSelf() %p\n", this)); + mChannel = nullptr; + mAuthProvider = nullptr; + IProtocol* mgr = Manager(); + if (CanRecv() && !Send__delete__(this)) { + return IPC_FAIL_NO_REASON(mgr); + } + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketChannelParent::RecvAsyncOpen( + nsIURI* aURI, const nsCString& aOrigin, + const OriginAttributes& aOriginAttributes, const uint64_t& aInnerWindowID, + const nsCString& aProtocol, const bool& aSecure, + const uint32_t& aPingInterval, const bool& aClientSetPingInterval, + const uint32_t& aPingTimeout, const bool& aClientSetPingTimeout, + const LoadInfoArgs& aLoadInfoArgs, + const Maybe<PTransportProviderParent*>& aTransportProvider, + const nsCString& aNegotiatedExtensions) { + LOG(("WebSocketChannelParent::RecvAsyncOpen() %p\n", this)); + + nsresult rv; + nsCOMPtr<nsILoadInfo> loadInfo; + nsCOMPtr<nsIURI> uri; + + rv = LoadInfoArgsToLoadInfo( + aLoadInfoArgs, + mozilla::dom::ContentParent::Cast(Manager()->Manager())->GetRemoteType(), + getter_AddRefs(loadInfo)); + if (NS_FAILED(rv)) { + goto fail; + } + + if (aSecure) { + mChannel = + do_CreateInstance("@mozilla.org/network/protocol;1?name=wss", &rv); + } else { + mChannel = + do_CreateInstance("@mozilla.org/network/protocol;1?name=ws", &rv); + } + if (NS_FAILED(rv)) goto fail; + + rv = mChannel->SetSerial(mSerial); + if (NS_WARN_IF(NS_FAILED(rv))) { + goto fail; + } + + rv = mChannel->SetLoadInfo(loadInfo); + if (NS_FAILED(rv)) { + goto fail; + } + + rv = mChannel->SetNotificationCallbacks(this); + if (NS_FAILED(rv)) goto fail; + + rv = mChannel->SetProtocol(aProtocol); + if (NS_FAILED(rv)) goto fail; + + if (aTransportProvider.isSome()) { + RefPtr<TransportProviderParent> provider = + static_cast<TransportProviderParent*>(aTransportProvider.value()); + rv = mChannel->SetServerParameters(provider, aNegotiatedExtensions); + if (NS_FAILED(rv)) { + goto fail; + } + } else { + uri = aURI; + if (!uri) { + rv = NS_ERROR_FAILURE; + goto fail; + } + } + + // only use ping values from child if they were overridden by client code. + if (aClientSetPingInterval) { + // IDL allows setting in seconds, so must be multiple of 1000 ms + MOZ_ASSERT(aPingInterval >= 1000 && !(aPingInterval % 1000)); + DebugOnly<nsresult> rv = mChannel->SetPingInterval(aPingInterval / 1000); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } + if (aClientSetPingTimeout) { + MOZ_ASSERT(aPingTimeout >= 1000 && !(aPingTimeout % 1000)); + DebugOnly<nsresult> rv = mChannel->SetPingTimeout(aPingTimeout / 1000); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } + + rv = mChannel->AsyncOpenNative(uri, aOrigin, aOriginAttributes, + aInnerWindowID, this, nullptr); + if (NS_FAILED(rv)) goto fail; + + return IPC_OK(); + +fail: + mChannel = nullptr; + if (!SendOnStop(rv)) { + return IPC_FAIL_NO_REASON(this); + } + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketChannelParent::RecvClose( + const uint16_t& code, const nsCString& reason) { + LOG(("WebSocketChannelParent::RecvClose() %p\n", this)); + if (mChannel) { + nsresult rv = mChannel->Close(code, reason); + NS_ENSURE_SUCCESS(rv, IPC_OK()); + } + + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketChannelParent::RecvSendMsg( + const nsCString& aMsg) { + LOG(("WebSocketChannelParent::RecvSendMsg() %p\n", this)); + if (mChannel) { + nsresult rv = mChannel->SendMsg(aMsg); + NS_ENSURE_SUCCESS(rv, IPC_OK()); + } + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketChannelParent::RecvSendBinaryMsg( + const nsCString& aMsg) { + LOG(("WebSocketChannelParent::RecvSendBinaryMsg() %p\n", this)); + if (mChannel) { + nsresult rv = mChannel->SendBinaryMsg(aMsg); + NS_ENSURE_SUCCESS(rv, IPC_OK()); + } + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketChannelParent::RecvSendBinaryStream( + const IPCStream& aStream, const uint32_t& aLength) { + LOG(("WebSocketChannelParent::RecvSendBinaryStream() %p\n", this)); + if (mChannel) { + nsCOMPtr<nsIInputStream> stream = DeserializeIPCStream(aStream); + if (!stream) { + return IPC_FAIL_NO_REASON(this); + } + nsresult rv = mChannel->SendBinaryStream(stream, aLength); + NS_ENSURE_SUCCESS(rv, IPC_OK()); + } + return IPC_OK(); +} + +//----------------------------------------------------------------------------- +// WebSocketChannelParent::nsIRequestObserver +//----------------------------------------------------------------------------- + +NS_IMETHODIMP +WebSocketChannelParent::OnStart(nsISupports* aContext) { + LOG(("WebSocketChannelParent::OnStart() %p\n", this)); + nsAutoCString protocol, extensions; + nsString effectiveURL; + bool encrypted = false; + uint64_t httpChannelId = 0; + if (mChannel) { + DebugOnly<nsresult> rv = mChannel->GetProtocol(protocol); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + rv = mChannel->GetExtensions(extensions); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + + RefPtr<WebSocketChannel> channel; + channel = static_cast<WebSocketChannel*>(mChannel.get()); + MOZ_ASSERT(channel); + + channel->GetEffectiveURL(effectiveURL); + encrypted = channel->IsEncrypted(); + httpChannelId = channel->HttpChannelId(); + } + if (!CanRecv() || !SendOnStart(protocol, extensions, effectiveURL, encrypted, + httpChannelId)) { + return NS_ERROR_FAILURE; + } + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannelParent::OnStop(nsISupports* aContext, nsresult aStatusCode) { + LOG(("WebSocketChannelParent::OnStop() %p\n", this)); + if (!CanRecv() || !SendOnStop(aStatusCode)) { + return NS_ERROR_FAILURE; + } + return NS_OK; +} + +static bool SendOnMessageAvailableHelper( + const nsACString& aMsg, + const std::function<bool(const nsDependentCSubstring&, bool)>& aSendFunc) { + // To avoid the crash caused by too large IPC message, we have to split the + // data in small chunks and send them to child process. Note that the chunk + // size used here is the same as what we used for PHttpChannel. + static uint32_t const kCopyChunkSize = 128 * 1024; + uint32_t count = aMsg.Length(); + if (count <= kCopyChunkSize) { + return aSendFunc(nsDependentCSubstring(aMsg), false); + } + + uint32_t start = 0; + uint32_t toRead = std::min<uint32_t>(count, kCopyChunkSize); + while (count) { + nsDependentCSubstring data(Substring(aMsg, start, toRead)); + + if (!aSendFunc(data, count > kCopyChunkSize)) { + return false; + } + + start += toRead; + count -= toRead; + toRead = std::min<uint32_t>(count, kCopyChunkSize); + } + + return true; +} + +NS_IMETHODIMP +WebSocketChannelParent::OnMessageAvailable(nsISupports* aContext, + const nsACString& aMsg) { + LOG(("WebSocketChannelParent::OnMessageAvailable() %p\n", this)); + + if (!CanRecv()) { + return NS_ERROR_FAILURE; + } + + auto sendFunc = [self = UnsafePtr<WebSocketChannelParent>(this)]( + const nsDependentCSubstring& aMsg, bool aMoreData) { + return self->SendOnMessageAvailable(aMsg, aMoreData); + }; + + if (!SendOnMessageAvailableHelper(aMsg, sendFunc)) { + return NS_ERROR_FAILURE; + } + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannelParent::OnBinaryMessageAvailable(nsISupports* aContext, + const nsACString& aMsg) { + LOG(("WebSocketChannelParent::OnBinaryMessageAvailable() %p\n", this)); + + if (!CanRecv()) { + return NS_ERROR_FAILURE; + } + + auto sendFunc = [self = UnsafePtr<WebSocketChannelParent>(this)]( + const nsDependentCSubstring& aMsg, bool aMoreData) { + return self->SendOnBinaryMessageAvailable(aMsg, aMoreData); + }; + + if (!SendOnMessageAvailableHelper(aMsg, sendFunc)) { + return NS_ERROR_FAILURE; + } + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannelParent::OnAcknowledge(nsISupports* aContext, uint32_t aSize) { + LOG(("WebSocketChannelParent::OnAcknowledge() %p\n", this)); + if (!CanRecv() || !SendOnAcknowledge(aSize)) { + return NS_ERROR_FAILURE; + } + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannelParent::OnServerClose(nsISupports* aContext, uint16_t code, + const nsACString& reason) { + LOG(("WebSocketChannelParent::OnServerClose() %p\n", this)); + if (!CanRecv() || !SendOnServerClose(code, reason)) { + return NS_ERROR_FAILURE; + } + return NS_OK; +} + +NS_IMETHODIMP +WebSocketChannelParent::OnError() { return NS_OK; } + +void WebSocketChannelParent::ActorDestroy(ActorDestroyReason why) { + LOG(("WebSocketChannelParent::ActorDestroy() %p\n", this)); + + // Make sure we close the channel if the content process dies without going + // through a clean shutdown. + if (mChannel) { + Unused << mChannel->Close(nsIWebSocketChannel::CLOSE_GOING_AWAY, + "Child was killed"_ns); + } +} + +//----------------------------------------------------------------------------- +// WebSocketChannelParent::nsIInterfaceRequestor +//----------------------------------------------------------------------------- + +NS_IMETHODIMP +WebSocketChannelParent::GetInterface(const nsIID& iid, void** result) { + LOG(("WebSocketChannelParent::GetInterface() %p\n", this)); + if (mAuthProvider && iid.Equals(NS_GET_IID(nsIAuthPromptProvider))) { + nsresult rv = mAuthProvider->GetAuthPrompt( + nsIAuthPromptProvider::PROMPT_NORMAL, iid, result); + if (NS_FAILED(rv)) { + return NS_ERROR_NO_INTERFACE; + } + return NS_OK; + } + + // Only support nsILoadContext if child channel's callbacks did too + if (iid.Equals(NS_GET_IID(nsILoadContext)) && mLoadContext) { + nsCOMPtr<nsILoadContext> copy = mLoadContext; + copy.forget(result); + return NS_OK; + } + + return QueryInterface(iid, result); +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketChannelParent.h b/netwerk/protocol/websocket/WebSocketChannelParent.h new file mode 100644 index 0000000000..534a737dcb --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketChannelParent.h @@ -0,0 +1,70 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketChannelParent_h +#define mozilla_net_WebSocketChannelParent_h + +#include "mozilla/net/PWebSocketParent.h" +#include "mozilla/net/NeckoParent.h" +#include "nsIInterfaceRequestor.h" +#include "nsIWebSocketListener.h" +#include "nsIWebSocketChannel.h" +#include "nsILoadContext.h" +#include "nsCOMPtr.h" +#include "nsString.h" + +class nsIAuthPromptProvider; + +namespace mozilla { +namespace net { + +class WebSocketChannelParent : public PWebSocketParent, + public nsIWebSocketListener, + public nsIInterfaceRequestor { + friend class PWebSocketParent; + + ~WebSocketChannelParent() = default; + + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIWEBSOCKETLISTENER + NS_DECL_NSIINTERFACEREQUESTOR + + WebSocketChannelParent(nsIAuthPromptProvider* aAuthProvider, + nsILoadContext* aLoadContext, + PBOverrideStatus aOverrideStatus, uint32_t aSerial); + + private: + mozilla::ipc::IPCResult RecvAsyncOpen( + nsIURI* aURI, const nsCString& aOrigin, + const OriginAttributes& aOriginAttributes, const uint64_t& aInnerWindowID, + const nsCString& aProtocol, const bool& aSecure, + const uint32_t& aPingInterval, const bool& aClientSetPingInterval, + const uint32_t& aPingTimeout, const bool& aClientSetPingTimeout, + const LoadInfoArgs& aLoadInfoArgs, + const Maybe<PTransportProviderParent*>& aTransportProvider, + const nsCString& aNegotiatedExtensions); + mozilla::ipc::IPCResult RecvClose(const uint16_t& code, + const nsCString& reason); + mozilla::ipc::IPCResult RecvSendMsg(const nsCString& aMsg); + mozilla::ipc::IPCResult RecvSendBinaryMsg(const nsCString& aMsg); + mozilla::ipc::IPCResult RecvSendBinaryStream(const IPCStream& aStream, + const uint32_t& aLength); + mozilla::ipc::IPCResult RecvDeleteSelf(); + + void ActorDestroy(ActorDestroyReason why) override; + + nsCOMPtr<nsIAuthPromptProvider> mAuthProvider; + nsCOMPtr<nsIWebSocketChannel> mChannel; + nsCOMPtr<nsILoadContext> mLoadContext; + + uint32_t mSerial; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketChannelParent_h diff --git a/netwerk/protocol/websocket/WebSocketConnection.cpp b/netwerk/protocol/websocket/WebSocketConnection.cpp new file mode 100644 index 0000000000..b57f4fc115 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketConnection.cpp @@ -0,0 +1,261 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsIInterfaceRequestor.h" +#include "WebSocketConnection.h" + +#include "WebSocketLog.h" +#include "mozilla/net/WebSocketConnectionListener.h" +#include "nsIOService.h" +#include "nsITLSSocketControl.h" +#include "nsISocketTransport.h" +#include "nsITransportSecurityInfo.h" +#include "nsSocketTransportService2.h" + +namespace mozilla::net { + +NS_IMPL_ISUPPORTS(WebSocketConnection, nsIInputStreamCallback, + nsIOutputStreamCallback) + +WebSocketConnection::WebSocketConnection(nsISocketTransport* aTransport, + nsIAsyncInputStream* aInputStream, + nsIAsyncOutputStream* aOutputStream) + : mTransport(aTransport), + mSocketIn(aInputStream), + mSocketOut(aOutputStream) { + LOG(("WebSocketConnection ctor %p\n", this)); +} + +WebSocketConnection::~WebSocketConnection() { + LOG(("WebSocketConnection dtor %p\n", this)); +} + +nsresult WebSocketConnection::Init(WebSocketConnectionListener* aListener) { + NS_ENSURE_ARG_POINTER(aListener); + + mListener = aListener; + nsresult rv; + mSocketThread = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + if (NS_FAILED(rv)) { + return rv; + } + + if (!mTransport) { + return NS_ERROR_FAILURE; + } + + if (XRE_IsParentProcess()) { + nsCOMPtr<nsIInterfaceRequestor> callbacks = do_QueryInterface(mListener); + mTransport->SetSecurityCallbacks(callbacks); + } else { + // NOTE: we don't use security callbacks in socket process. + mTransport->SetSecurityCallbacks(nullptr); + } + return mTransport->SetEventSink(nullptr, nullptr); +} + +void WebSocketConnection::GetIoTarget(nsIEventTarget** aTarget) { + nsCOMPtr<nsIEventTarget> target = mSocketThread; + return target.forget(aTarget); +} + +void WebSocketConnection::Close() { + LOG(("WebSocketConnection::Close %p\n", this)); + MOZ_ASSERT(OnSocketThread()); + + if (mTransport) { + mTransport->SetSecurityCallbacks(nullptr); + mTransport->SetEventSink(nullptr, nullptr); + mTransport->Close(NS_BASE_STREAM_CLOSED); + mTransport = nullptr; + } + + if (mSocketIn) { + if (mStartReadingCalled) { + mSocketIn->AsyncWait(nullptr, 0, 0, nullptr); + } + mSocketIn = nullptr; + } + + if (mSocketOut) { + mSocketOut->AsyncWait(nullptr, 0, 0, nullptr); + mSocketOut = nullptr; + } +} + +nsresult WebSocketConnection::WriteOutputData(nsTArray<uint8_t>&& aData) { + MOZ_ASSERT(OnSocketThread()); + + if (!mSocketOut) { + return NS_ERROR_NOT_AVAILABLE; + } + + mOutputQueue.emplace_back(std::move(aData)); + return OnOutputStreamReady(mSocketOut); +} + +nsresult WebSocketConnection::WriteOutputData(const uint8_t* aHdrBuf, + uint32_t aHdrBufLength, + const uint8_t* aPayloadBuf, + uint32_t aPayloadBufLength) { + MOZ_ASSERT_UNREACHABLE("Should not be called"); + return NS_ERROR_NOT_IMPLEMENTED; +} + +nsresult WebSocketConnection::StartReading() { + MOZ_ASSERT(OnSocketThread()); + + if (!mSocketIn) { + return NS_ERROR_NOT_AVAILABLE; + } + + MOZ_ASSERT(!mStartReadingCalled, "StartReading twice"); + mStartReadingCalled = true; + return mSocketIn->AsyncWait(this, 0, 0, mSocketThread); +} + +void WebSocketConnection::DrainSocketData() { + MOZ_ASSERT(OnSocketThread()); + + if (!mSocketIn || !mListener) { + return; + } + + // If we leave any data unconsumed (including the tcp fin) a RST will be + // generated The right thing to do here is shutdown(SHUT_WR) and then wait a + // little while to see if any data comes in.. but there is no reason to delay + // things for that when the websocket handshake is supposed to guarantee a + // quiet connection except for that fin. + char buffer[512]; + uint32_t count = 0; + uint32_t total = 0; + nsresult rv; + do { + total += count; + rv = mSocketIn->Read(buffer, 512, &count); + if (rv != NS_BASE_STREAM_WOULD_BLOCK && (NS_FAILED(rv) || count == 0)) { + mListener->OnTCPClosed(); + } + } while (NS_SUCCEEDED(rv) && count > 0 && total < 32000); +} + +nsresult WebSocketConnection::GetSecurityInfo( + nsITransportSecurityInfo** aSecurityInfo) { + LOG(("WebSocketConnection::GetSecurityInfo() %p\n", this)); + MOZ_ASSERT(OnSocketThread()); + *aSecurityInfo = nullptr; + + if (mTransport) { + nsCOMPtr<nsITLSSocketControl> tlsSocketControl; + nsresult rv = + mTransport->GetTlsSocketControl(getter_AddRefs(tlsSocketControl)); + if (NS_FAILED(rv)) { + return rv; + } + nsCOMPtr<nsITransportSecurityInfo> securityInfo( + do_QueryInterface(tlsSocketControl)); + if (securityInfo) { + securityInfo.forget(aSecurityInfo); + } + } + return NS_OK; +} + +NS_IMETHODIMP +WebSocketConnection::OnInputStreamReady(nsIAsyncInputStream* aStream) { + LOG(("WebSocketConnection::OnInputStreamReady() %p\n", this)); + MOZ_ASSERT(OnSocketThread()); + MOZ_ASSERT(mListener); + + // did we we clean up the socket after scheduling InputReady? + if (!mSocketIn) { + return NS_OK; + } + + // this is after the http upgrade - so we are speaking websockets + uint8_t buffer[2048]; + uint32_t count; + nsresult rv; + + do { + rv = mSocketIn->Read((char*)buffer, 2048, &count); + LOG(("WebSocketConnection::OnInputStreamReady: read %u rv %" PRIx32 "\n", + count, static_cast<uint32_t>(rv))); + + if (rv == NS_BASE_STREAM_WOULD_BLOCK) { + mSocketIn->AsyncWait(this, 0, 0, mSocketThread); + return NS_OK; + } + + if (NS_FAILED(rv)) { + mListener->OnError(rv); + return rv; + } + + if (count == 0) { + mListener->OnError(NS_BASE_STREAM_CLOSED); + return NS_OK; + } + + rv = mListener->OnDataReceived(buffer, count); + if (NS_FAILED(rv)) { + mListener->OnError(rv); + return rv; + } + } while (NS_SUCCEEDED(rv) && mSocketIn && mListener); + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketConnection::OnOutputStreamReady(nsIAsyncOutputStream* aStream) { + LOG(("WebSocketConnection::OnOutputStreamReady() %p\n", this)); + MOZ_ASSERT(OnSocketThread()); + MOZ_ASSERT(mListener); + + if (!mSocketOut) { + return NS_OK; + } + + while (!mOutputQueue.empty()) { + const OutputData& data = mOutputQueue.front(); + + char* buffer = reinterpret_cast<char*>( + const_cast<uint8_t*>(data.GetData().Elements())) + + mWriteOffset; + uint32_t toWrite = data.GetData().Length() - mWriteOffset; + + uint32_t wrote = 0; + nsresult rv = mSocketOut->Write(buffer, toWrite, &wrote); + LOG(("WebSocketConnection::OnOutputStreamReady: write %u rv %" PRIx32, + wrote, static_cast<uint32_t>(rv))); + + if (rv == NS_BASE_STREAM_WOULD_BLOCK) { + mSocketOut->AsyncWait(this, 0, 0, mSocketThread); + return rv; + } + + if (NS_FAILED(rv)) { + LOG(("WebSocketConnection::OnOutputStreamReady %p failed %u\n", this, + static_cast<uint32_t>(rv))); + mListener->OnError(rv); + return NS_OK; + } + + mWriteOffset += wrote; + + if (toWrite == wrote) { + mWriteOffset = 0; + mOutputQueue.pop_front(); + } else { + mSocketOut->AsyncWait(this, 0, 0, mSocketThread); + } + } + + return NS_OK; +} + +} // namespace mozilla::net diff --git a/netwerk/protocol/websocket/WebSocketConnection.h b/netwerk/protocol/websocket/WebSocketConnection.h new file mode 100644 index 0000000000..4e6a53b013 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketConnection.h @@ -0,0 +1,79 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketConnection_h +#define mozilla_net_WebSocketConnection_h + +#include <list> + +#include "nsIStreamListener.h" +#include "nsIAsyncInputStream.h" +#include "nsIAsyncOutputStream.h" +#include "mozilla/net/WebSocketConnectionBase.h" +#include "nsTArray.h" +#include "nsISocketTransport.h" + +class nsISocketTransport; + +namespace mozilla { +namespace net { + +class WebSocketConnectionListener; + +class WebSocketConnection : public nsIInputStreamCallback, + public nsIOutputStreamCallback, + public WebSocketConnectionBase { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIINPUTSTREAMCALLBACK + NS_DECL_NSIOUTPUTSTREAMCALLBACK + + explicit WebSocketConnection(nsISocketTransport* aTransport, + nsIAsyncInputStream* aInputStream, + nsIAsyncOutputStream* aOutputStream); + + nsresult Init(WebSocketConnectionListener* aListener) override; + void GetIoTarget(nsIEventTarget** aTarget) override; + void Close() override; + nsresult WriteOutputData(const uint8_t* aHdrBuf, uint32_t aHdrBufLength, + const uint8_t* aPayloadBuf, + uint32_t aPayloadBufLength) override; + nsresult WriteOutputData(nsTArray<uint8_t>&& aData); + nsresult StartReading() override; + void DrainSocketData() override; + nsresult GetSecurityInfo(nsITransportSecurityInfo** aSecurityInfo) override; + + private: + virtual ~WebSocketConnection(); + + class OutputData { + public: + explicit OutputData(nsTArray<uint8_t>&& aData) : mData(std::move(aData)) { + MOZ_COUNT_CTOR(OutputData); + } + + ~OutputData() { MOZ_COUNT_DTOR(OutputData); } + + const nsTArray<uint8_t>& GetData() const { return mData; } + + private: + nsTArray<uint8_t> mData; + }; + + RefPtr<WebSocketConnectionListener> mListener; + nsCOMPtr<nsISocketTransport> mTransport; + nsCOMPtr<nsIAsyncInputStream> mSocketIn; + nsCOMPtr<nsIAsyncOutputStream> mSocketOut; + nsCOMPtr<nsIEventTarget> mSocketThread; + size_t mWriteOffset{0}; + std::list<OutputData> mOutputQueue; + bool mStartReadingCalled{false}; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketConnection_h diff --git a/netwerk/protocol/websocket/WebSocketConnectionBase.h b/netwerk/protocol/websocket/WebSocketConnectionBase.h new file mode 100644 index 0000000000..a462b8a346 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketConnectionBase.h @@ -0,0 +1,86 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketConnectionBase_h +#define mozilla_net_WebSocketConnectionBase_h + +// Architecture view: +// Parent Process Socket Process +// ┌─────────────────┐ IPC ┌────────────────────────┐ +// │ WebSocketChannel│ ┌─────►WebSocketConnectionChild│ +// └─────────┬───────┘ │ └─────────┬──────────────┘ +// │ │ │ +// ┌──────────▼───────────────┐ │ ┌─────────▼─────────┐ +// │ WebSocketConnectionParent├──┘ │WebSocketConnection│ +// └──────────────────────────┘ └─────────┬─────────┘ +// │ +// ┌─────────▼─────────┐ +// │ nsSocketTransport │ +// └───────────────────┘ +// The main reason that we need WebSocketConnectionBase is that we need to +// encapsulate nsSockTransport, nsSocketOutoutStream, and nsSocketInputStream. +// These three objects only live in socket process, so we provide the necessary +// interfaces in WebSocketConnectionBase and WebSocketConnectionListener for +// reading/writing the real socket. +// +// The output path when WebSocketChannel wants to write data to socket: +// - WebSocketConnectionParent::WriteOutputData +// - WebSocketConnectionParent::SendWriteOutputData +// - WebSocketConnectionChild::RecvWriteOutputData +// - WebSocketConnection::WriteOutputData +// - WebSocketConnection::OnOutputStreamReady (writes data to the real socket) +// +// The input path when data is able to read from the socket +// - WebSocketConnection::OnInputStreamReady +// - WebSocketConnectionChild::OnDataReceived +// - WebSocketConnectionChild::SendOnDataReceived +// - WebSocketConnectionParent::RecvOnDataReceived +// - WebSocketChannel::OnDataReceived +// +// The path that WebSocketConnection is constructed. +// - nsHttpChannel::OnStopRequest +// - HttpConnectionMgrShell::CompleteUpgrade +// - HttpConnectionMgrParent::CompleteUpgrade (we store the +// nsIHttpUpgradeListener in a table and send an id to socket process) +// - HttpConnectionMgrParent::SendStartWebSocketConnection +// - HttpConnectionMgrChild::RecvStartWebSocketConnection (the listener id is +// saved in WebSocketConnectionChild and will be used when the socket +// transport is available) +// - WebSocketConnectionChild::Init (an IPC channel between socket thread in +// socket process and background thread in parent process is created) +// - nsHttpConnectionMgr::CompleteUpgrade +// - WebSocketConnectionChild::OnTransportAvailable (WebSocketConnection is +// created) +// - WebSocketConnectionChild::SendOnTransportAvailable +// - WebSocketConnectionParent::RecvOnTransportAvailable +// - WebSocketChannel::OnWebSocketConnectionAvailable + +class nsITransportSecurityInfo; + +namespace mozilla { +namespace net { + +class WebSocketConnectionListener; + +class WebSocketConnectionBase : public nsISupports { + public: + virtual nsresult Init(WebSocketConnectionListener* aListener) = 0; + virtual void GetIoTarget(nsIEventTarget** aTarget) = 0; + virtual void Close() = 0; + virtual nsresult WriteOutputData(const uint8_t* aHdrBuf, + uint32_t aHdrBufLength, + const uint8_t* aPayloadBuf, + uint32_t aPayloadBufLength) = 0; + virtual nsresult StartReading() = 0; + virtual void DrainSocketData() = 0; + virtual nsresult GetSecurityInfo( + nsITransportSecurityInfo** aSecurityInfo) = 0; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketConnectionBase_h diff --git a/netwerk/protocol/websocket/WebSocketConnectionChild.cpp b/netwerk/protocol/websocket/WebSocketConnectionChild.cpp new file mode 100644 index 0000000000..045c3763b7 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketConnectionChild.cpp @@ -0,0 +1,221 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketLog.h" +#include "WebSocketConnectionChild.h" + +#include "WebSocketConnection.h" +#include "mozilla/ipc/Endpoint.h" +#include "mozilla/net/SocketProcessBackgroundChild.h" +#include "nsISerializable.h" +#include "nsITLSSocketControl.h" +#include "nsITransportSecurityInfo.h" +#include "nsNetCID.h" +#include "nsSerializationHelper.h" +#include "nsSocketTransportService2.h" +#include "nsThreadUtils.h" + +namespace mozilla { +namespace net { + +NS_IMPL_ISUPPORTS(WebSocketConnectionChild, nsIHttpUpgradeListener) + +WebSocketConnectionChild::WebSocketConnectionChild() { + LOG(("WebSocketConnectionChild ctor %p\n", this)); +} + +WebSocketConnectionChild::~WebSocketConnectionChild() { + LOG(("WebSocketConnectionChild dtor %p\n", this)); +} + +void WebSocketConnectionChild::Init(uint32_t aListenerId) { + MOZ_ASSERT(NS_IsMainThread()); + + nsresult rv; + mSocketThread = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + if (!mSocketThread) { + return; + } + + ipc::Endpoint<PWebSocketConnectionParent> parentEndpoint; + ipc::Endpoint<PWebSocketConnectionChild> childEndpoint; + PWebSocketConnection::CreateEndpoints(&parentEndpoint, &childEndpoint); + + if (NS_FAILED(SocketProcessBackgroundChild::WithActor( + "SendInitWebSocketConnection", + [aListenerId, endpoint = std::move(parentEndpoint)]( + SocketProcessBackgroundChild* aActor) mutable { + Unused << aActor->SendInitWebSocketConnection(std::move(endpoint), + aListenerId); + }))) { + return; + } + + mSocketThread->Dispatch(NS_NewRunnableFunction( + "BindWebSocketConnectionChild", + [self = RefPtr{this}, endpoint = std::move(childEndpoint)]() mutable { + endpoint.Bind(self); + })); +} + +// nsIHttpUpgradeListener +NS_IMETHODIMP +WebSocketConnectionChild::OnTransportAvailable( + nsISocketTransport* aTransport, nsIAsyncInputStream* aSocketIn, + nsIAsyncOutputStream* aSocketOut) { + LOG(("WebSocketConnectionChild::OnTransportAvailable %p\n", this)); + if (!OnSocketThread()) { + nsCOMPtr<nsISocketTransport> transport = aTransport; + nsCOMPtr<nsIAsyncInputStream> inputStream = aSocketIn; + nsCOMPtr<nsIAsyncOutputStream> outputStream = aSocketOut; + RefPtr<WebSocketConnectionChild> self = this; + return mSocketThread->Dispatch( + NS_NewRunnableFunction("WebSocketConnectionChild::OnTransportAvailable", + [self, transport, inputStream, outputStream]() { + Unused << self->OnTransportAvailable( + transport, inputStream, outputStream); + }), + NS_DISPATCH_NORMAL); + } + + LOG(("WebSocketConnectionChild::OnTransportAvailable %p\n", this)); + MOZ_ASSERT(OnSocketThread()); + MOZ_ASSERT(!mConnection, "already called"); + MOZ_ASSERT(aTransport); + + if (!CanSend()) { + return NS_ERROR_NOT_AVAILABLE; + } + + nsCOMPtr<nsITLSSocketControl> tlsSocketControl; + aTransport->GetTlsSocketControl(getter_AddRefs(tlsSocketControl)); + nsCOMPtr<nsITransportSecurityInfo> securityInfo( + do_QueryInterface(tlsSocketControl)); + + RefPtr<WebSocketConnection> connection = + new WebSocketConnection(aTransport, aSocketIn, aSocketOut); + nsresult rv = connection->Init(this); + if (NS_FAILED(rv)) { + Unused << OnUpgradeFailed(rv); + return NS_OK; + } + + mConnection = std::move(connection); + + Unused << SendOnTransportAvailable(securityInfo); + return NS_OK; +} + +NS_IMETHODIMP +WebSocketConnectionChild::OnUpgradeFailed(nsresult aReason) { + if (!OnSocketThread()) { + return mSocketThread->Dispatch(NewRunnableMethod<nsresult>( + "WebSocketConnectionChild::OnUpgradeFailed", this, + &WebSocketConnectionChild::OnUpgradeFailed, aReason)); + } + + if (CanSend()) { + Unused << SendOnUpgradeFailed(aReason); + } + return NS_OK; +} + +NS_IMETHODIMP +WebSocketConnectionChild::OnWebSocketConnectionAvailable( + WebSocketConnectionBase* aConnection) { + return NS_ERROR_NOT_IMPLEMENTED; +} + +mozilla::ipc::IPCResult WebSocketConnectionChild::RecvWriteOutputData( + nsTArray<uint8_t>&& aData) { + LOG(("WebSocketConnectionChild::RecvWriteOutputData %p\n", this)); + + if (!mConnection) { + OnError(NS_ERROR_NOT_AVAILABLE); + return IPC_OK(); + } + + mConnection->WriteOutputData(std::move(aData)); + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketConnectionChild::RecvStartReading() { + LOG(("WebSocketConnectionChild::RecvStartReading %p\n", this)); + + if (!mConnection) { + OnError(NS_ERROR_NOT_AVAILABLE); + return IPC_OK(); + } + + mConnection->StartReading(); + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketConnectionChild::RecvDrainSocketData() { + LOG(("WebSocketConnectionChild::RecvDrainSocketData %p\n", this)); + + if (!mConnection) { + OnError(NS_ERROR_NOT_AVAILABLE); + return IPC_OK(); + } + + mConnection->DrainSocketData(); + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketConnectionChild::Recv__delete__() { + LOG(("WebSocketConnectionChild::Recv__delete__ %p\n", this)); + + if (!mConnection) { + OnError(NS_ERROR_NOT_AVAILABLE); + return IPC_OK(); + } + + mConnection->Close(); + mConnection = nullptr; + return IPC_OK(); +} + +void WebSocketConnectionChild::OnError(nsresult aStatus) { + LOG(("WebSocketConnectionChild::OnError %p\n", this)); + + if (CanSend()) { + Unused << SendOnError(aStatus); + } +} + +void WebSocketConnectionChild::OnTCPClosed() { + LOG(("WebSocketConnectionChild::OnTCPClosed %p\n", this)); + + if (CanSend()) { + Unused << SendOnTCPClosed(); + } +} + +nsresult WebSocketConnectionChild::OnDataReceived(uint8_t* aData, + uint32_t aCount) { + LOG(("WebSocketConnectionChild::OnDataReceived %p\n", this)); + + if (CanSend()) { + nsTArray<uint8_t> data; + data.AppendElements(aData, aCount); + Unused << SendOnDataReceived(data); + } + + return NS_OK; +} + +void WebSocketConnectionChild::ActorDestroy(ActorDestroyReason aWhy) { + LOG(("WebSocketConnectionChild::ActorDestroy %p\n", this)); + if (mConnection) { + mConnection->Close(); + mConnection = nullptr; + } +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketConnectionChild.h b/netwerk/protocol/websocket/WebSocketConnectionChild.h new file mode 100644 index 0000000000..a0f16b7fa8 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketConnectionChild.h @@ -0,0 +1,54 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketConnectionChild_h +#define mozilla_net_WebSocketConnectionChild_h + +#include "mozilla/net/PWebSocketConnectionChild.h" +#include "mozilla/net/WebSocketConnectionListener.h" +#include "nsIHttpChannelInternal.h" + +namespace mozilla { +namespace net { + +class WebSocketConnection; + +// WebSocketConnectionChild only lives in socket process and uses +// WebSocketConnection to send/read data from socket. Only IPDL holds a strong +// reference to WebSocketConnectionChild, so the life time of +// WebSocketConnectionChild is bound to the IPC actor. + +class WebSocketConnectionChild final : public PWebSocketConnectionChild, + public nsIHttpUpgradeListener, + public WebSocketConnectionListener { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIHTTPUPGRADELISTENER + + WebSocketConnectionChild(); + void Init(uint32_t aListenerId); + mozilla::ipc::IPCResult RecvWriteOutputData(nsTArray<uint8_t>&& aData); + mozilla::ipc::IPCResult RecvStartReading(); + mozilla::ipc::IPCResult RecvDrainSocketData(); + mozilla::ipc::IPCResult Recv__delete__() override; + + void ActorDestroy(ActorDestroyReason aWhy) override; + + void OnError(nsresult aStatus) override; + void OnTCPClosed() override; + nsresult OnDataReceived(uint8_t* aData, uint32_t aCount) override; + + private: + virtual ~WebSocketConnectionChild(); + + RefPtr<WebSocketConnection> mConnection; + nsCOMPtr<nsIEventTarget> mSocketThread; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketConnectionChild_h diff --git a/netwerk/protocol/websocket/WebSocketConnectionListener.h b/netwerk/protocol/websocket/WebSocketConnectionListener.h new file mode 100644 index 0000000000..31663f17f6 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketConnectionListener.h @@ -0,0 +1,23 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketConnectionListener_h +#define mozilla_net_WebSocketConnectionListener_h + +namespace mozilla { +namespace net { + +class WebSocketConnectionListener : public nsISupports { + public: + virtual void OnError(nsresult aStatus) = 0; + virtual void OnTCPClosed() = 0; + virtual nsresult OnDataReceived(uint8_t* aData, uint32_t aCount) = 0; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketConnectionListener_h diff --git a/netwerk/protocol/websocket/WebSocketConnectionParent.cpp b/netwerk/protocol/websocket/WebSocketConnectionParent.cpp new file mode 100644 index 0000000000..3fdf85337e --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketConnectionParent.cpp @@ -0,0 +1,210 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketLog.h" +#include "WebSocketConnectionParent.h" + +#include "nsIHttpChannelInternal.h" +#include "nsITransportSecurityInfo.h" +#include "nsSerializationHelper.h" +#include "nsThreadUtils.h" +#include "WebSocketConnectionListener.h" + +namespace mozilla { +namespace net { + +NS_IMPL_ISUPPORTS0(WebSocketConnectionParent) + +WebSocketConnectionParent::WebSocketConnectionParent( + nsIHttpUpgradeListener* aListener) + : mUpgradeListener(aListener), + mBackgroundThread(GetCurrentSerialEventTarget()) { + LOG(("WebSocketConnectionParent ctor %p\n", this)); + MOZ_ASSERT(mUpgradeListener); +} + +WebSocketConnectionParent::~WebSocketConnectionParent() { + LOG(("WebSocketConnectionParent dtor %p\n", this)); +} + +mozilla::ipc::IPCResult WebSocketConnectionParent::RecvOnTransportAvailable( + nsITransportSecurityInfo* aSecurityInfo) { + LOG(("WebSocketConnectionParent::RecvOnTransportAvailable %p\n", this)); + MOZ_ASSERT(mBackgroundThread->IsOnCurrentThread()); + + if (aSecurityInfo) { + MutexAutoLock lock(mMutex); + mSecurityInfo = aSecurityInfo; + } + + if (mUpgradeListener) { + Unused << mUpgradeListener->OnWebSocketConnectionAvailable(this); + mUpgradeListener = nullptr; + } + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketConnectionParent::RecvOnError( + const nsresult& aStatus) { + LOG(("WebSocketConnectionParent::RecvOnError %p\n", this)); + MOZ_ASSERT(mBackgroundThread->IsOnCurrentThread()); + + MOZ_ASSERT(mListener); + mListener->OnError(aStatus); + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketConnectionParent::RecvOnUpgradeFailed( + const nsresult& aReason) { + MOZ_ASSERT(mBackgroundThread->IsOnCurrentThread()); + + if (mUpgradeListener) { + Unused << mUpgradeListener->OnUpgradeFailed(aReason); + mUpgradeListener = nullptr; + } + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketConnectionParent::RecvOnTCPClosed() { + LOG(("WebSocketConnectionParent::RecvOnTCPClosed %p\n", this)); + MOZ_ASSERT(mBackgroundThread->IsOnCurrentThread()); + + MOZ_ASSERT(mListener); + mListener->OnTCPClosed(); + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketConnectionParent::RecvOnDataReceived( + nsTArray<uint8_t>&& aData) { + LOG(("WebSocketConnectionParent::RecvOnDataReceived %p\n", this)); + MOZ_ASSERT(mBackgroundThread->IsOnCurrentThread()); + + MOZ_ASSERT(mListener); + uint8_t* buffer = const_cast<uint8_t*>(aData.Elements()); + nsresult rv = mListener->OnDataReceived(buffer, aData.Length()); + if (NS_FAILED(rv)) { + mListener->OnError(rv); + } + + return IPC_OK(); +} + +void WebSocketConnectionParent::ActorDestroy(ActorDestroyReason aWhy) { + LOG(("WebSocketConnectionParent::ActorDestroy %p aWhy=%d\n", this, aWhy)); + if (!mClosed) { + // Treat this as an error when IPC is closed before + // WebSocketConnectionParent::Close() is called. + RefPtr<WebSocketConnectionListener> listener; + listener.swap(mListener); + if (listener) { + listener->OnError(NS_ERROR_FAILURE); + } + } + mBackgroundThread->Dispatch(NS_NewRunnableFunction( + "WebSocketConnectionParent::DefereredDestroy", [self = RefPtr{this}]() { + LOG(("WebSocketConnectionParent::DefereredDestroy")); + })); +}; + +nsresult WebSocketConnectionParent::Init( + WebSocketConnectionListener* aListener) { + NS_ENSURE_ARG_POINTER(aListener); + + mListener = aListener; + return NS_OK; +} + +void WebSocketConnectionParent::GetIoTarget(nsIEventTarget** aTarget) { + nsCOMPtr<nsIEventTarget> target = mBackgroundThread; + return target.forget(aTarget); +} + +void WebSocketConnectionParent::Close() { + LOG(("WebSocketConnectionParent::Close %p\n", this)); + + mClosed = true; + + auto task = [self = RefPtr{this}]() { + self->PWebSocketConnectionParent::Close(); + }; + + if (mBackgroundThread->IsOnCurrentThread()) { + task(); + } else { + mBackgroundThread->Dispatch(NS_NewRunnableFunction( + "WebSocketConnectionParent::Close", std::move(task))); + } +} + +nsresult WebSocketConnectionParent::WriteOutputData( + const uint8_t* aHdrBuf, uint32_t aHdrBufLength, const uint8_t* aPayloadBuf, + uint32_t aPayloadBufLength) { + LOG(("WebSocketConnectionParent::WriteOutputData %p", this)); + MOZ_ASSERT(mBackgroundThread->IsOnCurrentThread()); + + if (!CanSend()) { + return NS_ERROR_NOT_AVAILABLE; + } + + nsTArray<uint8_t> data; + data.AppendElements(aHdrBuf, aHdrBufLength); + data.AppendElements(aPayloadBuf, aPayloadBufLength); + return SendWriteOutputData(data) ? NS_OK : NS_ERROR_FAILURE; +} + +nsresult WebSocketConnectionParent::StartReading() { + LOG(("WebSocketConnectionParent::StartReading %p\n", this)); + + RefPtr<WebSocketConnectionParent> self = this; + auto task = [self{std::move(self)}]() { + if (!self->CanSend()) { + if (self->mListener) { + self->mListener->OnError(NS_ERROR_NOT_AVAILABLE); + } + return; + } + + Unused << self->SendStartReading(); + }; + + if (mBackgroundThread->IsOnCurrentThread()) { + task(); + } else { + mBackgroundThread->Dispatch(NS_NewRunnableFunction( + "WebSocketConnectionParent::SendStartReading", std::move(task))); + } + + return NS_OK; +} + +void WebSocketConnectionParent::DrainSocketData() { + LOG(("WebSocketConnectionParent::DrainSocketData %p\n", this)); + MOZ_ASSERT(mBackgroundThread->IsOnCurrentThread()); + + if (!CanSend()) { + MOZ_ASSERT(mListener); + mListener->OnError(NS_ERROR_NOT_AVAILABLE); + + return; + } + + Unused << SendDrainSocketData(); +} + +nsresult WebSocketConnectionParent::GetSecurityInfo( + nsITransportSecurityInfo** aSecurityInfo) { + LOG(("WebSocketConnectionParent::GetSecurityInfo() %p\n", this)); + MOZ_ASSERT(NS_IsMainThread()); + + NS_ENSURE_ARG_POINTER(aSecurityInfo); + + MutexAutoLock lock(mMutex); + NS_IF_ADDREF(*aSecurityInfo = mSecurityInfo); + return NS_OK; +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketConnectionParent.h b/netwerk/protocol/websocket/WebSocketConnectionParent.h new file mode 100644 index 0000000000..45b031e9cb --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketConnectionParent.h @@ -0,0 +1,68 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et ft=cpp : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketConnectionParent_h +#define mozilla_net_WebSocketConnectionParent_h + +#include "mozilla/net/PWebSocketConnectionParent.h" +#include "mozilla/net/WebSocketConnectionBase.h" +#include "nsISupportsImpl.h" +#include "WebSocketConnectionBase.h" + +class nsIHttpUpgradeListener; + +namespace mozilla { +namespace net { + +class WebSocketConnectionListener; + +// WebSocketConnectionParent implements WebSocketConnectionBase and provides +// interface for WebSocketChannel to send/receive data. The ownership model for +// WebSocketConnectionParent is that IPDL and WebSocketChannel hold strong +// reference of this object. When Close() is called, a __delete__ message will +// be sent and the IPC actor will be deallocated as well. + +class WebSocketConnectionParent final : public PWebSocketConnectionParent, + public WebSocketConnectionBase { + public: + NS_DECL_THREADSAFE_ISUPPORTS + + explicit WebSocketConnectionParent(nsIHttpUpgradeListener* aListener); + + mozilla::ipc::IPCResult RecvOnTransportAvailable( + nsITransportSecurityInfo* aSecurityInfo); + mozilla::ipc::IPCResult RecvOnError(const nsresult& aStatus); + mozilla::ipc::IPCResult RecvOnTCPClosed(); + mozilla::ipc::IPCResult RecvOnDataReceived(nsTArray<uint8_t>&& aData); + mozilla::ipc::IPCResult RecvOnUpgradeFailed(const nsresult& aReason); + + void ActorDestroy(ActorDestroyReason aWhy) override; + + nsresult Init(WebSocketConnectionListener* aListener) override; + void GetIoTarget(nsIEventTarget** aTarget) override; + void Close() override; + nsresult WriteOutputData(const uint8_t* aHdrBuf, uint32_t aHdrBufLength, + const uint8_t* aPayloadBuf, + uint32_t aPayloadBufLength) override; + nsresult StartReading() override; + void DrainSocketData() override; + nsresult GetSecurityInfo(nsITransportSecurityInfo** aSecurityInfo) override; + + private: + virtual ~WebSocketConnectionParent(); + + nsCOMPtr<nsIHttpUpgradeListener> mUpgradeListener; + RefPtr<WebSocketConnectionListener> mListener; + nsCOMPtr<nsISerialEventTarget> mBackgroundThread; + nsCOMPtr<nsITransportSecurityInfo> mSecurityInfo; + Atomic<bool> mClosed{false}; + Mutex mMutex MOZ_UNANNOTATED{"WebSocketConnectionParent::mMutex"}; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketConnectionParent_h diff --git a/netwerk/protocol/websocket/WebSocketEventListenerChild.cpp b/netwerk/protocol/websocket/WebSocketEventListenerChild.cpp new file mode 100644 index 0000000000..2232475382 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketEventListenerChild.cpp @@ -0,0 +1,109 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketEventListenerChild.h" + +#include "WebSocketEventService.h" +#include "WebSocketFrame.h" + +namespace mozilla { +namespace net { + +WebSocketEventListenerChild::WebSocketEventListenerChild( + uint64_t aInnerWindowID, nsISerialEventTarget* aTarget) + : NeckoTargetHolder(aTarget), + mService(WebSocketEventService::GetOrCreate()), + mInnerWindowID(aInnerWindowID) {} + +WebSocketEventListenerChild::~WebSocketEventListenerChild() { + MOZ_ASSERT(!mService); +} + +mozilla::ipc::IPCResult WebSocketEventListenerChild::RecvWebSocketCreated( + const uint32_t& aWebSocketSerialID, const nsString& aURI, + const nsCString& aProtocols) { + if (mService) { + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + mService->WebSocketCreated(aWebSocketSerialID, mInnerWindowID, aURI, + aProtocols, target); + } + + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketEventListenerChild::RecvWebSocketOpened( + const uint32_t& aWebSocketSerialID, const nsString& aEffectiveURI, + const nsCString& aProtocols, const nsCString& aExtensions, + const uint64_t& aHttpChannelId) { + if (mService) { + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + mService->WebSocketOpened(aWebSocketSerialID, mInnerWindowID, aEffectiveURI, + aProtocols, aExtensions, aHttpChannelId, target); + } + + return IPC_OK(); +} + +mozilla::ipc::IPCResult +WebSocketEventListenerChild::RecvWebSocketMessageAvailable( + const uint32_t& aWebSocketSerialID, const nsCString& aData, + const uint16_t& aMessageType) { + if (mService) { + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + mService->WebSocketMessageAvailable(aWebSocketSerialID, mInnerWindowID, + aData, aMessageType, target); + } + + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketEventListenerChild::RecvWebSocketClosed( + const uint32_t& aWebSocketSerialID, const bool& aWasClean, + const uint16_t& aCode, const nsString& aReason) { + if (mService) { + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + mService->WebSocketClosed(aWebSocketSerialID, mInnerWindowID, aWasClean, + aCode, aReason, target); + } + + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketEventListenerChild::RecvFrameReceived( + const uint32_t& aWebSocketSerialID, const WebSocketFrameData& aFrameData) { + if (mService) { + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + RefPtr<WebSocketFrame> frame = new WebSocketFrame(aFrameData); + mService->FrameReceived(aWebSocketSerialID, mInnerWindowID, frame.forget(), + target); + } + + return IPC_OK(); +} + +mozilla::ipc::IPCResult WebSocketEventListenerChild::RecvFrameSent( + const uint32_t& aWebSocketSerialID, const WebSocketFrameData& aFrameData) { + if (mService) { + nsCOMPtr<nsIEventTarget> target = GetNeckoTarget(); + RefPtr<WebSocketFrame> frame = new WebSocketFrame(aFrameData); + mService->FrameSent(aWebSocketSerialID, mInnerWindowID, frame.forget(), + target); + } + + return IPC_OK(); +} + +void WebSocketEventListenerChild::Close() { + mService = nullptr; + SendClose(); +} + +void WebSocketEventListenerChild::ActorDestroy(ActorDestroyReason aWhy) { + mService = nullptr; +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketEventListenerChild.h b/netwerk/protocol/websocket/WebSocketEventListenerChild.h new file mode 100644 index 0000000000..88dddb79a5 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketEventListenerChild.h @@ -0,0 +1,63 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketEventListenerChild_h +#define mozilla_net_WebSocketEventListenerChild_h + +#include "mozilla/net/NeckoTargetHolder.h" +#include "mozilla/net/PWebSocketEventListenerChild.h" + +namespace mozilla { +namespace net { + +class WebSocketEventService; + +class WebSocketEventListenerChild final : public PWebSocketEventListenerChild, + public NeckoTargetHolder { + public: + NS_INLINE_DECL_REFCOUNTING(WebSocketEventListenerChild) + + explicit WebSocketEventListenerChild(uint64_t aInnerWindowID, + nsISerialEventTarget* aTarget); + + mozilla::ipc::IPCResult RecvWebSocketCreated( + const uint32_t& aWebSocketSerialID, const nsString& aURI, + const nsCString& aProtocols); + + mozilla::ipc::IPCResult RecvWebSocketOpened( + const uint32_t& aWebSocketSerialID, const nsString& aEffectiveURI, + const nsCString& aProtocols, const nsCString& aExtensions, + const uint64_t& aHttpChannelId); + + mozilla::ipc::IPCResult RecvWebSocketMessageAvailable( + const uint32_t& aWebSocketSerialID, const nsCString& aData, + const uint16_t& aMessageType); + + mozilla::ipc::IPCResult RecvWebSocketClosed( + const uint32_t& aWebSocketSerialID, const bool& aWasClean, + const uint16_t& aCode, const nsString& aReason); + + mozilla::ipc::IPCResult RecvFrameReceived( + const uint32_t& aWebSocketSerialID, const WebSocketFrameData& aFrameData); + + mozilla::ipc::IPCResult RecvFrameSent(const uint32_t& aWebSocketSerialID, + const WebSocketFrameData& aFrameData); + + void Close(); + + private: + ~WebSocketEventListenerChild(); + + virtual void ActorDestroy(ActorDestroyReason aWhy) override; + + RefPtr<WebSocketEventService> mService; + uint64_t mInnerWindowID; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketEventListenerChild_h diff --git a/netwerk/protocol/websocket/WebSocketEventListenerParent.cpp b/netwerk/protocol/websocket/WebSocketEventListenerParent.cpp new file mode 100644 index 0000000000..5b0ea16838 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketEventListenerParent.cpp @@ -0,0 +1,117 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketEventService.h" +#include "WebSocketEventListenerParent.h" +#include "mozilla/Unused.h" +#include "WebSocketFrame.h" + +namespace mozilla { +namespace net { + +NS_INTERFACE_MAP_BEGIN(WebSocketEventListenerParent) + NS_INTERFACE_MAP_ENTRY_AMBIGUOUS(nsISupports, nsIWebSocketEventListener) + NS_INTERFACE_MAP_ENTRY(nsIWebSocketEventListener) +NS_INTERFACE_MAP_END + +NS_IMPL_ADDREF(WebSocketEventListenerParent) +NS_IMPL_RELEASE(WebSocketEventListenerParent) + +WebSocketEventListenerParent::WebSocketEventListenerParent( + uint64_t aInnerWindowID) + : mService(WebSocketEventService::GetOrCreate()), + mInnerWindowID(aInnerWindowID) { + DebugOnly<nsresult> rv = mService->AddListener(mInnerWindowID, this); + MOZ_ASSERT(NS_SUCCEEDED(rv)); +} + +WebSocketEventListenerParent::~WebSocketEventListenerParent() { + MOZ_ASSERT(!mService); +} + +mozilla::ipc::IPCResult WebSocketEventListenerParent::RecvClose() { + if (mService) { + UnregisterListener(); + Unused << Send__delete__(this); + } + + return IPC_OK(); +} + +void WebSocketEventListenerParent::ActorDestroy(ActorDestroyReason aWhy) { + UnregisterListener(); +} + +void WebSocketEventListenerParent::UnregisterListener() { + if (mService) { + DebugOnly<nsresult> rv = mService->RemoveListener(mInnerWindowID, this); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + mService = nullptr; + } +} + +NS_IMETHODIMP +WebSocketEventListenerParent::WebSocketCreated(uint32_t aWebSocketSerialID, + const nsAString& aURI, + const nsACString& aProtocols) { + Unused << SendWebSocketCreated(aWebSocketSerialID, aURI, aProtocols); + return NS_OK; +} + +NS_IMETHODIMP +WebSocketEventListenerParent::WebSocketOpened(uint32_t aWebSocketSerialID, + const nsAString& aEffectiveURI, + const nsACString& aProtocols, + const nsACString& aExtensions, + uint64_t aHttpChannelId) { + Unused << SendWebSocketOpened(aWebSocketSerialID, aEffectiveURI, aProtocols, + aExtensions, aHttpChannelId); + return NS_OK; +} + +NS_IMETHODIMP +WebSocketEventListenerParent::WebSocketClosed(uint32_t aWebSocketSerialID, + bool aWasClean, uint16_t aCode, + const nsAString& aReason) { + Unused << SendWebSocketClosed(aWebSocketSerialID, aWasClean, aCode, aReason); + return NS_OK; +} + +NS_IMETHODIMP +WebSocketEventListenerParent::WebSocketMessageAvailable( + uint32_t aWebSocketSerialID, const nsACString& aData, + uint16_t aMessageType) { + Unused << SendWebSocketMessageAvailable(aWebSocketSerialID, aData, + aMessageType); + return NS_OK; +} + +NS_IMETHODIMP +WebSocketEventListenerParent::FrameReceived(uint32_t aWebSocketSerialID, + nsIWebSocketFrame* aFrame) { + if (!aFrame) { + return NS_ERROR_FAILURE; + } + + WebSocketFrame* frame = static_cast<WebSocketFrame*>(aFrame); + Unused << SendFrameReceived(aWebSocketSerialID, frame->Data()); + return NS_OK; +} + +NS_IMETHODIMP +WebSocketEventListenerParent::FrameSent(uint32_t aWebSocketSerialID, + nsIWebSocketFrame* aFrame) { + if (!aFrame) { + return NS_ERROR_FAILURE; + } + + WebSocketFrame* frame = static_cast<WebSocketFrame*>(aFrame); + Unused << SendFrameSent(aWebSocketSerialID, frame->Data()); + return NS_OK; +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketEventListenerParent.h b/netwerk/protocol/websocket/WebSocketEventListenerParent.h new file mode 100644 index 0000000000..eb5e02679b --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketEventListenerParent.h @@ -0,0 +1,44 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketEventListenerParent_h +#define mozilla_net_WebSocketEventListenerParent_h + +#include "mozilla/net/PWebSocketEventListenerParent.h" +#include "nsIWebSocketEventService.h" + +namespace mozilla { +namespace net { + +class WebSocketEventService; + +class WebSocketEventListenerParent final : public PWebSocketEventListenerParent, + public nsIWebSocketEventListener { + friend class PWebSocketEventListenerParent; + + public: + NS_DECL_ISUPPORTS + NS_DECL_NSIWEBSOCKETEVENTLISTENER + + explicit WebSocketEventListenerParent(uint64_t aInnerWindowID); + + private: + ~WebSocketEventListenerParent(); + + mozilla::ipc::IPCResult RecvClose(); + + virtual void ActorDestroy(ActorDestroyReason aWhy) override; + + void UnregisterListener(); + + RefPtr<WebSocketEventService> mService; + uint64_t mInnerWindowID; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_net_WebSocketEventListenerParent_h diff --git a/netwerk/protocol/websocket/WebSocketEventService.cpp b/netwerk/protocol/websocket/WebSocketEventService.cpp new file mode 100644 index 0000000000..6e9a89a005 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketEventService.cpp @@ -0,0 +1,566 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketEventListenerChild.h" +#include "WebSocketEventService.h" +#include "WebSocketFrame.h" + +#include "mozilla/net/NeckoChild.h" +#include "mozilla/StaticPtr.h" +#include "nsISupportsPrimitives.h" +#include "nsIObserverService.h" +#include "nsXULAppAPI.h" +#include "nsSocketTransportService2.h" +#include "nsThreadUtils.h" +#include "mozilla/Services.h" +#include "nsIWebSocketImpl.h" + +namespace mozilla { +namespace net { + +namespace { + +StaticRefPtr<WebSocketEventService> gWebSocketEventService; + +bool IsChildProcess() { + return XRE_GetProcessType() != GeckoProcessType_Default; +} + +} // anonymous namespace + +class WebSocketBaseRunnable : public Runnable { + public: + WebSocketBaseRunnable(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID) + : Runnable("net::WebSocketBaseRunnable"), + mWebSocketSerialID(aWebSocketSerialID), + mInnerWindowID(aInnerWindowID) {} + + NS_IMETHOD Run() override { + MOZ_ASSERT(NS_IsMainThread()); + + RefPtr<WebSocketEventService> service = + WebSocketEventService::GetOrCreate(); + MOZ_ASSERT(service); + + WebSocketEventService::WindowListeners listeners; + service->GetListeners(mInnerWindowID, listeners); + + for (uint32_t i = 0; i < listeners.Length(); ++i) { + DoWork(listeners[i]); + } + + return NS_OK; + } + + protected: + ~WebSocketBaseRunnable() = default; + + virtual void DoWork(nsIWebSocketEventListener* aListener) = 0; + + uint32_t mWebSocketSerialID; + uint64_t mInnerWindowID; +}; + +class WebSocketFrameRunnable final : public WebSocketBaseRunnable { + public: + WebSocketFrameRunnable(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + already_AddRefed<WebSocketFrame> aFrame, + bool aFrameSent) + : WebSocketBaseRunnable(aWebSocketSerialID, aInnerWindowID), + mFrame(std::move(aFrame)), + mFrameSent(aFrameSent) {} + + private: + virtual void DoWork(nsIWebSocketEventListener* aListener) override { + DebugOnly<nsresult> rv{}; + if (mFrameSent) { + rv = aListener->FrameSent(mWebSocketSerialID, mFrame); + } else { + rv = aListener->FrameReceived(mWebSocketSerialID, mFrame); + } + + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "Frame op failed"); + } + + RefPtr<WebSocketFrame> mFrame; + bool mFrameSent; +}; + +class WebSocketCreatedRunnable final : public WebSocketBaseRunnable { + public: + WebSocketCreatedRunnable(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + const nsAString& aURI, const nsACString& aProtocols) + : WebSocketBaseRunnable(aWebSocketSerialID, aInnerWindowID), + mURI(aURI), + mProtocols(aProtocols) {} + + private: + virtual void DoWork(nsIWebSocketEventListener* aListener) override { + DebugOnly<nsresult> rv = + aListener->WebSocketCreated(mWebSocketSerialID, mURI, mProtocols); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "WebSocketCreated failed"); + } + + const nsString mURI; + const nsCString mProtocols; +}; + +class WebSocketOpenedRunnable final : public WebSocketBaseRunnable { + public: + WebSocketOpenedRunnable(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + const nsAString& aEffectiveURI, + const nsACString& aProtocols, + const nsACString& aExtensions, + uint64_t aHttpChannelId) + : WebSocketBaseRunnable(aWebSocketSerialID, aInnerWindowID), + mEffectiveURI(aEffectiveURI), + mProtocols(aProtocols), + mExtensions(aExtensions), + mHttpChannelId(aHttpChannelId) {} + + private: + virtual void DoWork(nsIWebSocketEventListener* aListener) override { + DebugOnly<nsresult> rv = + aListener->WebSocketOpened(mWebSocketSerialID, mEffectiveURI, + mProtocols, mExtensions, mHttpChannelId); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "WebSocketOpened failed"); + } + + const nsString mEffectiveURI; + const nsCString mProtocols; + const nsCString mExtensions; + uint64_t mHttpChannelId; +}; + +class WebSocketMessageAvailableRunnable final : public WebSocketBaseRunnable { + public: + WebSocketMessageAvailableRunnable(uint32_t aWebSocketSerialID, + uint64_t aInnerWindowID, + const nsACString& aData, + uint16_t aMessageType) + : WebSocketBaseRunnable(aWebSocketSerialID, aInnerWindowID), + mData(aData), + mMessageType(aMessageType) {} + + private: + virtual void DoWork(nsIWebSocketEventListener* aListener) override { + DebugOnly<nsresult> rv = aListener->WebSocketMessageAvailable( + mWebSocketSerialID, mData, mMessageType); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "WebSocketMessageAvailable failed"); + } + + const nsCString mData; + uint16_t mMessageType; +}; + +class WebSocketClosedRunnable final : public WebSocketBaseRunnable { + public: + WebSocketClosedRunnable(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + bool aWasClean, uint16_t aCode, + const nsAString& aReason) + : WebSocketBaseRunnable(aWebSocketSerialID, aInnerWindowID), + mWasClean(aWasClean), + mCode(aCode), + mReason(aReason) {} + + private: + virtual void DoWork(nsIWebSocketEventListener* aListener) override { + DebugOnly<nsresult> rv = aListener->WebSocketClosed( + mWebSocketSerialID, mWasClean, mCode, mReason); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "WebSocketClosed failed"); + } + + bool mWasClean; + uint16_t mCode; + const nsString mReason; +}; + +/* static */ +already_AddRefed<WebSocketEventService> WebSocketEventService::Get() { + MOZ_ASSERT(NS_IsMainThread()); + RefPtr<WebSocketEventService> service = gWebSocketEventService.get(); + return service.forget(); +} + +/* static */ +already_AddRefed<WebSocketEventService> WebSocketEventService::GetOrCreate() { + MOZ_ASSERT(NS_IsMainThread()); + + if (!gWebSocketEventService) { + gWebSocketEventService = new WebSocketEventService(); + } + + RefPtr<WebSocketEventService> service = gWebSocketEventService.get(); + return service.forget(); +} + +NS_INTERFACE_MAP_BEGIN(WebSocketEventService) + NS_INTERFACE_MAP_ENTRY_AMBIGUOUS(nsISupports, nsIWebSocketEventService) + NS_INTERFACE_MAP_ENTRY(nsIObserver) + NS_INTERFACE_MAP_ENTRY(nsIWebSocketEventService) +NS_INTERFACE_MAP_END + +NS_IMPL_ADDREF(WebSocketEventService) +NS_IMPL_RELEASE(WebSocketEventService) + +WebSocketEventService::WebSocketEventService() : mCountListeners(0) { + MOZ_ASSERT(NS_IsMainThread()); + + nsCOMPtr<nsIObserverService> obs = mozilla::services::GetObserverService(); + if (obs) { + obs->AddObserver(this, "xpcom-shutdown", false); + obs->AddObserver(this, "inner-window-destroyed", false); + } +} + +WebSocketEventService::~WebSocketEventService() { + MOZ_ASSERT(NS_IsMainThread()); +} + +void WebSocketEventService::WebSocketCreated(uint32_t aWebSocketSerialID, + uint64_t aInnerWindowID, + const nsAString& aURI, + const nsACString& aProtocols, + nsIEventTarget* aTarget) { + // Let's continue only if we have some listeners. + if (!HasListeners()) { + return; + } + + RefPtr<WebSocketCreatedRunnable> runnable = new WebSocketCreatedRunnable( + aWebSocketSerialID, aInnerWindowID, aURI, aProtocols); + DebugOnly<nsresult> rv = aTarget + ? aTarget->Dispatch(runnable, NS_DISPATCH_NORMAL) + : NS_DispatchToMainThread(runnable); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "NS_DispatchToMainThread failed"); +} + +void WebSocketEventService::WebSocketOpened(uint32_t aWebSocketSerialID, + uint64_t aInnerWindowID, + const nsAString& aEffectiveURI, + const nsACString& aProtocols, + const nsACString& aExtensions, + uint64_t aHttpChannelId, + nsIEventTarget* aTarget) { + // Let's continue only if we have some listeners. + if (!HasListeners()) { + return; + } + + RefPtr<WebSocketOpenedRunnable> runnable = new WebSocketOpenedRunnable( + aWebSocketSerialID, aInnerWindowID, aEffectiveURI, aProtocols, + aExtensions, aHttpChannelId); + DebugOnly<nsresult> rv = aTarget + ? aTarget->Dispatch(runnable, NS_DISPATCH_NORMAL) + : NS_DispatchToMainThread(runnable); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "NS_DispatchToMainThread failed"); +} + +void WebSocketEventService::WebSocketMessageAvailable( + uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + const nsACString& aData, uint16_t aMessageType, nsIEventTarget* aTarget) { + // Let's continue only if we have some listeners. + if (!HasListeners()) { + return; + } + + RefPtr<WebSocketMessageAvailableRunnable> runnable = + new WebSocketMessageAvailableRunnable(aWebSocketSerialID, aInnerWindowID, + aData, aMessageType); + DebugOnly<nsresult> rv = aTarget + ? aTarget->Dispatch(runnable, NS_DISPATCH_NORMAL) + : NS_DispatchToMainThread(runnable); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "NS_DispatchToMainThread failed"); +} + +void WebSocketEventService::WebSocketClosed(uint32_t aWebSocketSerialID, + uint64_t aInnerWindowID, + bool aWasClean, uint16_t aCode, + const nsAString& aReason, + nsIEventTarget* aTarget) { + // Let's continue only if we have some listeners. + if (!HasListeners()) { + return; + } + + RefPtr<WebSocketClosedRunnable> runnable = new WebSocketClosedRunnable( + aWebSocketSerialID, aInnerWindowID, aWasClean, aCode, aReason); + DebugOnly<nsresult> rv = aTarget + ? aTarget->Dispatch(runnable, NS_DISPATCH_NORMAL) + : NS_DispatchToMainThread(runnable); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "NS_DispatchToMainThread failed"); +} + +void WebSocketEventService::FrameReceived( + uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + already_AddRefed<WebSocketFrame> aFrame, nsIEventTarget* aTarget) { + RefPtr<WebSocketFrame> frame(std::move(aFrame)); + MOZ_ASSERT(frame); + + // Let's continue only if we have some listeners. + if (!HasListeners()) { + return; + } + + RefPtr<WebSocketFrameRunnable> runnable = + new WebSocketFrameRunnable(aWebSocketSerialID, aInnerWindowID, + frame.forget(), false /* frameSent */); + DebugOnly<nsresult> rv = aTarget + ? aTarget->Dispatch(runnable, NS_DISPATCH_NORMAL) + : NS_DispatchToMainThread(runnable); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "NS_DispatchToMainThread failed"); +} + +void WebSocketEventService::FrameSent(uint32_t aWebSocketSerialID, + uint64_t aInnerWindowID, + already_AddRefed<WebSocketFrame> aFrame, + nsIEventTarget* aTarget) { + RefPtr<WebSocketFrame> frame(std::move(aFrame)); + MOZ_ASSERT(frame); + + // Let's continue only if we have some listeners. + if (!HasListeners()) { + return; + } + + RefPtr<WebSocketFrameRunnable> runnable = new WebSocketFrameRunnable( + aWebSocketSerialID, aInnerWindowID, frame.forget(), true /* frameSent */); + + DebugOnly<nsresult> rv = aTarget + ? aTarget->Dispatch(runnable, NS_DISPATCH_NORMAL) + : NS_DispatchToMainThread(runnable); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "NS_DispatchToMainThread failed"); +} + +void WebSocketEventService::AssociateWebSocketImplWithSerialID( + nsIWebSocketImpl* aWebSocketImpl, uint32_t aWebSocketSerialID) { + MOZ_ASSERT(NS_IsMainThread()); + + mWebSocketImplMap.InsertOrUpdate(aWebSocketSerialID, + do_GetWeakReference(aWebSocketImpl)); +} + +NS_IMETHODIMP +WebSocketEventService::SendMessage(uint32_t aWebSocketSerialID, + const nsAString& aMessage) { + MOZ_ASSERT(NS_IsMainThread()); + + nsWeakPtr weakPtr = mWebSocketImplMap.Get(aWebSocketSerialID); + nsCOMPtr<nsIWebSocketImpl> webSocketImpl = do_QueryReferent(weakPtr); + + if (!webSocketImpl) { + return NS_ERROR_NOT_AVAILABLE; + } + + return webSocketImpl->SendMessage(aMessage); +} + +NS_IMETHODIMP +WebSocketEventService::AddListener(uint64_t aInnerWindowID, + nsIWebSocketEventListener* aListener) { + MOZ_ASSERT(NS_IsMainThread()); + + if (!aListener) { + return NS_ERROR_FAILURE; + } + + ++mCountListeners; + + mWindows + .LookupOrInsertWith( + aInnerWindowID, + [&] { + auto listener = MakeUnique<WindowListener>(); + + if (IsChildProcess()) { + PWebSocketEventListenerChild* actor = + gNeckoChild->SendPWebSocketEventListenerConstructor( + aInnerWindowID); + + listener->mActor = + static_cast<WebSocketEventListenerChild*>(actor); + MOZ_ASSERT(listener->mActor); + } + + return listener; + }) + ->mListeners.AppendElement(aListener); + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketEventService::RemoveListener(uint64_t aInnerWindowID, + nsIWebSocketEventListener* aListener) { + MOZ_ASSERT(NS_IsMainThread()); + + if (!aListener) { + return NS_ERROR_FAILURE; + } + + WindowListener* listener = mWindows.Get(aInnerWindowID); + if (!listener) { + return NS_ERROR_FAILURE; + } + + if (!listener->mListeners.RemoveElement(aListener)) { + return NS_ERROR_FAILURE; + } + + // The last listener for this window. + if (listener->mListeners.IsEmpty()) { + if (IsChildProcess()) { + ShutdownActorListener(listener); + } + + mWindows.Remove(aInnerWindowID); + } + + MOZ_ASSERT(mCountListeners); + --mCountListeners; + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketEventService::HasListenerFor(uint64_t aInnerWindowID, bool* aResult) { + MOZ_ASSERT(NS_IsMainThread()); + + *aResult = mWindows.Get(aInnerWindowID); + + return NS_OK; +} + +NS_IMETHODIMP +WebSocketEventService::Observe(nsISupports* aSubject, const char* aTopic, + const char16_t* aData) { + MOZ_ASSERT(NS_IsMainThread()); + + if (!strcmp(aTopic, "xpcom-shutdown")) { + Shutdown(); + return NS_OK; + } + + if (!strcmp(aTopic, "inner-window-destroyed") && HasListeners()) { + nsCOMPtr<nsISupportsPRUint64> wrapper = do_QueryInterface(aSubject); + NS_ENSURE_TRUE(wrapper, NS_ERROR_FAILURE); + + uint64_t innerID; + nsresult rv = wrapper->GetData(&innerID); + NS_ENSURE_SUCCESS(rv, rv); + + WindowListener* listener = mWindows.Get(innerID); + if (!listener) { + return NS_OK; + } + + MOZ_ASSERT(mCountListeners >= listener->mListeners.Length()); + mCountListeners -= listener->mListeners.Length(); + + if (IsChildProcess()) { + ShutdownActorListener(listener); + } + + mWindows.Remove(innerID); + } + + // This should not happen. + return NS_ERROR_FAILURE; +} + +void WebSocketEventService::Shutdown() { + MOZ_ASSERT(NS_IsMainThread()); + + if (gWebSocketEventService) { + nsCOMPtr<nsIObserverService> obs = mozilla::services::GetObserverService(); + if (obs) { + obs->RemoveObserver(gWebSocketEventService, "xpcom-shutdown"); + obs->RemoveObserver(gWebSocketEventService, "inner-window-destroyed"); + } + + mWindows.Clear(); + gWebSocketEventService = nullptr; + } +} + +bool WebSocketEventService::HasListeners() const { return !!mCountListeners; } + +void WebSocketEventService::GetListeners( + uint64_t aInnerWindowID, + WebSocketEventService::WindowListeners& aListeners) const { + aListeners.Clear(); + + WindowListener* listener = mWindows.Get(aInnerWindowID); + if (!listener) { + return; + } + + aListeners.AppendElements(listener->mListeners); +} + +void WebSocketEventService::ShutdownActorListener(WindowListener* aListener) { + MOZ_ASSERT(aListener); + MOZ_ASSERT(aListener->mActor); + aListener->mActor->Close(); + aListener->mActor = nullptr; +} + +already_AddRefed<WebSocketFrame> WebSocketEventService::CreateFrameIfNeeded( + bool aFinBit, bool aRsvBit1, bool aRsvBit2, bool aRsvBit3, uint8_t aOpCode, + bool aMaskBit, uint32_t aMask, const nsCString& aPayload) { + if (!HasListeners()) { + return nullptr; + } + + return MakeAndAddRef<WebSocketFrame>(aFinBit, aRsvBit1, aRsvBit2, aRsvBit3, + aOpCode, aMaskBit, aMask, aPayload); +} + +already_AddRefed<WebSocketFrame> WebSocketEventService::CreateFrameIfNeeded( + bool aFinBit, bool aRsvBit1, bool aRsvBit2, bool aRsvBit3, uint8_t aOpCode, + bool aMaskBit, uint32_t aMask, uint8_t* aPayload, uint32_t aPayloadLength) { + if (!HasListeners()) { + return nullptr; + } + + nsAutoCString payloadStr; + if (NS_WARN_IF(!(payloadStr.Assign((const char*)aPayload, aPayloadLength, + mozilla::fallible)))) { + return nullptr; + } + + return MakeAndAddRef<WebSocketFrame>(aFinBit, aRsvBit1, aRsvBit2, aRsvBit3, + aOpCode, aMaskBit, aMask, payloadStr); +} + +already_AddRefed<WebSocketFrame> WebSocketEventService::CreateFrameIfNeeded( + bool aFinBit, bool aRsvBit1, bool aRsvBit2, bool aRsvBit3, uint8_t aOpCode, + bool aMaskBit, uint32_t aMask, uint8_t* aPayloadInHdr, + uint32_t aPayloadInHdrLength, uint8_t* aPayload, uint32_t aPayloadLength) { + if (!HasListeners()) { + return nullptr; + } + + uint32_t payloadLength = aPayloadLength + aPayloadInHdrLength; + + nsAutoCString payload; + if (NS_WARN_IF(!payload.SetLength(payloadLength, fallible))) { + return nullptr; + } + + char* payloadPtr = payload.BeginWriting(); + if (aPayloadInHdrLength) { + memcpy(payloadPtr, aPayloadInHdr, aPayloadInHdrLength); + } + + memcpy(payloadPtr + aPayloadInHdrLength, aPayload, aPayloadLength); + + return MakeAndAddRef<WebSocketFrame>(aFinBit, aRsvBit1, aRsvBit2, aRsvBit3, + aOpCode, aMaskBit, aMask, payload); +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketEventService.h b/netwerk/protocol/websocket/WebSocketEventService.h new file mode 100644 index 0000000000..0f15e5058b --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketEventService.h @@ -0,0 +1,125 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketEventService_h +#define mozilla_net_WebSocketEventService_h + +#include "mozilla/AlreadyAddRefed.h" +#include "mozilla/Atomics.h" +#include "nsIWebSocketEventService.h" +#include "nsCOMPtr.h" +#include "nsClassHashtable.h" +#include "nsHashKeys.h" +#include "nsIObserver.h" +#include "nsISupportsImpl.h" +#include "nsTArray.h" +#include "nsIWeakReferenceUtils.h" +#include "nsTHashMap.h" + +class nsIWebSocketImpl; +class nsIEventTarget; + +namespace mozilla { +namespace net { + +class WebSocketFrame; +class WebSocketEventListenerChild; + +class WebSocketEventService final : public nsIWebSocketEventService, + public nsIObserver { + friend class WebSocketBaseRunnable; + + public: + NS_DECL_ISUPPORTS + NS_DECL_NSIOBSERVER + NS_DECL_NSIWEBSOCKETEVENTSERVICE + + static already_AddRefed<WebSocketEventService> Get(); + static already_AddRefed<WebSocketEventService> GetOrCreate(); + + void WebSocketCreated(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + const nsAString& aURI, const nsACString& aProtocols, + nsIEventTarget* aTarget = nullptr); + + void WebSocketOpened(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + const nsAString& aEffectiveURI, + const nsACString& aProtocols, + const nsACString& aExtensions, uint64_t aHttpChannelId, + nsIEventTarget* aTarget = nullptr); + + void WebSocketMessageAvailable(uint32_t aWebSocketSerialID, + uint64_t aInnerWindowID, + const nsACString& aData, uint16_t aMessageType, + nsIEventTarget* aTarget = nullptr); + + void WebSocketClosed(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + bool aWasClean, uint16_t aCode, const nsAString& aReason, + nsIEventTarget* aTarget = nullptr); + + void FrameReceived(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + already_AddRefed<WebSocketFrame> aFrame, + nsIEventTarget* aTarget = nullptr); + + void FrameSent(uint32_t aWebSocketSerialID, uint64_t aInnerWindowID, + already_AddRefed<WebSocketFrame> aFrame, + nsIEventTarget* aTarget = nullptr); + + void AssociateWebSocketImplWithSerialID(nsIWebSocketImpl* aWebSocketImpl, + uint32_t aWebSocketSerialID); + + already_AddRefed<WebSocketFrame> CreateFrameIfNeeded( + bool aFinBit, bool aRsvBit1, bool aRsvBit2, bool aRsvBit3, + uint8_t aOpCode, bool aMaskBit, uint32_t aMask, + const nsCString& aPayload); + + already_AddRefed<WebSocketFrame> CreateFrameIfNeeded( + bool aFinBit, bool aRsvBit1, bool aRsvBit2, bool aRsvBit3, + uint8_t aOpCode, bool aMaskBit, uint32_t aMask, uint8_t* aPayload, + uint32_t aPayloadLength); + + already_AddRefed<WebSocketFrame> CreateFrameIfNeeded( + bool aFinBit, bool aRsvBit1, bool aRsvBit2, bool aRsvBit3, + uint8_t aOpCode, bool aMaskBit, uint32_t aMask, uint8_t* aPayloadInHdr, + uint32_t aPayloadInHdrLength, uint8_t* aPayload, uint32_t aPayloadLength); + + private: + WebSocketEventService(); + ~WebSocketEventService(); + + bool HasListeners() const; + void Shutdown(); + + using WindowListeners = nsTArray<nsCOMPtr<nsIWebSocketEventListener>>; + + nsTHashMap<nsUint32HashKey, nsWeakPtr> mWebSocketImplMap; + + struct WindowListener { + WindowListeners mListeners; + RefPtr<WebSocketEventListenerChild> mActor; + }; + + void GetListeners(uint64_t aInnerWindowID, WindowListeners& aListeners) const; + + void ShutdownActorListener(WindowListener* aListener); + + // Used only on the main-thread. + nsClassHashtable<nsUint64HashKey, WindowListener> mWindows; + + Atomic<uint64_t> mCountListeners; +}; + +} // namespace net +} // namespace mozilla + +/** + * Casting WebSocketEventService to nsISupports is ambiguous. + * This method handles that. + */ +inline nsISupports* ToSupports(mozilla::net::WebSocketEventService* p) { + return NS_ISUPPORTS_CAST(nsIWebSocketEventService*, p); +} + +#endif // mozilla_net_WebSocketEventService_h diff --git a/netwerk/protocol/websocket/WebSocketFrame.cpp b/netwerk/protocol/websocket/WebSocketFrame.cpp new file mode 100644 index 0000000000..745b598ba9 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketFrame.cpp @@ -0,0 +1,133 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebSocketFrame.h" + +#include "ErrorList.h" +#include "MainThreadUtils.h" +#include "chrome/common/ipc_message_utils.h" +#include "mozilla/Assertions.h" +#include "nsDOMNavigationTiming.h" +#include "nsSocketTransportService2.h" +#include "nscore.h" +#include "prtime.h" + +namespace mozilla { +namespace net { + +NS_INTERFACE_MAP_BEGIN(WebSocketFrame) + NS_INTERFACE_MAP_ENTRY_AMBIGUOUS(nsISupports, nsIWebSocketFrame) + NS_INTERFACE_MAP_ENTRY(nsIWebSocketFrame) +NS_INTERFACE_MAP_END + +NS_IMPL_ADDREF(WebSocketFrame) +NS_IMPL_RELEASE(WebSocketFrame) + +WebSocketFrame::WebSocketFrame(const WebSocketFrameData& aData) + : mData(aData) {} + +WebSocketFrame::WebSocketFrame(bool aFinBit, bool aRsvBit1, bool aRsvBit2, + bool aRsvBit3, uint8_t aOpCode, bool aMaskBit, + uint32_t aMask, const nsCString& aPayload) + : mData(PR_Now(), aFinBit, aRsvBit1, aRsvBit2, aRsvBit3, aOpCode, aMaskBit, + aMask, aPayload) { + // This could be called on the background thread when socket process is + // enabled. + mData.mTimeStamp = PR_Now(); +} + +#define WSF_GETTER(method, value, type) \ + NS_IMETHODIMP \ + WebSocketFrame::method(type* aValue) { \ + MOZ_ASSERT(NS_IsMainThread()); \ + if (!aValue) { \ + return NS_ERROR_FAILURE; \ + } \ + *aValue = value; \ + return NS_OK; \ + } + +WSF_GETTER(GetTimeStamp, mData.mTimeStamp, DOMHighResTimeStamp); +WSF_GETTER(GetFinBit, mData.mFinBit, bool); +WSF_GETTER(GetRsvBit1, mData.mRsvBit1, bool); +WSF_GETTER(GetRsvBit2, mData.mRsvBit2, bool); +WSF_GETTER(GetRsvBit3, mData.mRsvBit3, bool); +WSF_GETTER(GetOpCode, mData.mOpCode, uint16_t); +WSF_GETTER(GetMaskBit, mData.mMaskBit, bool); +WSF_GETTER(GetMask, mData.mMask, uint32_t); + +#undef WSF_GETTER + +NS_IMETHODIMP +WebSocketFrame::GetPayload(nsACString& aValue) { + MOZ_ASSERT(NS_IsMainThread()); + aValue = mData.mPayload; + return NS_OK; +} + +WebSocketFrameData::WebSocketFrameData() + : mFinBit(false), + mRsvBit1(false), + mRsvBit2(false), + mRsvBit3(false), + mMaskBit(false) {} + +WebSocketFrameData::WebSocketFrameData(DOMHighResTimeStamp aTimeStamp, + bool aFinBit, bool aRsvBit1, + bool aRsvBit2, bool aRsvBit3, + uint8_t aOpCode, bool aMaskBit, + uint32_t aMask, + const nsCString& aPayload) + : mTimeStamp(aTimeStamp), + mFinBit(aFinBit), + mRsvBit1(aRsvBit1), + mRsvBit2(aRsvBit2), + mRsvBit3(aRsvBit3), + mMaskBit(aMaskBit), + mOpCode(aOpCode), + mMask(aMask), + mPayload(aPayload) {} + +void WebSocketFrameData::WriteIPCParams(IPC::MessageWriter* aWriter) const { + WriteParam(aWriter, mTimeStamp); + WriteParam(aWriter, mFinBit); + WriteParam(aWriter, mRsvBit1); + WriteParam(aWriter, mRsvBit2); + WriteParam(aWriter, mRsvBit3); + WriteParam(aWriter, mMaskBit); + WriteParam(aWriter, mOpCode); + WriteParam(aWriter, mMask); + WriteParam(aWriter, mPayload); +} + +bool WebSocketFrameData::ReadIPCParams(IPC::MessageReader* aReader) { + if (!ReadParam(aReader, &mTimeStamp)) { + return false; + } + +#define ReadParamHelper(x) \ + { \ + bool bit; \ + if (!ReadParam(aReader, &bit)) { \ + return false; \ + } \ + (x) = bit; \ + } + + ReadParamHelper(mFinBit); + ReadParamHelper(mRsvBit1); + ReadParamHelper(mRsvBit2); + ReadParamHelper(mRsvBit3); + ReadParamHelper(mMaskBit); + +#undef ReadParamHelper + + return ReadParam(aReader, &mOpCode) && ReadParam(aReader, &mMask) && + ReadParam(aReader, &mPayload); +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/protocol/websocket/WebSocketFrame.h b/netwerk/protocol/websocket/WebSocketFrame.h new file mode 100644 index 0000000000..8bcb672692 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketFrame.h @@ -0,0 +1,104 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_net_WebSocketFrame_h +#define mozilla_net_WebSocketFrame_h + +#include <cstdint> +#include "nsISupports.h" +#include "nsIWebSocketEventService.h" +#include "nsString.h" + +class PickleIterator; + +// Avoid including nsDOMNavigationTiming.h here, where the canonical definition +// of DOMHighResTimeStamp resides. +using DOMHighResTimeStamp = double; + +namespace IPC { +class Message; +class MessageReader; +class MessageWriter; +template <class P> +struct ParamTraits; +} // namespace IPC + +namespace mozilla { +namespace net { + +class WebSocketFrameData final { + public: + WebSocketFrameData(); + + explicit WebSocketFrameData(const WebSocketFrameData&) = default; + WebSocketFrameData(WebSocketFrameData&&) = default; + WebSocketFrameData& operator=(WebSocketFrameData&&) = default; + WebSocketFrameData& operator=(const WebSocketFrameData&) = default; + + WebSocketFrameData(DOMHighResTimeStamp aTimeStamp, bool aFinBit, + bool aRsvBit1, bool aRsvBit2, bool aRsvBit3, + uint8_t aOpCode, bool aMaskBit, uint32_t aMask, + const nsCString& aPayload); + + ~WebSocketFrameData() = default; + + // For IPC serialization + void WriteIPCParams(IPC::MessageWriter* aWriter) const; + bool ReadIPCParams(IPC::MessageReader* aReader); + + DOMHighResTimeStamp mTimeStamp{0}; + + bool mFinBit : 1; + bool mRsvBit1 : 1; + bool mRsvBit2 : 1; + bool mRsvBit3 : 1; + bool mMaskBit : 1; + uint8_t mOpCode{0}; + + uint32_t mMask{0}; + + nsCString mPayload; +}; + +class WebSocketFrame final : public nsIWebSocketFrame { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIWEBSOCKETFRAME + + explicit WebSocketFrame(const WebSocketFrameData& aData); + + WebSocketFrame(bool aFinBit, bool aRsvBit1, bool aRsvBit2, bool aRsvBit3, + uint8_t aOpCode, bool aMaskBit, uint32_t aMask, + const nsCString& aPayload); + + const WebSocketFrameData& Data() const { return mData; } + + private: + ~WebSocketFrame() = default; + + WebSocketFrameData mData; +}; + +} // namespace net +} // namespace mozilla + +namespace IPC { +template <> +struct ParamTraits<mozilla::net::WebSocketFrameData> { + using paramType = mozilla::net::WebSocketFrameData; + + static void Write(MessageWriter* aWriter, const paramType& aParam) { + aParam.WriteIPCParams(aWriter); + } + + static bool Read(MessageReader* aReader, paramType* aResult) { + return aResult->ReadIPCParams(aReader); + } +}; + +} // namespace IPC + +#endif // mozilla_net_WebSocketFrame_h diff --git a/netwerk/protocol/websocket/WebSocketLog.h b/netwerk/protocol/websocket/WebSocketLog.h new file mode 100644 index 0000000000..80122249c2 --- /dev/null +++ b/netwerk/protocol/websocket/WebSocketLog.h @@ -0,0 +1,24 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set sw=2 ts=8 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef WebSocketLog_h +#define WebSocketLog_h + +#include "base/basictypes.h" +#include "mozilla/Logging.h" +#include "mozilla/net/NeckoChild.h" + +namespace mozilla { +namespace net { +extern LazyLogModule webSocketLog; +} +} // namespace mozilla + +#undef LOG +#define LOG(args) \ + MOZ_LOG(mozilla::net::webSocketLog, mozilla::LogLevel::Debug, args) + +#endif diff --git a/netwerk/protocol/websocket/moz.build b/netwerk/protocol/websocket/moz.build new file mode 100644 index 0000000000..123d364b70 --- /dev/null +++ b/netwerk/protocol/websocket/moz.build @@ -0,0 +1,67 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +with Files("**"): + BUG_COMPONENT = ("Core", "Networking: WebSockets") + +XPIDL_SOURCES += [ + "nsITransportProvider.idl", + "nsIWebSocketChannel.idl", + "nsIWebSocketEventService.idl", + "nsIWebSocketImpl.idl", + "nsIWebSocketListener.idl", +] + +XPIDL_MODULE = "necko_websocket" + +EXPORTS.mozilla.net += [ + "BaseWebSocketChannel.h", + "IPCTransportProvider.h", + "WebSocketChannel.h", + "WebSocketChannelChild.h", + "WebSocketChannelParent.h", + "WebSocketConnectionBase.h", + "WebSocketConnectionChild.h", + "WebSocketConnectionListener.h", + "WebSocketConnectionParent.h", + "WebSocketEventListenerChild.h", + "WebSocketEventListenerParent.h", + "WebSocketEventService.h", + "WebSocketFrame.h", +] + +UNIFIED_SOURCES += [ + "BaseWebSocketChannel.cpp", + "IPCTransportProvider.cpp", + "WebSocketChannel.cpp", + "WebSocketChannelChild.cpp", + "WebSocketChannelParent.cpp", + "WebSocketConnection.cpp", + "WebSocketConnectionChild.cpp", + "WebSocketConnectionParent.cpp", + "WebSocketEventListenerChild.cpp", + "WebSocketEventListenerParent.cpp", + "WebSocketEventService.cpp", + "WebSocketFrame.cpp", +] + +IPDL_SOURCES += [ + "PTransportProvider.ipdl", + "PWebSocket.ipdl", + "PWebSocketConnection.ipdl", + "PWebSocketEventListener.ipdl", +] + +include("/ipc/chromium/chromium-config.mozbuild") + +FINAL_LIBRARY = "xul" + +LOCAL_INCLUDES += [ + "/dom/base", + "/netwerk/base", +] + +include("/tools/fuzzing/libfuzzer-config.mozbuild") diff --git a/netwerk/protocol/websocket/nsITransportProvider.idl b/netwerk/protocol/websocket/nsITransportProvider.idl new file mode 100644 index 0000000000..8b60eabd03 --- /dev/null +++ b/netwerk/protocol/websocket/nsITransportProvider.idl @@ -0,0 +1,36 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ +/* vim: set sw=4 ts=4 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +interface nsIHttpUpgradeListener; + +#include "nsISupports.idl" + +%{C++ +namespace mozilla { +namespace net { +class PTransportProviderChild; +} +} +%} + +[ptr] native PTransportProviderChild(mozilla::net::PTransportProviderChild); + +/** + * An interface which can be used to asynchronously request a nsITransport + * together with the input and output streams that go together with it. + */ +[scriptable, uuid(6fcec704-cfd2-46ef-a394-a64d5cb1475c)] +interface nsITransportProvider : nsISupports +{ + // This must not be called in a child process since transport + // objects are not accessible there. Call getIPCChild instead. + [must_use] void setListener(in nsIHttpUpgradeListener listener); + + // This must be implemented by nsITransportProvider objects running + // in the child process. It must return null when called in the parent + // process. + [noscript, must_use] PTransportProviderChild getIPCChild(); +}; diff --git a/netwerk/protocol/websocket/nsIWebSocketChannel.idl b/netwerk/protocol/websocket/nsIWebSocketChannel.idl new file mode 100644 index 0000000000..7092c3c6b1 --- /dev/null +++ b/netwerk/protocol/websocket/nsIWebSocketChannel.idl @@ -0,0 +1,268 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ +/* vim: set sw=4 ts=4 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +interface nsICookieJarSettings; +interface nsIInputStream; +interface nsIInterfaceRequestor; +interface nsILoadGroup; +interface nsILoadInfo; +interface nsIPrincipal; +interface nsITransportProvider; +interface nsITransportSecurityInfo; +interface nsIURI; +interface nsIWebSocketListener; + +webidl Node; + +#include "nsISupports.idl" +#include "nsIContentPolicy.idl" + + + +[ref] native OriginAttributes(const mozilla::OriginAttributes); + +/** + * Low-level websocket API: handles network protocol. + * + * This is primarly intended for use by the higher-level nsIWebSocket.idl. + * We are also making it scriptable for now, but this may change once we have + * WebSockets for Workers. + */ +[scriptable, builtinclass, uuid(ce71d028-322a-4105-a947-a894689b52bf)] +interface nsIWebSocketChannel : nsISupports +{ + /** + * The original URI used to construct the protocol connection. This is used + * in the case of a redirect or URI "resolution" (e.g. resolving a + * resource: URI to a file: URI) so that the original pre-redirect + * URI can still be obtained. This is never null. + */ + [must_use] readonly attribute nsIURI originalURI; + + /** + * The readonly URI corresponding to the protocol connection after any + * redirections are completed. + */ + [must_use] readonly attribute nsIURI URI; + + /** + * The notification callbacks for authorization, etc.. + */ + [must_use] attribute nsIInterfaceRequestor notificationCallbacks; + + /** + * Transport-level security information (if any) + */ + [must_use] readonly attribute nsITransportSecurityInfo securityInfo; + + /** + * The load group of of the websocket + */ + [must_use] attribute nsILoadGroup loadGroup; + + /** + * The load info of the websocket + */ + [must_use] attribute nsILoadInfo loadInfo; + + /** + * Sec-Websocket-Protocol value + */ + [must_use] attribute ACString protocol; + + /** + * Sec-Websocket-Extensions response header value + */ + [must_use] readonly attribute ACString extensions; + + /** + * The channelId of the underlying http channel. + * It's available only after nsIWebSocketListener::onStart + */ + [must_use] readonly attribute uint64_t httpChannelId; + + /** + * Init the WebSocketChannel with LoadInfo arguments. + * @param aLoadingNode + * @param aLoadingPrincipal + * @param aTriggeringPrincipal + * @param aCookieJarSettings + * @param aSecurityFlags + * @param aContentPolicyType + * These will be used as values for the nsILoadInfo object on the + * created channel. For details, see nsILoadInfo in nsILoadInfo.idl + * @return reference to the new nsIChannel object + * + * Keep in mind that URIs coming from a webpage should *never* use the + * systemPrincipal as the loadingPrincipal. + * + * Please note, if you provide both a loadingNode and a loadingPrincipal, + * then loadingPrincipal must be equal to loadingNode->NodePrincipal(). + * But less error prone is to just supply a loadingNode. + */ + [notxpcom] nsresult initLoadInfoNative(in Node aLoadingNode, + in nsIPrincipal aLoadingPrincipal, + in nsIPrincipal aTriggeringPrincipal, + in nsICookieJarSettings aCookieJarSettings, + in unsigned long aSecurityFlags, + in nsContentPolicyType aContentPolicyType, + in unsigned long aSandboxFlags); + + /** + * Similar to the previous one but without nsICookieJarSettings. + * This method is used by JS code where nsICookieJarSettings is not exposed. + */ + [must_use] void initLoadInfo(in Node aLoadingNode, + in nsIPrincipal aLoadingPrincipal, + in nsIPrincipal aTriggeringPrincipal, + in unsigned long aSecurityFlags, + in nsContentPolicyType aContentPolicyType); + + /** + * Asynchronously open the websocket connection. Received messages are fed + * to the socket listener as they arrive. The socket listener's methods + * are called on the thread that calls asyncOpen and are not called until + * after asyncOpen returns. If asyncOpen returns successfully, the + * protocol implementation promises to call at least onStop on the listener. + * + * NOTE: Implementations should throw NS_ERROR_ALREADY_OPENED if the + * websocket connection is reopened. + * + * @param aURI the uri of the websocket protocol - may be redirected + * @param aOrigin the uri of the originating resource + * @param aOriginAttributes attributes of the originating resource. + * @param aInnerWindowID the inner window ID + * @param aListener the nsIWebSocketListener implementation + * @param aContext an opaque parameter forwarded to aListener's methods + */ + [implicit_jscontext] + void asyncOpen(in nsIURI aURI, + in ACString aOrigin, + in jsval aOriginAttributes, + in unsigned long long aInnerWindowID, + in nsIWebSocketListener aListener, + in nsISupports aContext); + + [must_use] + void asyncOpenNative(in nsIURI aURI, + in ACString aOrigin, + in OriginAttributes aOriginAttributes, + in unsigned long long aInnerWindowID, + in nsIWebSocketListener aListener, + in nsISupports aContext); + + /* + * Close the websocket connection for writing - no more calls to sendMsg + * or sendBinaryMsg should be made after calling this. The listener object + * may receive more messages if a server close has not yet been received. + * + * @param aCode the websocket closing handshake close code. Set to 0 if + * you are not providing a code. + * @param aReason the websocket closing handshake close reason + */ + [must_use] void close(in unsigned short aCode, in AUTF8String aReason); + + // section 7.4.1 defines these close codes + const unsigned short CLOSE_NORMAL = 1000; + const unsigned short CLOSE_GOING_AWAY = 1001; + const unsigned short CLOSE_PROTOCOL_ERROR = 1002; + const unsigned short CLOSE_UNSUPPORTED_DATATYPE = 1003; + // code 1004 is reserved + const unsigned short CLOSE_NO_STATUS = 1005; + const unsigned short CLOSE_ABNORMAL = 1006; + const unsigned short CLOSE_INVALID_PAYLOAD = 1007; + const unsigned short CLOSE_POLICY_VIOLATION = 1008; + const unsigned short CLOSE_TOO_LARGE = 1009; + const unsigned short CLOSE_EXTENSION_MISSING = 1010; + // Initially used just for server-side internal errors: adopted later for + // client-side errors too (not clear if will make into spec: see + // http://www.ietf.org/mail-archive/web/hybi/current/msg09372.html + const unsigned short CLOSE_INTERNAL_ERROR = 1011; + // MUST NOT be set as a status code in Close control frame by an endpoint: + // To be used if TLS handshake failed (ex: server certificate unverifiable) + const unsigned short CLOSE_TLS_FAILED = 1015; + + /** + * Use to send text message down the connection to WebSocket peer. + * + * @param aMsg the utf8 string to send + */ + [must_use] void sendMsg(in AUTF8String aMsg); + + /** + * Use to send binary message down the connection to WebSocket peer. + * + * @param aMsg the data to send + */ + [must_use] void sendBinaryMsg(in ACString aMsg); + + /** + * Use to send a binary stream (Blob) to Websocket peer. + * + * @param aStream The input stream to be sent. + */ + [must_use] void sendBinaryStream(in nsIInputStream aStream, + in unsigned long length); + + /** + * This value determines how often (in seconds) websocket keepalive + * pings are sent. If set to 0 (the default), no pings are ever sent. + * + * This value can currently only be set before asyncOpen is called, else + * NS_ERROR_IN_PROGRESS is thrown. + * + * Be careful using this setting: ping traffic can consume lots of power and + * bandwidth over time. + */ + [must_use] attribute unsigned long pingInterval; + + /** + * This value determines how long (in seconds) the websocket waits for + * the server to reply to a ping that has been sent before considering the + * connection broken. + * + * This value can currently only be set before asyncOpen is called, else + * NS_ERROR_IN_PROGRESS is thrown. + */ + [must_use] attribute unsigned long pingTimeout; + + /** + * Unique ID for this channel. It's not readonly because when the channel is + * created via IPC, the serial number is received from the child process. + */ + [must_use] attribute unsigned long serial; + + /** + * Set a nsITransportProvider and negotated extensions to be used by this + * channel. Calling this function also means that this channel will + * implement the server-side part of a websocket connection rather than the + * client-side part. + */ + [must_use] void setServerParameters(in nsITransportProvider aProvider, + in ACString aNegotiatedExtensions); + +%{C++ + inline uint32_t Serial() + { + uint32_t serial; + nsresult rv = GetSerial(&serial); + if (NS_WARN_IF(NS_FAILED(rv))) { + return 0; + } + return serial; + } + + inline uint64_t HttpChannelId() + { + uint64_t httpChannelId; + nsresult rv = GetHttpChannelId(&httpChannelId); + if (NS_WARN_IF(NS_FAILED(rv))) { + return 0; + } + return httpChannelId; + } +%} +}; diff --git a/netwerk/protocol/websocket/nsIWebSocketEventService.idl b/netwerk/protocol/websocket/nsIWebSocketEventService.idl new file mode 100644 index 0000000000..9763850609 --- /dev/null +++ b/netwerk/protocol/websocket/nsIWebSocketEventService.idl @@ -0,0 +1,87 @@ +/* -*- Mode: IDL; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public +* License, v. 2.0. If a copy of the MPL was not distributed with this +* file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "domstubs.idl" +#include "nsISupports.idl" + +interface nsIWeakReference; + +[scriptable, builtinclass, uuid(6714a6be-2265-4f73-a988-d78a12416037)] +interface nsIWebSocketFrame : nsISupports +{ + [must_use] readonly attribute DOMHighResTimeStamp timeStamp; + + [must_use] readonly attribute boolean finBit; + + [must_use] readonly attribute boolean rsvBit1; + [must_use] readonly attribute boolean rsvBit2; + [must_use] readonly attribute boolean rsvBit3; + + [must_use] readonly attribute unsigned short opCode; + + [must_use] readonly attribute boolean maskBit; + + [must_use] readonly attribute unsigned long mask; + + [must_use] readonly attribute ACString payload; + + // Non-Control opCode values: + const unsigned short OPCODE_CONTINUATION = 0x0; + const unsigned short OPCODE_TEXT = 0x1; + const unsigned short OPCODE_BINARY = 0x2; + + // Control opCode values: + const unsigned short OPCODE_CLOSE = 0x8; + const unsigned short OPCODE_PING = 0x9; + const unsigned short OPCODE_PONG = 0xA; +}; + +[scriptable, uuid(e7c005ab-e694-489b-b741-96db43ffb16f)] +interface nsIWebSocketEventListener : nsISupports +{ + [must_use] void webSocketCreated(in unsigned long aWebSocketSerialID, + in AString aURI, + in ACString aProtocols); + + [must_use] void webSocketOpened(in unsigned long aWebSocketSerialID, + in AString aEffectiveURI, + in ACString aProtocols, + in ACString aExtensions, + in uint64_t aHttpChannelId); + + const unsigned short TYPE_STRING = 0x0; + const unsigned short TYPE_BLOB = 0x1; + const unsigned short TYPE_ARRAYBUFFER = 0x2; + + [must_use] void webSocketMessageAvailable(in unsigned long aWebSocketSerialID, + in ACString aMessage, + in unsigned short aType); + + [must_use] void webSocketClosed(in unsigned long aWebSocketSerialID, + in boolean aWasClean, + in unsigned short aCode, + in AString aReason); + + [must_use] void frameReceived(in unsigned long aWebSocketSerialID, + in nsIWebSocketFrame aFrame); + + [must_use] void frameSent(in unsigned long aWebSocketSerialID, + in nsIWebSocketFrame aFrame); +}; + +[scriptable, builtinclass, uuid(b89d1b90-2cf3-4d8f-ac21-5aedfb25c760)] +interface nsIWebSocketEventService : nsISupports +{ + [must_use] void sendMessage(in unsigned long aWebSocketSerialID, + in AString aMessage); + + [must_use] void addListener(in unsigned long long aInnerWindowID, + in nsIWebSocketEventListener aListener); + + [must_use] void removeListener(in unsigned long long aInnerWindowID, + in nsIWebSocketEventListener aListener); + + [must_use] bool hasListenerFor(in unsigned long long aInnerWindowID); +}; diff --git a/netwerk/protocol/websocket/nsIWebSocketImpl.idl b/netwerk/protocol/websocket/nsIWebSocketImpl.idl new file mode 100644 index 0000000000..ab4177a7e6 --- /dev/null +++ b/netwerk/protocol/websocket/nsIWebSocketImpl.idl @@ -0,0 +1,18 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ +/* vim: set sw=4 ts=4 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsISupports.idl" + +[scriptable, uuid(db1f4e2b-3cff-4615-a03c-341fda66c53d)] +interface nsIWebSocketImpl : nsISupports +{ + /** + * Called to send message of type string through web socket + * + * @param aMessage the message to send + */ + [must_use] void sendMessage(in AString aMessage); +}; diff --git a/netwerk/protocol/websocket/nsIWebSocketListener.idl b/netwerk/protocol/websocket/nsIWebSocketListener.idl new file mode 100644 index 0000000000..31c588630b --- /dev/null +++ b/netwerk/protocol/websocket/nsIWebSocketListener.idl @@ -0,0 +1,94 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ +/* vim: set sw=4 ts=4 et tw=80 : */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsISupports.idl" + +/** + * nsIWebSocketListener: passed to nsIWebSocketChannel::AsyncOpen. Receives + * websocket traffic events as they arrive. + */ +[scriptable, uuid(d74c96b2-65b3-4e39-9e39-c577de5d7a73)] +interface nsIWebSocketListener : nsISupports +{ + /** + * Called to signify the establishment of the message stream. + * + * Unlike most other networking channels (which use nsIRequestObserver + * instead of this class), we do not guarantee that OnStart is always + * called: OnStop is called without calling this function if errors occur + * during connection setup. If the websocket connection is successful, + * OnStart will be called before any other calls to this API. + * + * @param aContext user defined context + */ + [must_use] void onStart(in nsISupports aContext); + + /** + * Called to signify the completion of the message stream. + * OnStop is the final notification the listener will receive and it + * completes the WebSocket connection: after it returns the + * nsIWebSocketChannel will release its reference to the listener. + * + * Note: this event can be received in error cases even if + * nsIWebSocketChannel::Close() has not been called. + * + * @param aContext user defined context + * @param aStatusCode reason for stopping (NS_OK if completed successfully) + */ + [must_use] void onStop(in nsISupports aContext, + in nsresult aStatusCode); + + /** + * Called to deliver text message. + * + * @param aContext user defined context + * @param aMsg the message data + */ + [must_use] void onMessageAvailable(in nsISupports aContext, + in AUTF8String aMsg); + + /** + * Called to deliver binary message. + * + * @param aContext user defined context + * @param aMsg the message data + */ + [must_use] void onBinaryMessageAvailable(in nsISupports aContext, + in ACString aMsg); + + /** + * Called to acknowledge message sent via sendMsg() or sendBinaryMsg. + * + * @param aContext user defined context + * @param aSize number of bytes placed in OS send buffer + */ + [must_use] void onAcknowledge(in nsISupports aContext, in uint32_t aSize); + + /** + * Called to inform receipt of WebSocket Close message from server. + * In the case of errors onStop() can be called without ever + * receiving server close. + * + * No additional messages through onMessageAvailable(), + * onBinaryMessageAvailable() or onAcknowledge() will be delievered + * to the listener after onServerClose(), though outgoing messages can still + * be sent through the nsIWebSocketChannel connection. + * + * @param aContext user defined context + * @param aCode the websocket closing handshake close code. + * @param aReason the websocket closing handshake close reason + */ + [must_use] void onServerClose(in nsISupports aContext, + in unsigned short aCode, + in AUTF8String aReason); + + /** + * Called to inform an error is happened. The connection will be closed + * when this is called. + */ + [must_use] void OnError(); + +}; |