summaryrefslogtreecommitdiffstats
path: root/testing/geckodriver/marionette/src/message.rs
diff options
context:
space:
mode:
Diffstat (limited to 'testing/geckodriver/marionette/src/message.rs')
-rw-r--r--testing/geckodriver/marionette/src/message.rs336
1 files changed, 336 insertions, 0 deletions
diff --git a/testing/geckodriver/marionette/src/message.rs b/testing/geckodriver/marionette/src/message.rs
new file mode 100644
index 0000000000..704d52f67b
--- /dev/null
+++ b/testing/geckodriver/marionette/src/message.rs
@@ -0,0 +1,336 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use serde::de::{self, SeqAccess, Unexpected, Visitor};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use serde_json::{Map, Value};
+use serde_repr::{Deserialize_repr, Serialize_repr};
+use std::fmt;
+
+use crate::error::MarionetteError;
+use crate::marionette;
+use crate::result::MarionetteResult;
+use crate::webdriver;
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum Command {
+ WebDriver(webdriver::Command),
+ Marionette(marionette::Command),
+}
+
+impl Command {
+ pub fn name(&self) -> String {
+ let (command_name, _) = self.first_entry();
+ command_name
+ }
+
+ fn params(&self) -> Value {
+ let (_, params) = self.first_entry();
+ params
+ }
+
+ fn first_entry(&self) -> (String, serde_json::Value) {
+ match serde_json::to_value(self).unwrap() {
+ Value::String(cmd) => (cmd, Value::Object(Map::new())),
+ Value::Object(items) => {
+ let mut iter = items.iter();
+ let (cmd, params) = iter.next().unwrap();
+ (cmd.to_string(), params.clone())
+ }
+ _ => unreachable!(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize_repr, Deserialize_repr)]
+#[repr(u8)]
+enum MessageDirection {
+ Incoming = 0,
+ Outgoing = 1,
+}
+
+pub type MessageId = u32;
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct Request(pub MessageId, pub Command);
+
+impl Request {
+ pub fn id(&self) -> MessageId {
+ self.0
+ }
+
+ pub fn command(&self) -> &Command {
+ &self.1
+ }
+
+ pub fn params(&self) -> Value {
+ self.command().params()
+ }
+}
+
+impl Serialize for Request {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ (
+ MessageDirection::Incoming,
+ self.id(),
+ self.command().name(),
+ self.params(),
+ )
+ .serialize(serializer)
+ }
+}
+
+#[derive(Debug, PartialEq)]
+pub enum Response {
+ Result {
+ id: MessageId,
+ result: MarionetteResult,
+ },
+ Error {
+ id: MessageId,
+ error: MarionetteError,
+ },
+}
+
+impl Serialize for Response {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ match self {
+ Response::Result { id, result } => {
+ (MessageDirection::Outgoing, id, Value::Null, &result).serialize(serializer)
+ }
+ Response::Error { id, error } => {
+ (MessageDirection::Outgoing, id, &error, Value::Null).serialize(serializer)
+ }
+ }
+ }
+}
+
+#[derive(Debug, PartialEq, Serialize)]
+#[serde(untagged)]
+pub enum Message {
+ Incoming(Request),
+ Outgoing(Response),
+}
+
+struct MessageVisitor;
+
+impl<'de> Visitor<'de> for MessageVisitor {
+ type Value = Message;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("four-element array")
+ }
+
+ fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
+ let direction = seq
+ .next_element::<MessageDirection>()?
+ .ok_or_else(|| de::Error::invalid_length(0, &self))?;
+ let id: MessageId = seq
+ .next_element()?
+ .ok_or_else(|| de::Error::invalid_length(1, &self))?;
+
+ let msg = match direction {
+ MessageDirection::Incoming => {
+ let name: String = seq
+ .next_element()?
+ .ok_or_else(|| de::Error::invalid_length(2, &self))?;
+ let params: Value = seq
+ .next_element()?
+ .ok_or_else(|| de::Error::invalid_length(3, &self))?;
+
+ let command = match params {
+ Value::Object(ref items) if !items.is_empty() => {
+ let command_to_params = {
+ let mut m = Map::new();
+ m.insert(name, params);
+ Value::Object(m)
+ };
+ serde_json::from_value(command_to_params).map_err(de::Error::custom)
+ }
+ Value::Object(_) | Value::Null => {
+ serde_json::from_value(Value::String(name)).map_err(de::Error::custom)
+ }
+ x => Err(de::Error::custom(format!("unknown params type: {}", x))),
+ }?;
+ Message::Incoming(Request(id, command))
+ }
+
+ MessageDirection::Outgoing => {
+ let maybe_error: Option<MarionetteError> = seq
+ .next_element()?
+ .ok_or_else(|| de::Error::invalid_length(2, &self))?;
+
+ let response = if let Some(error) = maybe_error {
+ seq.next_element::<Value>()?
+ .ok_or_else(|| de::Error::invalid_length(3, &self))?
+ .as_null()
+ .ok_or_else(|| de::Error::invalid_type(Unexpected::Unit, &self))?;
+ Response::Error { id, error }
+ } else {
+ let result: MarionetteResult = seq
+ .next_element()?
+ .ok_or_else(|| de::Error::invalid_length(3, &self))?;
+ Response::Result { id, result }
+ };
+
+ Message::Outgoing(response)
+ }
+ };
+
+ Ok(msg)
+ }
+}
+
+impl<'de> Deserialize<'de> for Message {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ deserializer.deserialize_seq(MessageVisitor)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use serde_json::json;
+
+ use super::*;
+
+ use crate::common::*;
+ use crate::error::{ErrorKind, MarionetteError};
+ use crate::test::assert_ser_de;
+
+ #[test]
+ fn test_incoming() {
+ let json =
+ json!([0, 42, "WebDriver:FindElement", {"using": "css selector", "value": "value"}]);
+ let find_element = webdriver::Command::FindElement(webdriver::Locator {
+ using: webdriver::Selector::Css,
+ value: "value".into(),
+ });
+ let req = Request(42, Command::WebDriver(find_element));
+ let msg = Message::Incoming(req);
+ assert_ser_de(&msg, json);
+ }
+
+ #[test]
+ fn test_incoming_empty_params() {
+ let json = json!([0, 42, "WebDriver:GetTimeouts", {}]);
+ let req = Request(42, Command::WebDriver(webdriver::Command::GetTimeouts));
+ let msg = Message::Incoming(req);
+ assert_ser_de(&msg, json);
+ }
+
+ #[test]
+ fn test_incoming_common_params() {
+ let json = json!([0, 42, "Marionette:AcceptConnections", {"value": false}]);
+ let params = BoolValue::new(false);
+ let req = Request(
+ 42,
+ Command::Marionette(marionette::Command::AcceptConnections(params)),
+ );
+ let msg = Message::Incoming(req);
+ assert_ser_de(&msg, json);
+ }
+
+ #[test]
+ fn test_incoming_params_derived() {
+ assert!(serde_json::from_value::<Message>(
+ json!([0,42,"WebDriver:FindElement",{"using":"foo","value":"foo"}])
+ )
+ .is_err());
+ assert!(serde_json::from_value::<Message>(
+ json!([0,42,"Marionette:AcceptConnections",{"value":"foo"}])
+ )
+ .is_err());
+ }
+
+ #[test]
+ fn test_incoming_no_params() {
+ assert!(serde_json::from_value::<Message>(
+ json!([0,42,"WebDriver:GetTimeouts",{"value":true}])
+ )
+ .is_err());
+ assert!(serde_json::from_value::<Message>(
+ json!([0,42,"Marionette:Context",{"value":"foo"}])
+ )
+ .is_err());
+ assert!(serde_json::from_value::<Message>(
+ json!([0,42,"Marionette:GetScreenOrientation",{"value":true}])
+ )
+ .is_err());
+ }
+
+ #[test]
+ fn test_outgoing_result() {
+ let json = json!([1, 42, null, { "value": null }]);
+ let result = MarionetteResult::Null;
+ let msg = Message::Outgoing(Response::Result { id: 42, result });
+
+ assert_ser_de(&msg, json);
+ }
+
+ #[test]
+ fn test_outgoing_error() {
+ let json =
+ json!([1, 42, {"error": "no such element", "message": "", "stacktrace": ""}, null]);
+ let error = MarionetteError {
+ kind: ErrorKind::NoSuchElement,
+ message: "".into(),
+ stack: "".into(),
+ };
+ let msg = Message::Outgoing(Response::Error { id: 42, error });
+
+ assert_ser_de(&msg, json);
+ }
+
+ #[test]
+ fn test_invalid_type() {
+ assert!(
+ serde_json::from_value::<Message>(json!([2, 42, "WebDriver:GetTimeouts", {}])).is_err()
+ );
+ assert!(serde_json::from_value::<Message>(json!([3, 42, "no such element", {}])).is_err());
+ }
+
+ #[test]
+ fn test_missing_fields() {
+ // all fields are required
+ assert!(
+ serde_json::from_value::<Message>(json!([2, 42, "WebDriver:GetTimeouts"])).is_err()
+ );
+ assert!(serde_json::from_value::<Message>(json!([2, 42])).is_err());
+ assert!(serde_json::from_value::<Message>(json!([2])).is_err());
+ assert!(serde_json::from_value::<Message>(json!([])).is_err());
+ }
+
+ #[test]
+ fn test_unknown_command() {
+ assert!(serde_json::from_value::<Message>(json!([0, 42, "hooba", {}])).is_err());
+ }
+
+ #[test]
+ fn test_unknown_error() {
+ assert!(serde_json::from_value::<Message>(json!([1, 42, "flooba", {}])).is_err());
+ }
+
+ #[test]
+ fn test_message_id_bounds() {
+ let overflow = i64::from(std::u32::MAX) + 1;
+ let underflow = -1;
+
+ fn get_timeouts(message_id: i64) -> Value {
+ json!([0, message_id, "WebDriver:GetTimeouts", {}])
+ }
+
+ assert!(serde_json::from_value::<Message>(get_timeouts(overflow)).is_err());
+ assert!(serde_json::from_value::<Message>(get_timeouts(underflow)).is_err());
+ }
+}