diff options
Diffstat (limited to '')
-rw-r--r-- | testing/geckodriver/src/main.rs | 549 |
1 files changed, 549 insertions, 0 deletions
diff --git a/testing/geckodriver/src/main.rs b/testing/geckodriver/src/main.rs new file mode 100644 index 0000000000..64df65a0d0 --- /dev/null +++ b/testing/geckodriver/src/main.rs @@ -0,0 +1,549 @@ +#![forbid(unsafe_code)] + +extern crate chrono; +#[macro_use] +extern crate clap; +#[macro_use] +extern crate lazy_static; +extern crate hyper; +extern crate marionette as marionette_rs; +extern crate mozdevice; +extern crate mozprofile; +extern crate mozrunner; +extern crate mozversion; +extern crate regex; +extern crate serde; +#[macro_use] +extern crate serde_derive; +extern crate serde_json; +extern crate serde_yaml; +extern crate tempfile; +extern crate url; +extern crate uuid; +extern crate webdriver; +extern crate zip; + +#[macro_use] +extern crate log; + +use std::env; +use std::fmt; +use std::io; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use std::path::PathBuf; +use std::result; +use std::str::FromStr; + +use clap::{AppSettings, Arg, Command}; + +macro_rules! try_opt { + ($expr:expr, $err_type:expr, $err_msg:expr) => {{ + match $expr { + Some(x) => x, + None => return Err(WebDriverError::new($err_type, $err_msg)), + } + }}; +} + +mod android; +mod browser; +mod build; +mod capabilities; +mod command; +mod logging; +mod marionette; +mod prefs; + +#[cfg(test)] +pub mod test; + +use crate::command::extension_routes; +use crate::logging::Level; +use crate::marionette::{MarionetteHandler, MarionetteSettings}; +use mozdevice::AndroidStorageInput; +use url::{Host, Url}; + +const EXIT_SUCCESS: i32 = 0; +const EXIT_USAGE: i32 = 64; +const EXIT_UNAVAILABLE: i32 = 69; + +enum FatalError { + Parsing(clap::Error), + Usage(String), + Server(io::Error), +} + +impl FatalError { + fn exit_code(&self) -> i32 { + use FatalError::*; + match *self { + Parsing(_) | Usage(_) => EXIT_USAGE, + Server(_) => EXIT_UNAVAILABLE, + } + } + + fn help_included(&self) -> bool { + matches!(*self, FatalError::Parsing(_)) + } +} + +impl From<clap::Error> for FatalError { + fn from(err: clap::Error) -> FatalError { + FatalError::Parsing(err) + } +} + +impl From<io::Error> for FatalError { + fn from(err: io::Error) -> FatalError { + FatalError::Server(err) + } +} + +// harmonise error message from clap to avoid duplicate "error:" prefix +impl fmt::Display for FatalError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use FatalError::*; + let s = match *self { + Parsing(ref err) => err.to_string(), + Usage(ref s) => format!("error: {}", s), + Server(ref err) => format!("error: {}", err), + }; + write!(f, "{}", s) + } +} + +macro_rules! usage { + ($msg:expr) => { + return Err(FatalError::Usage($msg.to_string())) + }; + + ($fmt:expr, $($arg:tt)+) => { + return Err(FatalError::Usage(format!($fmt, $($arg)+))) + }; +} + +type ProgramResult<T> = result::Result<T, FatalError>; + +#[allow(clippy::large_enum_variant)] +enum Operation { + Help, + Version, + Server { + log_level: Option<Level>, + log_truncate: bool, + address: SocketAddr, + allow_hosts: Vec<Host>, + allow_origins: Vec<Url>, + settings: MarionetteSettings, + deprecated_storage_arg: bool, + }, +} + +/// Get a socket address from the provided host and port +/// +/// # Arguments +/// * `webdriver_host` - The hostname on which the server will listen +/// * `webdriver_port` - The port on which the server will listen +/// +/// When the host and port resolve to multiple addresses, prefer +/// IPv4 addresses vs IPv6. +fn server_address(webdriver_host: &str, webdriver_port: u16) -> ProgramResult<SocketAddr> { + let mut socket_addrs = match format!("{}:{}", webdriver_host, webdriver_port).to_socket_addrs() + { + Ok(addrs) => addrs.collect::<Vec<_>>(), + Err(e) => usage!("{}: {}:{}", e, webdriver_host, webdriver_port), + }; + if socket_addrs.is_empty() { + usage!( + "Unable to resolve host: {}:{}", + webdriver_host, + webdriver_port + ) + } + // Prefer ipv4 address + socket_addrs.sort_by(|a, b| { + let a_val = i32::from(!a.ip().is_ipv4()); + let b_val = i32::from(!b.ip().is_ipv4()); + a_val.partial_cmp(&b_val).expect("Comparison failed") + }); + Ok(socket_addrs.remove(0)) +} + +/// Parse a given string into a Host +fn parse_hostname(webdriver_host: &str) -> Result<Host, url::ParseError> { + let host_str = if let Ok(ip_addr) = IpAddr::from_str(webdriver_host) { + // In this case we have an IP address as the host + if ip_addr.is_ipv6() { + // Convert to quoted form + format!("[{}]", &webdriver_host) + } else { + webdriver_host.into() + } + } else { + webdriver_host.into() + }; + + Host::parse(&host_str) +} + +/// Get a list of default hostnames to allow +/// +/// This only covers domain names, not IP addresses, since IP adresses +/// are always accepted. +fn get_default_allowed_hosts(ip: IpAddr) -> Vec<Result<Host, url::ParseError>> { + let localhost_is_loopback = ("localhost".to_string(), 80) + .to_socket_addrs() + .map(|addr_iter| { + addr_iter + .map(|addr| addr.ip()) + .filter(|ip| ip.is_loopback()) + }) + .iter() + .len() + > 0; + if ip.is_loopback() && localhost_is_loopback { + vec![Host::parse("localhost")] + } else { + vec![] + } +} + +fn get_allowed_hosts( + host: Host, + allow_hosts: Option<clap::Values>, +) -> Result<Vec<Host>, url::ParseError> { + allow_hosts + .map(|hosts| hosts.map(Host::parse).collect::<Vec<_>>()) + .unwrap_or_else(|| match host { + Host::Domain(_) => { + vec![Ok(host.clone())] + } + Host::Ipv4(ip) => get_default_allowed_hosts(IpAddr::V4(ip)), + Host::Ipv6(ip) => get_default_allowed_hosts(IpAddr::V6(ip)), + }) + .into_iter() + .collect::<Result<Vec<Host>, url::ParseError>>() +} + +fn get_allowed_origins(allow_origins: Option<clap::Values>) -> Result<Vec<Url>, url::ParseError> { + allow_origins + .map(|origins| { + origins + .map(Url::parse) + .collect::<Result<Vec<Url>, url::ParseError>>() + }) + .unwrap_or_else(|| Ok(vec![])) +} + +fn parse_args(cmd: &mut Command) -> ProgramResult<Operation> { + let args = cmd.try_get_matches_from_mut(env::args())?; + + if args.is_present("help") { + return Ok(Operation::Help); + } else if args.is_present("version") { + return Ok(Operation::Version); + } + + let log_level = if args.is_present("log_level") { + Level::from_str(args.value_of("log_level").unwrap()).ok() + } else { + Some(match args.occurrences_of("verbosity") { + 0 => Level::Info, + 1 => Level::Debug, + _ => Level::Trace, + }) + }; + + let webdriver_host = args.value_of("webdriver_host").unwrap(); + let webdriver_port = { + let s = args.value_of("webdriver_port").unwrap(); + match u16::from_str(s) { + Ok(n) => n, + Err(e) => usage!("invalid --port: {}: {}", e, s), + } + }; + + let android_storage = args + .value_of_t::<AndroidStorageInput>("android_storage") + .unwrap_or(AndroidStorageInput::Auto); + + let binary = args.value_of("binary").map(PathBuf::from); + + let profile_root = args.value_of("profile_root").map(PathBuf::from); + + // Try to create a temporary directory on startup to check that the directory exists and is writable + { + let tmp_dir = if let Some(ref tmp_root) = profile_root { + tempfile::tempdir_in(tmp_root) + } else { + tempfile::tempdir() + }; + if tmp_dir.is_err() { + usage!("Unable to write to temporary directory; consider --profile-root with a writeable directory") + } + } + + let marionette_host = args.value_of("marionette_host").unwrap(); + let marionette_port = match args.value_of("marionette_port") { + Some(s) => match u16::from_str(s) { + Ok(n) => Some(n), + Err(e) => usage!("invalid --marionette-port: {}", e), + }, + None => None, + }; + + // For Android the port on the device must be the same as the one on the + // host. For now default to 9222, which is the default for --remote-debugging-port. + let websocket_port = match args.value_of("websocket_port") { + Some(s) => match u16::from_str(s) { + Ok(n) => n, + Err(e) => usage!("invalid --websocket-port: {}", e), + }, + None => 9222, + }; + + let host = match parse_hostname(webdriver_host) { + Ok(name) => name, + Err(e) => usage!("invalid --host {}: {}", webdriver_host, e), + }; + + let allow_hosts = match get_allowed_hosts(host, args.values_of("allow_hosts")) { + Ok(hosts) => hosts, + Err(e) => usage!("invalid --allow-hosts {}", e), + }; + + let allow_origins = match get_allowed_origins(args.values_of("allow_origins")) { + Ok(origins) => origins, + Err(e) => usage!("invalid --allow-origins {}", e), + }; + + let address = server_address(webdriver_host, webdriver_port)?; + + let settings = MarionetteSettings { + binary, + profile_root, + connect_existing: args.is_present("connect_existing"), + host: marionette_host.into(), + port: marionette_port, + websocket_port, + allow_hosts: allow_hosts.clone(), + allow_origins: allow_origins.clone(), + jsdebugger: args.is_present("jsdebugger"), + android_storage, + }; + Ok(Operation::Server { + log_level, + log_truncate: !args.is_present("log_no_truncate"), + allow_hosts, + allow_origins, + address, + settings, + deprecated_storage_arg: args.is_present("android_storage"), + }) +} + +fn inner_main(cmd: &mut Command) -> ProgramResult<()> { + match parse_args(cmd)? { + Operation::Help => print_help(cmd), + Operation::Version => print_version(), + + Operation::Server { + log_level, + log_truncate, + address, + allow_hosts, + allow_origins, + settings, + deprecated_storage_arg, + } => { + if let Some(ref level) = log_level { + logging::init_with_level(*level, log_truncate).unwrap(); + } else { + logging::init(log_truncate).unwrap(); + } + + if deprecated_storage_arg { + warn!("--android-storage argument is deprecated and will be removed soon."); + }; + + let handler = MarionetteHandler::new(settings); + let listening = webdriver::server::start( + address, + allow_hosts, + allow_origins, + handler, + extension_routes(), + )?; + info!("Listening on {}", listening.socket); + } + } + + Ok(()) +} + +fn main() { + use std::process::exit; + + let mut cmd = make_command(); + + // use std::process:Termination when it graduates + exit(match inner_main(&mut cmd) { + Ok(_) => EXIT_SUCCESS, + + Err(e) => { + eprintln!("{}: {}", get_program_name(), e); + if !e.help_included() { + print_help(&mut cmd); + } + + e.exit_code() + } + }); +} + +fn make_command<'a>() -> Command<'a> { + Command::new(format!("geckodriver {}", build::build_info())) + .setting(AppSettings::NoAutoHelp) + .setting(AppSettings::NoAutoVersion) + .about("WebDriver implementation for Firefox") + .arg( + Arg::new("webdriver_host") + .long("host") + .takes_value(true) + .value_name("HOST") + .default_value("127.0.0.1") + .help("Host IP to use for WebDriver server"), + ) + .arg( + Arg::new("webdriver_port") + .short('p') + .long("port") + .takes_value(true) + .value_name("PORT") + .default_value("4444") + .help("Port to use for WebDriver server"), + ) + .arg( + Arg::new("binary") + .short('b') + .long("binary") + .takes_value(true) + .value_name("BINARY") + .help("Path to the Firefox binary"), + ) + .arg( + Arg::new("marionette_host") + .long("marionette-host") + .takes_value(true) + .value_name("HOST") + .default_value("127.0.0.1") + .help("Host to use to connect to Gecko"), + ) + .arg( + Arg::new("marionette_port") + .long("marionette-port") + .takes_value(true) + .value_name("PORT") + .help("Port to use to connect to Gecko [default: system-allocated port]"), + ) + .arg( + Arg::new("websocket_port") + .long("websocket-port") + .takes_value(true) + .value_name("PORT") + .conflicts_with("connect_existing") + .help("Port to use to connect to WebDriver BiDi [default: 9222]"), + ) + .arg( + Arg::new("connect_existing") + .long("connect-existing") + .requires("marionette_port") + .help("Connect to an existing Firefox instance"), + ) + .arg( + Arg::new("jsdebugger") + .long("jsdebugger") + .help("Attach browser toolbox debugger for Firefox"), + ) + .arg( + Arg::new("verbosity") + .multiple_occurrences(true) + .conflicts_with("log_level") + .short('v') + .help("Log level verbosity (-v for debug and -vv for trace level)"), + ) + .arg( + Arg::new("log_level") + .long("log") + .takes_value(true) + .value_name("LEVEL") + .possible_values(["fatal", "error", "warn", "info", "config", "debug", "trace"]) + .help("Set Gecko log level"), + ) + .arg( + Arg::new("log_no_truncate") + .long("log-no-truncate") + .help("Disable truncation of long log lines"), + ) + .arg( + Arg::new("help") + .short('h') + .long("help") + .help("Prints this message"), + ) + .arg( + Arg::new("version") + .short('V') + .long("version") + .help("Prints version and copying information"), + ) + .arg( + Arg::new("profile_root") + .long("profile-root") + .takes_value(true) + .value_name("PROFILE_ROOT") + .help("Directory in which to create profiles. Defaults to the system temporary directory."), + ) + .arg( + Arg::new("android_storage") + .long("android-storage") + .possible_values(["auto", "app", "internal", "sdcard"]) + .value_name("ANDROID_STORAGE") + .help("Selects storage location to be used for test data (deprecated)."), + ) + .arg( + Arg::new("allow_hosts") + .long("allow-hosts") + .takes_value(true) + .multiple_values(true) + .value_name("ALLOW_HOSTS") + .help("List of hostnames to allow. By default the value of --host is allowed, and in addition if that's a well known local address, other variations on well known local addresses are allowed. If --allow-hosts is provided only exactly those hosts are allowed."), + ) + .arg( + Arg::new("allow_origins") + .long("allow-origins") + .takes_value(true) + .multiple_values(true) + .value_name("ALLOW_ORIGINS") + .help("List of request origins to allow. These must be formatted as scheme://host:port. By default any request with an origin header is rejected. If --allow-origins is provided then only exactly those origins are allowed."), + ) +} + +fn get_program_name() -> String { + env::args().next().unwrap() +} + +fn print_help(cmd: &mut Command) { + cmd.print_help().ok(); + println!(); +} + +fn print_version() { + println!("geckodriver {}", build::build_info()); + println!(); + println!("The source code of this program is available from"); + println!("testing/geckodriver in https://hg.mozilla.org/mozilla-central."); + println!(); + println!("This program is subject to the terms of the Mozilla Public License 2.0."); + println!("You can obtain a copy of the license at https://mozilla.org/MPL/2.0/."); +} |