diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /dom/media/webrtc/transport/mdns_service/src | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'dom/media/webrtc/transport/mdns_service/src')
-rw-r--r-- | dom/media/webrtc/transport/mdns_service/src/lib.rs | 843 |
1 files changed, 843 insertions, 0 deletions
diff --git a/dom/media/webrtc/transport/mdns_service/src/lib.rs b/dom/media/webrtc/transport/mdns_service/src/lib.rs new file mode 100644 index 0000000000..fbd07b45f8 --- /dev/null +++ b/dom/media/webrtc/transport/mdns_service/src/lib.rs @@ -0,0 +1,843 @@ +/* -*- Mode: rust; rust-indent-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +use byteorder::{BigEndian, WriteBytesExt}; +use socket2::{Domain, Socket, Type}; +use std::collections::HashMap; +use std::collections::LinkedList; +use std::ffi::{c_void, CStr, CString}; +use std::io; +use std::net; +use std::os::raw::c_char; +use std::sync::mpsc::channel; +use std::thread; +use std::time; +use uuid::Uuid; + +#[macro_use] +extern crate log; + +struct Callback { + data: *const c_void, + resolved: unsafe extern "C" fn(*const c_void, *const c_char, *const c_char), + timedout: unsafe extern "C" fn(*const c_void, *const c_char), +} + +unsafe impl Send for Callback {} + +fn hostname_resolved(callback: &Callback, hostname: &str, addr: &str) { + if let Ok(hostname) = CString::new(hostname) { + if let Ok(addr) = CString::new(addr) { + unsafe { + (callback.resolved)(callback.data, hostname.as_ptr(), addr.as_ptr()); + } + } + } +} + +fn hostname_timedout(callback: &Callback, hostname: &str) { + if let Ok(hostname) = CString::new(hostname) { + unsafe { + (callback.timedout)(callback.data, hostname.as_ptr()); + } + } +} + +// This code is derived from code for creating questions in the dns-parser +// crate. It would be nice to upstream this, or something similar. +fn create_answer(id: u16, answers: &[(String, &[u8])]) -> Result<Vec<u8>, io::Error> { + let mut buf = Vec::with_capacity(512); + let head = dns_parser::Header { + id, + query: false, + opcode: dns_parser::Opcode::StandardQuery, + authoritative: true, + truncated: false, + recursion_desired: false, + recursion_available: false, + authenticated_data: false, + checking_disabled: false, + response_code: dns_parser::ResponseCode::NoError, + questions: 0, + answers: answers.len() as u16, + nameservers: 0, + additional: 0, + }; + + buf.extend([0u8; 12].iter()); + head.write(&mut buf[..12]); + + for (name, addr) in answers { + for part in name.split('.') { + if part.len() > 62 { + return Err(io::Error::new( + io::ErrorKind::Other, + "Name part length too long", + )); + } + let ln = part.len() as u8; + buf.push(ln); + buf.extend(part.as_bytes()); + } + buf.push(0); + + if addr.len() == 4 { + buf.write_u16::<BigEndian>(dns_parser::Type::A as u16)?; + } else { + buf.write_u16::<BigEndian>(dns_parser::Type::AAAA as u16)?; + } + // set cache flush bit + buf.write_u16::<BigEndian>(dns_parser::Class::IN as u16 | (0x1 << 15))?; + buf.write_u32::<BigEndian>(120)?; + buf.write_u16::<BigEndian>(addr.len() as u16)?; + buf.extend(*addr); + } + + Ok(buf) +} + +fn create_query(id: u16, queries: &[String]) -> Result<Vec<u8>, io::Error> { + let mut buf = Vec::with_capacity(512); + let head = dns_parser::Header { + id, + query: true, + opcode: dns_parser::Opcode::StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: false, + recursion_available: false, + authenticated_data: false, + checking_disabled: false, + response_code: dns_parser::ResponseCode::NoError, + questions: queries.len() as u16, + answers: 0, + nameservers: 0, + additional: 0, + }; + + buf.extend([0u8; 12].iter()); + head.write(&mut buf[..12]); + + for name in queries { + for part in name.split('.') { + assert!(part.len() < 63); + let ln = part.len() as u8; + buf.push(ln); + buf.extend(part.as_bytes()); + } + buf.push(0); + + buf.write_u16::<BigEndian>(dns_parser::QueryType::A as u16)?; + buf.write_u16::<BigEndian>(dns_parser::QueryClass::IN as u16)?; + } + + Ok(buf) +} + +fn handle_queries( + socket: &std::net::UdpSocket, + mdns_addr: &std::net::SocketAddr, + pending_queries: &mut HashMap<String, Query>, + unsent_queries: &mut LinkedList<Query>, +) { + if pending_queries.len() < 50 { + let mut queries: Vec<Query> = Vec::new(); + while queries.len() < 5 && !unsent_queries.is_empty() { + if let Some(query) = unsent_queries.pop_front() { + if !pending_queries.contains_key(&query.hostname) { + queries.push(query); + } + } + } + if !queries.is_empty() { + let query_hostnames: Vec<String> = + queries.iter().map(|q| q.hostname.to_string()).collect(); + + if let Ok(buf) = create_query(0, &query_hostnames) { + match socket.send_to(&buf, &mdns_addr) { + Ok(_) => { + for query in queries { + pending_queries.insert(query.hostname.to_string(), query); + } + } + Err(err) => { + warn!("Sending mDNS query failed: {}", err); + if err.kind() != io::ErrorKind::PermissionDenied { + for query in queries { + unsent_queries.push_back(query); + } + } else { + for query in queries { + hostname_timedout(&query.callback, &query.hostname); + } + } + } + } + } + } + } + + let now = time::Instant::now(); + let expired: Vec<String> = pending_queries + .iter() + .filter(|(_, query)| now.duration_since(query.timestamp).as_secs() >= 3) + .map(|(hostname, _)| hostname.to_string()) + .collect(); + for hostname in expired { + if let Some(mut query) = pending_queries.remove(&hostname) { + query.attempts += 1; + if query.attempts < 3 { + query.timestamp = now; + unsent_queries.push_back(query); + } else { + hostname_timedout(&query.callback, &hostname); + } + } + } +} + +fn handle_mdns_socket( + socket: &std::net::UdpSocket, + mdns_addr: &std::net::SocketAddr, + mut buffer: &mut [u8], + hosts: &mut HashMap<String, Vec<u8>>, + pending_queries: &mut HashMap<String, Query>, +) -> bool { + // Record a simple marker to see how often this is called. + gecko_profiler::add_untyped_marker( + "handle_mdns_socket", + gecko_profiler::gecko_profiler_category!(Network), + Default::default(), + ); + + match socket.recv_from(&mut buffer) { + Ok((amt, _)) => { + if amt > 0 { + let buffer = &buffer[0..amt]; + match dns_parser::Packet::parse(&buffer) { + Ok(parsed) => { + let mut answers: Vec<(String, &[u8])> = Vec::new(); + + // If a packet contains both both questions and + // answers, the questions should be ignored. + if parsed.answers.is_empty() { + parsed + .questions + .iter() + .filter(|question| question.qtype == dns_parser::QueryType::A) + .for_each(|question| { + let qname = question.qname.to_string(); + trace!("mDNS question: {} {:?}", qname, question.qtype); + if let Some(octets) = hosts.get(&qname) { + trace!("Sending mDNS answer for {}: {:?}", qname, octets); + answers.push((qname, &octets)); + } + }); + } + for answer in parsed.answers { + let hostname = answer.name.to_string(); + match pending_queries.get(&hostname) { + Some(query) => { + match answer.data { + dns_parser::RData::A(dns_parser::rdata::a::Record( + addr, + )) => { + let addr = addr.to_string(); + trace!("mDNS response: {} {}", hostname, addr); + hostname_resolved(&query.callback, &hostname, &addr); + } + dns_parser::RData::AAAA( + dns_parser::rdata::aaaa::Record(addr), + ) => { + let addr = addr.to_string(); + trace!("mDNS response: {} {}", hostname, addr); + hostname_resolved(&query.callback, &hostname, &addr); + } + _ => {} + } + pending_queries.remove(&hostname); + } + None => { + continue; + } + } + } + // TODO: If we did not answer every query in this + // question, we should wait for a random amount of time + // so as to not collide with someone else responding to + // this query. + if !answers.is_empty() { + if let Ok(buf) = create_answer(parsed.header.id, &answers) { + if let Err(err) = socket.send_to(&buf, &mdns_addr) { + warn!("Sending mDNS answer failed: {}", err); + } + } + } + } + Err(err) => { + warn!("Could not parse mDNS packet: {}", err); + } + } + } + } + Err(err) => { + if err.kind() != io::ErrorKind::Interrupted + && err.kind() != io::ErrorKind::TimedOut + && err.kind() != io::ErrorKind::WouldBlock + { + error!("Socket error: {}", err); + return false; + } + } + } + + true +} + +fn validate_hostname(hostname: &str) -> bool { + match hostname.find(".local") { + Some(index) => match hostname.get(0..index) { + Some(uuid) => match uuid.get(0..36) { + Some(initial) => match Uuid::parse_str(initial) { + Ok(_) => { + // Oddly enough, Safari does not generate valid UUIDs, + // the last part sometimes contains more than 12 digits. + match uuid.get(36..) { + Some(trailing) => { + for c in trailing.chars() { + if !c.is_ascii_hexdigit() { + return false; + } + } + true + } + None => true, + } + } + Err(_) => false, + }, + None => false, + }, + None => false, + }, + None => false, + } +} + +enum ServiceControl { + Register { + hostname: String, + address: String, + }, + Query { + callback: Callback, + hostname: String, + }, + Unregister { + hostname: String, + }, + Stop, +} + +struct Query { + hostname: String, + callback: Callback, + timestamp: time::Instant, + attempts: i32, +} + +impl Query { + fn new(hostname: &str, callback: Callback) -> Query { + Query { + hostname: hostname.to_string(), + callback, + timestamp: time::Instant::now(), + attempts: 0, + } + } +} + +pub struct MDNSService { + handle: Option<std::thread::JoinHandle<()>>, + sender: Option<std::sync::mpsc::Sender<ServiceControl>>, +} + +impl MDNSService { + fn register_hostname(&mut self, hostname: &str, address: &str) { + if let Some(sender) = &self.sender { + if let Err(err) = sender.send(ServiceControl::Register { + hostname: hostname.to_string(), + address: address.to_string(), + }) { + warn!( + "Could not send register hostname {} message: {}", + hostname, err + ); + } + } + } + + fn query_hostname(&mut self, callback: Callback, hostname: &str) { + if let Some(sender) = &self.sender { + if let Err(err) = sender.send(ServiceControl::Query { + callback, + hostname: hostname.to_string(), + }) { + warn!( + "Could not send query hostname {} message: {}", + hostname, err + ); + } + } + } + + fn unregister_hostname(&mut self, hostname: &str) { + if let Some(sender) = &self.sender { + if let Err(err) = sender.send(ServiceControl::Unregister { + hostname: hostname.to_string(), + }) { + warn!( + "Could not send unregister hostname {} message: {}", + hostname, err + ); + } + } + } + + fn start(&mut self, addrs: Vec<std::net::Ipv4Addr>) -> io::Result<()> { + let (sender, receiver) = channel(); + self.sender = Some(sender); + + let mdns_addr = std::net::Ipv4Addr::new(224, 0, 0, 251); + let port = 5353; + + let socket = Socket::new(Domain::IPV4, Type::DGRAM, None)?; + socket.set_reuse_address(true)?; + + #[cfg(not(target_os = "windows"))] + socket.set_reuse_port(true)?; + socket.bind(&socket2::SockAddr::from(std::net::SocketAddr::from(( + [0, 0, 0, 0], + port, + ))))?; + + let socket = std::net::UdpSocket::from(socket); + socket.set_multicast_loop_v4(true)?; + socket.set_read_timeout(Some(time::Duration::from_millis(1)))?; + socket.set_write_timeout(Some(time::Duration::from_millis(1)))?; + for addr in addrs { + if let Err(err) = socket.join_multicast_v4(&mdns_addr, &addr) { + warn!( + "Could not join multicast group on interface: {:?}: {}", + addr, err + ); + } + } + + let thread_name = "mdns_service"; + let builder = thread::Builder::new().name(thread_name.into()); + self.handle = Some(builder.spawn(move || { + gecko_profiler::register_thread(thread_name); + let mdns_addr = std::net::SocketAddr::from(([224, 0, 0, 251], port)); + let mut buffer: [u8; 9_000] = [0; 9_000]; + let mut hosts = HashMap::new(); + let mut unsent_queries = LinkedList::new(); + let mut pending_queries = HashMap::new(); + loop { + match receiver.try_recv() { + Ok(msg) => match msg { + ServiceControl::Register { hostname, address } => { + if !validate_hostname(&hostname) { + warn!("Not registering invalid hostname: {}", hostname); + continue; + } + trace!("Registering {} for: {}", hostname, address); + match address.parse().and_then(|ip| { + Ok(match ip { + net::IpAddr::V4(ip) => ip.octets().to_vec(), + net::IpAddr::V6(ip) => ip.octets().to_vec(), + }) + }) { + Ok(octets) => { + let mut v = Vec::new(); + v.extend(octets); + hosts.insert(hostname, v); + } + Err(err) => { + warn!( + "Could not parse address for {}: {}: {}", + hostname, address, err + ); + } + } + } + ServiceControl::Query { callback, hostname } => { + trace!("Querying {}", hostname); + if !validate_hostname(&hostname) { + warn!("Not sending mDNS query for invalid hostname: {}", hostname); + continue; + } + unsent_queries.push_back(Query::new(&hostname, callback)); + } + ServiceControl::Unregister { hostname } => { + trace!("Unregistering {}", hostname); + hosts.remove(&hostname); + } + ServiceControl::Stop => { + trace!("Stopping"); + break; + } + }, + Err(std::sync::mpsc::TryRecvError::Disconnected) => { + break; + } + Err(std::sync::mpsc::TryRecvError::Empty) => {} + } + + handle_queries( + &socket, + &mdns_addr, + &mut pending_queries, + &mut unsent_queries, + ); + + if !handle_mdns_socket( + &socket, + &mdns_addr, + &mut buffer, + &mut hosts, + &mut pending_queries, + ) { + break; + } + } + gecko_profiler::unregister_thread(); + })?); + + Ok(()) + } + + fn stop(self) { + if let Some(sender) = self.sender { + if let Err(err) = sender.send(ServiceControl::Stop) { + warn!("Could not stop mDNS Service: {}", err); + } + if let Some(handle) = self.handle { + if handle.join().is_err() { + error!("Error on thread join"); + } + } + } + } + + fn new() -> MDNSService { + MDNSService { + handle: None, + sender: None, + } + } +} + +/// # Safety +/// +/// This function must only be called with a valid MDNSService pointer. +/// This hostname and address arguments must be zero terminated strings. +#[no_mangle] +pub unsafe extern "C" fn mdns_service_register_hostname( + serv: *mut MDNSService, + hostname: *const c_char, + address: *const c_char, +) { + assert!(!serv.is_null()); + assert!(!hostname.is_null()); + assert!(!address.is_null()); + let hostname = CStr::from_ptr(hostname).to_string_lossy(); + let address = CStr::from_ptr(address).to_string_lossy(); + (*serv).register_hostname(&hostname, &address); +} + +/// # Safety +/// +/// This ifaddrs argument must be a zero terminated string. +#[no_mangle] +pub unsafe extern "C" fn mdns_service_start(ifaddrs: *const c_char) -> *mut MDNSService { + assert!(!ifaddrs.is_null()); + let mut r = Box::new(MDNSService::new()); + let ifaddrs = CStr::from_ptr(ifaddrs).to_string_lossy(); + let addrs: Vec<std::net::Ipv4Addr> = + ifaddrs.split(';').filter_map(|x| x.parse().ok()).collect(); + + if addrs.is_empty() { + warn!("Could not parse interface addresses from: {}", ifaddrs); + } else if let Err(err) = r.start(addrs) { + warn!("Could not start mDNS Service: {}", err); + } + + Box::into_raw(r) +} + +/// # Safety +/// +/// This function must only be called with a valid MDNSService pointer. +#[no_mangle] +pub unsafe extern "C" fn mdns_service_stop(serv: *mut MDNSService) { + assert!(!serv.is_null()); + let boxed = Box::from_raw(serv); + boxed.stop(); +} + +/// # Safety +/// +/// This function must only be called with a valid MDNSService pointer. +/// The data argument will be passed back into the resolved and timedout +/// functions. The object it points to must not be freed until the MDNSService +/// has stopped. +#[no_mangle] +pub unsafe extern "C" fn mdns_service_query_hostname( + serv: *mut MDNSService, + data: *const c_void, + resolved: unsafe extern "C" fn(*const c_void, *const c_char, *const c_char), + timedout: unsafe extern "C" fn(*const c_void, *const c_char), + hostname: *const c_char, +) { + assert!(!serv.is_null()); + assert!(!data.is_null()); + assert!(!hostname.is_null()); + let hostname = CStr::from_ptr(hostname).to_string_lossy(); + let callback = Callback { + data, + resolved, + timedout, + }; + (*serv).query_hostname(callback, &hostname); +} + +/// # Safety +/// +/// This function must only be called with a valid MDNSService pointer. +/// This function should only be called once per hostname. +#[no_mangle] +pub unsafe extern "C" fn mdns_service_unregister_hostname( + serv: *mut MDNSService, + hostname: *const c_char, +) { + assert!(!serv.is_null()); + assert!(!hostname.is_null()); + let hostname = CStr::from_ptr(hostname).to_string_lossy(); + (*serv).unregister_hostname(&hostname); +} + +#[cfg(test)] +mod tests { + use crate::create_query; + use crate::validate_hostname; + use crate::Callback; + use crate::MDNSService; + use socket2::{Domain, Socket, Type}; + use std::collections::HashSet; + use std::ffi::c_void; + use std::io; + use std::iter::FromIterator; + use std::os::raw::c_char; + use std::thread; + use std::time; + use uuid::Uuid; + + #[no_mangle] + pub unsafe extern "C" fn mdns_service_resolved( + _: *const c_void, + _: *const c_char, + _: *const c_char, + ) -> () { + } + + #[no_mangle] + pub unsafe extern "C" fn mdns_service_timedout(_: *const c_void, _: *const c_char) -> () {} + + fn listen_until(addr: &std::net::Ipv4Addr, stop: u64) -> thread::JoinHandle<Vec<String>> { + let port = 5353; + + let socket = Socket::new(Domain::IPV4, Type::DGRAM, None).unwrap(); + socket.set_reuse_address(true).unwrap(); + + #[cfg(not(target_os = "windows"))] + socket.set_reuse_port(true).unwrap(); + socket + .bind(&socket2::SockAddr::from(std::net::SocketAddr::from(( + [0, 0, 0, 0], + port, + )))) + .unwrap(); + + let socket = std::net::UdpSocket::from(socket); + socket.set_multicast_loop_v4(true).unwrap(); + socket + .set_read_timeout(Some(time::Duration::from_millis(10))) + .unwrap(); + socket + .set_write_timeout(Some(time::Duration::from_millis(10))) + .unwrap(); + socket + .join_multicast_v4(&std::net::Ipv4Addr::new(224, 0, 0, 251), &addr) + .unwrap(); + + let mut buffer: [u8; 9_000] = [0; 9_000]; + thread::spawn(move || { + let start = time::Instant::now(); + let mut questions = Vec::new(); + while time::Instant::now().duration_since(start).as_secs() < stop { + match socket.recv_from(&mut buffer) { + Ok((amt, _)) => { + if amt > 0 { + let buffer = &buffer[0..amt]; + match dns_parser::Packet::parse(&buffer) { + Ok(parsed) => { + parsed + .questions + .iter() + .filter(|question| { + question.qtype == dns_parser::QueryType::A + }) + .for_each(|question| { + let qname = question.qname.to_string(); + questions.push(qname); + }); + } + Err(err) => { + warn!("Could not parse mDNS packet: {}", err); + } + } + } + } + Err(err) => { + if err.kind() != io::ErrorKind::WouldBlock + && err.kind() != io::ErrorKind::TimedOut + { + error!("Socket error: {}", err); + break; + } + } + } + } + questions + }) + } + + #[test] + fn test_validate_hostname() { + assert_eq!( + validate_hostname("e17f08d4-689a-4df6-ba31-35bb9f041100.local"), + true + ); + assert_eq!( + validate_hostname("62240723-ae6d-4f6a-99b8-94a233e3f84a2.local"), + true + ); + assert_eq!( + validate_hostname("62240723-ae6d-4f6a-99b8.94e3f84a2.local"), + false + ); + assert_eq!(validate_hostname("hi there"), false); + } + + #[test] + fn start_stop() { + let mut service = MDNSService::new(); + let addr = "127.0.0.1".parse().unwrap(); + service.start(vec![addr]).unwrap(); + service.stop(); + } + + #[test] + fn simple_query() { + let mut service = MDNSService::new(); + let addr = "127.0.0.1".parse().unwrap(); + let handle = listen_until(&addr, 1); + + service.start(vec![addr]).unwrap(); + + let callback = Callback { + data: 0 as *const c_void, + resolved: mdns_service_resolved, + timedout: mdns_service_timedout, + }; + let hostname = Uuid::new_v4().as_hyphenated().to_string() + ".local"; + service.query_hostname(callback, &hostname); + service.stop(); + let questions = handle.join().unwrap(); + assert!(questions.contains(&hostname)); + } + + #[test] + fn rate_limited_query() { + let mut service = MDNSService::new(); + let addr = "127.0.0.1".parse().unwrap(); + let handle = listen_until(&addr, 1); + + service.start(vec![addr]).unwrap(); + + let mut hostnames = HashSet::new(); + for _ in 0..100 { + let callback = Callback { + data: 0 as *const c_void, + resolved: mdns_service_resolved, + timedout: mdns_service_timedout, + }; + let hostname = Uuid::new_v4().as_hyphenated().to_string() + ".local"; + service.query_hostname(callback, &hostname); + hostnames.insert(hostname); + } + service.stop(); + let questions = HashSet::from_iter(handle.join().unwrap().iter().map(|x| x.to_string())); + let intersection: HashSet<&String> = questions.intersection(&hostnames).collect(); + assert_eq!(intersection.len(), 50); + } + + #[test] + fn repeat_failed_query() { + let mut service = MDNSService::new(); + let addr = "127.0.0.1".parse().unwrap(); + let handle = listen_until(&addr, 4); + + service.start(vec![addr]).unwrap(); + + let hostname = Uuid::new_v4().as_hyphenated().to_string() + ".local"; + let callback = Callback { + data: 0 as *const c_void, + resolved: mdns_service_resolved, + timedout: mdns_service_timedout, + }; + service.query_hostname(callback, &hostname); + thread::sleep(time::Duration::from_secs(4)); + service.stop(); + + let questions: Vec<String> = handle + .join() + .unwrap() + .iter() + .filter(|x| *x == &hostname) + .map(|x| x.to_string()) + .collect(); + assert_eq!(questions.len(), 2); + } + + #[test] + fn multiple_queries_in_a_single_packet() { + let mut hostnames: Vec<String> = Vec::new(); + for _ in 0..100 { + let hostname = Uuid::new_v4().as_hyphenated().to_string() + ".local"; + hostnames.push(hostname); + } + + match create_query(42, &hostnames) { + Ok(q) => match dns_parser::Packet::parse(&q) { + Ok(parsed) => { + assert_eq!(parsed.questions.len(), 100); + } + Err(_) => assert!(false), + }, + Err(_) => assert!(false), + } + } +} |