summaryrefslogtreecommitdiffstats
path: root/toolkit/components/url-classifier/tests/gtest/TestFindFullHash.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/url-classifier/tests/gtest/TestFindFullHash.cpp')
-rw-r--r--toolkit/components/url-classifier/tests/gtest/TestFindFullHash.cpp216
1 files changed, 216 insertions, 0 deletions
diff --git a/toolkit/components/url-classifier/tests/gtest/TestFindFullHash.cpp b/toolkit/components/url-classifier/tests/gtest/TestFindFullHash.cpp
new file mode 100644
index 0000000000..f0f5785017
--- /dev/null
+++ b/toolkit/components/url-classifier/tests/gtest/TestFindFullHash.cpp
@@ -0,0 +1,216 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/Base64.h"
+#include "nsUrlClassifierUtils.h"
+#include "safebrowsing.pb.h"
+
+#include "Common.h"
+
+using namespace mozilla;
+
+template <size_t N>
+static void ToBase64EncodedStringArray(nsCString (&aInput)[N],
+ nsTArray<nsCString>& aEncodedArray) {
+ for (size_t i = 0; i < N; i++) {
+ nsCString encoded;
+ nsresult rv = mozilla::Base64Encode(aInput[i], encoded);
+ NS_ENSURE_SUCCESS_VOID(rv);
+ aEncodedArray.AppendElement(std::move(encoded));
+ }
+}
+
+TEST(UrlClassifierFindFullHash, Request)
+{
+ nsUrlClassifierUtils* urlUtil = nsUrlClassifierUtils::GetInstance();
+
+ nsTArray<nsCString> listNames;
+ listNames.AppendElement("moztest-phish-proto");
+ listNames.AppendElement("moztest-unwanted-proto");
+
+ nsCString listStates[] = {nsCString("sta\x00te1", 7),
+ nsCString("sta\x00te2", 7)};
+ nsTArray<nsCString> listStateArray;
+ ToBase64EncodedStringArray(listStates, listStateArray);
+
+ nsCString prefixes[] = {nsCString("\x00\x00\x00\x01", 4),
+ nsCString("\x00\x00\x00\x00\x01", 5),
+ nsCString("\x00\xFF\x00\x01", 4),
+ nsCString("\x00\xFF\x00\x01\x11\x23\xAA\xBC", 8),
+ nsCString("\x00\x00\x00\x01\x00\x01\x98", 7)};
+ nsTArray<nsCString> prefixArray;
+ ToBase64EncodedStringArray(prefixes, prefixArray);
+
+ nsCString requestBase64;
+ nsresult rv;
+ rv = urlUtil->MakeFindFullHashRequestV4(listNames, listStateArray,
+ prefixArray, requestBase64);
+ ASSERT_NS_SUCCEEDED(rv);
+
+ // Base64 URL decode first.
+ FallibleTArray<uint8_t> requestBinary;
+ rv = Base64URLDecode(requestBase64, Base64URLDecodePaddingPolicy::Require,
+ requestBinary);
+ ASSERT_NS_SUCCEEDED(rv);
+
+ // Parse the FindFullHash binary and compare with the expected values.
+ FindFullHashesRequest r;
+ ASSERT_TRUE(r.ParseFromArray(&requestBinary[0], requestBinary.Length()));
+
+ // Compare client states.
+ ASSERT_EQ(r.client_states_size(), (int)ArrayLength(listStates));
+ for (int i = 0; i < r.client_states_size(); i++) {
+ auto s = r.client_states(i);
+ ASSERT_TRUE(listStates[i].Equals(nsCString(s.c_str(), s.size())));
+ }
+
+ auto threatInfo = r.threat_info();
+
+ // Compare threat types.
+ ASSERT_EQ(threatInfo.threat_types_size(), (int)ArrayLength(listStates));
+ for (int i = 0; i < threatInfo.threat_types_size(); i++) {
+ uint32_t expectedThreatType;
+ rv =
+ urlUtil->ConvertListNameToThreatType(listNames[i], &expectedThreatType);
+ ASSERT_NS_SUCCEEDED(rv);
+ ASSERT_EQ(threatInfo.threat_types(i), (int)expectedThreatType);
+ }
+
+ // Compare prefixes.
+ ASSERT_EQ(threatInfo.threat_entries_size(), (int)ArrayLength(prefixes));
+ for (int i = 0; i < threatInfo.threat_entries_size(); i++) {
+ auto p = threatInfo.threat_entries(i).hash();
+ ASSERT_TRUE(prefixes[i].Equals(nsCString(p.c_str(), p.size())));
+ }
+}
+
+/////////////////////////////////////////////////////////////
+// Following is to test parsing the gethash response.
+
+namespace {
+
+// safebrowsing::Duration manipulation.
+struct MyDuration {
+ uint32_t mSecs;
+ uint32_t mNanos;
+};
+void PopulateDuration(Duration& aDest, const MyDuration& aSrc) {
+ aDest.set_seconds(aSrc.mSecs);
+ aDest.set_nanos(aSrc.mNanos);
+}
+
+// The expected match data.
+static MyDuration EXPECTED_MIN_WAIT_DURATION = {12, 10};
+static MyDuration EXPECTED_NEG_CACHE_DURATION = {120, 9};
+static const struct ExpectedMatch {
+ nsCString mCompleteHash;
+ ThreatType mThreatType;
+ MyDuration mPerHashCacheDuration;
+} EXPECTED_MATCH[] = {
+ {nsCString("01234567890123456789012345678901"),
+ SOCIAL_ENGINEERING_PUBLIC,
+ {8, 500}},
+ {nsCString("12345678901234567890123456789012"),
+ SOCIAL_ENGINEERING_PUBLIC,
+ {7, 100}},
+ {nsCString("23456789012345678901234567890123"),
+ SOCIAL_ENGINEERING_PUBLIC,
+ {1, 20}},
+};
+
+class MyParseCallback final : public nsIUrlClassifierParseFindFullHashCallback {
+ public:
+ NS_DECL_ISUPPORTS
+
+ explicit MyParseCallback(uint32_t& aCallbackCount)
+ : mCallbackCount(aCallbackCount) {}
+
+ NS_IMETHOD
+ OnCompleteHashFound(const nsACString& aCompleteHash,
+ const nsACString& aTableNames,
+ uint32_t aPerHashCacheDuration) override {
+ Verify(aCompleteHash, aTableNames, aPerHashCacheDuration);
+
+ return NS_OK;
+ }
+
+ NS_IMETHOD
+ OnResponseParsed(uint32_t aMinWaitDuration,
+ uint32_t aNegCacheDuration) override {
+ VerifyDuration(aMinWaitDuration / 1000, EXPECTED_MIN_WAIT_DURATION);
+ VerifyDuration(aNegCacheDuration, EXPECTED_NEG_CACHE_DURATION);
+
+ return NS_OK;
+ }
+
+ private:
+ void Verify(const nsACString& aCompleteHash, const nsACString& aTableNames,
+ uint32_t aPerHashCacheDuration) {
+ auto expected = EXPECTED_MATCH[mCallbackCount];
+
+ ASSERT_TRUE(aCompleteHash.Equals(expected.mCompleteHash));
+
+ // Verify aTableNames
+ nsUrlClassifierUtils* urlUtil = nsUrlClassifierUtils::GetInstance();
+
+ nsCString tableNames;
+ nsresult rv =
+ urlUtil->ConvertThreatTypeToListNames(expected.mThreatType, tableNames);
+ ASSERT_NS_SUCCEEDED(rv);
+ ASSERT_TRUE(aTableNames.Equals(tableNames));
+
+ VerifyDuration(aPerHashCacheDuration, expected.mPerHashCacheDuration);
+
+ mCallbackCount++;
+ }
+
+ void VerifyDuration(uint32_t aToVerify, const MyDuration& aExpected) {
+ ASSERT_TRUE(aToVerify == aExpected.mSecs);
+ }
+
+ ~MyParseCallback() = default;
+
+ uint32_t& mCallbackCount;
+};
+
+NS_IMPL_ISUPPORTS(MyParseCallback, nsIUrlClassifierParseFindFullHashCallback)
+
+} // end of unnamed namespace.
+
+TEST(UrlClassifierFindFullHash, ParseRequest)
+{
+ // Build response.
+ FindFullHashesResponse r;
+
+ // Init response-wise durations.
+ auto minWaitDuration = r.mutable_minimum_wait_duration();
+ PopulateDuration(*minWaitDuration, EXPECTED_MIN_WAIT_DURATION);
+ auto negCacheDuration = r.mutable_negative_cache_duration();
+ PopulateDuration(*negCacheDuration, EXPECTED_NEG_CACHE_DURATION);
+
+ // Init matches.
+ for (uint32_t i = 0; i < ArrayLength(EXPECTED_MATCH); i++) {
+ auto expected = EXPECTED_MATCH[i];
+ auto match = r.mutable_matches()->Add();
+ match->set_threat_type(expected.mThreatType);
+ match->mutable_threat()->set_hash(expected.mCompleteHash.BeginReading(),
+ expected.mCompleteHash.Length());
+ auto perHashCacheDuration = match->mutable_cache_duration();
+ PopulateDuration(*perHashCacheDuration, expected.mPerHashCacheDuration);
+ }
+ std::string s;
+ r.SerializeToString(&s);
+
+ uint32_t callbackCount = 0;
+ nsCOMPtr<nsIUrlClassifierParseFindFullHashCallback> callback =
+ new MyParseCallback(callbackCount);
+
+ nsUrlClassifierUtils* urlUtil = nsUrlClassifierUtils::GetInstance();
+ nsresult rv = urlUtil->ParseFindFullHashResponseV4(
+ nsCString(s.c_str(), s.size()), callback);
+ NS_ENSURE_SUCCESS_VOID(rv);
+
+ ASSERT_EQ(callbackCount, ArrayLength(EXPECTED_MATCH));
+}