diff options
Diffstat (limited to 'third_party/content_analysis_sdk/demo/client.cc')
-rw-r--r-- | third_party/content_analysis_sdk/demo/client.cc | 411 |
1 files changed, 411 insertions, 0 deletions
diff --git a/third_party/content_analysis_sdk/demo/client.cc b/third_party/content_analysis_sdk/demo/client.cc new file mode 100644 index 0000000000..5e47fca57f --- /dev/null +++ b/third_party/content_analysis_sdk/demo/client.cc @@ -0,0 +1,411 @@ +// Copyright 2022 The Chromium Authors. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <time.h> + +#include <iostream> +#include <memory> +#include <mutex> +#include <sstream> +#include <string> +#include <thread> +#include <vector> + +#include "content_analysis/sdk/analysis_client.h" +#include "demo/atomic_output.h" + +using content_analysis::sdk::Client; +using content_analysis::sdk::ContentAnalysisRequest; +using content_analysis::sdk::ContentAnalysisResponse; +using content_analysis::sdk::ContentAnalysisAcknowledgement; + +// Different paths are used depending on whether this agent should run as a +// use specific agent or not. These values are chosen to match the test +// values in chrome browser. +constexpr char kPathUser[] = "path_user"; +constexpr char kPathSystem[] = "brcm_chrm_cas"; + +// Global app config. +std::string path = kPathSystem; +bool user_specific = false; +bool group = false; +std::unique_ptr<Client> client; + +// Paramters used to build the request. +content_analysis::sdk::AnalysisConnector connector = + content_analysis::sdk::FILE_ATTACHED; +time_t request_token_number = time(nullptr); +std::string request_token; +std::string tag = "dlp"; +bool threaded = false; +std::string digest = "sha256-123456"; +std::string url = "https://upload.example.com"; +std::string email = "me@example.com"; +std::string machine_user = "DOMAIN\\me"; +std::vector<std::string> datas; + +// When grouping, remember the tokens of all requests/responses in order to +// acknowledge them all with the same final action. +// +// This global state. It may be access from multiple thread so must be +// accessed from a critical section. +std::mutex global_mutex; +ContentAnalysisAcknowledgement::FinalAction global_final_action = + ContentAnalysisAcknowledgement::ALLOW; +std::vector<std::string> request_tokens; + +// Command line parameters. +constexpr const char* kArgConnector = "--connector="; +constexpr const char* kArgDigest = "--digest="; +constexpr const char* kArgEmail = "--email="; +constexpr const char* kArgGroup = "--group"; +constexpr const char* kArgMachineUser = "--machine-user="; +constexpr const char* kArgPath = "--path="; +constexpr const char* kArgRequestToken = "--request-token="; +constexpr const char* kArgTag = "--tag="; +constexpr const char* kArgThreaded = "--threaded"; +constexpr const char* kArgUrl = "--url="; +constexpr const char* kArgUserSpecific = "--user"; +constexpr const char* kArgHelp = "--help"; + +bool ParseCommandLine(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + const std::string arg = argv[i]; + if (arg.find(kArgConnector) == 0) { + std::string connector_str = arg.substr(strlen(kArgConnector)); + if (connector_str == "download") { + connector = content_analysis::sdk::FILE_DOWNLOADED; + } else if (connector_str == "attach") { + connector = content_analysis::sdk::FILE_ATTACHED; + } else if (connector_str == "bulk-data-entry") { + connector = content_analysis::sdk::BULK_DATA_ENTRY; + } else if (connector_str == "print") { + connector = content_analysis::sdk::PRINT; + } else if (connector_str == "file-transfer") { + connector = content_analysis::sdk::FILE_TRANSFER; + } else { + std::cout << "[Demo] Incorrect command line arg: " << arg << std::endl; + return false; + } + } else if (arg.find(kArgRequestToken) == 0) { + request_token = arg.substr(strlen(kArgRequestToken)); + } else if (arg.find(kArgTag) == 0) { + tag = arg.substr(strlen(kArgTag)); + } else if (arg.find(kArgThreaded) == 0) { + threaded = true; + } else if (arg.find(kArgDigest) == 0) { + digest = arg.substr(strlen(kArgDigest)); + } else if (arg.find(kArgUrl) == 0) { + url = arg.substr(strlen(kArgUrl)); + } else if (arg.find(kArgMachineUser) == 0) { + machine_user = arg.substr(strlen(kArgMachineUser)); + } else if (arg.find(kArgEmail) == 0) { + email = arg.substr(strlen(kArgEmail)); + } else if (arg.find(kArgPath) == 0) { + path = arg.substr(strlen(kArgPath)); + } else if (arg.find(kArgUserSpecific) == 0) { + // If kArgPath was already used, abort. + if (path != kPathSystem) { + std::cout << std::endl << "ERROR: use --path=<path> after --user"; + return false; + } + path = kPathUser; + user_specific = true; + } else if (arg.find(kArgGroup) == 0) { + group = true; + } else if (arg.find(kArgHelp) == 0) { + return false; + } else { + datas.push_back(arg); + } + } + + return true; +} + +void PrintHelp() { + std::cout + << std::endl << std::endl + << "Usage: client [OPTIONS] [@]content_or_file ..." << std::endl + << "A simple client to send content analysis requests to a running agent." << std::endl + << "Without @ the content to analyze is the argument itself." << std::endl + << "Otherwise the content is read from a file called 'content_or_file'." << std::endl + << "Multiple [@]content_or_file arguments may be specified, each generates one request." << std::endl + << std::endl << "Options:" << std::endl + << kArgConnector << "<connector> : one of 'download', 'attach' (default), 'bulk-data-entry', 'print', or 'file-transfer'" << std::endl + << kArgRequestToken << "<unique-token> : defaults to 'req-<number>' which auto increments" << std::endl + << kArgTag << "<tag> : defaults to 'dlp'" << std::endl + << kArgThreaded << " : handled multiple requests using threads" << std::endl + << kArgUrl << "<url> : defaults to 'https://upload.example.com'" << std::endl + << kArgMachineUser << "<machine-user> : defaults to 'DOMAIN\\me'" << std::endl + << kArgEmail << "<email> : defaults to 'me@example.com'" << std::endl + << kArgPath << " <path> : Used the specified path instead of default. Must come after --user." << std::endl + << kArgUserSpecific << " : Connects to an OS user specific agent" << std::endl + << kArgDigest << "<digest> : defaults to 'sha256-123456'" << std::endl + << kArgGroup << " : Generate the same final action for all requests" << std::endl + << kArgHelp << " : prints this help message" << std::endl; +} + +std::string GenerateRequestToken() { + std::stringstream stm; + stm << "req-" << request_token_number++; + return stm.str(); +} + +ContentAnalysisRequest BuildRequest(const std::string& data) { + std::string filepath; + std::string filename; + if (data[0] == '@') { + filepath = data.substr(1); + filename = filepath.substr(filepath.find_last_of("/\\") + 1); + } + + ContentAnalysisRequest request; + + // Set request to expire 5 minutes into the future. + request.set_expires_at(time(nullptr) + 5 * 60); + request.set_analysis_connector(connector); + request.set_request_token(!request_token.empty() + ? request_token : GenerateRequestToken()); + *request.add_tags() = tag; + + auto request_data = request.mutable_request_data(); + request_data->set_url(url); + request_data->set_email(email); + request_data->set_digest(digest); + if (!filename.empty()) { + request_data->set_filename(filename); + } + + auto client_metadata = request.mutable_client_metadata(); + auto browser = client_metadata->mutable_browser(); + browser->set_machine_user(machine_user); + + if (!filepath.empty()) { + request.set_file_path(filepath); + } else if (!data.empty()) { + request.set_text_content(data); + } else { + std::cout << "[Demo] Specify text content or a file path." << std::endl; + PrintHelp(); + exit(1); + } + + return request; +} + +// Gets the most severe action within the result. +ContentAnalysisResponse::Result::TriggeredRule::Action +GetActionFromResult(const ContentAnalysisResponse::Result& result) { + auto action = + ContentAnalysisResponse::Result::TriggeredRule::ACTION_UNSPECIFIED; + for (auto rule : result.triggered_rules()) { + if (rule.has_action() && rule.action() > action) + action = rule.action(); + } + return action; +} + +// Gets the most severe action within all the the results of a response. +ContentAnalysisResponse::Result::TriggeredRule::Action +GetActionFromResponse(const ContentAnalysisResponse& response) { + auto action = + ContentAnalysisResponse::Result::TriggeredRule::ACTION_UNSPECIFIED; + for (auto result : response.results()) { + auto action2 = GetActionFromResult(result); + if (action2 > action) + action = action2; + } + return action; +} + +void DumpResponse( + std::stringstream& stream, + const ContentAnalysisResponse& response) { + for (auto result : response.results()) { + auto tag = result.has_tag() ? result.tag() : "<no-tag>"; + + auto status = result.has_status() + ? result.status() + : ContentAnalysisResponse::Result::STATUS_UNKNOWN; + std::string status_str; + switch (status) { + case ContentAnalysisResponse::Result::STATUS_UNKNOWN: + status_str = "Unknown"; + break; + case ContentAnalysisResponse::Result::SUCCESS: + status_str = "Success"; + break; + case ContentAnalysisResponse::Result::FAILURE: + status_str = "Failure"; + break; + default: + status_str = "<Uknown>"; + break; + } + + auto action = GetActionFromResult(result); + std::string action_str; + switch (action) { + case ContentAnalysisResponse::Result::TriggeredRule::ACTION_UNSPECIFIED: + action_str = "allowed"; + break; + case ContentAnalysisResponse::Result::TriggeredRule::REPORT_ONLY: + action_str = "reported only"; + break; + case ContentAnalysisResponse::Result::TriggeredRule::WARN: + action_str = "warned"; + break; + case ContentAnalysisResponse::Result::TriggeredRule::BLOCK: + action_str = "blocked"; + break; + } + + time_t now = time(nullptr); + stream << "[Demo] Request " << response.request_token() << " is " << action_str + << " after " << tag + << " analysis, status=" << status_str + << " at " << ctime(&now); + } +} + +ContentAnalysisAcknowledgement BuildAcknowledgement( + const std::string& request_token, + ContentAnalysisAcknowledgement::FinalAction final_action) { + ContentAnalysisAcknowledgement ack; + ack.set_request_token(request_token); + ack.set_status(ContentAnalysisAcknowledgement::SUCCESS); + ack.set_final_action(final_action); + return ack; +} + +void HandleRequest(const ContentAnalysisRequest& request) { + AtomicCout aout; + ContentAnalysisResponse response; + int err = client->Send(request, &response); + if (err != 0) { + aout.stream() << "[Demo] Error sending request " << request.request_token() + << std::endl; + } else if (response.results_size() == 0) { + aout.stream() << "[Demo] Response " << request.request_token() << " is missing a result" + << std::endl; + } else { + DumpResponse(aout.stream(), response); + + auto final_action = ContentAnalysisAcknowledgement::ALLOW; + switch (GetActionFromResponse(response)) { + case ContentAnalysisResponse::Result::TriggeredRule::ACTION_UNSPECIFIED: + break; + case ContentAnalysisResponse::Result::TriggeredRule::REPORT_ONLY: + final_action = ContentAnalysisAcknowledgement::REPORT_ONLY; + break; + case ContentAnalysisResponse::Result::TriggeredRule::WARN: + final_action = ContentAnalysisAcknowledgement::WARN; + break; + case ContentAnalysisResponse::Result::TriggeredRule::BLOCK: + final_action = ContentAnalysisAcknowledgement::BLOCK; + break; + } + + // If grouping, remember the request's token in order to ack the response + // later. + if (group) { + std::unique_lock<std::mutex> lock(global_mutex); + request_tokens.push_back(request.request_token()); + if (final_action > global_final_action) + global_final_action = final_action; + } else { + int err = client->Acknowledge( + BuildAcknowledgement(request.request_token(), final_action)); + if (err != 0) { + aout.stream() << "[Demo] Error sending ack " << request.request_token() + << std::endl; + } + } + } +} + +void ProcessRequest(size_t i) { + auto request = BuildRequest(datas[i]); + + { + AtomicCout aout; + aout.stream() << "[Demo] Sending request " << request.request_token() << std::endl; + } + + HandleRequest(request); +} + +int main(int argc, char* argv[]) { + if (!ParseCommandLine(argc, argv)) { + PrintHelp(); + return 1; + } + + // Each client uses a unique name to identify itself with Google Chrome. + client = Client::Create({path, user_specific}); + if (!client) { + std::cout << "[Demo] Error starting client" << std::endl; + return 1; + }; + + auto info = client->GetAgentInfo(); + std::cout << "Agent pid=" << info.pid + << " path=" << info.binary_path << std::endl; + + if (threaded) { + std::vector<std::unique_ptr<std::thread>> threads; + for (int i = 0; i < datas.size(); ++i) { + AtomicCout aout; + aout.stream() << "Start thread " << i << std::endl; + threads.emplace_back(std::make_unique<std::thread>(ProcessRequest, i)); + } + + // Make sure all threads have terminated. + for (auto& thread : threads) { + thread->join(); + } + } + else { + for (size_t i = 0; i < datas.size(); ++i) { + ProcessRequest(i); + } + } + // It's safe to access global state beyond this point without locking since + // all no more responses will be touching them. + + if (group) { + std::cout << std::endl; + std::cout << "[Demo] Final action for all requests is "; + switch (global_final_action) { + // Google Chrome fails open, so if no action is specified that is the same + // as ALLOW. + case ContentAnalysisAcknowledgement::ACTION_UNSPECIFIED: + case ContentAnalysisAcknowledgement::ALLOW: + std::cout << "allowed"; + break; + case ContentAnalysisAcknowledgement::REPORT_ONLY: + std::cout << "reported only"; + break; + case ContentAnalysisAcknowledgement::WARN: + std::cout << "warned"; + break; + case ContentAnalysisAcknowledgement::BLOCK: + std::cout << "blocked"; + break; + } + std::cout << std::endl << std::endl; + + for (auto token : request_tokens) { + std::cout << "[Demo] Sending group Ack" << std::endl; + int err = client->Acknowledge( + BuildAcknowledgement(token, global_final_action)); + if (err != 0) { + std::cout << "[Demo] Error sending ack for " << token << std::endl; + } + } + } + + return 0; +}; |