summaryrefslogtreecommitdiffstats
path: root/testing/webdriver/src/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'testing/webdriver/src/server.rs')
-rw-r--r--testing/webdriver/src/server.rs691
1 files changed, 691 insertions, 0 deletions
diff --git a/testing/webdriver/src/server.rs b/testing/webdriver/src/server.rs
new file mode 100644
index 0000000000..3aa55c690e
--- /dev/null
+++ b/testing/webdriver/src/server.rs
@@ -0,0 +1,691 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use crate::command::{WebDriverCommand, WebDriverMessage};
+use crate::error::{ErrorStatus, WebDriverError, WebDriverResult};
+use crate::httpapi::{
+ standard_routes, Route, VoidWebDriverExtensionRoute, WebDriverExtensionRoute,
+};
+use crate::response::{CloseWindowResponse, WebDriverResponse};
+use crate::Parameters;
+use bytes::Bytes;
+use http::{self, Method, StatusCode};
+use std::marker::PhantomData;
+use std::net::{SocketAddr, TcpListener as StdTcpListener};
+use std::sync::mpsc::{channel, Receiver, Sender};
+use std::sync::{Arc, Mutex};
+use std::thread;
+use tokio::net::TcpListener;
+use tokio_stream::wrappers::TcpListenerStream;
+use url::{Host, Url};
+use warp::{self, Buf, Filter, Rejection};
+
+// Silence warning about Quit being unused for now.
+#[allow(dead_code)]
+enum DispatchMessage<U: WebDriverExtensionRoute> {
+ HandleWebDriver(
+ WebDriverMessage<U>,
+ Sender<WebDriverResult<WebDriverResponse>>,
+ ),
+ Quit,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+/// Representation of whether we managed to successfully send a DeleteSession message
+/// and read the response during session teardown.
+pub enum SessionTeardownKind {
+ /// A DeleteSession message has been sent and the response handled.
+ Deleted,
+ /// No DeleteSession message has been sent, or the response was not received.
+ NotDeleted,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub struct Session {
+ pub id: String,
+}
+
+impl Session {
+ fn new(id: String) -> Session {
+ Session { id }
+ }
+}
+
+pub trait WebDriverHandler<U: WebDriverExtensionRoute = VoidWebDriverExtensionRoute>: Send {
+ fn handle_command(
+ &mut self,
+ session: &Option<Session>,
+ msg: WebDriverMessage<U>,
+ ) -> WebDriverResult<WebDriverResponse>;
+ fn teardown_session(&mut self, kind: SessionTeardownKind);
+}
+
+#[derive(Debug)]
+struct Dispatcher<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> {
+ handler: T,
+ session: Option<Session>,
+ extension_type: PhantomData<U>,
+}
+
+impl<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> Dispatcher<T, U> {
+ fn new(handler: T) -> Dispatcher<T, U> {
+ Dispatcher {
+ handler,
+ session: None,
+ extension_type: PhantomData,
+ }
+ }
+
+ fn run(&mut self, msg_chan: &Receiver<DispatchMessage<U>>) {
+ loop {
+ match msg_chan.recv() {
+ Ok(DispatchMessage::HandleWebDriver(msg, resp_chan)) => {
+ let resp = match self.check_session(&msg) {
+ Ok(_) => self.handler.handle_command(&self.session, msg),
+ Err(e) => Err(e),
+ };
+
+ match resp {
+ Ok(WebDriverResponse::NewSession(ref new_session)) => {
+ self.session = Some(Session::new(new_session.session_id.clone()));
+ }
+ Ok(WebDriverResponse::CloseWindow(CloseWindowResponse(ref handles))) => {
+ if handles.is_empty() {
+ debug!("Last window was closed, deleting session");
+ // The teardown_session implementation is responsible for actually
+ // sending the DeleteSession message in this case
+ self.teardown_session(SessionTeardownKind::NotDeleted);
+ }
+ }
+ Ok(WebDriverResponse::DeleteSession) => {
+ self.teardown_session(SessionTeardownKind::Deleted);
+ }
+ Err(ref x) if x.delete_session => {
+ // This includes the case where we failed during session creation
+ self.teardown_session(SessionTeardownKind::NotDeleted)
+ }
+ _ => {}
+ }
+
+ if resp_chan.send(resp).is_err() {
+ error!("Sending response to the main thread failed");
+ };
+ }
+ Ok(DispatchMessage::Quit) => break,
+ Err(e) => panic!("Error receiving message in handler: {:?}", e),
+ }
+ }
+ }
+
+ fn teardown_session(&mut self, kind: SessionTeardownKind) {
+ debug!("Teardown session");
+ let final_kind = match kind {
+ SessionTeardownKind::NotDeleted if self.session.is_some() => {
+ let delete_session = WebDriverMessage {
+ session_id: Some(
+ self.session
+ .as_ref()
+ .expect("Failed to get session")
+ .id
+ .clone(),
+ ),
+ command: WebDriverCommand::DeleteSession,
+ };
+ match self.handler.handle_command(&self.session, delete_session) {
+ Ok(_) => SessionTeardownKind::Deleted,
+ Err(_) => SessionTeardownKind::NotDeleted,
+ }
+ }
+ _ => kind,
+ };
+ self.handler.teardown_session(final_kind);
+ self.session = None;
+ }
+
+ fn check_session(&self, msg: &WebDriverMessage<U>) -> WebDriverResult<()> {
+ match msg.session_id {
+ Some(ref msg_session_id) => match self.session {
+ Some(ref existing_session) => {
+ if existing_session.id != *msg_session_id {
+ Err(WebDriverError::new(
+ ErrorStatus::InvalidSessionId,
+ format!("Got unexpected session id {}", msg_session_id),
+ ))
+ } else {
+ Ok(())
+ }
+ }
+ None => Ok(()),
+ },
+ None => {
+ match self.session {
+ Some(_) => {
+ match msg.command {
+ WebDriverCommand::Status => Ok(()),
+ WebDriverCommand::NewSession(_) => Err(WebDriverError::new(
+ ErrorStatus::SessionNotCreated,
+ "Session is already started",
+ )),
+ _ => {
+ //This should be impossible
+ error!("Got a message with no session id");
+ Err(WebDriverError::new(
+ ErrorStatus::UnknownError,
+ "Got a command with no session?!",
+ ))
+ }
+ }
+ }
+ None => match msg.command {
+ WebDriverCommand::NewSession(_) => Ok(()),
+ WebDriverCommand::Status => Ok(()),
+ _ => Err(WebDriverError::new(
+ ErrorStatus::InvalidSessionId,
+ "Tried to run a command before creating a session",
+ )),
+ },
+ }
+ }
+ }
+ }
+}
+
+pub struct Listener {
+ guard: Option<thread::JoinHandle<()>>,
+ pub socket: SocketAddr,
+}
+
+impl Drop for Listener {
+ fn drop(&mut self) {
+ let _ = self.guard.take().map(|j| j.join());
+ }
+}
+
+pub fn start<T, U>(
+ mut address: SocketAddr,
+ allow_hosts: Vec<Host>,
+ allow_origins: Vec<Url>,
+ handler: T,
+ extension_routes: Vec<(Method, &'static str, U)>,
+) -> ::std::io::Result<Listener>
+where
+ T: 'static + WebDriverHandler<U>,
+ U: 'static + WebDriverExtensionRoute + Send + Sync,
+{
+ let listener = StdTcpListener::bind(address)?;
+ listener.set_nonblocking(true)?;
+ let addr = listener.local_addr()?;
+ if address.port() == 0 {
+ // If we passed in 0 as the port number the OS will assign an unused port;
+ // we want to update the address to the actual used port
+ address.set_port(addr.port())
+ }
+ let (msg_send, msg_recv) = channel();
+
+ let builder = thread::Builder::new().name("webdriver server".to_string());
+ let handle = builder.spawn(move || {
+ let rt = tokio::runtime::Builder::new_current_thread()
+ .enable_io()
+ .build()
+ .unwrap();
+ let listener = rt.block_on(async { TcpListener::from_std(listener).unwrap() });
+ let wroutes = build_warp_routes(
+ address,
+ allow_hosts,
+ allow_origins,
+ &extension_routes,
+ msg_send.clone(),
+ );
+ let fut = warp::serve(wroutes).run_incoming(TcpListenerStream::new(listener));
+ rt.block_on(fut);
+ })?;
+
+ let builder = thread::Builder::new().name("webdriver dispatcher".to_string());
+ builder.spawn(move || {
+ let mut dispatcher = Dispatcher::new(handler);
+ dispatcher.run(&msg_recv);
+ })?;
+
+ Ok(Listener {
+ guard: Some(handle),
+ socket: addr,
+ })
+}
+
+fn build_warp_routes<U: 'static + WebDriverExtensionRoute + Send + Sync>(
+ address: SocketAddr,
+ allow_hosts: Vec<Host>,
+ allow_origins: Vec<Url>,
+ ext_routes: &[(Method, &'static str, U)],
+ chan: Sender<DispatchMessage<U>>,
+) -> impl Filter<Extract = (impl warp::Reply,), Error = Rejection> + Clone {
+ let chan = Arc::new(Mutex::new(chan));
+ let mut std_routes = standard_routes::<U>();
+ let (method, path, res) = std_routes.pop().unwrap();
+ let mut wroutes = build_route(
+ address,
+ allow_hosts.clone(),
+ allow_origins.clone(),
+ method,
+ path,
+ res,
+ chan.clone(),
+ );
+ for (method, path, res) in std_routes {
+ wroutes = wroutes
+ .or(build_route(
+ address,
+ allow_hosts.clone(),
+ allow_origins.clone(),
+ method,
+ path,
+ res.clone(),
+ chan.clone(),
+ ))
+ .unify()
+ .boxed()
+ }
+ for (method, path, res) in ext_routes {
+ wroutes = wroutes
+ .or(build_route(
+ address,
+ allow_hosts.clone(),
+ allow_origins.clone(),
+ method.clone(),
+ path,
+ Route::Extension(res.clone()),
+ chan.clone(),
+ ))
+ .unify()
+ .boxed()
+ }
+ wroutes
+}
+
+fn is_host_allowed(server_address: &SocketAddr, allow_hosts: &[Host], host_header: &str) -> bool {
+ // Validate that the Host header value has a hostname in allow_hosts and
+ // the port matches the server configuration
+ let header_host_url = match Url::parse(&format!("http://{}", &host_header)) {
+ Ok(x) => x,
+ Err(_) => {
+ return false;
+ }
+ };
+
+ let host = match header_host_url.host() {
+ Some(host) => host.to_owned(),
+ None => {
+ // This shouldn't be possible since http URL always have a
+ // host, but conservatively return false here, which will cause
+ // an error response
+ return false;
+ }
+ };
+ let port = match header_host_url.port_or_known_default() {
+ Some(port) => port,
+ None => {
+ // This shouldn't be possible since http URL always have a
+ // default port, but conservatively return false here, which will cause
+ // an error response
+ return false;
+ }
+ };
+
+ let host_matches = match host {
+ Host::Domain(_) => allow_hosts.contains(&host),
+ Host::Ipv4(_) | Host::Ipv6(_) => true,
+ };
+ let port_matches = server_address.port() == port;
+ host_matches && port_matches
+}
+
+fn is_origin_allowed(allow_origins: &[Url], origin_url: Url) -> bool {
+ // Validate that the Origin header value is in allow_origins
+ allow_origins.contains(&origin_url)
+}
+
+fn build_route<U: 'static + WebDriverExtensionRoute + Send + Sync>(
+ server_address: SocketAddr,
+ allow_hosts: Vec<Host>,
+ allow_origins: Vec<Url>,
+ method: Method,
+ path: &'static str,
+ route: Route<U>,
+ chan: Arc<Mutex<Sender<DispatchMessage<U>>>>,
+) -> warp::filters::BoxedFilter<(impl warp::Reply,)> {
+ // Create an empty filter based on the provided method and append an empty hashmap to it. The
+ // hashmap will be used to store path parameters.
+ let mut subroute = match method {
+ Method::GET => warp::get().boxed(),
+ Method::POST => warp::post().boxed(),
+ Method::DELETE => warp::delete().boxed(),
+ Method::OPTIONS => warp::options().boxed(),
+ Method::PUT => warp::put().boxed(),
+ _ => panic!("Unsupported method"),
+ }
+ .or(warp::head())
+ .unify()
+ .map(Parameters::new)
+ .boxed();
+
+ // For each part of the path, if it's a normal part, just append it to the current filter,
+ // otherwise if it's a parameter (a named enclosed in { }), we take that parameter and insert
+ // it into the hashmap created earlier.
+ for part in path.split('/') {
+ if part.is_empty() {
+ continue;
+ } else if part.starts_with('{') {
+ assert!(part.ends_with('}'));
+
+ subroute = subroute
+ .and(warp::path::param())
+ .map(move |mut params: Parameters, param: String| {
+ let name = &part[1..part.len() - 1];
+ params.insert(name.to_string(), param);
+ params
+ })
+ .boxed();
+ } else {
+ subroute = subroute.and(warp::path(part)).boxed();
+ }
+ }
+
+ // Finally, tell warp that the path is complete
+ subroute
+ .and(warp::path::end())
+ .and(warp::path::full())
+ .and(warp::method())
+ .and(warp::header::optional::<String>("origin"))
+ .and(warp::header::optional::<String>("host"))
+ .and(warp::header::optional::<String>("content-type"))
+ .and(warp::body::bytes())
+ .map(
+ move |params,
+ full_path: warp::path::FullPath,
+ method,
+ origin_header: Option<String>,
+ host_header: Option<String>,
+ content_type_header: Option<String>,
+ body: Bytes| {
+ if method == Method::HEAD {
+ return warp::reply::with_status("".into(), StatusCode::OK);
+ }
+ if let Some(host) = host_header {
+ if !is_host_allowed(&server_address, &allow_hosts, &host) {
+ warn!(
+ "Rejected request with Host header {}, allowed values are [{}]",
+ host,
+ allow_hosts
+ .iter()
+ .map(|x| format!("{}:{}", x, server_address.port()))
+ .collect::<Vec<_>>()
+ .join(",")
+ );
+ let err = WebDriverError::new(
+ ErrorStatus::UnknownError,
+ format!("Invalid Host header {}", host),
+ );
+ return warp::reply::with_status(
+ serde_json::to_string(&err).unwrap(),
+ StatusCode::INTERNAL_SERVER_ERROR,
+ );
+ };
+ } else {
+ warn!("Rejected request with missing Host header");
+ let err = WebDriverError::new(
+ ErrorStatus::UnknownError,
+ "Missing Host header".to_string(),
+ );
+ return warp::reply::with_status(
+ serde_json::to_string(&err).unwrap(),
+ StatusCode::INTERNAL_SERVER_ERROR,
+ );
+ }
+ if let Some(origin) = origin_header {
+ let make_err = || {
+ warn!(
+ "Rejected request with Origin header {}, allowed values are [{}]",
+ origin,
+ allow_origins
+ .iter()
+ .map(|x| x.to_string())
+ .collect::<Vec<_>>()
+ .join(",")
+ );
+ WebDriverError::new(
+ ErrorStatus::UnknownError,
+ format!("Invalid Origin header {}", origin),
+ )
+ };
+ let origin_url = match Url::parse(&origin) {
+ Ok(url) => url,
+ Err(_) => {
+ return warp::reply::with_status(
+ serde_json::to_string(&make_err()).unwrap(),
+ StatusCode::INTERNAL_SERVER_ERROR,
+ );
+ }
+ };
+ if !is_origin_allowed(&allow_origins, origin_url) {
+ return warp::reply::with_status(
+ serde_json::to_string(&make_err()).unwrap(),
+ StatusCode::INTERNAL_SERVER_ERROR,
+ );
+ }
+ }
+ if method == Method::POST {
+ // Disallow CORS-safelisted request headers
+ // c.f. https://fetch.spec.whatwg.org/#cors-safelisted-request-header
+ let content_type = content_type_header
+ .as_ref()
+ .map(|x| x.find(';').and_then(|idx| x.get(0..idx)).unwrap_or(x))
+ .map(|x| x.trim())
+ .map(|x| x.to_lowercase());
+ match content_type.as_ref().map(|x| x.as_ref()) {
+ Some("application/x-www-form-urlencoded")
+ | Some("multipart/form-data")
+ | Some("text/plain") => {
+ warn!(
+ "Rejected POST request with disallowed content type {}",
+ content_type.unwrap_or_else(|| "".into())
+ );
+ let err = WebDriverError::new(
+ ErrorStatus::UnknownError,
+ "Invalid Content-Type",
+ );
+ return warp::reply::with_status(
+ serde_json::to_string(&err).unwrap(),
+ StatusCode::INTERNAL_SERVER_ERROR,
+ );
+ }
+ Some(_) | None => {}
+ }
+ }
+ let body = String::from_utf8(body.chunk().to_vec());
+ if body.is_err() {
+ let err = WebDriverError::new(
+ ErrorStatus::UnknownError,
+ "Request body wasn't valid UTF-8",
+ );
+ return warp::reply::with_status(
+ serde_json::to_string(&err).unwrap(),
+ StatusCode::INTERNAL_SERVER_ERROR,
+ );
+ }
+ let body = body.unwrap();
+
+ debug!("-> {} {} {}", method, full_path.as_str(), body);
+ let msg_result = WebDriverMessage::from_http(
+ route.clone(),
+ &params,
+ &body,
+ method == Method::POST,
+ );
+
+ let (status, resp_body) = match msg_result {
+ Ok(message) => {
+ let (send_res, recv_res) = channel();
+ match chan.lock() {
+ Ok(ref c) => {
+ let res =
+ c.send(DispatchMessage::HandleWebDriver(message, send_res));
+ match res {
+ Ok(x) => x,
+ Err(e) => panic!("Error: {:?}", e),
+ }
+ }
+ Err(e) => panic!("Error reading response: {:?}", e),
+ }
+
+ match recv_res.recv() {
+ Ok(data) => match data {
+ Ok(response) => {
+ (StatusCode::OK, serde_json::to_string(&response).unwrap())
+ }
+ Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()),
+ },
+ Err(e) => panic!("Error reading response: {:?}", e),
+ }
+ }
+ Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()),
+ };
+
+ debug!("<- {} {}", status, resp_body);
+ warp::reply::with_status(resp_body, status)
+ },
+ )
+ .with(warp::reply::with::header(
+ http::header::CONTENT_TYPE,
+ "application/json; charset=utf-8",
+ ))
+ .with(warp::reply::with::header(
+ http::header::CACHE_CONTROL,
+ "no-cache",
+ ))
+ .boxed()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::net::IpAddr;
+ use std::str::FromStr;
+
+ #[test]
+ fn test_host_allowed() {
+ let addr_80 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
+ let addr_8000 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 8000);
+ let addr_v6_80 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 80);
+ let addr_v6_8000 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 8000);
+
+ // We match the host ip address to the server, so we can only use hosts that actually resolve
+ let localhost_host = Host::Domain("localhost".to_string());
+ let test_host = Host::Domain("example.test".to_string());
+ let subdomain_localhost_host = Host::Domain("subdomain.localhost".to_string());
+
+ assert!(is_host_allowed(
+ &addr_80,
+ &[localhost_host.clone()],
+ "localhost:80"
+ ));
+ assert!(is_host_allowed(
+ &addr_80,
+ &[test_host.clone()],
+ "example.test:80"
+ ));
+ assert!(is_host_allowed(
+ &addr_80,
+ &[test_host.clone(), localhost_host.clone()],
+ "example.test"
+ ));
+ assert!(is_host_allowed(
+ &addr_80,
+ &[subdomain_localhost_host.clone()],
+ "subdomain.localhost"
+ ));
+
+ // ip address cases
+ assert!(is_host_allowed(&addr_80, &[], "127.0.0.1:80"));
+ assert!(is_host_allowed(&addr_v6_80, &[], "127.0.0.1"));
+ assert!(is_host_allowed(&addr_80, &[], "[::1]"));
+ assert!(is_host_allowed(&addr_8000, &[], "127.0.0.1:8000"));
+ assert!(is_host_allowed(
+ &addr_80,
+ &[subdomain_localhost_host.clone()],
+ "[::1]"
+ ));
+ assert!(is_host_allowed(
+ &addr_v6_8000,
+ &[subdomain_localhost_host.clone()],
+ "[::1]:8000"
+ ));
+
+ // Mismatch cases
+
+ assert!(!is_host_allowed(&addr_80, &[test_host], "localhost"));
+
+ assert!(!is_host_allowed(&addr_80, &[], "localhost:80"));
+
+ // Port mismatch cases
+
+ assert!(!is_host_allowed(
+ &addr_80,
+ &[localhost_host.clone()],
+ "localhost:8000"
+ ));
+ assert!(!is_host_allowed(
+ &addr_8000,
+ &[localhost_host.clone()],
+ "localhost"
+ ));
+ assert!(!is_host_allowed(
+ &addr_v6_8000,
+ &[localhost_host.clone()],
+ "[::1]"
+ ));
+ }
+
+ #[test]
+ fn test_origin_allowed() {
+ assert!(is_origin_allowed(
+ &[Url::parse("http://localhost").unwrap()],
+ Url::parse("http://localhost").unwrap()
+ ));
+ assert!(is_origin_allowed(
+ &[Url::parse("http://localhost").unwrap()],
+ Url::parse("http://localhost:80").unwrap()
+ ));
+ assert!(is_origin_allowed(
+ &[
+ Url::parse("https://test.example").unwrap(),
+ Url::parse("http://localhost").unwrap()
+ ],
+ Url::parse("http://localhost").unwrap()
+ ));
+ assert!(is_origin_allowed(
+ &[
+ Url::parse("https://test.example").unwrap(),
+ Url::parse("http://localhost").unwrap()
+ ],
+ Url::parse("https://test.example:443").unwrap()
+ ));
+ // Mismatch cases
+ assert!(!is_origin_allowed(
+ &[],
+ Url::parse("http://localhost").unwrap()
+ ));
+ assert!(!is_origin_allowed(
+ &[Url::parse("http://localhost").unwrap()],
+ Url::parse("http://localhost:8000").unwrap()
+ ));
+ assert!(!is_origin_allowed(
+ &[Url::parse("https://localhost").unwrap()],
+ Url::parse("http://localhost").unwrap()
+ ));
+ assert!(!is_origin_allowed(
+ &[Url::parse("https://example.test").unwrap()],
+ Url::parse("http://subdomain.example.test").unwrap()
+ ));
+ }
+}