diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
commit | 2aa4a82499d4becd2284cdb482213d541b8804dd (patch) | |
tree | b80bf8bf13c3766139fbacc530efd0dd9d54394c /toolkit/components/extensions/webrequest | |
parent | Initial commit. (diff) | |
download | firefox-upstream.tar.xz firefox-upstream.zip |
Adding upstream version 86.0.1.upstream/86.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'toolkit/components/extensions/webrequest')
18 files changed, 6018 insertions, 0 deletions
diff --git a/toolkit/components/extensions/webrequest/ChannelWrapper.cpp b/toolkit/components/extensions/webrequest/ChannelWrapper.cpp new file mode 100644 index 0000000000..f313a3aec8 --- /dev/null +++ b/toolkit/components/extensions/webrequest/ChannelWrapper.cpp @@ -0,0 +1,1205 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ChannelWrapper.h" + +#include "jsapi.h" +#include "xpcpublic.h" + +#include "mozilla/BasePrincipal.h" +#include "mozilla/SystemPrincipal.h" + +#include "NSSErrorsService.h" +#include "nsITransportSecurityInfo.h" + +#include "mozilla/AddonManagerWebAPI.h" +#include "mozilla/ClearOnShutdown.h" +#include "mozilla/ErrorNames.h" +#include "mozilla/ResultExtensions.h" +#include "mozilla/Unused.h" +#include "mozilla/dom/Element.h" +#include "mozilla/dom/Event.h" +#include "mozilla/dom/EventBinding.h" +#include "mozilla/dom/BrowserHost.h" +#include "mozIThirdPartyUtil.h" +#include "nsContentUtils.h" +#include "nsIContentPolicy.h" +#include "nsIClassifiedChannel.h" +#include "nsIHttpChannelInternal.h" +#include "nsIHttpHeaderVisitor.h" +#include "nsIInterfaceRequestor.h" +#include "nsIInterfaceRequestorUtils.h" +#include "nsILoadContext.h" +#include "nsIProxiedChannel.h" +#include "nsIProxyInfo.h" +#include "nsITraceableChannel.h" +#include "nsIWritablePropertyBag.h" +#include "nsIWritablePropertyBag2.h" +#include "nsNetUtil.h" +#include "nsProxyRelease.h" +#include "nsPrintfCString.h" + +using namespace mozilla::dom; +using namespace JS; + +namespace mozilla { +namespace extensions { + +#define CHANNELWRAPPER_PROP_KEY u"ChannelWrapper::CachedInstance"_ns + +using CF = nsIClassifiedChannel::ClassificationFlags; +using MUC = MozUrlClassificationFlags; + +struct ClassificationStruct { + uint32_t mFlag; + MozUrlClassificationFlags mValue; +}; +static const ClassificationStruct classificationArray[] = { + {CF::CLASSIFIED_FINGERPRINTING, MUC::Fingerprinting}, + {CF::CLASSIFIED_FINGERPRINTING_CONTENT, MUC::Fingerprinting_content}, + {CF::CLASSIFIED_CRYPTOMINING, MUC::Cryptomining}, + {CF::CLASSIFIED_CRYPTOMINING_CONTENT, MUC::Cryptomining_content}, + {CF::CLASSIFIED_TRACKING, MUC::Tracking}, + {CF::CLASSIFIED_TRACKING_AD, MUC::Tracking_ad}, + {CF::CLASSIFIED_TRACKING_ANALYTICS, MUC::Tracking_analytics}, + {CF::CLASSIFIED_TRACKING_SOCIAL, MUC::Tracking_social}, + {CF::CLASSIFIED_TRACKING_CONTENT, MUC::Tracking_content}, + {CF::CLASSIFIED_SOCIALTRACKING, MUC::Socialtracking}, + {CF::CLASSIFIED_SOCIALTRACKING_FACEBOOK, MUC::Socialtracking_facebook}, + {CF::CLASSIFIED_SOCIALTRACKING_LINKEDIN, MUC::Socialtracking_linkedin}, + {CF::CLASSIFIED_SOCIALTRACKING_TWITTER, MUC::Socialtracking_twitter}, + {CF::CLASSIFIED_ANY_BASIC_TRACKING, MUC::Any_basic_tracking}, + {CF::CLASSIFIED_ANY_STRICT_TRACKING, MUC::Any_strict_tracking}, + {CF::CLASSIFIED_ANY_SOCIAL_TRACKING, MUC::Any_social_tracking}}; + +/***************************************************************************** + * Lifetimes + *****************************************************************************/ + +namespace { +class ChannelListHolder : public LinkedList<ChannelWrapper> { + public: + ChannelListHolder() : LinkedList<ChannelWrapper>() {} + + ~ChannelListHolder(); +}; + +} // anonymous namespace + +ChannelListHolder::~ChannelListHolder() { + while (ChannelWrapper* wrapper = popFirst()) { + wrapper->Die(); + } +} + +static LinkedList<ChannelWrapper>& ChannelList() { + static UniquePtr<ChannelListHolder> sChannelList; + if (!sChannelList) { + sChannelList.reset(new ChannelListHolder()); + ClearOnShutdown(&sChannelList, ShutdownPhase::Shutdown); + } + return *sChannelList; +} + +NS_IMPL_CYCLE_COLLECTING_ADDREF(ChannelWrapper::ChannelWrapperStub) +NS_IMPL_CYCLE_COLLECTING_RELEASE(ChannelWrapper::ChannelWrapperStub) + +NS_IMPL_CYCLE_COLLECTION(ChannelWrapper::ChannelWrapperStub, mChannelWrapper) + +NS_INTERFACE_MAP_BEGIN_CYCLE_COLLECTION(ChannelWrapper::ChannelWrapperStub) + NS_INTERFACE_MAP_ENTRY_TEAROFF(ChannelWrapper, mChannelWrapper) + NS_INTERFACE_MAP_ENTRY(nsISupports) +NS_INTERFACE_MAP_END + +/***************************************************************************** + * Initialization + *****************************************************************************/ + +ChannelWrapper::ChannelWrapper(nsISupports* aParent, nsIChannel* aChannel) + : ChannelHolder(aChannel), mParent(aParent) { + mStub = new ChannelWrapperStub(this); + + ChannelList().insertBack(this); +} + +ChannelWrapper::~ChannelWrapper() { + if (LinkedListElement<ChannelWrapper>::isInList()) { + LinkedListElement<ChannelWrapper>::remove(); + } +} + +void ChannelWrapper::Die() { + if (mStub) { + mStub->mChannelWrapper = nullptr; + } +} + +/* static */ +already_AddRefed<ChannelWrapper> ChannelWrapper::Get(const GlobalObject& global, + nsIChannel* channel) { + RefPtr<ChannelWrapper> wrapper; + + nsCOMPtr<nsIWritablePropertyBag2> props = do_QueryInterface(channel); + if (props) { + Unused << props->GetPropertyAsInterface(CHANNELWRAPPER_PROP_KEY, + NS_GET_IID(ChannelWrapper), + getter_AddRefs(wrapper)); + + if (wrapper) { + // Assume cached attributes may have changed at this point. + wrapper->ClearCachedAttributes(); + } + } + + if (!wrapper) { + wrapper = new ChannelWrapper(global.GetAsSupports(), channel); + if (props) { + Unused << props->SetPropertyAsInterface(CHANNELWRAPPER_PROP_KEY, + wrapper->mStub); + } + } + + return wrapper.forget(); +} + +already_AddRefed<ChannelWrapper> ChannelWrapper::GetRegisteredChannel( + const GlobalObject& global, uint64_t aChannelId, + const WebExtensionPolicy& aAddon, nsIRemoteTab* aRemoteTab) { + ContentParent* contentParent = nullptr; + if (BrowserHost* host = BrowserHost::GetFrom(aRemoteTab)) { + contentParent = host->GetActor()->Manager(); + } + + auto& webreq = WebRequestService::GetSingleton(); + + nsCOMPtr<nsITraceableChannel> channel = + webreq.GetTraceableChannel(aChannelId, aAddon.Id(), contentParent); + if (!channel) { + return nullptr; + } + nsCOMPtr<nsIChannel> chan(do_QueryInterface(channel)); + return ChannelWrapper::Get(global, chan); +} + +void ChannelWrapper::SetChannel(nsIChannel* aChannel) { + detail::ChannelHolder::SetChannel(aChannel); + ClearCachedAttributes(); + ChannelWrapper_Binding::ClearCachedFinalURIValue(this); + ChannelWrapper_Binding::ClearCachedFinalURLValue(this); + mFinalURLInfo.reset(); + ChannelWrapper_Binding::ClearCachedProxyInfoValue(this); +} + +void ChannelWrapper::ClearCachedAttributes() { + ChannelWrapper_Binding::ClearCachedRemoteAddressValue(this); + ChannelWrapper_Binding::ClearCachedStatusCodeValue(this); + ChannelWrapper_Binding::ClearCachedStatusLineValue(this); + ChannelWrapper_Binding::ClearCachedUrlClassificationValue(this); + if (!mFiredErrorEvent) { + ChannelWrapper_Binding::ClearCachedErrorStringValue(this); + } + + ChannelWrapper_Binding::ClearCachedRequestSizeValue(this); + ChannelWrapper_Binding::ClearCachedResponseSizeValue(this); +} + +/***************************************************************************** + * ... + *****************************************************************************/ + +void ChannelWrapper::Cancel(uint32_t aResult, uint32_t aReason, + ErrorResult& aRv) { + nsresult rv = NS_ERROR_UNEXPECTED; + if (nsCOMPtr<nsIChannel> chan = MaybeChannel()) { + nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo(); + if (aReason > 0 && loadInfo) { + loadInfo->SetRequestBlockingReason(aReason); + } + rv = chan->Cancel(nsresult(aResult)); + ErrorCheck(); + } + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } +} + +void ChannelWrapper::RedirectTo(nsIURI* aURI, ErrorResult& aRv) { + nsresult rv = NS_ERROR_UNEXPECTED; + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + rv = chan->RedirectTo(aURI); + } + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } +} + +void ChannelWrapper::UpgradeToSecure(ErrorResult& aRv) { + nsresult rv = NS_ERROR_UNEXPECTED; + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + rv = chan->UpgradeToSecure(); + } + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } +} + +void ChannelWrapper::Suspend(ErrorResult& aRv) { + if (!mSuspended) { + nsresult rv = NS_ERROR_UNEXPECTED; + if (nsCOMPtr<nsIChannel> chan = MaybeChannel()) { + mSuspendTime = mozilla::TimeStamp::NowUnfuzzed(); + rv = chan->Suspend(); + } + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } else { + mSuspended = true; + } + } +} + +void ChannelWrapper::Resume(const nsCString& aText, ErrorResult& aRv) { + if (mSuspended) { + nsresult rv = NS_ERROR_UNEXPECTED; + if (nsCOMPtr<nsIChannel> chan = MaybeChannel()) { + rv = chan->Resume(); + + PROFILER_MARKER_TEXT("Extension Suspend", NETWORK, + MarkerTiming::IntervalUntilNowFrom(mSuspendTime), + aText); + } + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } else { + mSuspended = false; + } + } +} + +void ChannelWrapper::GetContentType(nsCString& aContentType) const { + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + Unused << chan->GetContentType(aContentType); + } +} + +void ChannelWrapper::SetContentType(const nsACString& aContentType) { + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + Unused << chan->SetContentType(aContentType); + } +} + +/***************************************************************************** + * Headers + *****************************************************************************/ + +namespace { + +class MOZ_STACK_CLASS HeaderVisitor final : public nsIHttpHeaderVisitor { + public: + NS_DECL_NSIHTTPHEADERVISITOR + + explicit HeaderVisitor(nsTArray<dom::MozHTTPHeader>& aHeaders) + : mHeaders(aHeaders) {} + + HeaderVisitor(nsTArray<dom::MozHTTPHeader>& aHeaders, + const nsCString& aContentTypeHdr) + : mHeaders(aHeaders), mContentTypeHdr(aContentTypeHdr) {} + + void VisitRequestHeaders(nsIHttpChannel* aChannel, ErrorResult& aRv) { + CheckResult(aChannel->VisitRequestHeaders(this), aRv); + } + + void VisitResponseHeaders(nsIHttpChannel* aChannel, ErrorResult& aRv) { + CheckResult(aChannel->VisitResponseHeaders(this), aRv); + } + + NS_IMETHOD QueryInterface(REFNSIID aIID, void** aInstancePtr) override; + + // Stub AddRef/Release since this is a stack class. + NS_IMETHOD_(MozExternalRefCountType) AddRef(void) override { + return ++mRefCnt; + } + + NS_IMETHOD_(MozExternalRefCountType) Release(void) override { + return --mRefCnt; + } + + virtual ~HeaderVisitor() { MOZ_DIAGNOSTIC_ASSERT(mRefCnt == 0); } + + private: + bool CheckResult(nsresult aNSRv, ErrorResult& aRv) { + if (NS_FAILED(aNSRv)) { + aRv.Throw(aNSRv); + return false; + } + return true; + } + + nsTArray<dom::MozHTTPHeader>& mHeaders; + nsCString mContentTypeHdr = VoidCString(); + + nsrefcnt mRefCnt = 0; +}; + +NS_IMETHODIMP +HeaderVisitor::VisitHeader(const nsACString& aHeader, + const nsACString& aValue) { + auto dict = mHeaders.AppendElement(fallible); + if (!dict) { + return NS_ERROR_OUT_OF_MEMORY; + } + dict->mName = aHeader; + + if (!mContentTypeHdr.IsVoid() && + aHeader.LowerCaseEqualsLiteral("content-type")) { + dict->mValue = mContentTypeHdr; + } else { + dict->mValue = aValue; + } + + return NS_OK; +} + +NS_IMPL_QUERY_INTERFACE(HeaderVisitor, nsIHttpHeaderVisitor) + +} // anonymous namespace + +void ChannelWrapper::GetRequestHeaders(nsTArray<dom::MozHTTPHeader>& aRetVal, + ErrorResult& aRv) const { + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + HeaderVisitor visitor(aRetVal); + visitor.VisitRequestHeaders(chan, aRv); + } else { + aRv.Throw(NS_ERROR_UNEXPECTED); + } +} + +void ChannelWrapper::GetRequestHeader(const nsCString& aHeader, + nsCString& aResult, + ErrorResult& aRv) const { + aResult.SetIsVoid(true); + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + Unused << chan->GetRequestHeader(aHeader, aResult); + } else { + aRv.Throw(NS_ERROR_UNEXPECTED); + } +} + +void ChannelWrapper::GetResponseHeaders(nsTArray<dom::MozHTTPHeader>& aRetVal, + ErrorResult& aRv) const { + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + HeaderVisitor visitor(aRetVal, mContentTypeHdr); + visitor.VisitResponseHeaders(chan, aRv); + } else { + aRv.Throw(NS_ERROR_UNEXPECTED); + } +} + +void ChannelWrapper::SetRequestHeader(const nsCString& aHeader, + const nsCString& aValue, bool aMerge, + ErrorResult& aRv) { + nsresult rv = NS_ERROR_UNEXPECTED; + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + rv = chan->SetRequestHeader(aHeader, aValue, aMerge); + } + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } +} + +void ChannelWrapper::SetResponseHeader(const nsCString& aHeader, + const nsCString& aValue, bool aMerge, + ErrorResult& aRv) { + nsresult rv = NS_ERROR_UNEXPECTED; + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + if (aHeader.LowerCaseEqualsLiteral("content-type")) { + rv = chan->SetContentType(aValue); + if (NS_SUCCEEDED(rv)) { + mContentTypeHdr = aValue; + } + } else { + rv = chan->SetResponseHeader(aHeader, aValue, aMerge); + } + } + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } +} + +/***************************************************************************** + * LoadInfo + *****************************************************************************/ + +already_AddRefed<nsILoadContext> ChannelWrapper::GetLoadContext() const { + if (nsCOMPtr<nsIChannel> chan = MaybeChannel()) { + nsCOMPtr<nsILoadContext> ctxt; + NS_QueryNotificationCallbacks(chan, ctxt); + return ctxt.forget(); + } + return nullptr; +} + +already_AddRefed<Element> ChannelWrapper::GetBrowserElement() const { + if (nsCOMPtr<nsILoadContext> ctxt = GetLoadContext()) { + RefPtr<Element> elem; + if (NS_SUCCEEDED(ctxt->GetTopFrameElement(getter_AddRefs(elem)))) { + return elem.forget(); + } + } + return nullptr; +} + +static inline bool IsSystemPrincipal(nsIPrincipal* aPrincipal) { + return BasePrincipal::Cast(aPrincipal)->Is<SystemPrincipal>(); +} + +bool ChannelWrapper::IsSystemLoad() const { + if (nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo()) { + if (nsIPrincipal* prin = loadInfo->GetLoadingPrincipal()) { + return IsSystemPrincipal(prin); + } + + if (RefPtr<BrowsingContext> bc = loadInfo->GetBrowsingContext(); + !bc || bc->IsTop()) { + return false; + } + + if (nsIPrincipal* prin = loadInfo->PrincipalToInherit()) { + return IsSystemPrincipal(prin); + } + if (nsIPrincipal* prin = loadInfo->TriggeringPrincipal()) { + return IsSystemPrincipal(prin); + } + } + return false; +} + +bool ChannelWrapper::CanModify() const { + if (WebExtensionPolicy::IsRestrictedURI(FinalURLInfo())) { + return false; + } + + if (nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo()) { + if (nsIPrincipal* prin = loadInfo->GetLoadingPrincipal()) { + if (IsSystemPrincipal(prin)) { + return false; + } + + auto* docURI = DocumentURLInfo(); + if (docURI && WebExtensionPolicy::IsRestrictedURI(*docURI)) { + return false; + } + } + } + return true; +} + +already_AddRefed<nsIURI> ChannelWrapper::GetOriginURI() const { + nsCOMPtr<nsIURI> uri; + if (nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo()) { + if (nsIPrincipal* prin = loadInfo->TriggeringPrincipal()) { + if (prin->GetIsContentPrincipal()) { + auto* basePrin = BasePrincipal::Cast(prin); + Unused << basePrin->GetURI(getter_AddRefs(uri)); + } + } + } + return uri.forget(); +} + +already_AddRefed<nsIURI> ChannelWrapper::GetDocumentURI() const { + nsCOMPtr<nsIURI> uri; + if (nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo()) { + if (nsIPrincipal* prin = loadInfo->GetLoadingPrincipal()) { + if (prin->GetIsContentPrincipal()) { + auto* basePrin = BasePrincipal::Cast(prin); + Unused << basePrin->GetURI(getter_AddRefs(uri)); + } + } + } + return uri.forget(); +} + +void ChannelWrapper::GetOriginURL(nsCString& aRetVal) const { + if (nsCOMPtr<nsIURI> uri = GetOriginURI()) { + Unused << uri->GetSpec(aRetVal); + } +} + +void ChannelWrapper::GetDocumentURL(nsCString& aRetVal) const { + if (nsCOMPtr<nsIURI> uri = GetDocumentURI()) { + Unused << uri->GetSpec(aRetVal); + } +} + +const URLInfo& ChannelWrapper::FinalURLInfo() const { + if (mFinalURLInfo.isNothing()) { + ErrorResult rv; + nsCOMPtr<nsIURI> uri = FinalURI(); + MOZ_ASSERT(uri); + + // If this is a view-source scheme, get the nested uri. + while (uri && uri->SchemeIs("view-source")) { + nsCOMPtr<nsINestedURI> nested = do_QueryInterface(uri); + if (!nested) { + break; + } + nested->GetInnerURI(getter_AddRefs(uri)); + } + mFinalURLInfo.emplace(uri.get(), true); + + // If this is a WebSocket request, mangle the URL so that the scheme is + // ws: or wss:, as appropriate. + auto& url = mFinalURLInfo.ref(); + if (Type() == MozContentPolicyType::Websocket && + (url.Scheme() == nsGkAtoms::http || url.Scheme() == nsGkAtoms::https)) { + nsAutoCString spec(url.CSpec()); + spec.Replace(0, 4, "ws"_ns); + + Unused << NS_NewURI(getter_AddRefs(uri), spec); + MOZ_RELEASE_ASSERT(uri); + mFinalURLInfo.reset(); + mFinalURLInfo.emplace(uri.get(), true); + } + } + return mFinalURLInfo.ref(); +} + +const URLInfo* ChannelWrapper::DocumentURLInfo() const { + if (mDocumentURLInfo.isNothing()) { + nsCOMPtr<nsIURI> uri = GetDocumentURI(); + if (!uri) { + return nullptr; + } + mDocumentURLInfo.emplace(uri.get(), true); + } + return &mDocumentURLInfo.ref(); +} + +bool ChannelWrapper::Matches( + const dom::MozRequestFilter& aFilter, const WebExtensionPolicy* aExtension, + const dom::MozRequestMatchOptions& aOptions) const { + if (!HaveChannel()) { + return false; + } + + if (!aFilter.mTypes.IsNull() && !aFilter.mTypes.Value().Contains(Type())) { + return false; + } + + auto& urlInfo = FinalURLInfo(); + if (aFilter.mUrls && !aFilter.mUrls->Matches(urlInfo)) { + return false; + } + + nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo(); + bool isPrivate = + loadInfo && loadInfo->GetOriginAttributes().mPrivateBrowsingId > 0; + if (!aFilter.mIncognito.IsNull() && aFilter.mIncognito.Value() != isPrivate) { + return false; + } + + if (aExtension) { + // Verify extension access to private requests + if (isPrivate && !aExtension->PrivateBrowsingAllowed()) { + return false; + } + + bool isProxy = + aOptions.mIsProxy && aExtension->HasPermission(nsGkAtoms::proxy); + // Proxies are allowed access to all urls, including restricted urls. + if (!aExtension->CanAccessURI(urlInfo, false, !isProxy, true)) { + return false; + } + + // If this isn't the proxy phase of the request, check that the extension + // has origin permissions for origin that originated the request. + if (!isProxy) { + if (IsSystemLoad()) { + return false; + } + + auto origin = DocumentURLInfo(); + // Extensions with the file:-permission may observe requests from file: + // origins, because such documents can already be modified by content + // scripts anyway. + if (origin && !aExtension->CanAccessURI(*origin, false, true, true)) { + return false; + } + } + } + + return true; +} + +int64_t NormalizeFrameID(nsILoadInfo* aLoadInfo, uint64_t bcID) { + if (RefPtr<BrowsingContext> bc = aLoadInfo->GetBrowsingContext(); + !bc || bcID == bc->Top()->Id()) { + return 0; + } + return bcID; +} + +uint64_t ChannelWrapper::BrowsingContextId(nsILoadInfo* aLoadInfo) const { + auto frameID = aLoadInfo->GetFrameBrowsingContextID(); + if (!frameID) { + frameID = aLoadInfo->GetBrowsingContextID(); + } + return frameID; +} + +int64_t ChannelWrapper::FrameId() const { + if (nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo()) { + return NormalizeFrameID(loadInfo, BrowsingContextId(loadInfo)); + } + return 0; +} + +int64_t ChannelWrapper::ParentFrameId() const { + if (nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo()) { + if (RefPtr<BrowsingContext> bc = loadInfo->GetBrowsingContext()) { + if (BrowsingContextId(loadInfo) == bc->Top()->Id()) { + return -1; + } + + uint64_t parentID = -1; + if (loadInfo->GetFrameBrowsingContextID()) { + parentID = loadInfo->GetBrowsingContextID(); + } else if (bc->GetParent()) { + parentID = bc->GetParent()->Id(); + } + return NormalizeFrameID(loadInfo, parentID); + } + } + return -1; +} + +void ChannelWrapper::GetFrameAncestors( + dom::Nullable<nsTArray<dom::MozFrameAncestorInfo>>& aFrameAncestors, + ErrorResult& aRv) const { + nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo(); + if (!loadInfo || BrowsingContextId(loadInfo) == 0) { + aFrameAncestors.SetNull(); + return; + } + + nsresult rv = GetFrameAncestors(loadInfo, aFrameAncestors.SetValue()); + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } +} + +nsresult ChannelWrapper::GetFrameAncestors( + nsILoadInfo* aLoadInfo, + nsTArray<dom::MozFrameAncestorInfo>& aFrameAncestors) const { + const nsTArray<nsCOMPtr<nsIPrincipal>>& ancestorPrincipals = + aLoadInfo->AncestorPrincipals(); + const nsTArray<uint64_t>& ancestorBrowsingContextIDs = + aLoadInfo->AncestorBrowsingContextIDs(); + uint32_t size = ancestorPrincipals.Length(); + MOZ_DIAGNOSTIC_ASSERT(size == ancestorBrowsingContextIDs.Length()); + if (size != ancestorBrowsingContextIDs.Length()) { + return NS_ERROR_UNEXPECTED; + } + + bool subFrame = aLoadInfo->GetExternalContentPolicyType() == + ExtContentPolicy::TYPE_SUBDOCUMENT; + if (!aFrameAncestors.SetCapacity(subFrame ? size : size + 1, fallible)) { + return NS_ERROR_OUT_OF_MEMORY; + } + + // The immediate parent is always the first element in the ancestor arrays, + // however SUBDOCUMENTs do not have their immediate parent included, so we + // inject it here. This will force wrapper.parentBrowsingContextId == + // wrapper.frameAncestors[0].frameId to always be true. All ather requests + // already match this way. + if (subFrame) { + auto ancestor = aFrameAncestors.AppendElement(); + GetDocumentURL(ancestor->mUrl); + ancestor->mFrameId = ParentFrameId(); + } + + for (uint32_t i = 0; i < size; ++i) { + auto ancestor = aFrameAncestors.AppendElement(); + MOZ_TRY(ancestorPrincipals[i]->GetAsciiSpec(ancestor->mUrl)); + ancestor->mFrameId = + NormalizeFrameID(aLoadInfo, ancestorBrowsingContextIDs[i]); + } + return NS_OK; +} + +/***************************************************************************** + * Response filtering + *****************************************************************************/ + +void ChannelWrapper::RegisterTraceableChannel(const WebExtensionPolicy& aAddon, + nsIRemoteTab* aBrowserParent) { + // We can't attach new listeners after the response has started, so don't + // bother registering anything. + if (mResponseStarted || !CanModify()) { + return; + } + + mAddonEntries.Put(aAddon.Id(), aBrowserParent); + if (!mChannelEntry) { + mChannelEntry = WebRequestService::GetSingleton().RegisterChannel(this); + CheckEventListeners(); + } +} + +already_AddRefed<nsITraceableChannel> ChannelWrapper::GetTraceableChannel( + nsAtom* aAddonId, dom::ContentParent* aContentParent) const { + nsCOMPtr<nsIRemoteTab> remoteTab; + if (mAddonEntries.Get(aAddonId, getter_AddRefs(remoteTab))) { + ContentParent* contentParent = nullptr; + if (remoteTab) { + contentParent = + BrowserHost::GetFrom(remoteTab.get())->GetActor()->Manager(); + } + + if (contentParent == aContentParent) { + nsCOMPtr<nsITraceableChannel> chan = QueryChannel(); + return chan.forget(); + } + } + return nullptr; +} + +/***************************************************************************** + * ... + *****************************************************************************/ + +MozContentPolicyType GetContentPolicyType(ExtContentPolicyType aType) { + // Note: Please keep this function in sync with the external types in + // nsIContentPolicy.idl + switch (aType) { + case ExtContentPolicy::TYPE_DOCUMENT: + return MozContentPolicyType::Main_frame; + case ExtContentPolicy::TYPE_SUBDOCUMENT: + return MozContentPolicyType::Sub_frame; + case ExtContentPolicy::TYPE_STYLESHEET: + return MozContentPolicyType::Stylesheet; + case ExtContentPolicy::TYPE_SCRIPT: + return MozContentPolicyType::Script; + case ExtContentPolicy::TYPE_IMAGE: + return MozContentPolicyType::Image; + case ExtContentPolicy::TYPE_OBJECT: + return MozContentPolicyType::Object; + case ExtContentPolicy::TYPE_OBJECT_SUBREQUEST: + return MozContentPolicyType::Object_subrequest; + case ExtContentPolicy::TYPE_XMLHTTPREQUEST: + return MozContentPolicyType::Xmlhttprequest; + // TYPE_FETCH returns xmlhttprequest for cross-browser compatibility. + case ExtContentPolicy::TYPE_FETCH: + return MozContentPolicyType::Xmlhttprequest; + case ExtContentPolicy::TYPE_XSLT: + return MozContentPolicyType::Xslt; + case ExtContentPolicy::TYPE_PING: + return MozContentPolicyType::Ping; + case ExtContentPolicy::TYPE_BEACON: + return MozContentPolicyType::Beacon; + case ExtContentPolicy::TYPE_DTD: + return MozContentPolicyType::Xml_dtd; + case ExtContentPolicy::TYPE_FONT: + return MozContentPolicyType::Font; + case ExtContentPolicy::TYPE_MEDIA: + return MozContentPolicyType::Media; + case ExtContentPolicy::TYPE_WEBSOCKET: + return MozContentPolicyType::Websocket; + case ExtContentPolicy::TYPE_CSP_REPORT: + return MozContentPolicyType::Csp_report; + case ExtContentPolicy::TYPE_IMAGESET: + return MozContentPolicyType::Imageset; + case ExtContentPolicy::TYPE_WEB_MANIFEST: + return MozContentPolicyType::Web_manifest; + case ExtContentPolicy::TYPE_SPECULATIVE: + return MozContentPolicyType::Speculative; + case ExtContentPolicy::TYPE_INVALID: + case ExtContentPolicy::TYPE_OTHER: + case ExtContentPolicy::TYPE_SAVEAS_DOWNLOAD: + break; + // Do not add default: so that compilers can catch the missing case. + } + return MozContentPolicyType::Other; +} + +MozContentPolicyType ChannelWrapper::Type() const { + if (nsCOMPtr<nsILoadInfo> loadInfo = GetLoadInfo()) { + return GetContentPolicyType(loadInfo->GetExternalContentPolicyType()); + } + return MozContentPolicyType::Other; +} + +void ChannelWrapper::GetMethod(nsCString& aMethod) const { + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + Unused << chan->GetRequestMethod(aMethod); + } +} + +/***************************************************************************** + * ... + *****************************************************************************/ + +uint32_t ChannelWrapper::StatusCode() const { + uint32_t result = 0; + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + Unused << chan->GetResponseStatus(&result); + } + return result; +} + +void ChannelWrapper::GetStatusLine(nsCString& aRetVal) const { + nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel(); + nsCOMPtr<nsIHttpChannelInternal> internal = do_QueryInterface(chan); + + if (internal) { + nsAutoCString statusText; + uint32_t major, minor, status; + if (NS_FAILED(chan->GetResponseStatus(&status)) || + NS_FAILED(chan->GetResponseStatusText(statusText)) || + NS_FAILED(internal->GetResponseVersion(&major, &minor))) { + return; + } + + aRetVal = nsPrintfCString("HTTP/%u.%u %u %s", major, minor, status, + statusText.get()); + } +} + +uint64_t ChannelWrapper::ResponseSize() const { + uint64_t result = 0; + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + Unused << chan->GetTransferSize(&result); + } + return result; +} + +uint64_t ChannelWrapper::RequestSize() const { + uint64_t result = 0; + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + Unused << chan->GetRequestSize(&result); + } + return result; +} + +/***************************************************************************** + * ... + *****************************************************************************/ + +already_AddRefed<nsIURI> ChannelWrapper::FinalURI() const { + nsCOMPtr<nsIURI> uri; + if (nsCOMPtr<nsIChannel> chan = MaybeChannel()) { + NS_GetFinalChannelURI(chan, getter_AddRefs(uri)); + } + return uri.forget(); +} + +void ChannelWrapper::GetFinalURL(nsString& aRetVal) const { + if (HaveChannel()) { + aRetVal = FinalURLInfo().Spec(); + } +} + +/***************************************************************************** + * ... + *****************************************************************************/ + +nsresult FillProxyInfo(MozProxyInfo& aDict, nsIProxyInfo* aProxyInfo) { + MOZ_TRY(aProxyInfo->GetHost(aDict.mHost)); + MOZ_TRY(aProxyInfo->GetPort(&aDict.mPort)); + MOZ_TRY(aProxyInfo->GetType(aDict.mType)); + MOZ_TRY(aProxyInfo->GetUsername(aDict.mUsername)); + MOZ_TRY( + aProxyInfo->GetProxyAuthorizationHeader(aDict.mProxyAuthorizationHeader)); + MOZ_TRY(aProxyInfo->GetConnectionIsolationKey(aDict.mConnectionIsolationKey)); + MOZ_TRY(aProxyInfo->GetFailoverTimeout(&aDict.mFailoverTimeout.Construct())); + + uint32_t flags; + MOZ_TRY(aProxyInfo->GetFlags(&flags)); + aDict.mProxyDNS = flags & nsIProxyInfo::TRANSPARENT_PROXY_RESOLVES_HOST; + + return NS_OK; +} + +void ChannelWrapper::GetProxyInfo(dom::Nullable<MozProxyInfo>& aRetVal, + ErrorResult& aRv) const { + nsCOMPtr<nsIProxyInfo> proxyInfo; + if (nsCOMPtr<nsIProxiedChannel> proxied = QueryChannel()) { + Unused << proxied->GetProxyInfo(getter_AddRefs(proxyInfo)); + } + if (proxyInfo) { + MozProxyInfo result; + + nsresult rv = FillProxyInfo(result, proxyInfo); + if (NS_FAILED(rv)) { + aRv.Throw(rv); + } else { + aRetVal.SetValue(std::move(result)); + } + } +} + +void ChannelWrapper::GetRemoteAddress(nsCString& aRetVal) const { + aRetVal.SetIsVoid(true); + if (nsCOMPtr<nsIHttpChannelInternal> internal = QueryChannel()) { + Unused << internal->GetRemoteAddress(aRetVal); + } +} + +void FillClassification( + Sequence<mozilla::dom::MozUrlClassificationFlags>& classifications, + uint32_t classificationFlags, ErrorResult& aRv) { + if (classificationFlags == 0) { + return; + } + for (const auto& entry : classificationArray) { + if (classificationFlags & entry.mFlag) { + if (!classifications.AppendElement(entry.mValue, mozilla::fallible)) { + aRv.Throw(NS_ERROR_OUT_OF_MEMORY); + return; + } + } + } +} + +void ChannelWrapper::GetUrlClassification( + dom::Nullable<dom::MozUrlClassification>& aRetVal, ErrorResult& aRv) const { + MozUrlClassification classification; + if (nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel()) { + nsCOMPtr<nsIClassifiedChannel> classified = do_QueryInterface(chan); + MOZ_DIAGNOSTIC_ASSERT( + classified, + "Must be an object inheriting from both nsIHttpChannel and " + "nsIClassifiedChannel"); + uint32_t classificationFlags; + classified->GetFirstPartyClassificationFlags(&classificationFlags); + FillClassification(classification.mFirstParty, classificationFlags, aRv); + if (aRv.Failed()) { + return; + } + classified->GetThirdPartyClassificationFlags(&classificationFlags); + FillClassification(classification.mThirdParty, classificationFlags, aRv); + } + aRetVal.SetValue(std::move(classification)); +} + +bool ChannelWrapper::ThirdParty() const { + nsCOMPtr<mozIThirdPartyUtil> thirdPartyUtil = services::GetThirdPartyUtil(); + if (NS_WARN_IF(!thirdPartyUtil)) { + return true; + } + + nsCOMPtr<nsIHttpChannel> chan = MaybeHttpChannel(); + if (!chan) { + return false; + } + + bool thirdParty = false; + nsresult rv = thirdPartyUtil->IsThirdPartyChannel(chan, nullptr, &thirdParty); + if (NS_WARN_IF(NS_FAILED(rv))) { + return true; + } + + return thirdParty; +} + +/***************************************************************************** + * Error handling + *****************************************************************************/ + +void ChannelWrapper::GetErrorString(nsString& aRetVal) const { + if (nsCOMPtr<nsIChannel> chan = MaybeChannel()) { + nsCOMPtr<nsISupports> securityInfo; + Unused << chan->GetSecurityInfo(getter_AddRefs(securityInfo)); + if (nsCOMPtr<nsITransportSecurityInfo> tsi = + do_QueryInterface(securityInfo)) { + int32_t errorCode = 0; + tsi->GetErrorCode(&errorCode); + if (psm::IsNSSErrorCode(errorCode)) { + nsCOMPtr<nsINSSErrorsService> nsserr = + do_GetService(NS_NSS_ERRORS_SERVICE_CONTRACTID); + + nsresult rv = psm::GetXPCOMFromNSSError(errorCode); + if (nsserr && NS_SUCCEEDED(nsserr->GetErrorMessage(rv, aRetVal))) { + return; + } + } + } + + nsresult status; + if (NS_SUCCEEDED(chan->GetStatus(&status)) && NS_FAILED(status)) { + nsAutoCString name; + GetErrorName(status, name); + AppendUTF8toUTF16(name, aRetVal); + } else { + aRetVal.SetIsVoid(true); + } + } else { + aRetVal.AssignLiteral("NS_ERROR_UNEXPECTED"); + } +} + +void ChannelWrapper::ErrorCheck() { + if (!mFiredErrorEvent) { + nsAutoString error; + GetErrorString(error); + if (error.Length()) { + mChannelEntry = nullptr; + mFiredErrorEvent = true; + ChannelWrapper_Binding::ClearCachedErrorStringValue(this); + FireEvent(u"error"_ns); + } + } +} + +/***************************************************************************** + * nsIWebRequestListener + *****************************************************************************/ + +NS_IMPL_ISUPPORTS(ChannelWrapper::RequestListener, nsIStreamListener, + nsIMultiPartChannelListener, nsIRequestObserver, + nsIThreadRetargetableStreamListener) + +ChannelWrapper::RequestListener::~RequestListener() { + NS_ReleaseOnMainThread("RequestListener::mChannelWrapper", + mChannelWrapper.forget()); +} + +nsresult ChannelWrapper::RequestListener::Init() { + if (nsCOMPtr<nsITraceableChannel> chan = mChannelWrapper->QueryChannel()) { + return chan->SetNewListener(this, false, + getter_AddRefs(mOrigStreamListener)); + } + return NS_ERROR_UNEXPECTED; +} + +NS_IMETHODIMP +ChannelWrapper::RequestListener::OnStartRequest(nsIRequest* request) { + MOZ_ASSERT(mOrigStreamListener, "Should have mOrigStreamListener"); + + mChannelWrapper->mChannelEntry = nullptr; + mChannelWrapper->mResponseStarted = true; + mChannelWrapper->ErrorCheck(); + mChannelWrapper->FireEvent(u"start"_ns); + + return mOrigStreamListener->OnStartRequest(request); +} + +NS_IMETHODIMP +ChannelWrapper::RequestListener::OnStopRequest(nsIRequest* request, + nsresult aStatus) { + MOZ_ASSERT(mOrigStreamListener, "Should have mOrigStreamListener"); + + mChannelWrapper->mChannelEntry = nullptr; + mChannelWrapper->ErrorCheck(); + mChannelWrapper->FireEvent(u"stop"_ns); + + return mOrigStreamListener->OnStopRequest(request, aStatus); +} + +NS_IMETHODIMP +ChannelWrapper::RequestListener::OnDataAvailable(nsIRequest* request, + nsIInputStream* inStr, + uint64_t sourceOffset, + uint32_t count) { + MOZ_ASSERT(mOrigStreamListener, "Should have mOrigStreamListener"); + return mOrigStreamListener->OnDataAvailable(request, inStr, sourceOffset, + count); +} + +NS_IMETHODIMP +ChannelWrapper::RequestListener::OnAfterLastPart(nsresult aStatus) { + MOZ_ASSERT(mOrigStreamListener, "Should have mOrigStreamListener"); + if (nsCOMPtr<nsIMultiPartChannelListener> listener = + do_QueryInterface(mOrigStreamListener)) { + return listener->OnAfterLastPart(aStatus); + } + return NS_OK; +} + +NS_IMETHODIMP +ChannelWrapper::RequestListener::CheckListenerChain() { + MOZ_ASSERT(NS_IsMainThread(), "Should be on main thread!"); + nsresult rv; + nsCOMPtr<nsIThreadRetargetableStreamListener> retargetableListener = + do_QueryInterface(mOrigStreamListener, &rv); + if (retargetableListener) { + return retargetableListener->CheckListenerChain(); + } + return rv; +} + +/***************************************************************************** + * Event dispatching + *****************************************************************************/ + +void ChannelWrapper::FireEvent(const nsAString& aType) { + EventInit init; + init.mBubbles = false; + init.mCancelable = false; + + RefPtr<Event> event = Event::Constructor(this, aType, init); + event->SetTrusted(true); + + DispatchEvent(*event); +} + +void ChannelWrapper::CheckEventListeners() { + if (!mAddedStreamListener && + (HasListenersFor(nsGkAtoms::onerror) || + HasListenersFor(nsGkAtoms::onstart) || + HasListenersFor(nsGkAtoms::onstop) || mChannelEntry)) { + auto listener = MakeRefPtr<RequestListener>(this); + if (!NS_WARN_IF(NS_FAILED(listener->Init()))) { + mAddedStreamListener = true; + } + } +} + +void ChannelWrapper::EventListenerAdded(nsAtom* aType) { + CheckEventListeners(); +} + +void ChannelWrapper::EventListenerRemoved(nsAtom* aType) { + CheckEventListeners(); +} + +/***************************************************************************** + * Glue + *****************************************************************************/ + +JSObject* ChannelWrapper::WrapObject(JSContext* aCx, HandleObject aGivenProto) { + return ChannelWrapper_Binding::Wrap(aCx, this, aGivenProto); +} + +NS_IMPL_CYCLE_COLLECTION_CLASS(ChannelWrapper) + +NS_INTERFACE_MAP_BEGIN_CYCLE_COLLECTION(ChannelWrapper) + NS_INTERFACE_MAP_ENTRY(ChannelWrapper) +NS_INTERFACE_MAP_END_INHERITING(DOMEventTargetHelper) + +NS_IMPL_CYCLE_COLLECTION_UNLINK_BEGIN_INHERITED(ChannelWrapper, + DOMEventTargetHelper) + NS_IMPL_CYCLE_COLLECTION_UNLINK(mParent) + NS_IMPL_CYCLE_COLLECTION_UNLINK(mStub) + NS_IMPL_CYCLE_COLLECTION_UNLINK_WEAK_PTR +NS_IMPL_CYCLE_COLLECTION_UNLINK_END + +NS_IMPL_CYCLE_COLLECTION_TRAVERSE_BEGIN_INHERITED(ChannelWrapper, + DOMEventTargetHelper) + NS_IMPL_CYCLE_COLLECTION_TRAVERSE(mParent) + NS_IMPL_CYCLE_COLLECTION_TRAVERSE(mStub) +NS_IMPL_CYCLE_COLLECTION_TRAVERSE_END + +NS_IMPL_CYCLE_COLLECTION_TRACE_BEGIN_INHERITED(ChannelWrapper, + DOMEventTargetHelper) +NS_IMPL_CYCLE_COLLECTION_TRACE_END + +NS_IMPL_ADDREF_INHERITED(ChannelWrapper, DOMEventTargetHelper) +NS_IMPL_RELEASE_INHERITED(ChannelWrapper, DOMEventTargetHelper) + +} // namespace extensions +} // namespace mozilla diff --git a/toolkit/components/extensions/webrequest/ChannelWrapper.h b/toolkit/components/extensions/webrequest/ChannelWrapper.h new file mode 100644 index 0000000000..3f2299c93c --- /dev/null +++ b/toolkit/components/extensions/webrequest/ChannelWrapper.h @@ -0,0 +1,352 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2; -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_extensions_ChannelWrapper_h +#define mozilla_extensions_ChannelWrapper_h + +#include "mozilla/dom/BindingDeclarations.h" +#include "mozilla/dom/ChannelWrapperBinding.h" + +#include "mozilla/WebRequestService.h" +#include "mozilla/extensions/MatchPattern.h" +#include "mozilla/extensions/WebExtensionPolicy.h" + +#include "mozilla/Attributes.h" +#include "mozilla/LinkedList.h" +#include "mozilla/Maybe.h" +#include "mozilla/UniquePtr.h" +#include "mozilla/WeakPtr.h" + +#include "mozilla/DOMEventTargetHelper.h" +#include "nsCOMPtr.h" +#include "nsCycleCollectionParticipant.h" +#include "nsIChannel.h" +#include "nsIHttpChannel.h" +#include "nsIMultiPartChannel.h" +#include "nsIStreamListener.h" +#include "nsIRemoteTab.h" +#include "nsIThreadRetargetableStreamListener.h" +#include "nsPointerHashKeys.h" +#include "nsInterfaceHashtable.h" +#include "nsIWeakReferenceUtils.h" +#include "nsWrapperCache.h" + +#define NS_CHANNELWRAPPER_IID \ + { \ + 0xc06162d2, 0xb803, 0x43b4, { \ + 0xaa, 0x31, 0xcf, 0x69, 0x7f, 0x93, 0x68, 0x1c \ + } \ + } + +class nsILoadContext; +class nsITraceableChannel; + +namespace mozilla { +namespace dom { +class ContentParent; +class Element; +} // namespace dom +namespace extensions { + +namespace detail { + +// We need to store our wrapped channel as a weak reference, since channels +// are not cycle collected, and we're going to be hanging this wrapper +// instance off the channel in order to ensure the same channel always has +// the same wrapper. +// +// But since performance matters here, and we don't want to have to +// QueryInterface the channel every time we touch it, we store separate +// nsIChannel and nsIHttpChannel weak references, and check that the WeakPtr +// is alive before returning it. +// +// This holder class prevents us from accidentally touching the weak pointer +// members directly from our ChannelWrapper class. +struct ChannelHolder { + explicit ChannelHolder(nsIChannel* aChannel) + : mChannel(do_GetWeakReference(aChannel)), mWeakChannel(aChannel) {} + + bool HaveChannel() const { return mChannel && mChannel->IsAlive(); } + + void SetChannel(nsIChannel* aChannel) { + mChannel = do_GetWeakReference(aChannel); + mWeakChannel = aChannel; + mWeakHttpChannel.reset(); + } + + already_AddRefed<nsIChannel> MaybeChannel() const { + if (!HaveChannel()) { + mWeakChannel = nullptr; + } + return do_AddRef(mWeakChannel); + } + + already_AddRefed<nsIHttpChannel> MaybeHttpChannel() const { + if (mWeakHttpChannel.isNothing()) { + nsCOMPtr<nsIHttpChannel> chan = QueryChannel(); + mWeakHttpChannel.emplace(chan.get()); + } + + if (!HaveChannel()) { + mWeakHttpChannel.ref() = nullptr; + } + return do_AddRef(mWeakHttpChannel.value()); + } + + const nsQueryReferent QueryChannel() const { + return do_QueryReferent(mChannel); + } + + private: + nsWeakPtr mChannel; + + mutable nsIChannel* MOZ_NON_OWNING_REF mWeakChannel; + mutable Maybe<nsIHttpChannel*> MOZ_NON_OWNING_REF mWeakHttpChannel; +}; +} // namespace detail + +class WebRequestChannelEntry; + +class ChannelWrapper final : public DOMEventTargetHelper, + public SupportsWeakPtr, + public LinkedListElement<ChannelWrapper>, + private detail::ChannelHolder { + public: + NS_DECL_ISUPPORTS_INHERITED + NS_DECL_CYCLE_COLLECTION_SCRIPT_HOLDER_CLASS_INHERITED(ChannelWrapper, + DOMEventTargetHelper) + + NS_DECLARE_STATIC_IID_ACCESSOR(NS_CHANNELWRAPPER_IID) + + void Die(); + + static already_AddRefed<extensions::ChannelWrapper> Get( + const dom::GlobalObject& global, nsIChannel* channel); + static already_AddRefed<extensions::ChannelWrapper> GetRegisteredChannel( + const dom::GlobalObject& global, uint64_t aChannelId, + const WebExtensionPolicy& aAddon, nsIRemoteTab* aBrowserParent); + + uint64_t Id() const { return mId; } + + already_AddRefed<nsIChannel> GetChannel() const { return MaybeChannel(); } + + void SetChannel(nsIChannel* aChannel); + + void Cancel(uint32_t result, uint32_t reason, ErrorResult& aRv); + + void RedirectTo(nsIURI* uri, ErrorResult& aRv); + void UpgradeToSecure(ErrorResult& aRv); + + bool Suspended() const { return mSuspended; } + void Suspend(ErrorResult& aRv); + void Resume(const nsCString& aText, ErrorResult& aRv); + + void GetContentType(nsCString& aContentType) const; + void SetContentType(const nsACString& aContentType); + + void RegisterTraceableChannel(const WebExtensionPolicy& aAddon, + nsIRemoteTab* aBrowserParent); + + already_AddRefed<nsITraceableChannel> GetTraceableChannel( + nsAtom* aAddonId, dom::ContentParent* aContentParent) const; + + void GetMethod(nsCString& aRetVal) const; + + dom::MozContentPolicyType Type() const; + + uint32_t StatusCode() const; + + uint64_t ResponseSize() const; + + uint64_t RequestSize() const; + + void GetStatusLine(nsCString& aRetVal) const; + + void GetErrorString(nsString& aRetVal) const; + + void ErrorCheck(); + + IMPL_EVENT_HANDLER(error); + IMPL_EVENT_HANDLER(start); + IMPL_EVENT_HANDLER(stop); + + already_AddRefed<nsIURI> FinalURI() const; + + void GetFinalURL(nsString& aRetVal) const; + + bool Matches(const dom::MozRequestFilter& aFilter, + const WebExtensionPolicy* aExtension, + const dom::MozRequestMatchOptions& aOptions) const; + + already_AddRefed<nsILoadInfo> GetLoadInfo() const { + nsCOMPtr<nsIChannel> chan = MaybeChannel(); + if (chan) { + return chan->LoadInfo(); + } + return nullptr; + } + + int64_t FrameId() const; + + int64_t ParentFrameId() const; + + void GetFrameAncestors( + dom::Nullable<nsTArray<dom::MozFrameAncestorInfo>>& aFrameAncestors, + ErrorResult& aRv) const; + + bool IsSystemLoad() const; + + void GetOriginURL(nsCString& aRetVal) const; + + void GetDocumentURL(nsCString& aRetVal) const; + + already_AddRefed<nsIURI> GetOriginURI() const; + + already_AddRefed<nsIURI> GetDocumentURI() const; + + already_AddRefed<nsILoadContext> GetLoadContext() const; + + already_AddRefed<dom::Element> GetBrowserElement() const; + + bool CanModify() const; + bool GetCanModify(ErrorResult& aRv) const { return CanModify(); } + + void GetProxyInfo(dom::Nullable<dom::MozProxyInfo>& aRetVal, + ErrorResult& aRv) const; + + void GetRemoteAddress(nsCString& aRetVal) const; + + void GetRequestHeaders(nsTArray<dom::MozHTTPHeader>& aRetVal, + ErrorResult& aRv) const; + void GetRequestHeader(const nsCString& aHeader, nsCString& aResult, + ErrorResult& aRv) const; + + void GetResponseHeaders(nsTArray<dom::MozHTTPHeader>& aRetVal, + ErrorResult& aRv) const; + + void SetRequestHeader(const nsCString& header, const nsCString& value, + bool merge, ErrorResult& aRv); + + void SetResponseHeader(const nsCString& header, const nsCString& value, + bool merge, ErrorResult& aRv); + + void GetUrlClassification(dom::Nullable<dom::MozUrlClassification>& aRetVal, + ErrorResult& aRv) const; + + bool ThirdParty() const; + + using EventTarget::EventListenerAdded; + using EventTarget::EventListenerRemoved; + virtual void EventListenerAdded(nsAtom* aType) override; + virtual void EventListenerRemoved(nsAtom* aType) override; + + nsISupports* GetParentObject() const { return mParent; } + + JSObject* WrapObject(JSContext* aCx, JS::HandleObject aGivenProto) override; + + protected: + ~ChannelWrapper(); + + private: + ChannelWrapper(nsISupports* aParent, nsIChannel* aChannel); + + void ClearCachedAttributes(); + + bool CheckAlive(ErrorResult& aRv) const { + if (!HaveChannel()) { + aRv.Throw(NS_ERROR_UNEXPECTED); + return false; + } + return true; + } + + void FireEvent(const nsAString& aType); + + const URLInfo& FinalURLInfo() const; + const URLInfo* DocumentURLInfo() const; + + uint64_t BrowsingContextId(nsILoadInfo* aLoadInfo) const; + + nsresult GetFrameAncestors( + nsILoadInfo* aLoadInfo, + nsTArray<dom::MozFrameAncestorInfo>& aFrameAncestors) const; + + static uint64_t GetNextId() { + static uint64_t sNextId = 1; + return ++sNextId; + } + + void CheckEventListeners(); + + class ChannelWrapperStub final : public nsISupports { + public: + NS_DECL_CYCLE_COLLECTING_ISUPPORTS + NS_DECL_CYCLE_COLLECTION_CLASS(ChannelWrapperStub) + + explicit ChannelWrapperStub(ChannelWrapper* aChannelWrapper) + : mChannelWrapper(aChannelWrapper) {} + + private: + friend class ChannelWrapper; + + RefPtr<ChannelWrapper> mChannelWrapper; + + protected: + ~ChannelWrapperStub() = default; + }; + + RefPtr<ChannelWrapperStub> mStub; + + mutable Maybe<URLInfo> mFinalURLInfo; + mutable Maybe<URLInfo> mDocumentURLInfo; + + UniquePtr<WebRequestChannelEntry> mChannelEntry; + + // The overridden Content-Type header value. + nsCString mContentTypeHdr = VoidCString(); + + const uint64_t mId = GetNextId(); + nsCOMPtr<nsISupports> mParent; + + bool mAddedStreamListener = false; + bool mFiredErrorEvent = false; + bool mSuspended = false; + bool mResponseStarted = false; + + nsInterfaceHashtable<nsPtrHashKey<const nsAtom>, nsIRemoteTab> mAddonEntries; + + mozilla::TimeStamp mSuspendTime; + + class RequestListener final : public nsIStreamListener, + public nsIMultiPartChannelListener, + public nsIThreadRetargetableStreamListener { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIREQUESTOBSERVER + NS_DECL_NSISTREAMLISTENER + NS_DECL_NSIMULTIPARTCHANNELLISTENER + NS_DECL_NSITHREADRETARGETABLESTREAMLISTENER + + explicit RequestListener(ChannelWrapper* aWrapper) + : mChannelWrapper(aWrapper) {} + + nsresult Init(); + + protected: + virtual ~RequestListener(); + + private: + RefPtr<ChannelWrapper> mChannelWrapper; + nsCOMPtr<nsIStreamListener> mOrigStreamListener; + }; +}; + +NS_DEFINE_STATIC_IID_ACCESSOR(ChannelWrapper, NS_CHANNELWRAPPER_IID) + +} // namespace extensions +} // namespace mozilla + +#endif // mozilla_extensions_ChannelWrapper_h diff --git a/toolkit/components/extensions/webrequest/PStreamFilter.ipdl b/toolkit/components/extensions/webrequest/PStreamFilter.ipdl new file mode 100644 index 0000000000..80e03cc8cf --- /dev/null +++ b/toolkit/components/extensions/webrequest/PStreamFilter.ipdl @@ -0,0 +1,38 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +include protocol PBackground; + +namespace mozilla { +namespace extensions { + +async protocol PStreamFilter +{ +parent: + async Write(uint8_t[] data); + + async FlushedData(); + + async Suspend(); + async Resume(); + async Close(); + async Disconnect(); + async Destroy(); + +child: + async Resumed(); + async Suspended(); + async Closed(); + async Error(nsCString error); + + async FlushData(); + + async StartRequest(); + async Data(uint8_t[] data); + async StopRequest(nsresult aStatus); +}; + +} // namespace extensions +} // namespace mozilla + diff --git a/toolkit/components/extensions/webrequest/SecurityInfo.jsm b/toolkit/components/extensions/webrequest/SecurityInfo.jsm new file mode 100644 index 0000000000..4652aa28da --- /dev/null +++ b/toolkit/components/extensions/webrequest/SecurityInfo.jsm @@ -0,0 +1,328 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +"use strict"; + +const EXPORTED_SYMBOLS = ["SecurityInfo"]; + +const { XPCOMUtils } = ChromeUtils.import( + "resource://gre/modules/XPCOMUtils.jsm" +); + +const wpl = Ci.nsIWebProgressListener; +XPCOMUtils.defineLazyServiceGetter( + this, + "NSSErrorsService", + "@mozilla.org/nss_errors_service;1", + "nsINSSErrorsService" +); +XPCOMUtils.defineLazyServiceGetter( + this, + "sss", + "@mozilla.org/ssservice;1", + "nsISiteSecurityService" +); + +// NOTE: SecurityInfo is largely reworked from the devtools NetworkHelper with changes +// to better support the WebRequest api. The objects returned are formatted specifically +// to pass through as part of a response to webRequest listeners. + +const SecurityInfo = { + /** + * Extracts security information from nsIChannel.securityInfo. + * + * @param {nsIChannel} channel + * If null channel is assumed to be insecure. + * @param {Object} options + * + * @returns {Object} + * Returns an object containing following members: + * - state: The security of the connection used to fetch this + * request. Has one of following string values: + * * "insecure": the connection was not secure (only http) + * * "weak": the connection has minor security issues + * * "broken": secure connection failed (e.g. expired cert) + * * "secure": the connection was properly secured. + * If state == broken: + * - errorMessage: full error message from + * nsITransportSecurityInfo. + * If state == secure: + * - protocolVersion: one of TLSv1, TLSv1.1, TLSv1.2, TLSv1.3. + * - cipherSuite: the cipher suite used in this connection. + * - cert: information about certificate used in this connection. + * See parseCertificateInfo for the contents. + * - hsts: true if host uses Strict Transport Security, + * false otherwise + * - hpkp: true if host uses Public Key Pinning, false otherwise + * If state == weak: Same as state == secure and + * - weaknessReasons: list of reasons that cause the request to be + * considered weak. See getReasonsForWeakness. + */ + getSecurityInfo(channel, options = {}) { + const info = { + state: "insecure", + }; + + /** + * Different scenarios to consider here and how they are handled: + * - request is HTTP, the connection is not secure + * => securityInfo is null + * => state === "insecure" + * + * - request is HTTPS, the connection is secure + * => .securityState has STATE_IS_SECURE flag + * => state === "secure" + * + * - request is HTTPS, the connection has security issues + * => .securityState has STATE_IS_INSECURE flag + * => .errorCode is an NSS error code. + * => state === "broken" + * + * - request is HTTPS, the connection was terminated before the security + * could be validated + * => .securityState has STATE_IS_INSECURE flag + * => .errorCode is NOT an NSS error code. + * => .errorMessage is not available. + * => state === "insecure" + * + * - request is HTTPS but it uses a weak cipher or old protocol, see + * https://hg.mozilla.org/mozilla-central/annotate/def6ed9d1c1a/ + * security/manager/ssl/nsNSSCallbacks.cpp#l1233 + * - request is mixed content (which makes no sense whatsoever) + * => .securityState has STATE_IS_BROKEN flag + * => .errorCode is NOT an NSS error code + * => .errorMessage is not available + * => state === "weak" + */ + + let securityInfo = channel.securityInfo; + if (!securityInfo) { + return info; + } + + securityInfo.QueryInterface(Ci.nsITransportSecurityInfo); + + if (NSSErrorsService.isNSSErrorCode(securityInfo.errorCode)) { + // The connection failed. + info.state = "broken"; + info.errorMessage = securityInfo.errorMessage; + if (options.certificateChain && securityInfo.failedCertChain) { + info.certificates = this.getCertificateChain( + securityInfo.failedCertChain, + options + ); + } + return info; + } + + const state = securityInfo.securityState; + + let uri = channel.URI; + if (uri && !uri.schemeIs("https") && !uri.schemeIs("wss")) { + // it is not enough to look at the transport security info - + // schemes other than https and wss are subject to + // downgrade/etc at the scheme level and should always be + // considered insecure. + // Leave info.state = "insecure"; + } else if (state & wpl.STATE_IS_SECURE) { + // The connection is secure if the scheme is sufficient + info.state = "secure"; + } else if (state & wpl.STATE_IS_BROKEN) { + // The connection is not secure, there was no error but there's some + // minor security issues. + info.state = "weak"; + info.weaknessReasons = this.getReasonsForWeakness(state); + } else if (state & wpl.STATE_IS_INSECURE) { + // This was most likely an https request that was aborted before + // validation. Return info as info.state = insecure. + return info; + } else { + // No known STATE_IS_* flags. + return info; + } + + // Cipher suite. + info.cipherSuite = securityInfo.cipherName; + + // Key exchange group name. + if (securityInfo.keaGroupName !== "none") { + info.keaGroupName = securityInfo.keaGroupName; + } + + // Certificate signature scheme. + if (securityInfo.signatureSchemeName !== "none") { + info.signatureSchemeName = securityInfo.signatureSchemeName; + } + + info.isDomainMismatch = securityInfo.isDomainMismatch; + info.isExtendedValidation = securityInfo.isExtendedValidation; + info.isNotValidAtThisTime = securityInfo.isNotValidAtThisTime; + info.isUntrusted = securityInfo.isUntrusted; + + info.certificateTransparencyStatus = this.getTransparencyStatus( + securityInfo.certificateTransparencyStatus + ); + + // Protocol version. + info.protocolVersion = this.formatSecurityProtocol( + securityInfo.protocolVersion + ); + + if (options.certificateChain && securityInfo.succeededCertChain) { + info.certificates = this.getCertificateChain( + securityInfo.succeededCertChain, + options + ); + } else { + info.certificates = [ + this.parseCertificateInfo(securityInfo.serverCert, options), + ]; + } + + // HSTS and static pinning if available. + if (uri && uri.host) { + // SiteSecurityService uses different storage if the channel is + // private. Thus we must give isSecureURI correct flags or we + // might get incorrect results. + let flags = 0; + if ( + channel instanceof Ci.nsIPrivateBrowsingChannel && + channel.isChannelPrivate + ) { + flags = Ci.nsISocketProvider.NO_PERMANENT_STORAGE; + } + + info.hsts = sss.isSecureURI(sss.HEADER_HSTS, uri, flags); + info.hpkp = sss.isSecureURI(sss.STATIC_PINNING, uri, flags); + } else { + info.hsts = false; + info.hpkp = false; + } + + return info; + }, + + getCertificateChain(certChain, options = {}) { + let certificates = []; + for (let cert of certChain) { + certificates.push(this.parseCertificateInfo(cert, options)); + } + return certificates; + }, + + /** + * Takes an nsIX509Cert and returns an object with certificate information. + * + * @param {nsIX509Cert} cert + * The certificate to extract the information from. + * @param {Object} options + * @returns {Object} + * An object with following format: + * { + * subject: subjectName, + * issuer: issuerName, + * validity: { start, end }, + * fingerprint: { sha1, sha256 } + * } + */ + parseCertificateInfo(cert, options = {}) { + if (!cert) { + return {}; + } + + let certData = { + subject: cert.subjectName, + issuer: cert.issuerName, + validity: { + start: cert.validity.notBefore + ? Math.trunc(cert.validity.notBefore / 1000) + : 0, + end: cert.validity.notAfter + ? Math.trunc(cert.validity.notAfter / 1000) + : 0, + }, + fingerprint: { + sha1: cert.sha1Fingerprint, + sha256: cert.sha256Fingerprint, + }, + serialNumber: cert.serialNumber, + isBuiltInRoot: cert.isBuiltInRoot, + subjectPublicKeyInfoDigest: { + sha256: cert.sha256SubjectPublicKeyInfoDigest, + }, + }; + if (options.rawDER) { + certData.rawDER = cert.getRawDER(); + } + return certData; + }, + + // Bug 1355903 Transparency is currently disabled using security.pki.certificate_transparency.mode + getTransparencyStatus(status) { + switch (status) { + case Ci.nsITransportSecurityInfo.CERTIFICATE_TRANSPARENCY_NOT_APPLICABLE: + return "not_applicable"; + case Ci.nsITransportSecurityInfo + .CERTIFICATE_TRANSPARENCY_POLICY_COMPLIANT: + return "policy_compliant"; + case Ci.nsITransportSecurityInfo + .CERTIFICATE_TRANSPARENCY_POLICY_NOT_ENOUGH_SCTS: + return "policy_not_enough_scts"; + case Ci.nsITransportSecurityInfo + .CERTIFICATE_TRANSPARENCY_POLICY_NOT_DIVERSE_SCTS: + return "policy_not_diverse_scts"; + } + return "unknown"; + }, + + /** + * Takes protocolVersion of TransportSecurityInfo object and returns human readable + * description. + * + * @param {number} version + * One of nsITransportSecurityInfo version constants. + * @returns {string} + * One of TLSv1, TLSv1.1, TLSv1.2, TLSv1.3 if version + * is valid, Unknown otherwise. + */ + formatSecurityProtocol(version) { + switch (version) { + case Ci.nsITransportSecurityInfo.TLS_VERSION_1: + return "TLSv1"; + case Ci.nsITransportSecurityInfo.TLS_VERSION_1_1: + return "TLSv1.1"; + case Ci.nsITransportSecurityInfo.TLS_VERSION_1_2: + return "TLSv1.2"; + case Ci.nsITransportSecurityInfo.TLS_VERSION_1_3: + return "TLSv1.3"; + } + return "unknown"; + }, + + /** + * Takes the securityState bitfield and returns reasons for weak connection + * as an array of strings. + * + * @param {number} state + * nsITransportSecurityInfo.securityState. + * + * @returns {array<string>} + * List of weakness reasons. A subset of { cipher } where + * * cipher: The cipher suite is consireded to be weak (RC4). + */ + getReasonsForWeakness(state) { + // If there's non-fatal security issues the request has STATE_IS_BROKEN + // flag set. See https://hg.mozilla.org/mozilla-central/file/44344099d119 + // /security/manager/ssl/nsNSSCallbacks.cpp#l1233 + let reasons = []; + + if (state & wpl.STATE_IS_BROKEN) { + if (state & wpl.STATE_USES_WEAK_CRYPTO) { + reasons.push("cipher"); + } + } + + return reasons; + }, +}; diff --git a/toolkit/components/extensions/webrequest/StreamFilter.cpp b/toolkit/components/extensions/webrequest/StreamFilter.cpp new file mode 100644 index 0000000000..178689f333 --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilter.cpp @@ -0,0 +1,286 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "StreamFilter.h" + +#include "jsapi.h" +#include "jsfriendapi.h" +#include "xpcpublic.h" + +#include "mozilla/AbstractThread.h" +#include "mozilla/HoldDropJSObjects.h" +#include "mozilla/extensions/StreamFilterChild.h" +#include "mozilla/extensions/StreamFilterEvents.h" +#include "mozilla/extensions/StreamFilterParent.h" +#include "mozilla/dom/ContentChild.h" +#include "mozilla/ipc/Endpoint.h" +#include "nsContentUtils.h" +#include "nsCycleCollectionParticipant.h" +#include "nsLiteralString.h" +#include "nsThreadUtils.h" +#include "nsTArray.h" + +using namespace JS; +using namespace mozilla::dom; + +namespace mozilla { +namespace extensions { + +/***************************************************************************** + * Initialization + *****************************************************************************/ + +StreamFilter::StreamFilter(nsIGlobalObject* aParent, uint64_t aRequestId, + const nsAString& aAddonId) + : mParent(aParent), mChannelId(aRequestId), mAddonId(NS_Atomize(aAddonId)) { + MOZ_ASSERT(aParent); + + Connect(); +}; + +StreamFilter::~StreamFilter() { ForgetActor(); } + +void StreamFilter::ForgetActor() { + if (mActor) { + mActor->Cleanup(); + mActor->SetStreamFilter(nullptr); + } +} + +/* static */ +already_AddRefed<StreamFilter> StreamFilter::Create(GlobalObject& aGlobal, + uint64_t aRequestId, + const nsAString& aAddonId) { + nsCOMPtr<nsIGlobalObject> global = do_QueryInterface(aGlobal.GetAsSupports()); + MOZ_ASSERT(global); + + RefPtr<StreamFilter> filter = new StreamFilter(global, aRequestId, aAddonId); + return filter.forget(); +} + +/***************************************************************************** + * Actor allocation + *****************************************************************************/ + +void StreamFilter::Connect() { + MOZ_ASSERT(!mActor); + + mActor = new StreamFilterChild(); + mActor->SetStreamFilter(this); + + nsAutoString addonId; + mAddonId->ToString(addonId); + + ContentChild* cc = ContentChild::GetSingleton(); + RefPtr<StreamFilter> self(this); + if (cc) { + cc->SendInitStreamFilter(mChannelId, addonId) + ->Then( + GetCurrentSerialEventTarget(), __func__, + [self](mozilla::ipc::Endpoint<PStreamFilterChild>&& aEndpoint) { + self->FinishConnect(std::move(aEndpoint)); + }, + [self](mozilla::ipc::ResponseRejectReason&& aReason) { + self->mActor->RecvInitialized(false); + }); + } else { + StreamFilterParent::Create(nullptr, mChannelId, addonId) + ->Then( + GetCurrentSerialEventTarget(), __func__, + [self](mozilla::ipc::Endpoint<PStreamFilterChild>&& aEndpoint) { + self->FinishConnect(std::move(aEndpoint)); + }, + [self](bool aDummy) { self->mActor->RecvInitialized(false); }); + } +} + +void StreamFilter::FinishConnect( + mozilla::ipc::Endpoint<PStreamFilterChild>&& aEndpoint) { + if (aEndpoint.IsValid()) { + MOZ_RELEASE_ASSERT(aEndpoint.Bind(mActor)); + mActor->RecvInitialized(true); + + // IPC now owns this reference. + Unused << do_AddRef(mActor); + } else { + mActor->RecvInitialized(false); + } +} + +bool StreamFilter::CheckAlive() { + // Check whether the global that owns this StreamFitler is still scriptable + // and, if not, disconnect the actor so that it can be cleaned up. + JSObject* wrapper = GetWrapperPreserveColor(); + if (!wrapper || !xpc::Scriptability::Get(wrapper).Allowed()) { + ForgetActor(); + return false; + } + return true; +} + +/***************************************************************************** + * Binding methods + *****************************************************************************/ + +template <typename T> +static inline bool ReadTypedArrayData(nsTArray<uint8_t>& aData, const T& aArray, + ErrorResult& aRv) { + aArray.ComputeState(); + if (!aData.SetLength(aArray.Length(), fallible)) { + aRv.Throw(NS_ERROR_OUT_OF_MEMORY); + return false; + } + memcpy(aData.Elements(), aArray.Data(), aArray.Length()); + return true; +} + +void StreamFilter::Write(const ArrayBufferOrUint8Array& aData, + ErrorResult& aRv) { + if (!mActor) { + aRv.Throw(NS_ERROR_NOT_INITIALIZED); + return; + } + + nsTArray<uint8_t> data; + + bool ok; + if (aData.IsArrayBuffer()) { + ok = ReadTypedArrayData(data, aData.GetAsArrayBuffer(), aRv); + } else if (aData.IsUint8Array()) { + ok = ReadTypedArrayData(data, aData.GetAsUint8Array(), aRv); + } else { + MOZ_ASSERT_UNREACHABLE("Argument should be ArrayBuffer or Uint8Array"); + return; + } + + if (ok) { + mActor->Write(std::move(data), aRv); + } +} + +StreamFilterStatus StreamFilter::Status() const { + if (!mActor) { + return StreamFilterStatus::Uninitialized; + } + return mActor->Status(); +} + +void StreamFilter::Suspend(ErrorResult& aRv) { + if (mActor) { + mActor->Suspend(aRv); + } else { + aRv.Throw(NS_ERROR_NOT_INITIALIZED); + } +} + +void StreamFilter::Resume(ErrorResult& aRv) { + if (mActor) { + mActor->Resume(aRv); + } else { + aRv.Throw(NS_ERROR_NOT_INITIALIZED); + } +} + +void StreamFilter::Disconnect(ErrorResult& aRv) { + if (mActor) { + mActor->Disconnect(aRv); + } else { + aRv.Throw(NS_ERROR_NOT_INITIALIZED); + } +} + +void StreamFilter::Close(ErrorResult& aRv) { + if (mActor) { + mActor->Close(aRv); + } else { + aRv.Throw(NS_ERROR_NOT_INITIALIZED); + } +} + +/***************************************************************************** + * Event emitters + *****************************************************************************/ + +void StreamFilter::FireEvent(const nsAString& aType) { + EventInit init; + init.mBubbles = false; + init.mCancelable = false; + + RefPtr<Event> event = Event::Constructor(this, aType, init); + event->SetTrusted(true); + + DispatchEvent(*event); +} + +void StreamFilter::FireDataEvent(const nsTArray<uint8_t>& aData) { + AutoEntryScript aes(mParent, "StreamFilter data event"); + JSContext* cx = aes.cx(); + + RootedDictionary<StreamFilterDataEventInit> init(cx); + init.mBubbles = false; + init.mCancelable = false; + + auto buffer = ArrayBuffer::Create(cx, aData.Length(), aData.Elements()); + if (!buffer) { + // TODO: There is no way to recover from this. This chunk of data is lost. + FireErrorEvent(u"Out of memory"_ns); + return; + } + + init.mData.Init(buffer); + + RefPtr<StreamFilterDataEvent> event = + StreamFilterDataEvent::Constructor(this, u"data"_ns, init); + event->SetTrusted(true); + + DispatchEvent(*event); +} + +void StreamFilter::FireErrorEvent(const nsAString& aError) { + MOZ_ASSERT(mError.IsEmpty()); + + mError = aError; + FireEvent(u"error"_ns); +} + +/***************************************************************************** + * Glue + *****************************************************************************/ + +/* static */ +bool StreamFilter::IsAllowedInContext(JSContext* aCx, JSObject* /* unused */) { + return nsContentUtils::CallerHasPermission(aCx, + nsGkAtoms::webRequestBlocking); +} + +JSObject* StreamFilter::WrapObject(JSContext* aCx, HandleObject aGivenProto) { + return StreamFilter_Binding::Wrap(aCx, this, aGivenProto); +} + +NS_IMPL_CYCLE_COLLECTION_CLASS(StreamFilter) + +NS_INTERFACE_MAP_BEGIN_CYCLE_COLLECTION(StreamFilter) +NS_INTERFACE_MAP_END_INHERITING(DOMEventTargetHelper) + +NS_IMPL_CYCLE_COLLECTION_UNLINK_BEGIN_INHERITED(StreamFilter, + DOMEventTargetHelper) + NS_IMPL_CYCLE_COLLECTION_UNLINK(mParent) +NS_IMPL_CYCLE_COLLECTION_UNLINK_END + +NS_IMPL_CYCLE_COLLECTION_TRAVERSE_BEGIN_INHERITED(StreamFilter, + DOMEventTargetHelper) + NS_IMPL_CYCLE_COLLECTION_TRAVERSE(mParent) +NS_IMPL_CYCLE_COLLECTION_TRAVERSE_END + +NS_IMPL_CYCLE_COLLECTION_TRACE_BEGIN_INHERITED(StreamFilter, + DOMEventTargetHelper) +NS_IMPL_CYCLE_COLLECTION_TRACE_END + +NS_IMPL_ADDREF_INHERITED(StreamFilter, DOMEventTargetHelper) +NS_IMPL_RELEASE_INHERITED(StreamFilter, DOMEventTargetHelper) + +} // namespace extensions +} // namespace mozilla diff --git a/toolkit/components/extensions/webrequest/StreamFilter.h b/toolkit/components/extensions/webrequest/StreamFilter.h new file mode 100644 index 0000000000..4c29dc19ac --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilter.h @@ -0,0 +1,97 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2; -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_extensions_StreamFilter_h +#define mozilla_extensions_StreamFilter_h + +#include "mozilla/dom/BindingDeclarations.h" +#include "mozilla/dom/StreamFilterBinding.h" + +#include "mozilla/DOMEventTargetHelper.h" +#include "nsCOMPtr.h" +#include "nsCycleCollectionParticipant.h" +#include "nsAtom.h" + +namespace mozilla { +namespace ipc { +template <class T> +class Endpoint; +} + +namespace extensions { + +class PStreamFilterChild; +class StreamFilterChild; + +using namespace mozilla::dom; + +class StreamFilter : public DOMEventTargetHelper { + friend class StreamFilterChild; + + NS_DECL_ISUPPORTS_INHERITED + NS_DECL_CYCLE_COLLECTION_SCRIPT_HOLDER_CLASS_INHERITED(StreamFilter, + DOMEventTargetHelper) + + static already_AddRefed<StreamFilter> Create(GlobalObject& global, + uint64_t aRequestId, + const nsAString& aAddonId); + + explicit StreamFilter(nsIGlobalObject* aParent, uint64_t aRequestId, + const nsAString& aAddonId); + + IMPL_EVENT_HANDLER(start); + IMPL_EVENT_HANDLER(stop); + IMPL_EVENT_HANDLER(data); + IMPL_EVENT_HANDLER(error); + + void Write(const ArrayBufferOrUint8Array& aData, ErrorResult& aRv); + + void GetError(nsAString& aError) { aError = mError; } + + StreamFilterStatus Status() const; + void Suspend(ErrorResult& aRv); + void Resume(ErrorResult& aRv); + void Disconnect(ErrorResult& aRv); + void Close(ErrorResult& aRv); + + nsISupports* GetParentObject() const { return mParent; } + + virtual JSObject* WrapObject(JSContext* aCx, + JS::Handle<JSObject*> aGivenProto) override; + + static bool IsAllowedInContext(JSContext* aCx, JSObject* aObj); + + protected: + virtual ~StreamFilter(); + + void FireEvent(const nsAString& aType); + + void FireDataEvent(const nsTArray<uint8_t>& aData); + + void FireErrorEvent(const nsAString& aError); + + bool CheckAlive(); + + private: + void Connect(); + + void FinishConnect(mozilla::ipc::Endpoint<PStreamFilterChild>&& aEndpoint); + + void ForgetActor(); + + nsCOMPtr<nsIGlobalObject> mParent; + RefPtr<StreamFilterChild> mActor; + + nsString mError; + + const uint64_t mChannelId; + const RefPtr<nsAtom> mAddonId; +}; + +} // namespace extensions +} // namespace mozilla + +#endif // mozilla_extensions_StreamFilter_h diff --git a/toolkit/components/extensions/webrequest/StreamFilterBase.h b/toolkit/components/extensions/webrequest/StreamFilterBase.h new file mode 100644 index 0000000000..4f413835ef --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilterBase.h @@ -0,0 +1,38 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_extensions_StreamFilterBase_h +#define mozilla_extensions_StreamFilterBase_h + +#include "mozilla/LinkedList.h" +#include "nsTArray.h" + +namespace mozilla { +namespace extensions { + +class StreamFilterBase { + public: + typedef nsTArray<uint8_t> Data; + + protected: + class BufferedData : public LinkedListElement<BufferedData> { + public: + explicit BufferedData(Data&& aData) : mData(std::move(aData)) {} + + Data mData; + }; + + LinkedList<BufferedData> mBufferedData; + + inline void BufferData(Data&& aData) { + mBufferedData.insertBack(new BufferedData(std::move(aData))); + }; +}; + +} // namespace extensions +} // namespace mozilla + +#endif // mozilla_extensions_StreamFilterBase_h diff --git a/toolkit/components/extensions/webrequest/StreamFilterChild.cpp b/toolkit/components/extensions/webrequest/StreamFilterChild.cpp new file mode 100644 index 0000000000..90a6e11ad2 --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilterChild.cpp @@ -0,0 +1,520 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "StreamFilterChild.h" +#include "StreamFilter.h" + +#include "mozilla/Assertions.h" +#include "mozilla/UniquePtr.h" + +namespace mozilla { +namespace extensions { + +using mozilla::dom::StreamFilterStatus; +using mozilla::ipc::IPCResult; + +/***************************************************************************** + * Initialization and cleanup + *****************************************************************************/ + +void StreamFilterChild::Cleanup() { + switch (mState) { + case State::Closing: + case State::Closed: + case State::Error: + case State::Disconnecting: + case State::Disconnected: + break; + + default: + ErrorResult rv; + Disconnect(rv); + break; + } +} + +/***************************************************************************** + * State change methods + *****************************************************************************/ + +void StreamFilterChild::Suspend(ErrorResult& aRv) { + switch (mState) { + case State::TransferringData: + mState = State::Suspending; + mNextState = State::Suspended; + + SendSuspend(); + break; + + case State::Suspending: + switch (mNextState) { + case State::Suspended: + case State::Resuming: + mNextState = State::Suspended; + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } + break; + + case State::Resuming: + switch (mNextState) { + case State::TransferringData: + case State::Suspending: + mNextState = State::Suspending; + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } + break; + + case State::Suspended: + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + break; + } +} + +void StreamFilterChild::Resume(ErrorResult& aRv) { + switch (mState) { + case State::Suspended: + mState = State::Resuming; + mNextState = State::TransferringData; + + SendResume(); + break; + + case State::Suspending: + switch (mNextState) { + case State::Suspended: + case State::Resuming: + mNextState = State::Resuming; + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } + break; + + case State::Resuming: + case State::TransferringData: + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } + + FlushBufferedData(); +} + +void StreamFilterChild::Disconnect(ErrorResult& aRv) { + switch (mState) { + case State::Suspended: + case State::TransferringData: + case State::FinishedTransferringData: + mState = State::Disconnecting; + mNextState = State::Disconnected; + + WriteBufferedData(); + SendDisconnect(); + break; + + case State::Suspending: + case State::Resuming: + switch (mNextState) { + case State::Suspended: + case State::Resuming: + case State::Disconnecting: + mNextState = State::Disconnecting; + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } + break; + + case State::Disconnecting: + case State::Disconnected: + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } +} + +void StreamFilterChild::Close(ErrorResult& aRv) { + switch (mState) { + case State::Suspended: + case State::TransferringData: + case State::FinishedTransferringData: + mState = State::Closing; + mNextState = State::Closed; + + SendClose(); + break; + + case State::Suspending: + case State::Resuming: + mNextState = State::Closing; + break; + + case State::Closing: + MOZ_DIAGNOSTIC_ASSERT(mNextState == State::Closed); + break; + + case State::Closed: + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } + + mBufferedData.clear(); +} + +/***************************************************************************** + * Internal state management + *****************************************************************************/ + +void StreamFilterChild::SetNextState() { + mState = mNextState; + + switch (mNextState) { + case State::Suspending: + mNextState = State::Suspended; + SendSuspend(); + break; + + case State::Resuming: + mNextState = State::TransferringData; + SendResume(); + break; + + case State::Closing: + mNextState = State::Closed; + SendClose(); + break; + + case State::Disconnecting: + mNextState = State::Disconnected; + SendDisconnect(); + break; + + case State::FinishedTransferringData: + if (mStreamFilter) { + mStreamFilter->FireEvent(u"stop"_ns); + // We don't need access to the stream filter after this point, so break + // our reference cycle, so that it can be collected if we're the last + // reference. + mStreamFilter = nullptr; + } + break; + + case State::TransferringData: + FlushBufferedData(); + break; + + case State::Closed: + case State::Disconnected: + case State::Error: + mStreamFilter = nullptr; + break; + + default: + break; + } +} + +void StreamFilterChild::MaybeStopRequest() { + if (!mReceivedOnStop || !mBufferedData.isEmpty()) { + return; + } + + if (mStreamFilter) { + Unused << mStreamFilter->CheckAlive(); + } + + switch (mState) { + case State::Suspending: + case State::Resuming: + mNextState = State::FinishedTransferringData; + return; + + case State::Disconnecting: + case State::Closing: + case State::Closed: + break; + + default: + mState = State::FinishedTransferringData; + if (mStreamFilter) { + mStreamFilter->FireEvent(u"stop"_ns); + // We don't need access to the stream filter after this point, so break + // our reference cycle, so that it can be collected if we're the last + // reference. + mStreamFilter = nullptr; + } + break; + } +} + +/***************************************************************************** + * State change acknowledgment callbacks + *****************************************************************************/ + +void StreamFilterChild::RecvInitialized(bool aSuccess) { + MOZ_ASSERT(mState == State::Uninitialized); + + if (aSuccess) { + mState = State::Initialized; + } else { + mState = State::Error; + if (mStreamFilter) { + mStreamFilter->FireErrorEvent(u"Invalid request ID"_ns); + mStreamFilter = nullptr; + } + } +} + +IPCResult StreamFilterChild::RecvError(const nsCString& aError) { + mState = State::Error; + if (mStreamFilter) { + mStreamFilter->FireErrorEvent(NS_ConvertUTF8toUTF16(aError)); + mStreamFilter = nullptr; + } + SendDestroy(); + return IPC_OK(); +} + +IPCResult StreamFilterChild::RecvClosed() { + MOZ_DIAGNOSTIC_ASSERT(mState == State::Closing); + + SetNextState(); + return IPC_OK(); +} + +IPCResult StreamFilterChild::RecvSuspended() { + MOZ_DIAGNOSTIC_ASSERT(mState == State::Suspending); + + SetNextState(); + return IPC_OK(); +} + +IPCResult StreamFilterChild::RecvResumed() { + MOZ_DIAGNOSTIC_ASSERT(mState == State::Resuming); + + SetNextState(); + return IPC_OK(); +} + +IPCResult StreamFilterChild::RecvFlushData() { + MOZ_DIAGNOSTIC_ASSERT(mState == State::Disconnecting); + + SendFlushedData(); + SetNextState(); + return IPC_OK(); +} + +/***************************************************************************** + * Other binding methods + *****************************************************************************/ + +void StreamFilterChild::Write(Data&& aData, ErrorResult& aRv) { + switch (mState) { + case State::Suspending: + case State::Resuming: + switch (mNextState) { + case State::Suspended: + case State::TransferringData: + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } + break; + + case State::Suspended: + case State::TransferringData: + case State::FinishedTransferringData: + break; + + default: + aRv.Throw(NS_ERROR_FAILURE); + return; + } + + SendWrite(std::move(aData)); +} + +StreamFilterStatus StreamFilterChild::Status() const { + switch (mState) { + case State::Uninitialized: + case State::Initialized: + return StreamFilterStatus::Uninitialized; + + case State::TransferringData: + return StreamFilterStatus::Transferringdata; + + case State::Suspended: + return StreamFilterStatus::Suspended; + + case State::FinishedTransferringData: + return StreamFilterStatus::Finishedtransferringdata; + + case State::Resuming: + case State::Suspending: + switch (mNextState) { + case State::TransferringData: + case State::Resuming: + return StreamFilterStatus::Transferringdata; + + case State::Suspended: + case State::Suspending: + return StreamFilterStatus::Suspended; + + case State::Closing: + return StreamFilterStatus::Closed; + + case State::Disconnecting: + return StreamFilterStatus::Disconnected; + + default: + MOZ_ASSERT_UNREACHABLE("Unexpected next state"); + return StreamFilterStatus::Suspended; + } + break; + + case State::Closing: + case State::Closed: + return StreamFilterStatus::Closed; + + case State::Disconnecting: + case State::Disconnected: + return StreamFilterStatus::Disconnected; + + case State::Error: + return StreamFilterStatus::Failed; + }; + + MOZ_ASSERT_UNREACHABLE("Not reached"); + return StreamFilterStatus::Failed; +} + +/***************************************************************************** + * Request state notifications + *****************************************************************************/ + +IPCResult StreamFilterChild::RecvStartRequest() { + MOZ_ASSERT(mState == State::Initialized); + + mState = State::TransferringData; + + if (mStreamFilter) { + mStreamFilter->FireEvent(u"start"_ns); + Unused << mStreamFilter->CheckAlive(); + } + return IPC_OK(); +} + +IPCResult StreamFilterChild::RecvStopRequest(const nsresult& aStatus) { + mReceivedOnStop = true; + MaybeStopRequest(); + return IPC_OK(); +} + +/***************************************************************************** + * Incoming request data handling + *****************************************************************************/ + +void StreamFilterChild::EmitData(const Data& aData) { + MOZ_ASSERT(CanFlushData()); + if (mStreamFilter) { + mStreamFilter->FireDataEvent(aData); + } + + MaybeStopRequest(); +} + +void StreamFilterChild::FlushBufferedData() { + while (!mBufferedData.isEmpty() && CanFlushData()) { + UniquePtr<BufferedData> data(mBufferedData.popFirst()); + + EmitData(data->mData); + } +} + +void StreamFilterChild::WriteBufferedData() { + while (!mBufferedData.isEmpty()) { + UniquePtr<BufferedData> data(mBufferedData.popFirst()); + + SendWrite(data->mData); + } +} + +IPCResult StreamFilterChild::RecvData(Data&& aData) { + MOZ_ASSERT(!mReceivedOnStop); + + if (mStreamFilter) { + Unused << mStreamFilter->CheckAlive(); + } + + switch (mState) { + case State::TransferringData: + case State::Resuming: + EmitData(aData); + break; + + case State::FinishedTransferringData: + MOZ_ASSERT_UNREACHABLE("Received data in unexpected state"); + EmitData(aData); + break; + + case State::Suspending: + case State::Suspended: + BufferData(std::move(aData)); + break; + + case State::Disconnecting: + SendWrite(std::move(aData)); + break; + + case State::Closing: + break; + + default: + MOZ_ASSERT_UNREACHABLE("Received data in unexpected state"); + return IPC_FAIL_NO_REASON(this); + } + + return IPC_OK(); +} + +/***************************************************************************** + * Glue + *****************************************************************************/ + +void StreamFilterChild::ActorDestroy(ActorDestroyReason aWhy) { + mStreamFilter = nullptr; +} + +void StreamFilterChild::ActorDealloc() { + RefPtr<StreamFilterChild> self = dont_AddRef(this); +} + +} // namespace extensions +} // namespace mozilla diff --git a/toolkit/components/extensions/webrequest/StreamFilterChild.h b/toolkit/components/extensions/webrequest/StreamFilterChild.h new file mode 100644 index 0000000000..9cc6e04cce --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilterChild.h @@ -0,0 +1,137 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_extensions_StreamFilterChild_h +#define mozilla_extensions_StreamFilterChild_h + +#include "StreamFilterBase.h" +#include "mozilla/extensions/PStreamFilterChild.h" +#include "mozilla/extensions/StreamFilter.h" + +#include "mozilla/LinkedList.h" +#include "mozilla/dom/StreamFilterBinding.h" +#include "nsISupportsImpl.h" + +namespace mozilla { +class ErrorResult; + +namespace extensions { + +using mozilla::dom::StreamFilterStatus; +using mozilla::ipc::IPCResult; + +class StreamFilter; + +class StreamFilterChild final : public PStreamFilterChild, + public StreamFilterBase { + friend class StreamFilter; + friend class PStreamFilterChild; + + public: + NS_INLINE_DECL_REFCOUNTING(StreamFilterChild) + + StreamFilterChild() : mState(State::Uninitialized), mReceivedOnStop(false) {} + + enum class State { + // Uninitialized, waiting for constructor response from parent. + Uninitialized, + // Initialized, but channel has not begun transferring data. + Initialized, + // The stream's OnStartRequest event has been dispatched, and the channel is + // transferring data. + TransferringData, + // The channel's OnStopRequest event has been dispatched, and the channel is + // no longer transferring data. Data may still be written to the output + // stream listener. + FinishedTransferringData, + // The channel is being suspended, and we're waiting for confirmation of + // suspension from the parent. + Suspending, + // The channel has been suspended in the parent. Data may still be written + // to the output stream listener in this state. + Suspended, + // The channel is suspended. Resume has been called, and we are waiting for + // confirmation of resumption from the parent. + Resuming, + // The close() method has been called, and no further output may be written. + // We are waiting for confirmation from the parent. + Closing, + // The close() method has been called, and we have been disconnected from + // our parent. + Closed, + // The channel is being disconnected from the parent, and all further events + // and data will pass unfiltered. Data received by the child in this state + // will be automatically written to the output stream listener. No data may + // be explicitly written. + Disconnecting, + // The channel has been disconnected from the parent, and all further data + // and events will be transparently passed to the output stream listener + // without passing through the child. + Disconnected, + // An error has occurred and the child is disconnected from the parent. + Error, + }; + + void Suspend(ErrorResult& aRv); + void Resume(ErrorResult& aRv); + void Disconnect(ErrorResult& aRv); + void Close(ErrorResult& aRv); + void Cleanup(); + + void Write(Data&& aData, ErrorResult& aRv); + + State GetState() const { return mState; } + + StreamFilterStatus Status() const; + + void RecvInitialized(bool aSuccess); + + protected: + IPCResult RecvStartRequest(); + IPCResult RecvData(Data&& data); + IPCResult RecvStopRequest(const nsresult& aStatus); + IPCResult RecvError(const nsCString& aError); + + IPCResult RecvClosed(); + IPCResult RecvSuspended(); + IPCResult RecvResumed(); + IPCResult RecvFlushData(); + + virtual void ActorDealloc() override; + + void SetStreamFilter(StreamFilter* aStreamFilter) { + mStreamFilter = aStreamFilter; + } + + private: + ~StreamFilterChild() = default; + + void SetNextState(); + + void MaybeStopRequest(); + + void EmitData(const Data& aData); + + bool CanFlushData() { + return (mState == State::TransferringData || mState == State::Resuming); + } + + void FlushBufferedData(); + void WriteBufferedData(); + + virtual void ActorDestroy(ActorDestroyReason aWhy) override; + + State mState; + State mNextState; + bool mReceivedOnStop; + + RefPtr<StreamFilter> mStreamFilter; +}; + +} // namespace extensions +} // namespace mozilla + +#endif // mozilla_extensions_StreamFilterChild_h diff --git a/toolkit/components/extensions/webrequest/StreamFilterEvents.cpp b/toolkit/components/extensions/webrequest/StreamFilterEvents.cpp new file mode 100644 index 0000000000..980e155d42 --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilterEvents.cpp @@ -0,0 +1,53 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2; -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/extensions/StreamFilterEvents.h" + +namespace mozilla { +namespace extensions { + +NS_IMPL_CYCLE_COLLECTION_CLASS(StreamFilterDataEvent) + +NS_IMPL_ADDREF_INHERITED(StreamFilterDataEvent, Event) +NS_IMPL_RELEASE_INHERITED(StreamFilterDataEvent, Event) + +NS_IMPL_CYCLE_COLLECTION_TRAVERSE_BEGIN_INHERITED(StreamFilterDataEvent, Event) +NS_IMPL_CYCLE_COLLECTION_TRAVERSE_END + +NS_IMPL_CYCLE_COLLECTION_TRACE_BEGIN_INHERITED(StreamFilterDataEvent, Event) + NS_IMPL_CYCLE_COLLECTION_TRACE_JS_MEMBER_CALLBACK(mData) +NS_IMPL_CYCLE_COLLECTION_TRACE_END + +NS_IMPL_CYCLE_COLLECTION_UNLINK_BEGIN_INHERITED(StreamFilterDataEvent, Event) + tmp->mData = nullptr; +NS_IMPL_CYCLE_COLLECTION_UNLINK_END + +NS_INTERFACE_MAP_BEGIN_CYCLE_COLLECTION(StreamFilterDataEvent) +NS_INTERFACE_MAP_END_INHERITING(Event) + +/* static */ +already_AddRefed<StreamFilterDataEvent> StreamFilterDataEvent::Constructor( + EventTarget* aEventTarget, const nsAString& aType, + const StreamFilterDataEventInit& aParam) { + RefPtr<StreamFilterDataEvent> event = new StreamFilterDataEvent(aEventTarget); + + bool trusted = event->Init(aEventTarget); + event->InitEvent(aType, aParam.mBubbles, aParam.mCancelable); + event->SetTrusted(trusted); + event->SetComposed(aParam.mComposed); + + event->SetData(aParam.mData); + + return event.forget(); +} + +JSObject* StreamFilterDataEvent::WrapObjectInternal( + JSContext* aCx, JS::Handle<JSObject*> aGivenProto) { + return StreamFilterDataEvent_Binding::Wrap(aCx, this, aGivenProto); +} + +} // namespace extensions +} // namespace mozilla diff --git a/toolkit/components/extensions/webrequest/StreamFilterEvents.h b/toolkit/components/extensions/webrequest/StreamFilterEvents.h new file mode 100644 index 0000000000..c058fa1910 --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilterEvents.h @@ -0,0 +1,68 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2; -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_extensions_StreamFilterEvents_h +#define mozilla_extensions_StreamFilterEvents_h + +#include "mozilla/dom/BindingDeclarations.h" +#include "mozilla/dom/StreamFilterDataEventBinding.h" +#include "mozilla/extensions/StreamFilter.h" + +#include "js/RootingAPI.h" +#include "js/TypeDecls.h" + +#include "mozilla/HoldDropJSObjects.h" +#include "mozilla/dom/Event.h" +#include "nsCOMPtr.h" +#include "nsCycleCollectionParticipant.h" + +namespace mozilla { +namespace extensions { + +using namespace JS; +using namespace mozilla::dom; + +class StreamFilterDataEvent : public Event { + NS_DECL_ISUPPORTS_INHERITED + NS_DECL_CYCLE_COLLECTION_SCRIPT_HOLDER_CLASS_INHERITED(StreamFilterDataEvent, + Event) + + explicit StreamFilterDataEvent(EventTarget* aEventTarget) + : Event(aEventTarget, nullptr, nullptr) { + mozilla::HoldJSObjects(this); + } + + static already_AddRefed<StreamFilterDataEvent> Constructor( + EventTarget* aEventTarget, const nsAString& aType, + const StreamFilterDataEventInit& aParam); + + static already_AddRefed<StreamFilterDataEvent> Constructor( + GlobalObject& aGlobal, const nsAString& aType, + const StreamFilterDataEventInit& aParam) { + nsCOMPtr<EventTarget> target = do_QueryInterface(aGlobal.GetAsSupports()); + return Constructor(target, aType, aParam); + } + + void GetData(JSContext* aCx, JS::MutableHandleObject aResult) { + aResult.set(mData); + } + + virtual JSObject* WrapObjectInternal( + JSContext* aCx, JS::Handle<JSObject*> aGivenProto) override; + + protected: + virtual ~StreamFilterDataEvent() { mozilla::DropJSObjects(this); } + + private: + JS::Heap<JSObject*> mData; + + void SetData(const ArrayBuffer& aData) { mData = aData.Obj(); } +}; + +} // namespace extensions +} // namespace mozilla + +#endif // mozilla_extensions_StreamFilterEvents_h diff --git a/toolkit/components/extensions/webrequest/StreamFilterParent.cpp b/toolkit/components/extensions/webrequest/StreamFilterParent.cpp new file mode 100644 index 0000000000..e71e23ec33 --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilterParent.cpp @@ -0,0 +1,777 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "StreamFilterParent.h" + +#include "mozilla/Unused.h" +#include "mozilla/dom/ContentParent.h" +#include "mozilla/net/ChannelEventQueue.h" +#include "nsHttpChannel.h" +#include "nsIChannel.h" +#include "nsIInputStream.h" +#include "nsITraceableChannel.h" +#include "nsProxyRelease.h" +#include "nsQueryObject.h" +#include "nsSocketTransportService2.h" +#include "nsStringStream.h" +#include "mozilla/net/DocumentChannelChild.h" +#include "nsIViewSourceChannel.h" + +namespace mozilla { +namespace extensions { + +/***************************************************************************** + * Event queueing helpers + *****************************************************************************/ + +using net::ChannelEvent; +using net::ChannelEventQueue; + +namespace { + +// Define some simple ChannelEvent sub-classes that store the appropriate +// EventTarget and delegate their Run methods to a wrapped Runnable or lambda +// function. + +class ChannelEventWrapper : public ChannelEvent { + public: + ChannelEventWrapper(nsIEventTarget* aTarget) : mTarget(aTarget) {} + + already_AddRefed<nsIEventTarget> GetEventTarget() override { + return do_AddRef(mTarget); + } + + protected: + ~ChannelEventWrapper() override = default; + + private: + nsCOMPtr<nsIEventTarget> mTarget; +}; + +class ChannelEventFunction final : public ChannelEventWrapper { + public: + ChannelEventFunction(nsIEventTarget* aTarget, std::function<void()>&& aFunc) + : ChannelEventWrapper(aTarget), mFunc(std::move(aFunc)) {} + + void Run() override { mFunc(); } + + protected: + ~ChannelEventFunction() override = default; + + private: + std::function<void()> mFunc; +}; + +class ChannelEventRunnable final : public ChannelEventWrapper { + public: + ChannelEventRunnable(nsIEventTarget* aTarget, + already_AddRefed<Runnable> aRunnable) + : ChannelEventWrapper(aTarget), mRunnable(aRunnable) {} + + void Run() override { + nsresult rv = mRunnable->Run(); + Unused << NS_WARN_IF(NS_FAILED(rv)); + } + + protected: + ~ChannelEventRunnable() override = default; + + private: + RefPtr<Runnable> mRunnable; +}; + +} // anonymous namespace + +/***************************************************************************** + * Initialization + *****************************************************************************/ + +StreamFilterParent::StreamFilterParent() + : mMainThread(GetCurrentEventTarget()), + mIOThread(mMainThread), + mQueue(new ChannelEventQueue(static_cast<nsIStreamListener*>(this))), + mBufferMutex("StreamFilter buffer mutex"), + mReceivedStop(false), + mSentStop(false), + mContext(nullptr), + mOffset(0), + mState(State::Uninitialized) {} + +StreamFilterParent::~StreamFilterParent() { + NS_ReleaseOnMainThread("StreamFilterParent::mChannel", mChannel.forget()); + NS_ReleaseOnMainThread("StreamFilterParent::mLoadGroup", mLoadGroup.forget()); + NS_ReleaseOnMainThread("StreamFilterParent::mOrigListener", + mOrigListener.forget()); + NS_ReleaseOnMainThread("StreamFilterParent::mContext", mContext.forget()); +} + +auto StreamFilterParent::Create(dom::ContentParent* aContentParent, + uint64_t aChannelId, const nsAString& aAddonId) + -> RefPtr<ChildEndpointPromise> { + AssertIsMainThread(); + + auto& webreq = WebRequestService::GetSingleton(); + + RefPtr<nsAtom> addonId = NS_Atomize(aAddonId); + nsCOMPtr<nsITraceableChannel> channel = + webreq.GetTraceableChannel(aChannelId, addonId, aContentParent); + + RefPtr<mozilla::net::nsHttpChannel> chan = do_QueryObject(channel); + if (!chan) { + return ChildEndpointPromise::CreateAndReject(false, __func__); + } + + // Disable alt-data for extension stream listeners. + nsCOMPtr<nsIHttpChannelInternal> internal(do_QueryObject(channel)); + internal->DisableAltDataCache(); + + return chan->AttachStreamFilter(aContentParent ? aContentParent->OtherPid() + : base::GetCurrentProcId()); +} + +/* static */ +void StreamFilterParent::Attach(nsIChannel* aChannel, + ParentEndpoint&& aEndpoint) { + auto self = MakeRefPtr<StreamFilterParent>(); + + self->ActorThread()->Dispatch( + NewRunnableMethod<ParentEndpoint&&>("StreamFilterParent::Bind", self, + &StreamFilterParent::Bind, + std::move(aEndpoint)), + NS_DISPATCH_NORMAL); + + self->Init(aChannel); + + // IPC owns this reference now. + Unused << self.forget(); +} + +void StreamFilterParent::Bind(ParentEndpoint&& aEndpoint) { + aEndpoint.Bind(this); +} + +void StreamFilterParent::Init(nsIChannel* aChannel) { + mChannel = aChannel; + + nsCOMPtr<nsITraceableChannel> traceable = do_QueryInterface(aChannel); + if (MOZ_UNLIKELY(!traceable)) { + // nsViewSourceChannel is not nsITraceableChannel, but wraps one. Unwrap it. + nsCOMPtr<nsIViewSourceChannel> vsc = do_QueryInterface(aChannel); + if (vsc) { + traceable = do_QueryObject(vsc->GetInnerChannel()); + // OnStartRequest etc. is passed the unwrapped channel, so update mChannel + // to prevent OnStartRequest from mistaking it for a redirect, which would + // close the filter. + mChannel = do_QueryObject(traceable); + } + // TODO bug 1683403: Replace assertion; Close StreamFilter instead. + MOZ_RELEASE_ASSERT(traceable); + } + + nsresult rv = + traceable->SetNewListener(this, /* aMustApplyContentConversion = */ true, + getter_AddRefs(mOrigListener)); + MOZ_RELEASE_ASSERT(NS_SUCCEEDED(rv)); +} + +/***************************************************************************** + * nsIThreadRetargetableStreamListener + *****************************************************************************/ + +NS_IMETHODIMP +StreamFilterParent::CheckListenerChain() { + AssertIsMainThread(); + + nsCOMPtr<nsIThreadRetargetableStreamListener> trsl = + do_QueryInterface(mOrigListener); + if (trsl) { + return trsl->CheckListenerChain(); + } + return NS_ERROR_FAILURE; +} + +/***************************************************************************** + * Error handling + *****************************************************************************/ + +void StreamFilterParent::Broken() { + AssertIsActorThread(); + + switch (mState) { + case State::Initialized: + case State::TransferringData: + case State::Suspended: { + mState = State::Disconnecting; + RefPtr<StreamFilterParent> self(this); + RunOnMainThread(FUNC, [=] { + if (self->mChannel) { + self->mChannel->Cancel(NS_ERROR_FAILURE); + } + }); + + FinishDisconnect(); + } break; + + default: + break; + } +} + +/***************************************************************************** + * State change requests + *****************************************************************************/ + +IPCResult StreamFilterParent::RecvClose() { + AssertIsActorThread(); + + mState = State::Closed; + + if (!mSentStop) { + RefPtr<StreamFilterParent> self(this); + RunOnMainThread(FUNC, [=] { + nsresult rv = self->EmitStopRequest(NS_OK); + Unused << NS_WARN_IF(NS_FAILED(rv)); + }); + } + + Unused << SendClosed(); + Destroy(); + return IPC_OK(); +} + +void StreamFilterParent::Destroy() { + // Close the channel asynchronously so the actor is never destroyed before + // this message is fully processed. + ActorThread()->Dispatch(NewRunnableMethod("StreamFilterParent::Close", this, + &StreamFilterParent::Close), + NS_DISPATCH_NORMAL); +} + +IPCResult StreamFilterParent::RecvDestroy() { + AssertIsActorThread(); + Destroy(); + return IPC_OK(); +} + +IPCResult StreamFilterParent::RecvSuspend() { + AssertIsActorThread(); + + if (mState == State::TransferringData) { + RefPtr<StreamFilterParent> self(this); + RunOnMainThread(FUNC, [=] { + self->mChannel->Suspend(); + + RunOnActorThread(FUNC, [=] { + if (self->IPCActive()) { + self->mState = State::Suspended; + self->CheckResult(self->SendSuspended()); + } + }); + }); + } + return IPC_OK(); +} + +IPCResult StreamFilterParent::RecvResume() { + AssertIsActorThread(); + + if (mState == State::Suspended) { + // Change state before resuming so incoming data is handled correctly + // immediately after resuming. + mState = State::TransferringData; + + RefPtr<StreamFilterParent> self(this); + RunOnMainThread(FUNC, [=] { + self->mChannel->Resume(); + + RunOnActorThread(FUNC, [=] { + if (self->IPCActive()) { + self->CheckResult(self->SendResumed()); + } + }); + }); + } + return IPC_OK(); +} +IPCResult StreamFilterParent::RecvDisconnect() { + AssertIsActorThread(); + + if (mState == State::Suspended) { + RefPtr<StreamFilterParent> self(this); + RunOnMainThread(FUNC, [=] { self->mChannel->Resume(); }); + } else if (mState != State::TransferringData) { + return IPC_OK(); + } + + mState = State::Disconnecting; + CheckResult(SendFlushData()); + return IPC_OK(); +} + +IPCResult StreamFilterParent::RecvFlushedData() { + AssertIsActorThread(); + + MOZ_ASSERT(mState == State::Disconnecting); + + Destroy(); + + FinishDisconnect(); + return IPC_OK(); +} + +void StreamFilterParent::FinishDisconnect() { + RefPtr<StreamFilterParent> self(this); + RunOnIOThread(FUNC, [=] { + self->FlushBufferedData(); + + RunOnMainThread(FUNC, [=] { + if (self->mReceivedStop && !self->mSentStop) { + nsresult rv = self->EmitStopRequest(NS_OK); + Unused << NS_WARN_IF(NS_FAILED(rv)); + } else if (self->mLoadGroup && !self->mDisconnected) { + Unused << self->mLoadGroup->RemoveRequest(self, nullptr, NS_OK); + } + self->mDisconnected = true; + }); + + RunOnActorThread(FUNC, [=] { + if (self->mState != State::Closed) { + self->mState = State::Disconnected; + } + }); + }); +} + +/***************************************************************************** + * Data output + *****************************************************************************/ + +IPCResult StreamFilterParent::RecvWrite(Data&& aData) { + AssertIsActorThread(); + + RunOnIOThread(NewRunnableMethod<Data&&>("StreamFilterParent::WriteMove", this, + &StreamFilterParent::WriteMove, + std::move(aData))); + return IPC_OK(); +} + +void StreamFilterParent::WriteMove(Data&& aData) { + nsresult rv = Write(aData); + Unused << NS_WARN_IF(NS_FAILED(rv)); +} + +nsresult StreamFilterParent::Write(Data& aData) { + AssertIsIOThread(); + + nsCOMPtr<nsIInputStream> stream; + nsresult rv = NS_NewByteInputStream( + getter_AddRefs(stream), + Span(reinterpret_cast<char*>(aData.Elements()), aData.Length()), + NS_ASSIGNMENT_DEPEND); + NS_ENSURE_SUCCESS(rv, rv); + + rv = + mOrigListener->OnDataAvailable(mChannel, stream, mOffset, aData.Length()); + NS_ENSURE_SUCCESS(rv, rv); + + mOffset += aData.Length(); + return NS_OK; +} + +/***************************************************************************** + * nsIRequest + *****************************************************************************/ + +NS_IMETHODIMP +StreamFilterParent::GetName(nsACString& aName) { + AssertIsMainThread(); + MOZ_ASSERT(mChannel); + return mChannel->GetName(aName); +} + +NS_IMETHODIMP +StreamFilterParent::GetStatus(nsresult* aStatus) { + AssertIsMainThread(); + MOZ_ASSERT(mChannel); + return mChannel->GetStatus(aStatus); +} + +NS_IMETHODIMP +StreamFilterParent::IsPending(bool* aIsPending) { + switch (mState) { + case State::Initialized: + case State::TransferringData: + case State::Suspended: + *aIsPending = true; + break; + default: + *aIsPending = false; + } + return NS_OK; +} + +NS_IMETHODIMP +StreamFilterParent::Cancel(nsresult aResult) { + AssertIsMainThread(); + MOZ_ASSERT(mChannel); + return mChannel->Cancel(aResult); +} + +NS_IMETHODIMP +StreamFilterParent::Suspend() { + AssertIsMainThread(); + MOZ_ASSERT(mChannel); + return mChannel->Suspend(); +} + +NS_IMETHODIMP +StreamFilterParent::Resume() { + AssertIsMainThread(); + MOZ_ASSERT(mChannel); + return mChannel->Resume(); +} + +NS_IMETHODIMP +StreamFilterParent::GetLoadGroup(nsILoadGroup** aLoadGroup) { + *aLoadGroup = mLoadGroup; + return NS_OK; +} + +NS_IMETHODIMP +StreamFilterParent::SetLoadGroup(nsILoadGroup* aLoadGroup) { + return NS_ERROR_NOT_IMPLEMENTED; +} + +NS_IMETHODIMP +StreamFilterParent::GetLoadFlags(nsLoadFlags* aLoadFlags) { + AssertIsMainThread(); + MOZ_ASSERT(mChannel); + MOZ_TRY(mChannel->GetLoadFlags(aLoadFlags)); + *aLoadFlags &= ~nsIChannel::LOAD_DOCUMENT_URI; + return NS_OK; +} + +NS_IMETHODIMP +StreamFilterParent::SetLoadFlags(nsLoadFlags aLoadFlags) { + AssertIsMainThread(); + MOZ_ASSERT(mChannel); + return mChannel->SetLoadFlags(aLoadFlags); +} + +NS_IMETHODIMP +StreamFilterParent::GetTRRMode(nsIRequest::TRRMode* aTRRMode) { + return GetTRRModeImpl(aTRRMode); +} + +NS_IMETHODIMP +StreamFilterParent::SetTRRMode(nsIRequest::TRRMode aTRRMode) { + return SetTRRModeImpl(aTRRMode); +} + +/***************************************************************************** + * nsIStreamListener + *****************************************************************************/ + +NS_IMETHODIMP +StreamFilterParent::OnStartRequest(nsIRequest* aRequest) { + AssertIsMainThread(); + + // Always reset mChannel if aRequest is different. Various calls in + // StreamFilterParent will use mChannel, but aRequest is *always* the + // right channel to use at this point. + // + // For ALL redirections, we will disconnect this listener. Extensions + // will create a new filter if they need it. + if (aRequest != mChannel) { + nsCOMPtr<nsIChannel> channel = do_QueryInterface(aRequest); + nsCOMPtr<nsILoadInfo> loadInfo = channel ? channel->LoadInfo() : nullptr; + mChannel = channel; + + if (!(loadInfo && + loadInfo->RedirectChainIncludingInternalRedirects().IsEmpty())) { + mDisconnected = true; + mDisconnectedByOnStartRequest = true; + + RefPtr<StreamFilterParent> self(this); + RunOnActorThread(FUNC, [=] { + if (self->IPCActive()) { + self->mState = State::Disconnected; + CheckResult(self->SendError("Channel redirected"_ns)); + } + }); + } + } + + // Check if alterate cached data is being sent, if so we receive un-decoded + // data and we must disconnect the filter and send an error to the extension. + if (!mDisconnected) { + RefPtr<net::HttpBaseChannel> chan = do_QueryObject(aRequest); + if (chan && chan->IsDeliveringAltData()) { + mDisconnected = true; + mDisconnectedByOnStartRequest = true; + + RefPtr<StreamFilterParent> self(this); + RunOnActorThread(FUNC, [=] { + if (self->IPCActive()) { + self->mState = State::Disconnected; + CheckResult( + self->SendError("Channel is delivering cached alt-data"_ns)); + } + }); + } + } + + if (!mDisconnected) { + Unused << mChannel->GetLoadGroup(getter_AddRefs(mLoadGroup)); + if (mLoadGroup) { + Unused << mLoadGroup->AddRequest(this, nullptr); + } + } + + nsresult rv = mOrigListener->OnStartRequest(aRequest); + + // Important: Do this only *after* running the next listener in the chain, so + // that we get the final delivery target after any retargeting that it may do. + if (nsCOMPtr<nsIThreadRetargetableRequest> req = + do_QueryInterface(aRequest)) { + nsCOMPtr<nsIEventTarget> thread; + Unused << req->GetDeliveryTarget(getter_AddRefs(thread)); + if (thread) { + mIOThread = std::move(thread); + } + } + + // Important: Do this *after* we have set the thread delivery target, or it is + // possible in rare circumstances for an extension to attempt to write data + // before the thread has been set up, even though there are several layers of + // asynchrony involved. + if (!mDisconnected) { + RefPtr<StreamFilterParent> self(this); + RunOnActorThread(FUNC, [=] { + if (self->IPCActive()) { + self->mState = State::TransferringData; + self->CheckResult(self->SendStartRequest()); + } + }); + } + + return rv; +} + +NS_IMETHODIMP +StreamFilterParent::OnStopRequest(nsIRequest* aRequest, nsresult aStatusCode) { + AssertIsMainThread(); + MOZ_ASSERT(aRequest == mChannel); + + mReceivedStop = true; + if (mDisconnected) { + return EmitStopRequest(aStatusCode); + } + + RefPtr<StreamFilterParent> self(this); + RunOnActorThread(FUNC, [=] { + if (self->IPCActive()) { + self->CheckResult(self->SendStopRequest(aStatusCode)); + } else if (self->mState != State::Disconnecting) { + // If we're currently disconnecting, then we'll emit a stop + // request at the end of that process. Otherwise we need to + // manually emit one here, since we won't be getting a response + // from the child. + RunOnMainThread(FUNC, [=] { + if (!self->mSentStop) { + self->EmitStopRequest(aStatusCode); + } + }); + } + }); + return NS_OK; +} + +nsresult StreamFilterParent::EmitStopRequest(nsresult aStatusCode) { + AssertIsMainThread(); + MOZ_ASSERT(!mSentStop); + + mSentStop = true; + nsresult rv = mOrigListener->OnStopRequest(mChannel, aStatusCode); + + if (mLoadGroup && !mDisconnected) { + Unused << mLoadGroup->RemoveRequest(this, nullptr, aStatusCode); + } + + return rv; +} + +/***************************************************************************** + * Incoming data handling + *****************************************************************************/ + +void StreamFilterParent::DoSendData(Data&& aData) { + AssertIsActorThread(); + + if (mState == State::TransferringData) { + CheckResult(SendData(aData)); + } +} + +NS_IMETHODIMP +StreamFilterParent::OnDataAvailable(nsIRequest* aRequest, + nsIInputStream* aInputStream, + uint64_t aOffset, uint32_t aCount) { + AssertIsIOThread(); + + if (mDisconnectedByOnStartRequest || mState == State::Disconnected) { + // If we're offloading data in a thread pool, it's possible that we'll + // have buffered some additional data while waiting for the buffer to + // flush. So, if there's any buffered data left, flush that before we + // flush this incoming data. + // + // Note: When in the eDisconnected state, the buffer list is guaranteed + // never to be accessed by another thread during an OnDataAvailable call. + if (!mBufferedData.isEmpty()) { + FlushBufferedData(); + } + + mOffset += aCount; + return mOrigListener->OnDataAvailable(aRequest, aInputStream, + mOffset - aCount, aCount); + } + + Data data; + data.SetLength(aCount); + + uint32_t count; + nsresult rv = aInputStream->Read(reinterpret_cast<char*>(data.Elements()), + aCount, &count); + NS_ENSURE_SUCCESS(rv, rv); + NS_ENSURE_TRUE(count == aCount, NS_ERROR_UNEXPECTED); + + if (mState == State::Disconnecting) { + MutexAutoLock al(mBufferMutex); + BufferData(std::move(data)); + } else if (mState == State::Closed) { + return NS_ERROR_FAILURE; + } else { + ActorThread()->Dispatch( + NewRunnableMethod<Data&&>("StreamFilterParent::DoSendData", this, + &StreamFilterParent::DoSendData, + std::move(data)), + NS_DISPATCH_NORMAL); + } + return NS_OK; +} + +nsresult StreamFilterParent::FlushBufferedData() { + AssertIsIOThread(); + + // When offloading data to a thread pool, OnDataAvailable isn't guaranteed + // to always run in the same thread, so it's possible for this function to + // run in parallel with OnDataAvailable. + MutexAutoLock al(mBufferMutex); + + while (!mBufferedData.isEmpty()) { + UniquePtr<BufferedData> data(mBufferedData.popFirst()); + + nsresult rv = Write(data->mData); + NS_ENSURE_SUCCESS(rv, rv); + } + + return NS_OK; +} + +/***************************************************************************** + * Thread helpers + *****************************************************************************/ + +nsIEventTarget* StreamFilterParent::ActorThread() { + return net::gSocketTransportService; +} + +bool StreamFilterParent::IsActorThread() { + return ActorThread()->IsOnCurrentThread(); +} + +void StreamFilterParent::AssertIsActorThread() { MOZ_ASSERT(IsActorThread()); } + +nsIEventTarget* StreamFilterParent::IOThread() { return mIOThread; } + +bool StreamFilterParent::IsIOThread() { return mIOThread->IsOnCurrentThread(); } + +void StreamFilterParent::AssertIsIOThread() { MOZ_ASSERT(IsIOThread()); } + +template <typename Function> +void StreamFilterParent::RunOnMainThread(const char* aName, Function&& aFunc) { + mQueue->RunOrEnqueue(new ChannelEventFunction(mMainThread, std::move(aFunc))); +} + +void StreamFilterParent::RunOnMainThread(already_AddRefed<Runnable> aRunnable) { + mQueue->RunOrEnqueue( + new ChannelEventRunnable(mMainThread, std::move(aRunnable))); +} + +template <typename Function> +void StreamFilterParent::RunOnIOThread(const char* aName, Function&& aFunc) { + mQueue->RunOrEnqueue(new ChannelEventFunction(mIOThread, std::move(aFunc))); +} + +void StreamFilterParent::RunOnIOThread(already_AddRefed<Runnable> aRunnable) { + mQueue->RunOrEnqueue( + new ChannelEventRunnable(mIOThread, std::move(aRunnable))); +} + +template <typename Function> +void StreamFilterParent::RunOnActorThread(const char* aName, Function&& aFunc) { + // We don't use mQueue for dispatch to the actor thread. + // + // The main thread and IO thread are used for dispatching events to the + // wrapped stream listener, and those events need to be processed + // consistently, in the order they were dispatched. An event dispatched to the + // main thread can't be run before events that were dispatched to the IO + // thread before it. + // + // Additionally, the IO thread is likely to be a thread pool, which means that + // without thread-safe queuing, it's possible for multiple events dispatched + // to it to be processed in parallel, or out of order. + // + // The actor thread, however, is always a serial event target. Its events are + // always processed in order, and events dispatched to the actor thread are + // independent of the events in the output event queue. + if (IsActorThread()) { + aFunc(); + } else { + ActorThread()->Dispatch(std::move(NS_NewRunnableFunction(aName, aFunc)), + NS_DISPATCH_NORMAL); + } +} + +/***************************************************************************** + * Glue + *****************************************************************************/ + +void StreamFilterParent::ActorDestroy(ActorDestroyReason aWhy) { + AssertIsActorThread(); + + if (mState != State::Disconnected && mState != State::Closed) { + Broken(); + } +} + +void StreamFilterParent::ActorDealloc() { + RefPtr<StreamFilterParent> self = dont_AddRef(this); +} + +NS_INTERFACE_MAP_BEGIN(StreamFilterParent) + NS_INTERFACE_MAP_ENTRY(nsIStreamListener) + NS_INTERFACE_MAP_ENTRY(nsIRequestObserver) + NS_INTERFACE_MAP_ENTRY(nsIRequest) + NS_INTERFACE_MAP_ENTRY(nsIThreadRetargetableStreamListener) + NS_INTERFACE_MAP_ENTRY_AMBIGUOUS(nsISupports, nsIStreamListener) +NS_INTERFACE_MAP_END + +NS_IMPL_ADDREF(StreamFilterParent) +NS_IMPL_RELEASE(StreamFilterParent) + +} // namespace extensions +} // namespace mozilla diff --git a/toolkit/components/extensions/webrequest/StreamFilterParent.h b/toolkit/components/extensions/webrequest/StreamFilterParent.h new file mode 100644 index 0000000000..6a7b6cd0d0 --- /dev/null +++ b/toolkit/components/extensions/webrequest/StreamFilterParent.h @@ -0,0 +1,195 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_extensions_StreamFilterParent_h +#define mozilla_extensions_StreamFilterParent_h + +#include "StreamFilterBase.h" +#include "mozilla/extensions/PStreamFilterParent.h" + +#include "mozilla/LinkedList.h" +#include "mozilla/Mutex.h" +#include "mozilla/WebRequestService.h" +#include "nsIStreamListener.h" +#include "nsIThread.h" +#include "nsIThreadRetargetableStreamListener.h" +#include "nsThreadUtils.h" + +#if defined(_MSC_VER) +# define FUNC __FUNCSIG__ +#else +# define FUNC __PRETTY_FUNCTION__ +#endif + +namespace mozilla { +namespace dom { +class ContentParent; +} +namespace net { +class ChannelEventQueue; +class nsHttpChannel; +} // namespace net + +namespace extensions { + +using namespace mozilla::dom; +using mozilla::ipc::IPCResult; + +class StreamFilterParent final : public PStreamFilterParent, + public nsIStreamListener, + public nsIThreadRetargetableStreamListener, + public nsIRequest, + public StreamFilterBase { + friend class PStreamFilterParent; + + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSISTREAMLISTENER + NS_DECL_NSIREQUEST + NS_DECL_NSIREQUESTOBSERVER + NS_DECL_NSITHREADRETARGETABLESTREAMLISTENER + + StreamFilterParent(); + + using ParentEndpoint = mozilla::ipc::Endpoint<PStreamFilterParent>; + using ChildEndpoint = mozilla::ipc::Endpoint<PStreamFilterChild>; + + using ChildEndpointPromise = MozPromise<ChildEndpoint, bool, true>; + + [[nodiscard]] static RefPtr<ChildEndpointPromise> Create( + ContentParent* aContentParent, uint64_t aChannelId, + const nsAString& aAddonId); + + static void Attach(nsIChannel* aChannel, ParentEndpoint&& aEndpoint); + + enum class State { + // The parent has been created, but not yet constructed by the child. + Uninitialized, + // The parent has been successfully constructed. + Initialized, + // The OnRequestStarted event has been received, and data is being + // transferred to the child. + TransferringData, + // The channel is suspended. + Suspended, + // The channel has been closed by the child, and will send or receive data. + Closed, + // The channel is being disconnected from the child, so that all further + // data and events pass unfiltered to the output listener. Any data + // currnetly in transit to, or buffered by, the child will be written to the + // output listener before we enter the Disconnected atate. + Disconnecting, + // The channel has been disconnected from the child, and all further data + // and events will be passed directly to the output listener. + Disconnected, + }; + + protected: + virtual ~StreamFilterParent(); + + IPCResult RecvWrite(Data&& aData); + IPCResult RecvFlushedData(); + IPCResult RecvSuspend(); + IPCResult RecvResume(); + IPCResult RecvClose(); + IPCResult RecvDisconnect(); + IPCResult RecvDestroy(); + + virtual void ActorDealloc() override; + + private: + bool IPCActive() { + return (mState != State::Closed && mState != State::Disconnecting && + mState != State::Disconnected); + } + + void Init(nsIChannel* aChannel); + + void Bind(ParentEndpoint&& aEndpoint); + + void Destroy(); + + nsresult FlushBufferedData(); + + nsresult Write(Data& aData); + + void WriteMove(Data&& aData); + + void DoSendData(Data&& aData); + + nsresult EmitStopRequest(nsresult aStatusCode); + + virtual void ActorDestroy(ActorDestroyReason aWhy) override; + + void Broken(); + void FinishDisconnect(); + + void CheckResult(bool aResult) { + if (NS_WARN_IF(!aResult)) { + Broken(); + } + } + + inline nsIEventTarget* ActorThread(); + + inline nsIEventTarget* IOThread(); + + inline bool IsIOThread(); + + inline bool IsActorThread(); + + inline void AssertIsActorThread(); + + inline void AssertIsIOThread(); + + static void AssertIsMainThread() { MOZ_ASSERT(NS_IsMainThread()); } + + template <typename Function> + void RunOnMainThread(const char* aName, Function&& aFunc); + + void RunOnMainThread(already_AddRefed<Runnable> aRunnable); + + template <typename Function> + void RunOnActorThread(const char* aName, Function&& aFunc); + + template <typename Function> + void RunOnIOThread(const char* aName, Function&& aFunc); + + void RunOnIOThread(already_AddRefed<Runnable>); + + nsCOMPtr<nsIChannel> mChannel; + nsCOMPtr<nsILoadGroup> mLoadGroup; + nsCOMPtr<nsIStreamListener> mOrigListener; + + nsCOMPtr<nsIEventTarget> mMainThread; + nsCOMPtr<nsIEventTarget> mIOThread; + + RefPtr<net::ChannelEventQueue> mQueue; + + Mutex mBufferMutex; + + bool mReceivedStop; + bool mSentStop; + bool mDisconnected = false; + + // If redirection happens or alterate cached data is being sent, the stream + // filter is disconnected in OnStartRequest and the following ODA would not + // be filtered. Using mDisconnected causes race condition. mState is possible + // to late to be set, which leads out of sync. + bool mDisconnectedByOnStartRequest = false; + + nsCOMPtr<nsISupports> mContext; + uint64_t mOffset; + + // Use Release-Acquire ordering to ensure the OMT ODA is not sent while + // the channel is disconnecting or closed. + Atomic<State, ReleaseAcquire> mState; +}; + +} // namespace extensions +} // namespace mozilla + +#endif // mozilla_extensions_StreamFilterParent_h diff --git a/toolkit/components/extensions/webrequest/WebRequest.jsm b/toolkit/components/extensions/webrequest/WebRequest.jsm new file mode 100644 index 0000000000..7385178ce8 --- /dev/null +++ b/toolkit/components/extensions/webrequest/WebRequest.jsm @@ -0,0 +1,1187 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +"use strict"; + +const EXPORTED_SYMBOLS = ["WebRequest"]; + +/* exported WebRequest */ + +/* globals ChannelWrapper */ + +const { nsIHttpActivityObserver, nsISocketTransport } = Ci; + +const { Services } = ChromeUtils.import("resource://gre/modules/Services.jsm"); +const { XPCOMUtils } = ChromeUtils.import( + "resource://gre/modules/XPCOMUtils.jsm" +); + +XPCOMUtils.defineLazyModuleGetters(this, { + ExtensionParent: "resource://gre/modules/ExtensionParent.jsm", + ExtensionUtils: "resource://gre/modules/ExtensionUtils.jsm", + WebRequestUpload: "resource://gre/modules/WebRequestUpload.jsm", + SecurityInfo: "resource://gre/modules/SecurityInfo.jsm", +}); + +// WebRequest.jsm's only consumer is ext-webRequest.js, so we can depend on +// the apiManager.global being initialized. +XPCOMUtils.defineLazyGetter(this, "tabTracker", () => { + return ExtensionParent.apiManager.global.tabTracker; +}); +XPCOMUtils.defineLazyGetter(this, "getCookieStoreIdForOriginAttributes", () => { + return ExtensionParent.apiManager.global.getCookieStoreIdForOriginAttributes; +}); + +function runLater(job) { + Services.tm.dispatchToMainThread(job); +} + +function parseFilter(filter) { + if (!filter) { + filter = {}; + } + + return { + urls: filter.urls || null, + types: filter.types || null, + tabId: filter.tabId ?? null, + windowId: filter.windowId ?? null, + incognito: filter.incognito ?? null, + }; +} + +function parseExtra(extra, allowed = [], optionsObj = {}) { + if (extra) { + for (let ex of extra) { + if (!allowed.includes(ex)) { + throw new ExtensionUtils.ExtensionError(`Invalid option ${ex}`); + } + } + } + + let result = Object.assign({}, optionsObj); + for (let al of allowed) { + if (extra && extra.includes(al)) { + result[al] = true; + } + } + return result; +} + +function isThenable(value) { + return value && typeof value === "object" && typeof value.then === "function"; +} + +class HeaderChanger { + constructor(channel) { + this.channel = channel; + + this.array = this.readHeaders(); + } + + getMap() { + if (!this.map) { + this.map = new Map(); + for (let header of this.array) { + this.map.set(header.name.toLowerCase(), header); + } + } + return this.map; + } + + toArray() { + return this.array; + } + + validateHeaders(headers) { + // We should probably use schema validation for this. + + if (!Array.isArray(headers)) { + return false; + } + + return headers.every(header => { + if (typeof header !== "object" || header === null) { + return false; + } + + if (typeof header.name !== "string") { + return false; + } + + return ( + typeof header.value === "string" || Array.isArray(header.binaryValue) + ); + }); + } + + applyChanges(headers, opts = {}) { + if (!this.validateHeaders(headers)) { + /* globals uneval */ + Cu.reportError(`Invalid header array: ${uneval(headers)}`); + return; + } + + let newHeaders = new Set(headers.map(({ name }) => name.toLowerCase())); + + // Remove missing headers. + let origHeaders = this.getMap(); + for (let name of origHeaders.keys()) { + if (!newHeaders.has(name)) { + this.setHeader(name, "", false, opts, name); + } + } + + // Set new or changed headers. If there are multiple headers with the same + // name (e.g. Set-Cookie), merge them, instead of having new values + // overwrite previous ones. + // + // When the new value of a header is equal the existing value of the header + // (e.g. the initial response set "Set-Cookie: examplename=examplevalue", + // and an extension also added the header + // "Set-Cookie: examplename=examplevalue") then the header value is not + // re-set, but subsequent headers of the same type will be merged in. + // + // Multiple addons will be able to provide modifications to any headers + // listed in the default set. + let headersAlreadySet = new Set(); + for (let { name, value, binaryValue } of headers) { + if (binaryValue) { + value = String.fromCharCode(...binaryValue); + } + + let lowerCaseName = name.toLowerCase(); + let original = origHeaders.get(lowerCaseName); + + if (!original || value !== original.value) { + let shouldMerge = headersAlreadySet.has(lowerCaseName); + this.setHeader(name, value, shouldMerge, opts, lowerCaseName); + } + + headersAlreadySet.add(lowerCaseName); + } + } +} + +const checkRestrictedHeaderValue = (value, opts = {}) => { + let uri = Services.io.newURI(`https://${value}/`); + let { policy } = opts; + + if (policy && !policy.allowedOrigins.matches(uri)) { + throw new Error(`Unable to set host header, url missing from permissions.`); + } + + if (WebExtensionPolicy.isRestrictedURI(uri)) { + throw new Error(`Unable to set host header to restricted url.`); + } +}; + +class RequestHeaderChanger extends HeaderChanger { + setHeader(name, value, merge, opts, lowerCaseName) { + try { + if (value && lowerCaseName === "host") { + checkRestrictedHeaderValue(value, opts); + } + this.channel.setRequestHeader(name, value, merge); + } catch (e) { + Cu.reportError(new Error(`Error setting request header ${name}: ${e}`)); + } + } + + readHeaders() { + return this.channel.getRequestHeaders(); + } +} + +class ResponseHeaderChanger extends HeaderChanger { + didModifyCSP = false; + + setHeader(name, value, merge, opts, lowerCaseName) { + if (lowerCaseName === "content-security-policy") { + // When multiple add-ons change the CSP, enforce the combined (strictest) + // policy - see bug 1462989 for motivation. + // When value is unset, don't force the header to be merged, to allow + // add-ons to clear the header if wanted. + if (value) { + merge = merge || this.didModifyCSP; + } + this.didModifyCSP = true; + } + try { + this.channel.setResponseHeader(name, value, merge); + } catch (e) { + Cu.reportError(new Error(`Error setting response header ${name}: ${e}`)); + } + } + + readHeaders() { + return this.channel.getResponseHeaders(); + } +} + +const MAYBE_CACHED_EVENTS = new Set([ + "onResponseStarted", + "onHeadersReceived", + "onBeforeRedirect", + "onCompleted", + "onErrorOccurred", +]); + +const OPTIONAL_PROPERTIES = [ + "requestHeaders", + "responseHeaders", + "statusCode", + "statusLine", + "error", + "redirectUrl", + "requestBody", + "scheme", + "realm", + "isProxy", + "challenger", + "proxyInfo", + "ip", + "frameAncestors", + "urlClassification", + "requestSize", + "responseSize", +]; + +function serializeRequestData(eventName) { + let data = { + requestId: this.requestId, + url: this.url, + originUrl: this.originUrl, + documentUrl: this.documentUrl, + method: this.method, + type: this.type, + timeStamp: Date.now(), + tabId: this.tabId, + frameId: this.frameId, + parentFrameId: this.parentFrameId, + incognito: this.incognito, + thirdParty: this.thirdParty, + cookieStoreId: this.cookieStoreId, + }; + + if (MAYBE_CACHED_EVENTS.has(eventName)) { + data.fromCache = !!this.fromCache; + } + + for (let opt of OPTIONAL_PROPERTIES) { + if (typeof this[opt] !== "undefined") { + data[opt] = this[opt]; + } + } + + if (this.urlClassification) { + data.urlClassification = { + firstParty: this.urlClassification.firstParty.filter( + c => !c.startsWith("socialtracking_") + ), + thirdParty: this.urlClassification.thirdParty.filter( + c => !c.startsWith("socialtracking_") + ), + }; + } + + return data; +} + +var HttpObserverManager; + +var ChannelEventSink = { + _classDescription: "WebRequest channel event sink", + _classID: Components.ID("115062f8-92f1-11e5-8b7f-080027b0f7ec"), + _contractID: "@mozilla.org/webrequest/channel-event-sink;1", + + QueryInterface: ChromeUtils.generateQI(["nsIChannelEventSink", "nsIFactory"]), + + init() { + Components.manager + .QueryInterface(Ci.nsIComponentRegistrar) + .registerFactory( + this._classID, + this._classDescription, + this._contractID, + this + ); + }, + + register() { + Services.catMan.addCategoryEntry( + "net-channel-event-sinks", + this._contractID, + this._contractID, + false, + true + ); + }, + + unregister() { + Services.catMan.deleteCategoryEntry( + "net-channel-event-sinks", + this._contractID, + false + ); + }, + + // nsIChannelEventSink implementation + asyncOnChannelRedirect(oldChannel, newChannel, flags, redirectCallback) { + runLater(() => redirectCallback.onRedirectVerifyCallback(Cr.NS_OK)); + try { + HttpObserverManager.onChannelReplaced(oldChannel, newChannel); + } catch (e) { + // we don't wanna throw: it would abort the redirection + } + }, + + // nsIFactory implementation + createInstance(outer, iid) { + if (outer) { + throw Components.Exception("", Cr.NS_ERROR_NO_AGGREGATION); + } + return this.QueryInterface(iid); + }, +}; + +ChannelEventSink.init(); + +// nsIAuthPrompt2 implementation for onAuthRequired +class AuthRequestor { + constructor(channel, httpObserver) { + this.notificationCallbacks = channel.notificationCallbacks; + this.loadGroupCallbacks = + channel.loadGroup && channel.loadGroup.notificationCallbacks; + this.httpObserver = httpObserver; + } + + getInterface(iid) { + if (iid.equals(Ci.nsIAuthPromptProvider) || iid.equals(Ci.nsIAuthPrompt2)) { + return this; + } + try { + return this.notificationCallbacks.getInterface(iid); + } catch (e) {} + throw Components.Exception("", Cr.NS_ERROR_NO_INTERFACE); + } + + _getForwardedInterface(iid) { + try { + return this.notificationCallbacks.getInterface(iid); + } catch (e) { + return this.loadGroupCallbacks.getInterface(iid); + } + } + + // nsIAuthPromptProvider getAuthPrompt + getAuthPrompt(reason, iid) { + // This should never get called without getInterface having been called first. + if (iid.equals(Ci.nsIAuthPrompt2)) { + return this; + } + return this._getForwardedInterface(Ci.nsIAuthPromptProvider).getAuthPrompt( + reason, + iid + ); + } + + // nsIAuthPrompt2 promptAuth + promptAuth(channel, level, authInfo) { + this._getForwardedInterface(Ci.nsIAuthPrompt2).promptAuth( + channel, + level, + authInfo + ); + } + + _getForwardPrompt(data) { + let reason = data.isProxy + ? Ci.nsIAuthPromptProvider.PROMPT_PROXY + : Ci.nsIAuthPromptProvider.PROMPT_NORMAL; + for (let callbacks of [ + this.notificationCallbacks, + this.loadGroupCallbacks, + ]) { + try { + return callbacks + .getInterface(Ci.nsIAuthPromptProvider) + .getAuthPrompt(reason, Ci.nsIAuthPrompt2); + } catch (e) {} + try { + return callbacks.getInterface(Ci.nsIAuthPrompt2); + } catch (e) {} + } + throw Components.Exception("", Cr.NS_ERROR_NO_INTERFACE); + } + + // nsIAuthPrompt2 asyncPromptAuth + asyncPromptAuth(channel, callback, context, level, authInfo) { + let wrapper = ChannelWrapper.get(channel); + + let uri = channel.URI; + let proxyInfo; + let isProxy = !!(authInfo.flags & authInfo.AUTH_PROXY); + if (isProxy && channel instanceof Ci.nsIProxiedChannel) { + proxyInfo = channel.proxyInfo; + } + let data = { + scheme: authInfo.authenticationScheme, + realm: authInfo.realm, + isProxy, + challenger: { + host: proxyInfo ? proxyInfo.host : uri.host, + port: proxyInfo ? proxyInfo.port : uri.port, + }, + }; + + // In the case that no listener provides credentials, we fallback to the + // previously set callback class for authentication. + wrapper.authPromptForward = () => { + try { + let prompt = this._getForwardPrompt(data); + prompt.asyncPromptAuth(channel, callback, context, level, authInfo); + } catch (e) { + Cu.reportError(`webRequest asyncPromptAuth failure ${e}`); + callback.onAuthCancelled(context, false); + } + wrapper.authPromptForward = null; + wrapper.authPromptCallback = null; + }; + wrapper.authPromptCallback = authCredentials => { + // The API allows for canceling the request, providing credentials or + // doing nothing, so we do not provide a way to call onAuthCanceled. + // Canceling the request will result in canceling the authentication. + if ( + authCredentials && + typeof authCredentials.username === "string" && + typeof authCredentials.password === "string" + ) { + authInfo.username = authCredentials.username; + authInfo.password = authCredentials.password; + try { + callback.onAuthAvailable(context, authInfo); + } catch (e) { + Cu.reportError(`webRequest onAuthAvailable failure ${e}`); + } + // At least one addon has responded, so we won't forward to the regular + // prompt handlers. + wrapper.authPromptForward = null; + wrapper.authPromptCallback = null; + } + }; + + this.httpObserver.runChannelListener(wrapper, "onAuthRequired", data); + + return { + QueryInterface: ChromeUtils.generateQI(["nsICancelable"]), + cancel() { + try { + callback.onAuthCancelled(context, false); + } catch (e) { + Cu.reportError(`webRequest onAuthCancelled failure ${e}`); + } + wrapper.authPromptForward = null; + wrapper.authPromptCallback = null; + }, + }; + } +} + +AuthRequestor.prototype.QueryInterface = ChromeUtils.generateQI([ + "nsIInterfaceRequestor", + "nsIAuthPromptProvider", + "nsIAuthPrompt2", +]); + +// Most WebRequest events are implemented via the observer services, but +// a few use custom xpcom interfaces. This class (HttpObserverManager) +// serves two main purposes: +// 1. It abstracts away the names and details of the underlying +// implementation (e.g., onBeforeBeforeRequest is dispatched from +// the http-on-modify-request observable). +// 2. It aggregates multiple listeners so that a single observer or +// handler can serve multiple webRequest listeners. +HttpObserverManager = { + listeners: { + // onBeforeRequest uses http-on-modify observer for HTTP(S). + onBeforeRequest: new Map(), + + // onBeforeSendHeaders and onSendHeaders correspond to the + // http-on-before-connect observer. + onBeforeSendHeaders: new Map(), + onSendHeaders: new Map(), + + // onHeadersReceived corresponds to the http-on-examine-* obserservers. + onHeadersReceived: new Map(), + + // onAuthRequired is handled via the nsIAuthPrompt2 xpcom interface + // which is managed here by AuthRequestor. + onAuthRequired: new Map(), + + // onBeforeRedirect is handled by the nsIChannelEVentSink xpcom interface + // which is managed here by ChannelEventSink. + onBeforeRedirect: new Map(), + + // onResponseStarted, onErrorOccurred, and OnCompleted correspond + // to events dispatched by the ChannelWrapper EventTarget. + onResponseStarted: new Map(), + onErrorOccurred: new Map(), + onCompleted: new Map(), + }, + + openingInitialized: false, + beforeConnectInitialized: false, + examineInitialized: false, + redirectInitialized: false, + activityInitialized: false, + needTracing: false, + hasRedirects: false, + + getWrapper(nativeChannel) { + let wrapper = ChannelWrapper.get(nativeChannel); + if (!wrapper._addedListeners) { + /* eslint-disable mozilla/balanced-listeners */ + if (this.listeners.onErrorOccurred.size) { + wrapper.addEventListener("error", this); + } + if (this.listeners.onResponseStarted.size) { + wrapper.addEventListener("start", this); + } + if (this.listeners.onCompleted.size) { + wrapper.addEventListener("stop", this); + } + /* eslint-enable mozilla/balanced-listeners */ + + wrapper._addedListeners = true; + } + return wrapper; + }, + + get activityDistributor() { + return Cc["@mozilla.org/network/http-activity-distributor;1"].getService( + Ci.nsIHttpActivityDistributor + ); + }, + + // This method is called whenever webRequest listeners are added or removed. + // It reconciles the set of listeners with underlying observers, event + // handlers, etc. by adding new low-level handlers for any newly added + // webRequest listeners and removing those that are no longer needed if + // there are no more listeners for corresponding webRequest events. + addOrRemove() { + let needOpening = this.listeners.onBeforeRequest.size; + let needBeforeConnect = + this.listeners.onBeforeSendHeaders.size || + this.listeners.onSendHeaders.size; + if (needOpening && !this.openingInitialized) { + this.openingInitialized = true; + Services.obs.addObserver(this, "http-on-modify-request"); + } else if (!needOpening && this.openingInitialized) { + this.openingInitialized = false; + Services.obs.removeObserver(this, "http-on-modify-request"); + } + if (needBeforeConnect && !this.beforeConnectInitialized) { + this.beforeConnectInitialized = true; + Services.obs.addObserver(this, "http-on-before-connect"); + } else if (!needBeforeConnect && this.beforeConnectInitialized) { + this.beforeConnectInitialized = false; + Services.obs.removeObserver(this, "http-on-before-connect"); + } + + let haveBlocking = Object.values(this.listeners).some(listeners => + Array.from(listeners.values()).some(listener => listener.blockingAllowed) + ); + + this.needTracing = + this.listeners.onResponseStarted.size || + this.listeners.onErrorOccurred.size || + this.listeners.onCompleted.size || + haveBlocking; + + let needExamine = + this.needTracing || + this.listeners.onHeadersReceived.size || + this.listeners.onAuthRequired.size; + + if (needExamine && !this.examineInitialized) { + this.examineInitialized = true; + Services.obs.addObserver(this, "http-on-examine-response"); + Services.obs.addObserver(this, "http-on-examine-cached-response"); + Services.obs.addObserver(this, "http-on-examine-merged-response"); + } else if (!needExamine && this.examineInitialized) { + this.examineInitialized = false; + Services.obs.removeObserver(this, "http-on-examine-response"); + Services.obs.removeObserver(this, "http-on-examine-cached-response"); + Services.obs.removeObserver(this, "http-on-examine-merged-response"); + } + + // If we have any listeners, we need the channelsink so the channelwrapper is + // updated properly. Otherwise events for channels that are redirected will not + // happen correctly. If we have no listeners, shut it down. + this.hasRedirects = this.listeners.onBeforeRedirect.size > 0; + let needRedirect = + this.hasRedirects || needExamine || needOpening || needBeforeConnect; + if (needRedirect && !this.redirectInitialized) { + this.redirectInitialized = true; + ChannelEventSink.register(); + } else if (!needRedirect && this.redirectInitialized) { + this.redirectInitialized = false; + ChannelEventSink.unregister(); + } + + let needActivity = this.listeners.onErrorOccurred.size; + if (needActivity && !this.activityInitialized) { + this.activityInitialized = true; + this.activityDistributor.addObserver(this); + } else if (!needActivity && this.activityInitialized) { + this.activityInitialized = false; + this.activityDistributor.removeObserver(this); + } + }, + + addListener(kind, callback, opts) { + this.listeners[kind].set(callback, opts); + this.addOrRemove(); + }, + + removeListener(kind, callback) { + this.listeners[kind].delete(callback); + this.addOrRemove(); + }, + + observe(subject, topic, data) { + let channel = this.getWrapper(subject); + switch (topic) { + case "http-on-modify-request": + this.runChannelListener(channel, "onBeforeRequest"); + break; + case "http-on-before-connect": + this.runChannelListener(channel, "onBeforeSendHeaders"); + break; + case "http-on-examine-cached-response": + case "http-on-examine-merged-response": + channel.fromCache = true; + // falls through + case "http-on-examine-response": + this.examine(channel, topic, data); + break; + } + }, + + // We map activity values with tentative error names, e.g. "STATUS_RESOLVING" => "NS_ERROR_NET_ON_RESOLVING". + get activityErrorsMap() { + let prefix = /^(?:ACTIVITY_SUBTYPE_|STATUS_)/; + let map = new Map(); + for (let iface of [nsIHttpActivityObserver, nsISocketTransport]) { + for (let c of Object.keys(iface).filter(name => prefix.test(name))) { + map.set(iface[c], c.replace(prefix, "NS_ERROR_NET_ON_")); + } + } + delete this.activityErrorsMap; + this.activityErrorsMap = map; + return this.activityErrorsMap; + }, + GOOD_LAST_ACTIVITY: nsIHttpActivityObserver.ACTIVITY_SUBTYPE_RESPONSE_HEADER, + observeActivity( + nativeChannel, + activityType, + activitySubtype /* , aTimestamp, aExtraSizeData, aExtraStringData */ + ) { + // Sometimes we get a NullHttpChannel, which implements + // nsIHttpChannel but not nsIChannel. + if (!(nativeChannel instanceof Ci.nsIChannel)) { + return; + } + let channel = this.getWrapper(nativeChannel); + + let lastActivity = channel.lastActivity || 0; + if ( + activitySubtype === + nsIHttpActivityObserver.ACTIVITY_SUBTYPE_RESPONSE_COMPLETE && + lastActivity && + lastActivity !== this.GOOD_LAST_ACTIVITY + ) { + // Make a trip through the event loop to make sure errors have a + // chance to be processed before we fall back to a generic error + // string. + Services.tm.dispatchToMainThread(() => { + channel.errorCheck(); + if (!channel.errorString) { + this.runChannelListener(channel, "onErrorOccurred", { + error: + this.activityErrorsMap.get(lastActivity) || + `NS_ERROR_NET_UNKNOWN_${lastActivity}`, + }); + } + }); + } else if ( + lastActivity !== this.GOOD_LAST_ACTIVITY && + lastActivity !== + nsIHttpActivityObserver.ACTIVITY_SUBTYPE_TRANSACTION_CLOSE + ) { + channel.lastActivity = activitySubtype; + } + }, + + getRequestData(channel, extraData) { + let originAttributes = channel.loadInfo?.originAttributes; + let data = { + requestId: String(channel.id), + url: channel.finalURL, + method: channel.method, + type: channel.type, + fromCache: channel.fromCache, + incognito: originAttributes?.privateBrowsingId > 0, + thirdParty: channel.thirdParty, + + originUrl: channel.originURL || undefined, + documentUrl: channel.documentURL || undefined, + + tabId: this.getBrowserData(channel).tabId, + frameId: channel.frameId, + parentFrameId: channel.parentFrameId, + + frameAncestors: channel.frameAncestors || undefined, + + ip: channel.remoteAddress, + + proxyInfo: channel.proxyInfo, + + serialize: serializeRequestData, + requestSize: channel.requestSize, + responseSize: channel.responseSize, + urlClassification: channel.urlClassification, + }; + + if (originAttributes) { + data.cookieStoreId = getCookieStoreIdForOriginAttributes( + originAttributes + ); + } + + return Object.assign(data, extraData); + }, + + handleEvent(event) { + let channel = event.currentTarget; + switch (event.type) { + case "error": + this.runChannelListener(channel, "onErrorOccurred", { + error: channel.errorString, + }); + break; + case "start": + this.runChannelListener(channel, "onResponseStarted"); + break; + case "stop": + this.runChannelListener(channel, "onCompleted"); + break; + } + }, + + STATUS_TYPES: new Set([ + "onHeadersReceived", + "onAuthRequired", + "onBeforeRedirect", + "onResponseStarted", + "onCompleted", + ]), + FILTER_TYPES: new Set([ + "onBeforeRequest", + "onBeforeSendHeaders", + "onSendHeaders", + "onHeadersReceived", + "onAuthRequired", + "onBeforeRedirect", + ]), + + getBrowserData(wrapper) { + let browserData = wrapper._browserData; + if (!browserData) { + if (wrapper.browserElement) { + browserData = tabTracker.getBrowserData(wrapper.browserElement); + } else { + browserData = { tabId: -1, windowId: -1 }; + } + wrapper._browserData = browserData; + } + return browserData; + }, + + runChannelListener(channel, kind, extraData = null) { + let handlerResults = []; + let requestHeaders; + let responseHeaders; + + try { + if (kind !== "onErrorOccurred" && channel.errorString) { + return; + } + + let registerFilter = this.FILTER_TYPES.has(kind); + let commonData = null; + let requestBody; + this.listeners[kind].forEach((opts, callback) => { + if (opts.filter.tabId !== null || opts.filter.windowId !== null) { + const { tabId, windowId } = this.getBrowserData(channel); + if ( + (opts.filter.tabId !== null && tabId != opts.filter.tabId) || + (opts.filter.windowId !== null && windowId != opts.filter.windowId) + ) { + return; + } + } + if (!channel.matches(opts.filter, opts.policy, extraData)) { + return; + } + + if (!commonData) { + commonData = this.getRequestData(channel, extraData); + if (this.STATUS_TYPES.has(kind)) { + commonData.statusCode = channel.statusCode; + commonData.statusLine = channel.statusLine; + } + } + let data = Object.create(commonData); + + if (registerFilter && opts.blocking && opts.policy) { + data.registerTraceableChannel = (policy, remoteTab) => { + // `channel` is a ChannelWrapper, which contains the actual + // underlying nsIChannel in `channel.channel`. For startup events + // that are held until the extension background page is started, + // it is possible that the underlying channel can be closed and + // cleaned up between the time the event occurred and the time + // we reach this code. + if (channel.channel) { + channel.registerTraceableChannel(policy, remoteTab); + } + }; + } + + if (opts.requestHeaders) { + requestHeaders = requestHeaders || new RequestHeaderChanger(channel); + data.requestHeaders = requestHeaders.toArray(); + } + + if (opts.responseHeaders) { + try { + responseHeaders = + responseHeaders || new ResponseHeaderChanger(channel); + data.responseHeaders = responseHeaders.toArray(); + } catch (e) { + /* headers may not be available on some redirects */ + } + } + + if (opts.requestBody && channel.canModify) { + requestBody = + requestBody || WebRequestUpload.createRequestBody(channel.channel); + data.requestBody = requestBody; + } + + try { + let result = callback(data); + + // isProxy is set during onAuth if the auth request is for a proxy. + // We allow handling proxy auth regardless of canModify. + if ( + (channel.canModify || data.isProxy) && + typeof result === "object" && + opts.blocking + ) { + handlerResults.push({ opts, result }); + } + } catch (e) { + Cu.reportError(e); + } + }); + } catch (e) { + Cu.reportError(e); + } + + return this.applyChanges( + kind, + channel, + handlerResults, + requestHeaders, + responseHeaders + ); + }, + + async applyChanges( + kind, + channel, + handlerResults, + requestHeaders, + responseHeaders + ) { + let shouldResume = !channel.suspended; + let suspenders = []; + + try { + for (let { opts, result } of handlerResults) { + if (isThenable(result)) { + suspenders.push(opts.addonId); + channel.suspend(); + try { + result = await result; + } catch (e) { + let error; + + if (e instanceof Error) { + error = e; + } else if (typeof e === "object" && e.message) { + error = new Error(e.message, e.fileName, e.lineNumber); + } + + Cu.reportError(error); + continue; + } + if (!result || typeof result !== "object") { + continue; + } + } + + if ( + kind === "onAuthRequired" && + result.authCredentials && + channel.authPromptCallback + ) { + channel.authPromptCallback(result.authCredentials); + } + + // We allow proxy auth to cancel or handle authCredentials regardless of + // canModify, but ensure we do nothing else. + if (!channel.canModify) { + continue; + } + + if (result.cancel) { + let text = ""; + if (Services.profiler?.IsActive()) { + text = + `${kind} ${channel.finalURL}` + + ` by ${suspenders.join(", ")} canceled`; + } + channel.resume(text); + channel.cancel( + Cr.NS_ERROR_ABORT, + Ci.nsILoadInfo.BLOCKING_REASON_EXTENSION_WEBREQUEST + ); + let { policy } = opts; + if (policy) { + let properties = channel.channel.QueryInterface( + Ci.nsIWritablePropertyBag + ); + properties.setProperty("cancelledByExtension", policy.id); + } + return; + } + + if (result.redirectUrl) { + try { + let text = ""; + if (Services.profiler?.IsActive()) { + text = + `${kind} ${channel.finalURL}` + + ` by ${suspenders.join(", ")}` + + ` redirected to ${result.redirectUrl}`; + } + channel.resume(text); + channel.redirectTo(Services.io.newURI(result.redirectUrl)); + + // Web Extensions using the WebRequest API are allowed + // to redirect a channel to a data: URI, hence we mark + // the channel to let the redirect blocker know. Please + // note that this marking needs to happen after the + // channel.redirectTo is called because the channel's + // RedirectTo() implementation explicitly drops the flag + // to avoid additional redirects not caused by the + // Web Extension. + channel.loadInfo.allowInsecureRedirectToDataURI = true; + + // To pass CORS checks, we pretend the current request's + // response allows the triggering origin to access. + let origin = channel.getRequestHeader("Origin"); + if (origin) { + channel.setResponseHeader("Access-Control-Allow-Origin", origin); + channel.setResponseHeader( + "Access-Control-Allow-Credentials", + "true" + ); + + // Compute an arbitrary 'Access-Control-Allow-Headers' + // for the internal Redirect + + let allowHeaders = channel + .getRequestHeaders() + .map(header => header.name) + .join(); + channel.setResponseHeader( + "Access-Control-Allow-Headers", + allowHeaders + ); + + channel.setResponseHeader( + "Access-Control-Allow-Methods", + channel.method + ); + } + + return; + } catch (e) { + Cu.reportError(e); + } + } + + if (result.upgradeToSecure && kind === "onBeforeRequest") { + try { + channel.upgradeToSecure(); + } catch (e) { + Cu.reportError(e); + } + } + + if (opts.requestHeaders && result.requestHeaders && requestHeaders) { + requestHeaders.applyChanges(result.requestHeaders, opts); + } + + if (opts.responseHeaders && result.responseHeaders && responseHeaders) { + responseHeaders.applyChanges(result.responseHeaders, opts); + } + } + + // If a listener did not cancel the request or provide credentials, we + // forward the auth request to the base handler. + if (kind === "onAuthRequired" && channel.authPromptForward) { + channel.authPromptForward(); + } + + if (kind === "onBeforeSendHeaders" && this.listeners.onSendHeaders.size) { + this.runChannelListener(channel, "onSendHeaders"); + } else if (kind !== "onErrorOccurred") { + channel.errorCheck(); + } + } catch (e) { + Cu.reportError(e); + } + + // Only resume the channel if it was suspended by this call. + if (shouldResume) { + let text = ""; + if (Services.profiler?.IsActive()) { + text = `${kind} ${channel.finalURL} by ${suspenders.join(", ")}`; + } + channel.resume(text); + } + }, + + shouldHookListener(listener, channel, extraData) { + if (listener.size == 0) { + return false; + } + + for (let opts of listener.values()) { + if (channel.matches(opts.filter, opts.policy, extraData)) { + return true; + } + } + return false; + }, + + examine(channel, topic, data) { + if (this.listeners.onHeadersReceived.size) { + this.runChannelListener(channel, "onHeadersReceived"); + } + + if ( + !channel.hasAuthRequestor && + this.shouldHookListener(this.listeners.onAuthRequired, channel, { + isProxy: true, + }) + ) { + channel.channel.notificationCallbacks = new AuthRequestor( + channel.channel, + this + ); + channel.hasAuthRequestor = true; + } + }, + + onChannelReplaced(oldChannel, newChannel) { + let channel = this.getWrapper(oldChannel); + + // We want originalURI, this will provide a moz-ext rather than jar or file + // uri on redirects. + if (this.hasRedirects) { + this.runChannelListener(channel, "onBeforeRedirect", { + redirectUrl: newChannel.originalURI.spec, + }); + } + channel.channel = newChannel; + }, +}; + +function HttpEvent(internalEvent, options) { + this.internalEvent = internalEvent; + this.options = options; +} + +HttpEvent.prototype = { + addListener(callback, filter = null, options = null, optionsObject = null) { + let opts = parseExtra(options, this.options, optionsObject); + opts.filter = parseFilter(filter); + HttpObserverManager.addListener(this.internalEvent, callback, opts); + }, + + removeListener(callback) { + HttpObserverManager.removeListener(this.internalEvent, callback); + }, +}; + +var onBeforeRequest = new HttpEvent("onBeforeRequest", [ + "blocking", + "requestBody", +]); +var onBeforeSendHeaders = new HttpEvent("onBeforeSendHeaders", [ + "requestHeaders", + "blocking", +]); +var onSendHeaders = new HttpEvent("onSendHeaders", ["requestHeaders"]); +var onHeadersReceived = new HttpEvent("onHeadersReceived", [ + "blocking", + "responseHeaders", +]); +var onAuthRequired = new HttpEvent("onAuthRequired", [ + "blocking", + "responseHeaders", +]); +var onBeforeRedirect = new HttpEvent("onBeforeRedirect", ["responseHeaders"]); +var onResponseStarted = new HttpEvent("onResponseStarted", ["responseHeaders"]); +var onCompleted = new HttpEvent("onCompleted", ["responseHeaders"]); +var onErrorOccurred = new HttpEvent("onErrorOccurred"); + +var WebRequest = { + onBeforeRequest, + onBeforeSendHeaders, + onSendHeaders, + onHeadersReceived, + onAuthRequired, + onBeforeRedirect, + onResponseStarted, + onCompleted, + onErrorOccurred, + + getSecurityInfo: details => { + let channel = ChannelWrapper.getRegisteredChannel( + details.id, + details.policy, + details.remoteTab + ); + if (channel) { + return SecurityInfo.getSecurityInfo(channel.channel, details.options); + } + }, +}; diff --git a/toolkit/components/extensions/webrequest/WebRequestService.cpp b/toolkit/components/extensions/webrequest/WebRequestService.cpp new file mode 100644 index 0000000000..891edf2515 --- /dev/null +++ b/toolkit/components/extensions/webrequest/WebRequestService.cpp @@ -0,0 +1,54 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "WebRequestService.h" + +#include "mozilla/Assertions.h" +#include "mozilla/ClearOnShutdown.h" + +using namespace mozilla; +using namespace mozilla::dom; +using namespace mozilla::extensions; + +static StaticRefPtr<WebRequestService> sWebRequestService; + +/* static */ WebRequestService& WebRequestService::GetSingleton() { + if (!sWebRequestService) { + sWebRequestService = new WebRequestService(); + ClearOnShutdown(&sWebRequestService); + } + return *sWebRequestService; +} + +UniquePtr<WebRequestChannelEntry> WebRequestService::RegisterChannel( + ChannelWrapper* aChannel) { + UniquePtr<ChannelEntry> entry(new ChannelEntry(aChannel)); + + auto key = mChannelEntries.LookupForAdd(entry->mChannelId); + MOZ_DIAGNOSTIC_ASSERT(!key); + key.OrInsert([&entry]() { return entry.get(); }); + + return entry; +} + +already_AddRefed<nsITraceableChannel> WebRequestService::GetTraceableChannel( + uint64_t aChannelId, nsAtom* aAddonId, ContentParent* aContentParent) { + if (auto entry = mChannelEntries.Get(aChannelId)) { + if (entry->mChannel) { + return entry->mChannel->GetTraceableChannel(aAddonId, aContentParent); + } + } + return nullptr; +} + +WebRequestChannelEntry::WebRequestChannelEntry(ChannelWrapper* aChannel) + : mChannelId(aChannel->Id()), mChannel(aChannel) {} + +WebRequestChannelEntry::~WebRequestChannelEntry() { + if (sWebRequestService) { + sWebRequestService->mChannelEntries.Remove(mChannelId); + } +} diff --git a/toolkit/components/extensions/webrequest/WebRequestService.h b/toolkit/components/extensions/webrequest/WebRequestService.h new file mode 100644 index 0000000000..0fec784b8d --- /dev/null +++ b/toolkit/components/extensions/webrequest/WebRequestService.h @@ -0,0 +1,77 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_WebRequestService_h +#define mozilla_WebRequestService_h + +#include "mozilla/LinkedList.h" +#include "mozilla/UniquePtr.h" + +#include "mozilla/extensions/ChannelWrapper.h" +#include "mozilla/extensions/WebExtensionPolicy.h" + +#include "nsHashKeys.h" +#include "nsDataHashtable.h" + +class nsAtom; +class nsIRemoteTab; +class nsITraceableChannel; + +namespace mozilla { +namespace dom { +class BrowserParent; +class ContentParent; +} // namespace dom + +namespace extensions { + +class WebRequestChannelEntry final { + public: + ~WebRequestChannelEntry(); + + private: + friend class WebRequestService; + + explicit WebRequestChannelEntry(ChannelWrapper* aChannel); + + uint64_t mChannelId; + WeakPtr<ChannelWrapper> mChannel; +}; + +class WebRequestService final { + public: + NS_INLINE_DECL_REFCOUNTING(WebRequestService) + + WebRequestService() = default; + + static already_AddRefed<WebRequestService> GetInstance() { + return do_AddRef(&GetSingleton()); + } + + static WebRequestService& GetSingleton(); + + using ChannelEntry = WebRequestChannelEntry; + + UniquePtr<ChannelEntry> RegisterChannel(ChannelWrapper* aChannel); + + void UnregisterTraceableChannel(uint64_t aChannelId); + + already_AddRefed<nsITraceableChannel> GetTraceableChannel( + uint64_t aChannelId, nsAtom* aAddonId, + dom::ContentParent* aContentParent); + + private: + ~WebRequestService() = default; + + friend ChannelEntry; + + nsDataHashtable<nsUint64HashKey, ChannelEntry*> mChannelEntries; +}; + +} // namespace extensions +} // namespace mozilla + +#endif // mozilla_WebRequestService_h diff --git a/toolkit/components/extensions/webrequest/WebRequestUpload.jsm b/toolkit/components/extensions/webrequest/WebRequestUpload.jsm new file mode 100644 index 0000000000..eb8a2bc6b5 --- /dev/null +++ b/toolkit/components/extensions/webrequest/WebRequestUpload.jsm @@ -0,0 +1,552 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +"use strict"; + +const EXPORTED_SYMBOLS = ["WebRequestUpload"]; + +/* exported WebRequestUpload */ + +const { XPCOMUtils } = ChromeUtils.import( + "resource://gre/modules/XPCOMUtils.jsm" +); + +const { ExtensionUtils } = ChromeUtils.import( + "resource://gre/modules/ExtensionUtils.jsm" +); + +const { DefaultMap } = ExtensionUtils; + +XPCOMUtils.defineLazyGlobalGetters(this, ["TextEncoder"]); + +XPCOMUtils.defineLazyServiceGetter( + this, + "mimeHeader", + "@mozilla.org/network/mime-hdrparam;1", + "nsIMIMEHeaderParam" +); + +const BinaryInputStream = Components.Constructor( + "@mozilla.org/binaryinputstream;1", + "nsIBinaryInputStream", + "setInputStream" +); +const ConverterInputStream = Components.Constructor( + "@mozilla.org/intl/converter-input-stream;1", + "nsIConverterInputStream", + "init" +); + +var WebRequestUpload; + +/** + * Parses the given raw header block, and stores the value of each + * lower-cased header name in the resulting map. + */ +class Headers extends Map { + constructor(headerText) { + super(); + + if (headerText) { + this.parseHeaders(headerText); + } + } + + parseHeaders(headerText) { + let lines = headerText.split("\r\n"); + + let lastHeader; + for (let line of lines) { + // The first empty line indicates the end of the header block. + if (line === "") { + return; + } + + // Lines starting with whitespace are appended to the previous + // header. + if (/^\s/.test(line)) { + if (lastHeader) { + let val = this.get(lastHeader); + this.set(lastHeader, `${val}\r\n${line}`); + } + continue; + } + + let match = /^(.*?)\s*:\s+(.*)/.exec(line); + if (match) { + lastHeader = match[1].toLowerCase(); + this.set(lastHeader, match[2]); + } + } + } + + /** + * If the given header exists, and contains the given parameter, + * returns the value of that parameter. + * + * @param {string} name + * The lower-cased header name. + * @param {string} paramName + * The name of the parameter to retrieve, or empty to retrieve + * the first (possibly unnamed) parameter. + * @returns {string | null} + */ + getParam(name, paramName) { + return Headers.getParam(this.get(name), paramName); + } + + /** + * If the given header value is non-null, and contains the given + * parameter, returns the value of that parameter. + * + * @param {string | null} header + * The text of the header from which to retrieve the param. + * @param {string} paramName + * The name of the parameter to retrieve, or empty to retrieve + * the first (possibly unnamed) parameter. + * @returns {string | null} + */ + static getParam(header, paramName) { + if (header) { + // The service expects this to be a raw byte string, so convert to + // UTF-8. + let bytes = new TextEncoder().encode(header); + let binHeader = String.fromCharCode(...bytes); + + return mimeHeader.getParameterHTTP(binHeader, paramName, null, false, {}); + } + + return null; + } +} + +/** + * Creates a new Object with a corresponding property for every + * key-value pair in the given Map. + * + * @param {Map} map + * The map to convert. + * @returns {Object} + */ +function mapToObject(map) { + let result = {}; + for (let [key, value] of map) { + result[key] = value; + } + return result; +} + +/** + * Rewinds the given seekable input stream to its beginning, and catches + * any resulting errors. + * + * @param {nsISeekableStream} stream + * The stream to rewind. + */ +function rewind(stream) { + // Do this outside the try-catch so that we throw if the stream is not + // actually seekable. + stream.QueryInterface(Ci.nsISeekableStream); + + try { + stream.seek(0, 0); + } catch (e) { + // It might be already closed, e.g. because of a previous error. + Cu.reportError(e); + } +} + +/** + * Iterates over all of the sub-streams that make up the given stream, + * or yields the stream itself if it is not a multi-part stream. + * + * @param {nsIIMultiplexInputStream|nsIStreamBufferAccess<nsIMultiplexInputStream>|nsIInputStream} outerStream + * The outer stream over which to iterate. + */ +function* getStreams(outerStream) { + // If this is a multi-part stream, we need to iterate over its sub-streams, + // rather than treating it as a simple input stream. Since it may be wrapped + // in a buffered input stream, unwrap it before we do any checks. + let unbuffered = outerStream; + if (outerStream instanceof Ci.nsIStreamBufferAccess) { + unbuffered = outerStream.unbufferedStream; + } + + if (unbuffered instanceof Ci.nsIMultiplexInputStream) { + let count = unbuffered.count; + for (let i = 0; i < count; i++) { + yield unbuffered.getStream(i); + } + } else { + yield outerStream; + } +} + +/** + * Parses the form data of the given stream as either multipart/form-data or + * x-www-form-urlencoded, and returns a map of its fields. + * + * @param {nsIInputStream} stream + * The input stream from which to parse the form data. + * @param {nsIHttpChannel} channel + * The channel to which the stream belongs. + * @param {boolean} [lenient = false] + * If true, the operation will succeed even if there are UTF-8 + * decoding errors. + * + * @returns {Map<string, Array<string>> | null} + */ +function parseFormData(stream, channel, lenient = false) { + const BUFFER_SIZE = 8192; + + let touchedStreams = new Set(); + let converterStreams = []; + + /** + * Creates a converter input stream from the given raw input stream, + * and adds it to the list of streams to be rewound at the end of + * parsing. + * + * Returns null if the given raw stream cannot be rewound. + * + * @param {nsIInputStream} stream + * The base stream from which to create a converter. + * @returns {ConverterInputStream | null} + */ + function createTextStream(stream) { + if (!(stream instanceof Ci.nsISeekableStream)) { + return null; + } + + touchedStreams.add(stream); + let converterStream = ConverterInputStream( + stream, + "UTF-8", + 0, + lenient ? Ci.nsIConverterInputStream.DEFAULT_REPLACEMENT_CHARACTER : 0 + ); + converterStreams.push(converterStream); + return converterStream; + } + + /** + * Reads a string of no more than the given length from the given text + * stream. + * + * @param {ConverterInputStream} stream + * The stream to read. + * @param {integer} [length = BUFFER_SIZE] + * The maximum length of data to read. + * @returns {string} + */ + function readString(stream, length = BUFFER_SIZE) { + let data = {}; + stream.readString(length, data); + return data.value; + } + + /** + * Iterates over all of the sub-streams of the given (possibly multi-part) + * input stream, and yields a ConverterInputStream for each + * nsIStringInputStream among them. + * + * @param {nsIInputStream|nsIMultiplexInputStream} outerStream + * The multi-part stream over which to iterate. + */ + function* getTextStreams(outerStream) { + for (let stream of getStreams(outerStream)) { + if (stream instanceof Ci.nsIStringInputStream) { + touchedStreams.add(outerStream); + yield createTextStream(stream); + } + } + } + + /** + * Iterates over all of the string streams of the given (possibly + * multi-part) input stream, and yields all of the available data in each as + * chunked strings, each no more than BUFFER_SIZE in length. + * + * @param {nsIInputStream|nsIMultiplexInputStream} outerStream + * The multi-part stream over which to iterate. + */ + function* readAllStrings(outerStream) { + for (let textStream of getTextStreams(outerStream)) { + let str; + while ((str = readString(textStream))) { + yield str; + } + } + } + + /** + * Iterates over the text contents of all of the string streams in the given + * (possibly multi-part) input stream, splits them at occurrences of the + * given boundary string, and yields each part. + * + * @param {nsIInputStream|nsIMultiplexInputStream} stream + * The multi-part stream over which to iterate. + * @param {string} boundary + * The boundary at which to split the parts. + * @param {string} [tail = ""] + * Any initial data to prepend to the start of the stream data. + */ + function* getParts(stream, boundary, tail = "") { + for (let chunk of readAllStrings(stream)) { + chunk = tail + chunk; + + let parts = chunk.split(boundary); + tail = parts.pop(); + + yield* parts; + } + + if (tail) { + yield tail; + } + } + + /** + * Parses the given stream as multipart/form-data and returns a map of its fields. + * + * @param {nsIMultiplexInputStream|nsIInputStream} stream + * The (possibly multi-part) stream to parse. + * @param {string} boundary + * The boundary at which to split the parts. + * @returns {Map<string, Array<string>>} + */ + function parseMultiPart(stream, boundary) { + let formData = new DefaultMap(() => []); + + for (let part of getParts(stream, boundary, "\r\n")) { + if (part === "") { + // The first part will always be empty. + continue; + } + if (part === "--\r\n") { + // This indicates the end of the stream. + break; + } + + let end = part.indexOf("\r\n\r\n"); + + // All valid parts must begin with \r\n, and we can't process form + // fields without any header block. + if (!part.startsWith("\r\n") || end <= 0) { + throw new Error("Invalid MIME stream"); + } + + let content = part.slice(end + 4); + let headerText = part.slice(2, end); + let headers = new Headers(headerText); + + let name = headers.getParam("content-disposition", "name"); + if ( + !name || + headers.getParam("content-disposition", "") !== "form-data" + ) { + throw new Error( + "Invalid MIME stream: No valid Content-Disposition header" + ); + } + + if (headers.has("content-type")) { + // For file upload fields, we return the filename, rather than the + // file data. + let filename = headers.getParam("content-disposition", "filename"); + content = filename || ""; + } + formData.get(name).push(content); + } + + return formData; + } + + /** + * Parses the given stream as x-www-form-urlencoded, and returns a map of its fields. + * + * @param {nsIInputStream} stream + * The stream to parse. + * @returns {Map<string, Array<string>>} + */ + function parseUrlEncoded(stream) { + let formData = new DefaultMap(() => []); + + for (let part of getParts(stream, "&")) { + let [name, value] = part + .replace(/\+/g, " ") + .split("=") + .map(decodeURIComponent); + formData.get(name).push(value); + } + + return formData; + } + + try { + if (stream instanceof Ci.nsIMIMEInputStream && stream.data) { + stream = stream.data; + } + + channel.QueryInterface(Ci.nsIHttpChannel); + let contentType = channel.getRequestHeader("Content-Type"); + + switch (Headers.getParam(contentType, "")) { + case "multipart/form-data": + let boundary = Headers.getParam(contentType, "boundary"); + return parseMultiPart(stream, `\r\n--${boundary}`); + + case "application/x-www-form-urlencoded": + return parseUrlEncoded(stream); + } + } finally { + for (let stream of touchedStreams) { + rewind(stream); + } + for (let converterStream of converterStreams) { + // Release the reference to the underlying input stream, to prevent the + // destructor of nsConverterInputStream from closing the stream, which + // would cause uploads to break. + converterStream.init(null, null, 0, 0); + } + } + + return null; +} + +/** + * Parses the form data of the given stream as either multipart/form-data or + * x-www-form-urlencoded, and returns a map of its fields. + * + * Returns null if the stream is not seekable. + * + * @param {nsIMultiplexInputStream|nsIInputStream} stream + * The (possibly multi-part) stream from which to create the form data. + * @param {nsIChannel} channel + * The channel to which the stream belongs. + * @param {boolean} [lenient = false] + * If true, the operation will succeed even if there are UTF-8 + * decoding errors. + * @returns {Map<string, Array<string>> | null} + */ +function createFormData(stream, channel, lenient) { + if (!(stream instanceof Ci.nsISeekableStream)) { + return null; + } + + try { + let formData = parseFormData(stream, channel, lenient); + if (formData) { + return mapToObject(formData); + } + } catch (e) { + Cu.reportError(e); + } finally { + rewind(stream); + } + return null; +} + +/** + * Iterates over all of the sub-streams of the given (possibly multi-part) + * input stream, and yields an object containing the data for each chunk, up + * to a total of `maxRead` bytes. + * + * @param {nsIMultiplexInputStream|nsIInputStream} outerStream + * The stream for which to return data. + * @param {integer} [maxRead = WebRequestUpload.MAX_RAW_BYTES] + * The maximum total bytes to read. + */ +function* getRawDataChunked( + outerStream, + maxRead = WebRequestUpload.MAX_RAW_BYTES +) { + for (let stream of getStreams(outerStream)) { + // We need to inspect the stream to make sure it's not a file input + // stream. If it's wrapped in a buffered input stream, unwrap it first, + // so we can inspect the inner stream directly. + let unbuffered = stream; + if (stream instanceof Ci.nsIStreamBufferAccess) { + unbuffered = stream.unbufferedStream; + } + + // For file fields, we return an object containing the full path of + // the file, rather than its data. + if ( + unbuffered instanceof Ci.nsIFileInputStream || + unbuffered instanceof Ci.mozIRemoteLazyInputStream + ) { + // But this is not actually supported yet. + yield { file: "<file>" }; + continue; + } + + try { + let binaryStream = BinaryInputStream(stream); + let available; + while ((available = binaryStream.available())) { + let buffer = new ArrayBuffer(Math.min(maxRead, available)); + binaryStream.readArrayBuffer(buffer.byteLength, buffer); + + maxRead -= buffer.byteLength; + + let chunk = { bytes: buffer }; + + if (buffer.byteLength < available) { + chunk.truncated = true; + chunk.originalSize = available; + } + + yield chunk; + + if (maxRead <= 0) { + return; + } + } + } finally { + rewind(stream); + } + } +} + +WebRequestUpload = { + createRequestBody(channel) { + if (!(channel instanceof Ci.nsIUploadChannel) || !channel.uploadStream) { + return null; + } + + if ( + channel instanceof Ci.nsIUploadChannel2 && + channel.uploadStreamHasHeaders + ) { + return { error: "Upload streams with headers are unsupported" }; + } + + try { + let stream = channel.uploadStream; + + let formData = createFormData(stream, channel); + if (formData) { + return { formData }; + } + + // If we failed to parse the stream as form data, return it as a + // sequence of raw data chunks, along with a leniently-parsed form + // data object, which ignores encoding errors. + return { + raw: Array.from(getRawDataChunked(stream)), + lenientFormData: createFormData(stream, channel, true), + }; + } catch (e) { + Cu.reportError(e); + return { error: e.message || String(e) }; + } + }, +}; + +XPCOMUtils.defineLazyPreferenceGetter( + WebRequestUpload, + "MAX_RAW_BYTES", + "webextensions.webRequest.requestBodyMaxRawBytes" +); diff --git a/toolkit/components/extensions/webrequest/moz.build b/toolkit/components/extensions/webrequest/moz.build new file mode 100644 index 0000000000..d80b1333a9 --- /dev/null +++ b/toolkit/components/extensions/webrequest/moz.build @@ -0,0 +1,54 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +EXTRA_JS_MODULES += [ + "SecurityInfo.jsm", + "WebRequest.jsm", + "WebRequestUpload.jsm", +] + +UNIFIED_SOURCES += [ + "ChannelWrapper.cpp", + "StreamFilter.cpp", + "StreamFilterChild.cpp", + "StreamFilterEvents.cpp", + "StreamFilterParent.cpp", + "WebRequestService.cpp", +] + +IPDL_SOURCES += [ + "PStreamFilter.ipdl", +] + +EXPORTS.mozilla += [ + "WebRequestService.h", +] + +EXPORTS.mozilla.extensions += [ + "ChannelWrapper.h", + "StreamFilter.h", + "StreamFilterBase.h", + "StreamFilterChild.h", + "StreamFilterEvents.h", + "StreamFilterParent.h", +] + +LOCAL_INCLUDES += [ + "/caps", +] + +include("/ipc/chromium/chromium-config.mozbuild") + +LOCAL_INCLUDES += [ + # For nsHttpChannel.h + "/netwerk/base", + "/netwerk/protocol/http", +] + +FINAL_LIBRARY = "xul" + +with Files("**"): + BUG_COMPONENT = ("WebExtensions", "Request Handling") |