/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use serde::de::{self, SeqAccess, Unexpected, Visitor}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{Map, Value}; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::fmt; use crate::error::MarionetteError; use crate::marionette; use crate::result::MarionetteResult; use crate::webdriver; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(untagged)] pub enum Command { WebDriver(webdriver::Command), Marionette(marionette::Command), } impl Command { pub fn name(&self) -> String { let (command_name, _) = self.first_entry(); command_name } fn params(&self) -> Value { let (_, params) = self.first_entry(); params } fn first_entry(&self) -> (String, serde_json::Value) { match serde_json::to_value(self).unwrap() { Value::String(cmd) => (cmd, Value::Object(Map::new())), Value::Object(items) => { let mut iter = items.iter(); let (cmd, params) = iter.next().unwrap(); (cmd.to_string(), params.clone()) } _ => unreachable!(), } } } #[derive(Clone, Debug, PartialEq, Serialize_repr, Deserialize_repr)] #[repr(u8)] enum MessageDirection { Incoming = 0, Outgoing = 1, } pub type MessageId = u32; #[derive(Debug, Clone, PartialEq)] pub struct Request(pub MessageId, pub Command); impl Request { pub fn id(&self) -> MessageId { self.0 } pub fn command(&self) -> &Command { &self.1 } pub fn params(&self) -> Value { self.command().params() } } impl Serialize for Request { fn serialize(&self, serializer: S) -> Result where S: Serializer, { ( MessageDirection::Incoming, self.id(), self.command().name(), self.params(), ) .serialize(serializer) } } #[derive(Debug, PartialEq)] pub enum Response { Result { id: MessageId, result: MarionetteResult, }, Error { id: MessageId, error: MarionetteError, }, } impl Serialize for Response { fn serialize(&self, serializer: S) -> Result where S: Serializer, { match self { Response::Result { id, result } => { (MessageDirection::Outgoing, id, Value::Null, &result).serialize(serializer) } Response::Error { id, error } => { (MessageDirection::Outgoing, id, &error, Value::Null).serialize(serializer) } } } } #[derive(Debug, PartialEq, Serialize)] #[serde(untagged)] pub enum Message { Incoming(Request), Outgoing(Response), } struct MessageVisitor; impl<'de> Visitor<'de> for MessageVisitor { type Value = Message; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("four-element array") } fn visit_seq>(self, mut seq: A) -> Result { let direction = seq .next_element::()? .ok_or_else(|| de::Error::invalid_length(0, &self))?; let id: MessageId = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(1, &self))?; let msg = match direction { MessageDirection::Incoming => { let name: String = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(2, &self))?; let params: Value = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(3, &self))?; let command = match params { Value::Object(ref items) if !items.is_empty() => { let command_to_params = { let mut m = Map::new(); m.insert(name, params); Value::Object(m) }; serde_json::from_value(command_to_params).map_err(de::Error::custom) } Value::Object(_) | Value::Null => { serde_json::from_value(Value::String(name)).map_err(de::Error::custom) } x => Err(de::Error::custom(format!("unknown params type: {}", x))), }?; Message::Incoming(Request(id, command)) } MessageDirection::Outgoing => { let maybe_error: Option = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(2, &self))?; let response = if let Some(error) = maybe_error { seq.next_element::()? .ok_or_else(|| de::Error::invalid_length(3, &self))? .as_null() .ok_or_else(|| de::Error::invalid_type(Unexpected::Unit, &self))?; Response::Error { id, error } } else { let result: MarionetteResult = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(3, &self))?; Response::Result { id, result } }; Message::Outgoing(response) } }; Ok(msg) } } impl<'de> Deserialize<'de> for Message { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { deserializer.deserialize_seq(MessageVisitor) } } #[cfg(test)] mod tests { use serde_json::json; use super::*; use crate::common::*; use crate::error::{ErrorKind, MarionetteError}; use crate::test::assert_ser_de; #[test] fn test_incoming() { let json = json!([0, 42, "WebDriver:FindElement", {"using": "css selector", "value": "value"}]); let find_element = webdriver::Command::FindElement(webdriver::Locator { using: webdriver::Selector::Css, value: "value".into(), }); let req = Request(42, Command::WebDriver(find_element)); let msg = Message::Incoming(req); assert_ser_de(&msg, json); } #[test] fn test_incoming_empty_params() { let json = json!([0, 42, "WebDriver:GetTimeouts", {}]); let req = Request(42, Command::WebDriver(webdriver::Command::GetTimeouts)); let msg = Message::Incoming(req); assert_ser_de(&msg, json); } #[test] fn test_incoming_common_params() { let json = json!([0, 42, "Marionette:AcceptConnections", {"value": false}]); let params = BoolValue::new(false); let req = Request( 42, Command::Marionette(marionette::Command::AcceptConnections(params)), ); let msg = Message::Incoming(req); assert_ser_de(&msg, json); } #[test] fn test_incoming_params_derived() { assert!(serde_json::from_value::( json!([0,42,"WebDriver:FindElement",{"using":"foo","value":"foo"}]) ) .is_err()); assert!(serde_json::from_value::( json!([0,42,"Marionette:AcceptConnections",{"value":"foo"}]) ) .is_err()); } #[test] fn test_incoming_no_params() { assert!(serde_json::from_value::( json!([0,42,"WebDriver:GetTimeouts",{"value":true}]) ) .is_err()); assert!(serde_json::from_value::( json!([0,42,"Marionette:Context",{"value":"foo"}]) ) .is_err()); assert!(serde_json::from_value::( json!([0,42,"Marionette:GetScreenOrientation",{"value":true}]) ) .is_err()); } #[test] fn test_outgoing_result() { let json = json!([1, 42, null, { "value": null }]); let result = MarionetteResult::Null; let msg = Message::Outgoing(Response::Result { id: 42, result }); assert_ser_de(&msg, json); } #[test] fn test_outgoing_error() { let json = json!([1, 42, {"error": "no such element", "message": "", "stacktrace": ""}, null]); let error = MarionetteError { kind: ErrorKind::NoSuchElement, message: "".into(), stack: "".into(), }; let msg = Message::Outgoing(Response::Error { id: 42, error }); assert_ser_de(&msg, json); } #[test] fn test_invalid_type() { assert!( serde_json::from_value::(json!([2, 42, "WebDriver:GetTimeouts", {}])).is_err() ); assert!(serde_json::from_value::(json!([3, 42, "no such element", {}])).is_err()); } #[test] fn test_missing_fields() { // all fields are required assert!( serde_json::from_value::(json!([2, 42, "WebDriver:GetTimeouts"])).is_err() ); assert!(serde_json::from_value::(json!([2, 42])).is_err()); assert!(serde_json::from_value::(json!([2])).is_err()); assert!(serde_json::from_value::(json!([])).is_err()); } #[test] fn test_unknown_command() { assert!(serde_json::from_value::(json!([0, 42, "hooba", {}])).is_err()); } #[test] fn test_unknown_error() { assert!(serde_json::from_value::(json!([1, 42, "flooba", {}])).is_err()); } #[test] fn test_message_id_bounds() { let overflow = i64::from(std::u32::MAX) + 1; let underflow = -1; fn get_timeouts(message_id: i64) -> Value { json!([0, message_id, "WebDriver:GetTimeouts", {}]) } assert!(serde_json::from_value::(get_timeouts(overflow)).is_err()); assert!(serde_json::from_value::(get_timeouts(underflow)).is_err()); } }