diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/sync15/src/client | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/sync15/src/client')
-rw-r--r-- | third_party/rust/sync15/src/client/coll_state.rs | 361 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/coll_update.rs | 121 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/collection_keys.rs | 61 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/mod.rs | 39 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/request.rs | 1199 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/state.rs | 1089 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/status.rs | 106 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/storage_client.rs | 605 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/sync.rs | 114 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/sync_multiple.rs | 491 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/token.rs | 602 | ||||
-rw-r--r-- | third_party/rust/sync15/src/client/util.rs | 102 |
12 files changed, 4890 insertions, 0 deletions
diff --git a/third_party/rust/sync15/src/client/coll_state.rs b/third_party/rust/sync15/src/client/coll_state.rs new file mode 100644 index 0000000000..da50bbee6c --- /dev/null +++ b/third_party/rust/sync15/src/client/coll_state.rs @@ -0,0 +1,361 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::request::InfoConfiguration; +use super::{CollectionKeys, GlobalState}; +use crate::engine::{CollSyncIds, EngineSyncAssociation, SyncEngine}; +use crate::error; +use crate::KeyBundle; +use crate::ServerTimestamp; + +/// Holds state for a collection necessary to perform a sync of it. Lives for the lifetime +/// of a single sync. +#[derive(Debug, Clone)] +pub struct CollState { + // Info about the server configuration/capabilities + pub config: InfoConfiguration, + // from meta/global, used for XIUS when we POST outgoing record based on this state. + pub last_modified: ServerTimestamp, + pub key: KeyBundle, +} + +/// This mini state-machine helps build a CollState +#[derive(Debug)] +pub enum LocalCollState { + /// The state is unknown, with the EngineSyncAssociation the collection + /// reports. + Unknown { assoc: EngineSyncAssociation }, + + /// The engine has been declined. This is a "terminal" state. + Declined, + + /// There's no such collection in meta/global. We could possibly update + /// meta/global, but currently all known collections are there by default, + /// so this is, basically, an error condition. + NoSuchCollection, + + /// Either the global or collection sync ID has changed - we will reset the engine. + SyncIdChanged { ids: CollSyncIds }, + + /// The collection is ready to sync. + Ready { coll_state: CollState }, +} + +pub struct LocalCollStateMachine<'state> { + global_state: &'state GlobalState, + root_key: &'state KeyBundle, +} + +impl<'state> LocalCollStateMachine<'state> { + fn advance( + &self, + from: LocalCollState, + engine: &dyn SyncEngine, + ) -> error::Result<LocalCollState> { + let name = &engine.collection_name().to_string(); + let meta_global = &self.global_state.global; + match from { + LocalCollState::Unknown { assoc } => { + if meta_global.declined.contains(name) { + return Ok(LocalCollState::Declined); + } + match meta_global.engines.get(name) { + Some(engine_meta) => match assoc { + EngineSyncAssociation::Disconnected => Ok(LocalCollState::SyncIdChanged { + ids: CollSyncIds { + global: meta_global.sync_id.clone(), + coll: engine_meta.sync_id.clone(), + }, + }), + EngineSyncAssociation::Connected(ref ids) + if ids.global == meta_global.sync_id + && ids.coll == engine_meta.sync_id => + { + // We are done - build the CollState + let coll_keys = CollectionKeys::from_encrypted_payload( + self.global_state.keys.clone(), + self.global_state.keys_timestamp, + self.root_key, + )?; + let key = coll_keys.key_for_collection(name).clone(); + let name = engine.collection_name(); + let config = self.global_state.config.clone(); + let last_modified = self + .global_state + .collections + .get(name.as_ref()) + .cloned() + .unwrap_or_default(); + let coll_state = CollState { + config, + last_modified, + key, + }; + Ok(LocalCollState::Ready { coll_state }) + } + _ => Ok(LocalCollState::SyncIdChanged { + ids: CollSyncIds { + global: meta_global.sync_id.clone(), + coll: engine_meta.sync_id.clone(), + }, + }), + }, + None => Ok(LocalCollState::NoSuchCollection), + } + } + + LocalCollState::Declined => unreachable!("can't advance from declined"), + + LocalCollState::NoSuchCollection => unreachable!("the collection is unknown"), + + LocalCollState::SyncIdChanged { ids } => { + let assoc = EngineSyncAssociation::Connected(ids); + log::info!("Resetting {} engine", engine.collection_name()); + engine.reset(&assoc)?; + Ok(LocalCollState::Unknown { assoc }) + } + + LocalCollState::Ready { .. } => unreachable!("can't advance from ready"), + } + } + + // A little whimsy - a portmanteau of far and fast + fn run_and_run_as_farst_as_you_can( + &mut self, + engine: &dyn SyncEngine, + ) -> error::Result<Option<CollState>> { + let mut s = LocalCollState::Unknown { + assoc: engine.get_sync_assoc()?, + }; + // This is a simple state machine and should never take more than + // 10 goes around. + let mut count = 0; + loop { + log::trace!("LocalCollState in {:?}", s); + match s { + LocalCollState::Ready { coll_state } => return Ok(Some(coll_state)), + LocalCollState::Declined | LocalCollState::NoSuchCollection => return Ok(None), + _ => { + count += 1; + if count > 10 { + log::warn!("LocalCollStateMachine appears to be looping"); + return Ok(None); + } + // should we have better loop detection? Our limit of 10 + // goes is probably OK for now, but not really ideal. + s = self.advance(s, engine)?; + } + }; + } + } + + pub fn get_state( + engine: &dyn SyncEngine, + global_state: &'state GlobalState, + root_key: &'state KeyBundle, + ) -> error::Result<Option<CollState>> { + let mut gingerbread_man = Self { + global_state, + root_key, + }; + gingerbread_man.run_and_run_as_farst_as_you_can(engine) + } +} + +#[cfg(test)] +mod tests { + use super::super::request::{InfoCollections, InfoConfiguration}; + use super::super::CollectionKeys; + use super::*; + use crate::bso::{IncomingBso, OutgoingBso}; + use crate::engine::CollectionRequest; + use crate::record_types::{MetaGlobalEngine, MetaGlobalRecord}; + use crate::{telemetry, CollectionName}; + use anyhow::Result; + use std::cell::{Cell, RefCell}; + use std::collections::HashMap; + use sync_guid::Guid; + + fn get_global_state(root_key: &KeyBundle) -> GlobalState { + let keys = CollectionKeys::new_random() + .unwrap() + .to_encrypted_payload(root_key) + .unwrap(); + GlobalState { + config: InfoConfiguration::default(), + collections: InfoCollections::new(HashMap::new()), + global: MetaGlobalRecord { + sync_id: "syncIDAAAAAA".into(), + storage_version: 5usize, + engines: vec![( + "bookmarks", + MetaGlobalEngine { + version: 1usize, + sync_id: "syncIDBBBBBB".into(), + }, + )] + .into_iter() + .map(|(key, value)| (key.to_owned(), value)) + .collect(), + declined: vec![], + }, + global_timestamp: ServerTimestamp::default(), + keys, + keys_timestamp: ServerTimestamp::default(), + } + } + + struct TestSyncEngine { + collection_name: &'static str, + assoc: Cell<EngineSyncAssociation>, + num_resets: RefCell<usize>, + } + + impl TestSyncEngine { + fn new(collection_name: &'static str, assoc: EngineSyncAssociation) -> Self { + Self { + collection_name, + assoc: Cell::new(assoc), + num_resets: RefCell::new(0), + } + } + fn get_num_resets(&self) -> usize { + *self.num_resets.borrow() + } + } + + impl SyncEngine for TestSyncEngine { + fn collection_name(&self) -> CollectionName { + self.collection_name.into() + } + + fn stage_incoming( + &self, + _inbound: Vec<IncomingBso>, + _telem: &mut telemetry::Engine, + ) -> Result<()> { + unreachable!("these tests shouldn't call these"); + } + + fn apply( + &self, + _timestamp: ServerTimestamp, + _telem: &mut telemetry::Engine, + ) -> Result<Vec<OutgoingBso>> { + unreachable!("these tests shouldn't call these"); + } + + fn set_uploaded(&self, _new_timestamp: ServerTimestamp, _ids: Vec<Guid>) -> Result<()> { + unreachable!("these tests shouldn't call these"); + } + + fn sync_finished(&self) -> Result<()> { + unreachable!("these tests shouldn't call these"); + } + + fn get_collection_request( + &self, + _server_timestamp: ServerTimestamp, + ) -> Result<Option<CollectionRequest>> { + unreachable!("these tests shouldn't call these"); + } + + fn get_sync_assoc(&self) -> Result<EngineSyncAssociation> { + Ok(self.assoc.replace(EngineSyncAssociation::Disconnected)) + } + + fn reset(&self, new_assoc: &EngineSyncAssociation) -> Result<()> { + self.assoc.replace(new_assoc.clone()); + *self.num_resets.borrow_mut() += 1; + Ok(()) + } + + fn wipe(&self) -> Result<()> { + unreachable!("these tests shouldn't call these"); + } + } + + #[test] + fn test_unknown() { + let root_key = KeyBundle::new_random().expect("should work"); + let gs = get_global_state(&root_key); + let engine = TestSyncEngine::new("unknown", EngineSyncAssociation::Disconnected); + let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work"); + assert!(cs.is_none(), "unknown collection name can't sync"); + assert_eq!(engine.get_num_resets(), 0); + } + + #[test] + fn test_known_no_state() { + let root_key = KeyBundle::new_random().expect("should work"); + let gs = get_global_state(&root_key); + let engine = TestSyncEngine::new("bookmarks", EngineSyncAssociation::Disconnected); + let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work"); + assert!(cs.is_some(), "collection can sync"); + assert_eq!( + engine.assoc.replace(EngineSyncAssociation::Disconnected), + EngineSyncAssociation::Connected(CollSyncIds { + global: "syncIDAAAAAA".into(), + coll: "syncIDBBBBBB".into(), + }) + ); + assert_eq!(engine.get_num_resets(), 1); + } + + #[test] + fn test_known_wrong_state() { + let root_key = KeyBundle::new_random().expect("should work"); + let gs = get_global_state(&root_key); + let engine = TestSyncEngine::new( + "bookmarks", + EngineSyncAssociation::Connected(CollSyncIds { + global: "syncIDXXXXXX".into(), + coll: "syncIDYYYYYY".into(), + }), + ); + let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work"); + assert!(cs.is_some(), "collection can sync"); + assert_eq!( + engine.assoc.replace(EngineSyncAssociation::Disconnected), + EngineSyncAssociation::Connected(CollSyncIds { + global: "syncIDAAAAAA".into(), + coll: "syncIDBBBBBB".into(), + }) + ); + assert_eq!(engine.get_num_resets(), 1); + } + + #[test] + fn test_known_good_state() { + let root_key = KeyBundle::new_random().expect("should work"); + let gs = get_global_state(&root_key); + let engine = TestSyncEngine::new( + "bookmarks", + EngineSyncAssociation::Connected(CollSyncIds { + global: "syncIDAAAAAA".into(), + coll: "syncIDBBBBBB".into(), + }), + ); + let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work"); + assert!(cs.is_some(), "collection can sync"); + assert_eq!(engine.get_num_resets(), 0); + } + + #[test] + fn test_declined() { + let root_key = KeyBundle::new_random().expect("should work"); + let mut gs = get_global_state(&root_key); + gs.global.declined.push("bookmarks".to_string()); + let engine = TestSyncEngine::new( + "bookmarks", + EngineSyncAssociation::Connected(CollSyncIds { + global: "syncIDAAAAAA".into(), + coll: "syncIDBBBBBB".into(), + }), + ); + let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work"); + assert!(cs.is_none(), "declined collection can sync"); + assert_eq!(engine.get_num_resets(), 0); + } +} diff --git a/third_party/rust/sync15/src/client/coll_update.rs b/third_party/rust/sync15/src/client/coll_update.rs new file mode 100644 index 0000000000..87430b7727 --- /dev/null +++ b/third_party/rust/sync15/src/client/coll_update.rs @@ -0,0 +1,121 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::{ + request::{NormalResponseHandler, UploadInfo}, + CollState, Sync15ClientResponse, Sync15StorageClient, +}; +use crate::bso::{IncomingBso, OutgoingBso, OutgoingEncryptedBso}; +use crate::engine::CollectionRequest; +use crate::error::{self, Error, Result}; +use crate::{CollectionName, KeyBundle, ServerTimestamp}; + +fn encrypt_outgoing(o: Vec<OutgoingBso>, key: &KeyBundle) -> Result<Vec<OutgoingEncryptedBso>> { + o.into_iter() + .map(|change| change.into_encrypted(key)) + .collect() +} + +pub fn fetch_incoming( + client: &Sync15StorageClient, + state: &CollState, + collection_request: CollectionRequest, +) -> Result<Vec<IncomingBso>> { + let (records, _timestamp) = match client.get_encrypted_records(collection_request)? { + Sync15ClientResponse::Success { + record, + last_modified, + .. + } => (record, last_modified), + other => return Err(other.create_storage_error()), + }; + let mut result = Vec::with_capacity(records.len()); + for record in records { + // if we see a HMAC error, we've made an explicit decision to + // NOT handle it here, but restart the global state machine. + // That should cause us to re-read crypto/keys and things should + // work (although if for some reason crypto/keys was updated but + // not all storage was wiped we are probably screwed.) + result.push(record.into_decrypted(&state.key)?); + } + Ok(result) +} + +pub struct CollectionUpdate<'a> { + client: &'a Sync15StorageClient, + state: &'a CollState, + collection: CollectionName, + xius: ServerTimestamp, + to_update: Vec<OutgoingEncryptedBso>, + fully_atomic: bool, +} + +impl<'a> CollectionUpdate<'a> { + pub fn new( + client: &'a Sync15StorageClient, + state: &'a CollState, + collection: CollectionName, + xius: ServerTimestamp, + records: Vec<OutgoingEncryptedBso>, + fully_atomic: bool, + ) -> CollectionUpdate<'a> { + CollectionUpdate { + client, + state, + collection, + xius, + to_update: records, + fully_atomic, + } + } + + pub fn new_from_changeset( + client: &'a Sync15StorageClient, + state: &'a CollState, + collection: CollectionName, + changeset: Vec<OutgoingBso>, + fully_atomic: bool, + ) -> Result<CollectionUpdate<'a>> { + let to_update = encrypt_outgoing(changeset, &state.key)?; + Ok(CollectionUpdate::new( + client, + state, + collection, + state.last_modified, + to_update, + fully_atomic, + )) + } + + /// Returns a list of the IDs that failed if allowed_dropped_records is true, otherwise + /// returns an empty vec. + pub fn upload(self) -> error::Result<UploadInfo> { + let mut failed = vec![]; + let mut q = self.client.new_post_queue( + &self.collection, + &self.state.config, + self.xius, + NormalResponseHandler::new(!self.fully_atomic), + )?; + + for record in self.to_update.into_iter() { + let enqueued = q.enqueue(&record)?; + if !enqueued && self.fully_atomic { + return Err(Error::RecordTooLargeError); + } + } + + q.flush(true)?; + let mut info = q.completed_upload_info(); + info.failed_ids.append(&mut failed); + if self.fully_atomic { + assert_eq!( + info.failed_ids.len(), + 0, + "Bug: Should have failed by now if we aren't allowing dropped records" + ); + } + Ok(info) + } +} diff --git a/third_party/rust/sync15/src/client/collection_keys.rs b/third_party/rust/sync15/src/client/collection_keys.rs new file mode 100644 index 0000000000..f51894f756 --- /dev/null +++ b/third_party/rust/sync15/src/client/collection_keys.rs @@ -0,0 +1,61 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use crate::error::Result; +use crate::record_types::CryptoKeysRecord; +use crate::{EncryptedPayload, KeyBundle, ServerTimestamp}; +use std::collections::HashMap; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CollectionKeys { + pub timestamp: ServerTimestamp, + pub default: KeyBundle, + pub collections: HashMap<String, KeyBundle>, +} + +impl CollectionKeys { + pub fn new_random() -> Result<CollectionKeys> { + let default = KeyBundle::new_random()?; + Ok(CollectionKeys { + timestamp: ServerTimestamp(0), + default, + collections: HashMap::new(), + }) + } + + pub fn from_encrypted_payload( + record: EncryptedPayload, + timestamp: ServerTimestamp, + root_key: &KeyBundle, + ) -> Result<CollectionKeys> { + let keys: CryptoKeysRecord = record.decrypt_into(root_key)?; + Ok(CollectionKeys { + timestamp, + default: KeyBundle::from_base64(&keys.default[0], &keys.default[1])?, + collections: keys + .collections + .into_iter() + .map(|kv| Ok((kv.0, KeyBundle::from_base64(&kv.1[0], &kv.1[1])?))) + .collect::<Result<HashMap<String, KeyBundle>>>()?, + }) + } + + pub fn to_encrypted_payload(&self, root_key: &KeyBundle) -> Result<EncryptedPayload> { + let record = CryptoKeysRecord { + id: "keys".into(), + collection: "crypto".into(), + default: self.default.to_b64_array(), + collections: self + .collections + .iter() + .map(|kv| (kv.0.clone(), kv.1.to_b64_array())) + .collect(), + }; + EncryptedPayload::from_cleartext_payload(root_key, &record) + } + + pub fn key_for_collection<'a>(&'a self, collection: &str) -> &'a KeyBundle { + self.collections.get(collection).unwrap_or(&self.default) + } +} diff --git a/third_party/rust/sync15/src/client/mod.rs b/third_party/rust/sync15/src/client/mod.rs new file mode 100644 index 0000000000..c49486bfe6 --- /dev/null +++ b/third_party/rust/sync15/src/client/mod.rs @@ -0,0 +1,39 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +//! A module for everything needed to be a "sync client" - ie, a device which +//! can perform a full sync of any number of collections, including managing +//! the server state. +//! +//! In general, the client is responsible for all communication with the sync server, +//! including ensuring the state is correct, and encrypting/decrypting all records +//! to and from the server. However, the actual syncing of the collections is +//! delegated to an external [crate::engine](Sync Engine). +//! +//! One exception is that the "sync client" owns one sync engine - the +//! [crate::clients_engine], which is managed internally. +mod coll_state; +mod coll_update; +mod collection_keys; +mod request; +mod state; +mod status; +mod storage_client; +mod sync; +mod sync_multiple; +mod token; +mod util; + +pub(crate) use coll_state::{CollState, LocalCollStateMachine}; +pub(crate) use coll_update::{fetch_incoming, CollectionUpdate}; +pub(crate) use collection_keys::CollectionKeys; +pub(crate) use request::InfoConfiguration; +pub(crate) use state::GlobalState; +pub use status::{ServiceStatus, SyncResult}; +pub use storage_client::{ + SetupStorageClient, Sync15ClientResponse, Sync15StorageClient, Sync15StorageClientInit, +}; +pub use sync_multiple::{ + sync_multiple, sync_multiple_with_command_processor, MemoryCachedState, SyncRequestInfo, +}; diff --git a/third_party/rust/sync15/src/client/request.rs b/third_party/rust/sync15/src/client/request.rs new file mode 100644 index 0000000000..c69b630c8d --- /dev/null +++ b/third_party/rust/sync15/src/client/request.rs @@ -0,0 +1,1199 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::storage_client::Sync15ClientResponse; +use crate::bso::OutgoingEncryptedBso; +use crate::error::{self, Error as ErrorKind, Result}; +use crate::ServerTimestamp; +use serde_derive::*; +use std::collections::HashMap; +use std::default::Default; +use std::ops::Deref; +use sync_guid::Guid; +use viaduct::status_codes; + +/// Manages a pair of (byte, count) limits for a PostQueue, such as +/// (max_post_bytes, max_post_records) or (max_total_bytes, max_total_records). +#[derive(Debug, Clone)] +struct LimitTracker { + max_bytes: usize, + max_records: usize, + cur_bytes: usize, + cur_records: usize, +} + +impl LimitTracker { + pub fn new(max_bytes: usize, max_records: usize) -> LimitTracker { + LimitTracker { + max_bytes, + max_records, + cur_bytes: 0, + cur_records: 0, + } + } + + pub fn clear(&mut self) { + self.cur_records = 0; + self.cur_bytes = 0; + } + + pub fn can_add_record(&self, payload_size: usize) -> bool { + // Desktop does the cur_bytes check as exclusive, but we shouldn't see any servers that + // don't have https://github.com/mozilla-services/server-syncstorage/issues/73 + self.cur_records < self.max_records && self.cur_bytes + payload_size <= self.max_bytes + } + + pub fn can_never_add(&self, record_size: usize) -> bool { + record_size >= self.max_bytes + } + + pub fn record_added(&mut self, record_size: usize) { + assert!( + self.can_add_record(record_size), + "LimitTracker::record_added caller must check can_add_record" + ); + self.cur_records += 1; + self.cur_bytes += record_size; + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct InfoConfiguration { + /// The maximum size in bytes of the overall HTTP request body that will be accepted by the + /// server. + #[serde(default = "default_max_request_bytes")] + pub max_request_bytes: usize, + + /// The maximum number of records that can be uploaded to a collection in a single POST request. + #[serde(default = "usize::max_value")] + pub max_post_records: usize, + + /// The maximum combined size in bytes of the record payloads that can be uploaded to a + /// collection in a single POST request. + #[serde(default = "usize::max_value")] + pub max_post_bytes: usize, + + /// The maximum total number of records that can be uploaded to a collection as part of a + /// batched upload. + #[serde(default = "usize::max_value")] + pub max_total_records: usize, + + /// The maximum total combined size in bytes of the record payloads that can be uploaded to a + /// collection as part of a batched upload. + #[serde(default = "usize::max_value")] + pub max_total_bytes: usize, + + /// The maximum size of an individual BSO payload, in bytes. + #[serde(default = "default_max_record_payload_bytes")] + pub max_record_payload_bytes: usize, +} + +// This is annoying but seems to be the only way to do it... +fn default_max_request_bytes() -> usize { + 260 * 1024 +} +fn default_max_record_payload_bytes() -> usize { + 256 * 1024 +} + +impl Default for InfoConfiguration { + #[inline] + fn default() -> InfoConfiguration { + InfoConfiguration { + max_request_bytes: default_max_request_bytes(), + max_record_payload_bytes: default_max_record_payload_bytes(), + max_post_records: usize::max_value(), + max_post_bytes: usize::max_value(), + max_total_records: usize::max_value(), + max_total_bytes: usize::max_value(), + } + } +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct InfoCollections(pub(crate) HashMap<String, ServerTimestamp>); + +impl InfoCollections { + pub fn new(collections: HashMap<String, ServerTimestamp>) -> InfoCollections { + InfoCollections(collections) + } +} + +impl Deref for InfoCollections { + type Target = HashMap<String, ServerTimestamp>; + + fn deref(&self) -> &HashMap<String, ServerTimestamp> { + &self.0 + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct UploadResult { + batch: Option<String>, + /// Maps record id => why failed + #[serde(default = "HashMap::new")] + pub failed: HashMap<Guid, String>, + /// Vec of ids + #[serde(default = "Vec::new")] + pub success: Vec<Guid>, +} + +pub type PostResponse = Sync15ClientResponse<UploadResult>; + +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub enum BatchState { + Unsupported, + NoBatch, + InBatch(String), +} + +#[derive(Debug)] +pub struct PostQueue<Post, OnResponse> { + poster: Post, + on_response: OnResponse, + post_limits: LimitTracker, + batch_limits: LimitTracker, + max_payload_bytes: usize, + max_request_bytes: usize, + queued: Vec<u8>, + batch: BatchState, + last_modified: ServerTimestamp, +} + +pub trait BatchPoster { + /// Note: Last argument (reference to the batch poster) is provided for the purposes of testing + /// Important: Poster should not report non-success HTTP statuses as errors!! + fn post<P, O>( + &self, + body: Vec<u8>, + xius: ServerTimestamp, + batch: Option<String>, + commit: bool, + queue: &PostQueue<P, O>, + ) -> Result<PostResponse>; +} + +// We don't just use a FnMut here since we want to override it in mocking for RefCell<TestType>, +// which we can't do for FnMut since neither FnMut nor RefCell are defined here. Also, this +// is somewhat better for documentation. +pub trait PostResponseHandler { + fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> Result<()>; +} + +#[derive(Debug, Clone)] +pub(crate) struct NormalResponseHandler { + pub failed_ids: Vec<Guid>, + pub successful_ids: Vec<Guid>, + pub allow_failed: bool, + pub pending_failed: Vec<Guid>, + pub pending_success: Vec<Guid>, +} + +impl NormalResponseHandler { + pub fn new(allow_failed: bool) -> NormalResponseHandler { + NormalResponseHandler { + failed_ids: vec![], + successful_ids: vec![], + pending_failed: vec![], + pending_success: vec![], + allow_failed, + } + } +} + +impl PostResponseHandler for NormalResponseHandler { + fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> error::Result<()> { + match r { + Sync15ClientResponse::Success { record, .. } => { + if !record.failed.is_empty() && !self.allow_failed { + return Err(ErrorKind::RecordUploadFailed); + } + for id in record.success.iter() { + self.pending_success.push(id.clone()); + } + for kv in record.failed.iter() { + self.pending_failed.push(kv.0.clone()); + } + if !mid_batch { + self.successful_ids.append(&mut self.pending_success); + self.failed_ids.append(&mut self.pending_failed); + } + Ok(()) + } + _ => Err(r.create_storage_error()), + } + } +} + +impl<Poster, OnResponse> PostQueue<Poster, OnResponse> +where + Poster: BatchPoster, + OnResponse: PostResponseHandler, +{ + pub fn new( + config: &InfoConfiguration, + ts: ServerTimestamp, + poster: Poster, + on_response: OnResponse, + ) -> PostQueue<Poster, OnResponse> { + PostQueue { + poster, + on_response, + last_modified: ts, + post_limits: LimitTracker::new(config.max_post_bytes, config.max_post_records), + batch_limits: LimitTracker::new(config.max_total_bytes, config.max_total_records), + batch: BatchState::NoBatch, + max_payload_bytes: config.max_record_payload_bytes, + max_request_bytes: config.max_request_bytes, + queued: Vec::new(), + } + } + + #[inline] + fn in_batch(&self) -> bool { + !matches!(&self.batch, BatchState::Unsupported | BatchState::NoBatch) + } + + pub fn enqueue(&mut self, record: &OutgoingEncryptedBso) -> Result<bool> { + let payload_length = record.serialized_payload_len(); + + if self.post_limits.can_never_add(payload_length) + || self.batch_limits.can_never_add(payload_length) + || payload_length >= self.max_payload_bytes + { + log::warn!( + "Single record too large to submit to server ({} b)", + payload_length + ); + return Ok(false); + } + + // Write directly into `queued` but undo if necessary (the vast majority of the time + // it won't be necessary). If we hit a problem we need to undo that, but the only error + // case we have to worry about right now is in flush() + let item_start = self.queued.len(); + + // This is conservative but can't hurt. + self.queued.reserve(payload_length + 2); + + // Either the first character in an array, or a comma separating + // it from the previous item. + let c = if self.queued.is_empty() { b'[' } else { b',' }; + self.queued.push(c); + + // This unwrap is fine, since serde_json's failure case is HashMaps that have non-object + // keys, which is impossible. If you decide to change this part, you *need* to call + // `self.queued.truncate(item_start)` here in the failure case! + serde_json::to_writer(&mut self.queued, &record).unwrap(); + + let item_end = self.queued.len(); + + debug_assert!( + item_end >= payload_length, + "EncryptedPayload::serialized_len is bugged" + ); + + // The + 1 is only relevant for the final record, which will have a trailing ']'. + let item_len = item_end - item_start + 1; + + if item_len >= self.max_request_bytes { + self.queued.truncate(item_start); + log::warn!( + "Single record too large to submit to server ({} b)", + item_len + ); + return Ok(false); + } + + let can_post_record = self.post_limits.can_add_record(payload_length); + let can_batch_record = self.batch_limits.can_add_record(payload_length); + let can_send_record = self.queued.len() < self.max_request_bytes; + + if !can_post_record || !can_send_record || !can_batch_record { + log::debug!( + "PostQueue flushing! (can_post = {}, can_send = {}, can_batch = {})", + can_post_record, + can_send_record, + can_batch_record + ); + // "unwrite" the record. + self.queued.truncate(item_start); + // Flush whatever we have queued. + self.flush(!can_batch_record)?; + // And write it again. + let c = if self.queued.is_empty() { b'[' } else { b',' }; + self.queued.push(c); + serde_json::to_writer(&mut self.queued, &record).unwrap(); + } + + self.post_limits.record_added(payload_length); + self.batch_limits.record_added(payload_length); + + Ok(true) + } + + pub fn flush(&mut self, want_commit: bool) -> Result<()> { + if self.queued.is_empty() { + assert!( + !self.in_batch(), + "Bug: Somehow we're in a batch but have no queued records" + ); + // Nothing to do! + return Ok(()); + } + + self.queued.push(b']'); + let batch_id = match &self.batch { + // Not the first post and we know we have no batch semantics. + BatchState::Unsupported => None, + // First commit in possible batch + BatchState::NoBatch => Some("true".into()), + // In a batch and we have a batch id. + BatchState::InBatch(ref s) => Some(s.clone()), + }; + + log::info!( + "Posting {} records of {} bytes", + self.post_limits.cur_records, + self.queued.len() + ); + + let is_commit = want_commit && batch_id.is_some(); + // Weird syntax for calling a function object that is a property. + let resp_or_error = self.poster.post( + self.queued.clone(), + self.last_modified, + batch_id, + is_commit, + self, + ); + + self.queued.truncate(0); + + if want_commit || self.batch == BatchState::Unsupported { + self.batch_limits.clear(); + } + self.post_limits.clear(); + + let resp = resp_or_error?; + + let (status, last_modified, record) = match resp { + Sync15ClientResponse::Success { + status, + last_modified, + ref record, + .. + } => (status, last_modified, record), + _ => { + self.on_response.handle_response(resp, !want_commit)?; + // on_response() should always fail! + unreachable!(); + } + }; + + if want_commit || self.batch == BatchState::Unsupported { + self.last_modified = last_modified; + } + + if want_commit { + log::debug!("Committed batch {:?}", self.batch); + self.batch = BatchState::NoBatch; + self.on_response.handle_response(resp, false)?; + return Ok(()); + } + + if status != status_codes::ACCEPTED { + if self.in_batch() { + return Err(ErrorKind::ServerBatchProblem( + "Server responded non-202 success code while a batch was in progress", + )); + } + self.last_modified = last_modified; + self.batch = BatchState::Unsupported; + self.batch_limits.clear(); + self.on_response.handle_response(resp, false)?; + return Ok(()); + } + + let batch_id = record + .batch + .as_ref() + .ok_or({ + ErrorKind::ServerBatchProblem("Invalid server response: 202 without a batch ID") + })? + .clone(); + + match &self.batch { + BatchState::Unsupported => { + log::warn!("Server changed its mind about supporting batching mid-batch..."); + } + + BatchState::InBatch(ref cur_id) => { + if cur_id != &batch_id { + return Err(ErrorKind::ServerBatchProblem( + "Invalid server response: 202 without a batch ID", + )); + } + } + _ => {} + } + + // Can't change this in match arms without NLL + self.batch = BatchState::InBatch(batch_id); + self.last_modified = last_modified; + + self.on_response.handle_response(resp, true)?; + + Ok(()) + } +} + +#[derive(Clone)] +pub struct UploadInfo { + pub successful_ids: Vec<Guid>, + pub failed_ids: Vec<Guid>, + pub modified_timestamp: ServerTimestamp, +} + +impl<Poster> PostQueue<Poster, NormalResponseHandler> { + // TODO: should take by move + pub fn completed_upload_info(&mut self) -> UploadInfo { + let mut result = UploadInfo { + successful_ids: Vec::with_capacity(self.on_response.successful_ids.len()), + failed_ids: Vec::with_capacity( + self.on_response.failed_ids.len() + + self.on_response.pending_failed.len() + + self.on_response.pending_success.len(), + ), + modified_timestamp: self.last_modified, + }; + + result + .successful_ids + .append(&mut self.on_response.successful_ids); + + result.failed_ids.append(&mut self.on_response.failed_ids); + result + .failed_ids + .append(&mut self.on_response.pending_failed); + result + .failed_ids + .append(&mut self.on_response.pending_success); + + result + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::bso::{IncomingEncryptedBso, OutgoingEncryptedBso, OutgoingEnvelope}; + use crate::EncryptedPayload; + use lazy_static::lazy_static; + use std::cell::RefCell; + use std::collections::VecDeque; + use std::rc::Rc; + + #[derive(Debug, Clone)] + struct PostedData { + body: String, + _xius: ServerTimestamp, + batch: Option<String>, + commit: bool, + payload_bytes: usize, + records: usize, + } + + impl PostedData { + fn records_as_json(&self) -> Vec<serde_json::Value> { + let values = + serde_json::from_str::<serde_json::Value>(&self.body).expect("Posted invalid json"); + // Check that they actually deserialize as what we want + let records_or_err = + serde_json::from_value::<Vec<IncomingEncryptedBso>>(values.clone()); + records_or_err.expect("Failed to deserialize data"); + serde_json::from_value(values).unwrap() + } + } + + #[derive(Debug, Clone)] + struct BatchInfo { + id: Option<String>, + posts: Vec<PostedData>, + bytes: usize, + records: usize, + } + + #[derive(Debug, Clone)] + struct TestPoster { + all_posts: Vec<PostedData>, + responses: VecDeque<PostResponse>, + batches: Vec<BatchInfo>, + cur_batch: Option<BatchInfo>, + cfg: InfoConfiguration, + } + + type TestPosterRef = Rc<RefCell<TestPoster>>; + impl TestPoster { + pub fn new<T>(cfg: &InfoConfiguration, responses: T) -> TestPosterRef + where + T: Into<VecDeque<PostResponse>>, + { + Rc::new(RefCell::new(TestPoster { + all_posts: vec![], + responses: responses.into(), + batches: vec![], + cur_batch: None, + cfg: cfg.clone(), + })) + } + // Adds &mut + fn do_post<T, O>( + &mut self, + body: &[u8], + xius: ServerTimestamp, + batch: Option<String>, + commit: bool, + queue: &PostQueue<T, O>, + ) -> Sync15ClientResponse<UploadResult> { + let mut post = PostedData { + body: String::from_utf8(body.into()).expect("Posted invalid utf8..."), + batch: batch.clone(), + _xius: xius, + commit, + payload_bytes: 0, + records: 0, + }; + + assert!(body.len() <= self.cfg.max_request_bytes); + + let (num_records, record_payload_bytes) = { + let recs = post.records_as_json(); + assert!(recs.len() <= self.cfg.max_post_records); + assert!(recs.len() <= self.cfg.max_total_records); + let payload_bytes: usize = recs + .iter() + .map(|r| { + let len = r["payload"] + .as_str() + .expect("Non string payload property") + .len(); + assert!(len <= self.cfg.max_record_payload_bytes); + len + }) + .sum(); + assert!(payload_bytes <= self.cfg.max_post_bytes); + assert!(payload_bytes <= self.cfg.max_total_bytes); + + assert_eq!(queue.post_limits.cur_bytes, payload_bytes); + assert_eq!(queue.post_limits.cur_records, recs.len()); + (recs.len(), payload_bytes) + }; + post.payload_bytes = record_payload_bytes; + post.records = num_records; + + self.all_posts.push(post.clone()); + let response = self.responses.pop_front().unwrap(); + + let record = match response { + Sync15ClientResponse::Success { ref record, .. } => record, + _ => { + panic!("only success codes are used in this test"); + } + }; + + if self.cur_batch.is_none() { + assert!( + batch.is_none() || batch == Some("true".into()), + "We shouldn't be in a batch now" + ); + self.cur_batch = Some(BatchInfo { + id: record.batch.clone(), + posts: vec![], + records: 0, + bytes: 0, + }); + } else { + assert_eq!( + batch, + self.cur_batch.as_ref().unwrap().id, + "We're in a batch but got the wrong batch id" + ); + } + + { + let batch = self.cur_batch.as_mut().unwrap(); + batch.posts.push(post); + batch.records += num_records; + batch.bytes += record_payload_bytes; + + assert!(batch.bytes <= self.cfg.max_total_bytes); + assert!(batch.records <= self.cfg.max_total_records); + + assert_eq!(batch.records, queue.batch_limits.cur_records); + assert_eq!(batch.bytes, queue.batch_limits.cur_bytes); + } + + if commit || record.batch.is_none() { + let batch = self.cur_batch.take().unwrap(); + self.batches.push(batch); + } + + response + } + + fn do_handle_response(&mut self, _: PostResponse, mid_batch: bool) { + assert_eq!(mid_batch, self.cur_batch.is_some()); + } + } + impl BatchPoster for TestPosterRef { + fn post<T, O>( + &self, + body: Vec<u8>, + xius: ServerTimestamp, + batch: Option<String>, + commit: bool, + queue: &PostQueue<T, O>, + ) -> Result<PostResponse> { + Ok(self.borrow_mut().do_post(&body, xius, batch, commit, queue)) + } + } + + impl PostResponseHandler for TestPosterRef { + fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> Result<()> { + self.borrow_mut().do_handle_response(r, mid_batch); + Ok(()) + } + } + + type MockedPostQueue = PostQueue<TestPosterRef, TestPosterRef>; + + fn pq_test_setup( + cfg: InfoConfiguration, + lm: i64, + resps: Vec<PostResponse>, + ) -> (MockedPostQueue, TestPosterRef) { + let tester = TestPoster::new(&cfg, resps); + let pq = PostQueue::new(&cfg, ServerTimestamp(lm), tester.clone(), tester.clone()); + (pq, tester) + } + + fn fake_response<'a, T: Into<Option<&'a str>>>(status: u16, lm: i64, batch: T) -> PostResponse { + assert!(status_codes::is_success_code(status)); + Sync15ClientResponse::Success { + status, + last_modified: ServerTimestamp(lm), + record: UploadResult { + batch: batch.into().map(Into::into), + failed: HashMap::new(), + success: vec![], + }, + route: "test/path".into(), + } + } + + lazy_static! { + // ~40b + static ref PAYLOAD_OVERHEAD: usize = { + let payload = EncryptedPayload { + iv: "".into(), + hmac: "".into(), + ciphertext: "".into() + }; + serde_json::to_string(&payload).unwrap().len() + }; + // ~80b + static ref TOTAL_RECORD_OVERHEAD: usize = { + let val = serde_json::to_value(OutgoingEncryptedBso::new(OutgoingEnvelope { + id: "".into(), + sortindex: None, + ttl: None, + }, + EncryptedPayload { + iv: "".into(), + hmac: "".into(), + ciphertext: "".into() + }, + )).unwrap(); + serde_json::to_string(&val).unwrap().len() + }; + // There's some subtlety in how we calulate this having to do with the fact that + // the quotes in the payload are escaped but the escape chars count to the request len + // and *not* to the payload len (the payload len check happens after json parsing the + // top level object). + static ref NON_PAYLOAD_OVERHEAD: usize = { + *TOTAL_RECORD_OVERHEAD - *PAYLOAD_OVERHEAD + }; + } + + // Actual record size (for max_request_len) will be larger by some amount + fn make_record(payload_size: usize) -> OutgoingEncryptedBso { + assert!(payload_size > *PAYLOAD_OVERHEAD); + let ciphertext_len = payload_size - *PAYLOAD_OVERHEAD; + OutgoingEncryptedBso::new( + OutgoingEnvelope { + id: "".into(), + sortindex: None, + ttl: None, + }, + EncryptedPayload { + iv: "".into(), + hmac: "".into(), + ciphertext: "x".repeat(ciphertext_len), + }, + ) + } + + fn request_bytes_for_payloads(payloads: &[usize]) -> usize { + 1 + payloads + .iter() + .map(|&size| size + 1 + *NON_PAYLOAD_OVERHEAD) + .sum::<usize>() + } + + #[test] + fn test_pq_basic() { + let cfg = InfoConfiguration { + max_request_bytes: 1000, + max_record_payload_bytes: 1000, + ..InfoConfiguration::default() + }; + let time = 11_111_111_000; + let (mut pq, tester) = pq_test_setup( + cfg, + time, + vec![fake_response(status_codes::OK, time + 100_000, None)], + ); + + pq.enqueue(&make_record(100)).unwrap(); + pq.flush(true).unwrap(); + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 1); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].posts.len(), 1); + assert_eq!(t.batches[0].records, 1); + assert_eq!(t.batches[0].bytes, 100); + assert_eq!( + t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100]) + ); + } + + #[test] + fn test_pq_max_request_bytes_no_batch() { + let cfg = InfoConfiguration { + max_request_bytes: 250, + ..InfoConfiguration::default() + }; + let time = 11_111_111_000; + let (mut pq, tester) = pq_test_setup( + cfg, + time, + vec![ + fake_response(status_codes::OK, time + 100_000, None), + fake_response(status_codes::OK, time + 200_000, None), + ], + ); + + // Note that the total record overhead is around 85 bytes + let payload_size = 100 - *NON_PAYLOAD_OVERHEAD; + pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 102; [r] + pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 203; [r,r] + pq.enqueue(&make_record(payload_size)).unwrap(); // too big, 2nd post. + pq.flush(true).unwrap(); + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 2); + assert_eq!(t.batches.len(), 2); + assert_eq!(t.batches[0].posts.len(), 1); + assert_eq!(t.batches[0].records, 2); + assert_eq!(t.batches[0].bytes, payload_size * 2); + assert_eq!(t.batches[0].posts[0].batch, Some("true".into())); + assert_eq!( + t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[payload_size, payload_size]) + ); + + assert_eq!(t.batches[1].posts.len(), 1); + assert_eq!(t.batches[1].records, 1); + assert_eq!(t.batches[1].bytes, payload_size); + // We know at this point that the server does not support batching. + assert_eq!(t.batches[1].posts[0].batch, None); + assert!(!t.batches[1].posts[0].commit); + assert_eq!( + t.batches[1].posts[0].body.len(), + request_bytes_for_payloads(&[payload_size]) + ); + } + + #[test] + fn test_pq_max_record_payload_bytes_no_batch() { + let cfg = InfoConfiguration { + max_record_payload_bytes: 150, + max_request_bytes: 350, + ..InfoConfiguration::default() + }; + let time = 11_111_111_000; + let (mut pq, tester) = pq_test_setup( + cfg, + time, + vec![ + fake_response(status_codes::OK, time + 100_000, None), + fake_response(status_codes::OK, time + 200_000, None), + ], + ); + + // Note that the total record overhead is around 85 bytes + let payload_size = 100 - *NON_PAYLOAD_OVERHEAD; + pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 102; [r] + let enqueued = pq.enqueue(&make_record(151)).unwrap(); // still 102 + assert!(!enqueued, "Should not have fit"); + pq.enqueue(&make_record(payload_size)).unwrap(); + pq.flush(true).unwrap(); + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 1); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].posts.len(), 1); + assert_eq!(t.batches[0].records, 2); + assert_eq!(t.batches[0].bytes, payload_size * 2); + assert_eq!( + t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[payload_size, payload_size]) + ); + } + + #[test] + fn test_pq_single_batch() { + let cfg = InfoConfiguration::default(); + let time = 11_111_111_000; + let (mut pq, tester) = pq_test_setup( + cfg, + time, + vec![fake_response( + status_codes::ACCEPTED, + time + 100_000, + Some("1234"), + )], + ); + + let payload_size = 100 - *NON_PAYLOAD_OVERHEAD; + pq.enqueue(&make_record(payload_size)).unwrap(); + pq.enqueue(&make_record(payload_size)).unwrap(); + pq.enqueue(&make_record(payload_size)).unwrap(); + pq.flush(true).unwrap(); + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 1); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].id.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts.len(), 1); + assert_eq!(t.batches[0].records, 3); + assert_eq!(t.batches[0].bytes, payload_size * 3); + assert!(t.batches[0].posts[0].commit); + assert_eq!( + t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[payload_size, payload_size, payload_size]) + ); + } + + #[test] + fn test_pq_multi_post_batch_bytes() { + let cfg = InfoConfiguration { + max_post_bytes: 200, + ..InfoConfiguration::default() + }; + let time = 11_111_111_000; + let (mut pq, tester) = pq_test_setup( + cfg, + time, + vec![ + fake_response(status_codes::ACCEPTED, time, Some("1234")), + fake_response(status_codes::ACCEPTED, time + 100_000, Some("1234")), + ], + ); + + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.flush(true).unwrap(); // COMMIT + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 2); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].posts.len(), 2); + assert_eq!(t.batches[0].records, 3); + assert_eq!(t.batches[0].bytes, 300); + + assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[0].posts[0].records, 2); + assert_eq!(t.batches[0].posts[0].payload_bytes, 200); + assert!(!t.batches[0].posts[0].commit); + assert_eq!( + t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100]) + ); + + assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[1].records, 1); + assert_eq!(t.batches[0].posts[1].payload_bytes, 100); + assert!(t.batches[0].posts[1].commit); + assert_eq!( + t.batches[0].posts[1].body.len(), + request_bytes_for_payloads(&[100]) + ); + } + + #[test] + fn test_pq_multi_post_batch_records() { + let cfg = InfoConfiguration { + max_post_records: 3, + ..InfoConfiguration::default() + }; + let time = 11_111_111_000; + let (mut pq, tester) = pq_test_setup( + cfg, + time, + vec![ + fake_response(status_codes::ACCEPTED, time, Some("1234")), + fake_response(status_codes::ACCEPTED, time, Some("1234")), + fake_response(status_codes::ACCEPTED, time + 100_000, Some("1234")), + ], + ); + + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.flush(true).unwrap(); // COMMIT + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 3); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].posts.len(), 3); + assert_eq!(t.batches[0].records, 7); + assert_eq!(t.batches[0].bytes, 700); + + assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[0].posts[0].records, 3); + assert_eq!(t.batches[0].posts[0].payload_bytes, 300); + assert!(!t.batches[0].posts[0].commit); + assert_eq!( + t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100]) + ); + + assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[1].records, 3); + assert_eq!(t.batches[0].posts[1].payload_bytes, 300); + assert!(!t.batches[0].posts[1].commit); + assert_eq!( + t.batches[0].posts[1].body.len(), + request_bytes_for_payloads(&[100, 100, 100]) + ); + + assert_eq!(t.batches[0].posts[2].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[2].records, 1); + assert_eq!(t.batches[0].posts[2].payload_bytes, 100); + assert!(t.batches[0].posts[2].commit); + assert_eq!( + t.batches[0].posts[2].body.len(), + request_bytes_for_payloads(&[100]) + ); + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_pq_multi_post_multi_batch_records() { + let cfg = InfoConfiguration { + max_post_records: 3, + max_total_records: 5, + ..InfoConfiguration::default() + }; + let time = 11_111_111_000; + let (mut pq, tester) = pq_test_setup( + cfg, + time, + vec![ + fake_response(status_codes::ACCEPTED, time, Some("1234")), + fake_response(status_codes::ACCEPTED, time + 100_000, Some("1234")), + fake_response(status_codes::ACCEPTED, time + 100_000, Some("abcd")), + fake_response(status_codes::ACCEPTED, time + 200_000, Some("abcd")), + ], + ); + + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + COMMIT + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.flush(true).unwrap(); // COMMIT + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 4); + assert_eq!(t.batches.len(), 2); + assert_eq!(t.batches[0].posts.len(), 2); + assert_eq!(t.batches[1].posts.len(), 2); + + assert_eq!(t.batches[0].records, 5); + assert_eq!(t.batches[1].records, 4); + + assert_eq!(t.batches[0].bytes, 500); + assert_eq!(t.batches[1].bytes, 400); + + assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[0].posts[0].records, 3); + assert_eq!(t.batches[0].posts[0].payload_bytes, 300); + assert!(!t.batches[0].posts[0].commit); + assert_eq!( + t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100]) + ); + + assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[1].records, 2); + assert_eq!(t.batches[0].posts[1].payload_bytes, 200); + assert!(t.batches[0].posts[1].commit); + assert_eq!( + t.batches[0].posts[1].body.len(), + request_bytes_for_payloads(&[100, 100]) + ); + + assert_eq!(t.batches[1].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[1].posts[0].records, 3); + assert_eq!(t.batches[1].posts[0].payload_bytes, 300); + assert!(!t.batches[1].posts[0].commit); + assert_eq!( + t.batches[1].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100]) + ); + + assert_eq!(t.batches[1].posts[1].batch.as_ref().unwrap(), "abcd"); + assert_eq!(t.batches[1].posts[1].records, 1); + assert_eq!(t.batches[1].posts[1].payload_bytes, 100); + assert!(t.batches[1].posts[1].commit); + assert_eq!( + t.batches[1].posts[1].body.len(), + request_bytes_for_payloads(&[100]) + ); + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_pq_multi_post_multi_batch_bytes() { + let cfg = InfoConfiguration { + max_post_bytes: 300, + max_total_bytes: 500, + ..InfoConfiguration::default() + }; + let time = 11_111_111_000; + let (mut pq, tester) = pq_test_setup( + cfg, + time, + vec![ + fake_response(status_codes::ACCEPTED, time, Some("1234")), + fake_response(status_codes::ACCEPTED, time + 100_000, Some("1234")), // should commit + fake_response(status_codes::ACCEPTED, time + 100_000, Some("abcd")), + fake_response(status_codes::ACCEPTED, time + 200_000, Some("abcd")), // should commit + ], + ); + + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + assert_eq!(pq.last_modified.0, time); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + COMMIT + pq.enqueue(&make_record(100)).unwrap(); + assert_eq!(pq.last_modified.0, time + 100_000); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + + // POST + pq.enqueue(&make_record(100)).unwrap(); + assert_eq!(pq.last_modified.0, time + 100_000); + pq.flush(true).unwrap(); // COMMIT + + assert_eq!(pq.last_modified.0, time + 200_000); + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 4); + assert_eq!(t.batches.len(), 2); + assert_eq!(t.batches[0].posts.len(), 2); + assert_eq!(t.batches[1].posts.len(), 2); + + assert_eq!(t.batches[0].records, 5); + assert_eq!(t.batches[1].records, 4); + + assert_eq!(t.batches[0].bytes, 500); + assert_eq!(t.batches[1].bytes, 400); + + assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[0].posts[0].records, 3); + assert_eq!(t.batches[0].posts[0].payload_bytes, 300); + assert!(!t.batches[0].posts[0].commit); + assert_eq!( + t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100]) + ); + + assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[1].records, 2); + assert_eq!(t.batches[0].posts[1].payload_bytes, 200); + assert!(t.batches[0].posts[1].commit); + assert_eq!( + t.batches[0].posts[1].body.len(), + request_bytes_for_payloads(&[100, 100]) + ); + + assert_eq!(t.batches[1].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[1].posts[0].records, 3); + assert_eq!(t.batches[1].posts[0].payload_bytes, 300); + assert!(!t.batches[1].posts[0].commit); + assert_eq!( + t.batches[1].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100]) + ); + + assert_eq!(t.batches[1].posts[1].batch.as_ref().unwrap(), "abcd"); + assert_eq!(t.batches[1].posts[1].records, 1); + assert_eq!(t.batches[1].posts[1].payload_bytes, 100); + assert!(t.batches[1].posts[1].commit); + assert_eq!( + t.batches[1].posts[1].body.len(), + request_bytes_for_payloads(&[100]) + ); + } + + // TODO: Test + // + // - error cases!!! We don't test our handling of server errors at all! + // - mixed bytes/record limits + // + // A lot of these have good examples in test_postqueue.js on deskftop sync +} diff --git a/third_party/rust/sync15/src/client/state.rs b/third_party/rust/sync15/src/client/state.rs new file mode 100644 index 0000000000..78e9a6a718 --- /dev/null +++ b/third_party/rust/sync15/src/client/state.rs @@ -0,0 +1,1089 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use std::collections::{HashMap, HashSet}; + +use super::request::{InfoCollections, InfoConfiguration}; +use super::storage_client::{SetupStorageClient, Sync15ClientResponse}; +use super::CollectionKeys; +use crate::bso::OutgoingEncryptedBso; +use crate::error::{self, Error as ErrorKind, ErrorResponse}; +use crate::record_types::{MetaGlobalEngine, MetaGlobalRecord}; +use crate::EncryptedPayload; +use crate::{Guid, KeyBundle, ServerTimestamp}; +use interrupt_support::Interruptee; +use serde_derive::*; + +use self::SetupState::*; + +const STORAGE_VERSION: usize = 5; + +/// Maps names to storage versions for engines to include in a fresh +/// `meta/global` record. We include engines that we don't implement +/// because they'll be disabled on other clients if we omit them +/// (bug 1479929). +const DEFAULT_ENGINES: &[(&str, usize)] = &[ + ("passwords", 1), + ("clients", 1), + ("addons", 1), + ("addresses", 1), + ("bookmarks", 2), + ("creditcards", 1), + ("forms", 1), + ("history", 1), + ("prefs", 2), + ("tabs", 1), +]; + +// Declined engines to include in a fresh `meta/global` record. +const DEFAULT_DECLINED: &[&str] = &[]; + +/// State that we require the app to persist to storage for us. +/// It's a little unfortunate we need this, because it's only tracking +/// "declined engines", and even then, only needed in practice when there's +/// no meta/global so we need to create one. It's extra unfortunate because we +/// want to move away from "globally declined" engines anyway, moving towards +/// allowing engines to be enabled or disabled per client rather than globally. +/// +/// Apps are expected to treat this as opaque, so we support serializing it. +/// Note that this structure is *not* used to *change* the declined engines +/// list - that will be done in the future, but the API exposed for that +/// purpose will also take a mutable PersistedGlobalState. +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "schema_version")] +pub enum PersistedGlobalState { + /// V1 was when we persisted the entire GlobalState, keys and all! + + /// V2 is just tracking the globally declined list. + /// None means "I've no idea" and theoretically should only happen on the + /// very first sync for an app. + V2 { declined: Option<Vec<String>> }, +} + +impl Default for PersistedGlobalState { + #[inline] + fn default() -> PersistedGlobalState { + PersistedGlobalState::V2 { declined: None } + } +} + +#[derive(Debug, Default, Clone, PartialEq)] +pub(crate) struct EngineChangesNeeded { + pub local_resets: HashSet<String>, + pub remote_wipes: HashSet<String>, +} + +#[derive(Debug, Default, Clone, PartialEq)] +struct RemoteEngineState { + info_collections: HashSet<String>, + declined: HashSet<String>, +} + +#[derive(Debug, Default, Clone, PartialEq)] +struct EngineStateInput { + local_declined: HashSet<String>, + remote: Option<RemoteEngineState>, + user_changes: HashMap<String, bool>, +} + +#[derive(Debug, Default, Clone, PartialEq)] +struct EngineStateOutput { + // The new declined. + declined: HashSet<String>, + // Which engines need resets or wipes. + changes_needed: EngineChangesNeeded, +} + +fn compute_engine_states(input: EngineStateInput) -> EngineStateOutput { + use super::util::*; + log::debug!("compute_engine_states: input {:?}", input); + let (must_enable, must_disable) = partition_by_value(&input.user_changes); + let have_remote = input.remote.is_some(); + let RemoteEngineState { + info_collections, + declined: remote_declined, + } = input.remote.clone().unwrap_or_default(); + + let both_declined_and_remote = set_intersection(&info_collections, &remote_declined); + if !both_declined_and_remote.is_empty() { + // Should we wipe these too? + log::warn!( + "Remote state contains engines which are in both info/collections and meta/global's declined: {:?}", + both_declined_and_remote, + ); + } + + let most_recent_declined_list = if have_remote { + &remote_declined + } else { + &input.local_declined + }; + + let result_declined = set_difference( + &set_union(most_recent_declined_list, &must_disable), + &must_enable, + ); + + let output = EngineStateOutput { + changes_needed: EngineChangesNeeded { + // Anything now declined which wasn't in our declined list before gets a reset. + local_resets: set_difference(&result_declined, &input.local_declined), + // Anything remote that we just declined gets a wipe. In the future + // we might want to consider wiping things in both remote declined + // and info/collections, but we'll let other clients pick up their + // own mess for now. + remote_wipes: set_intersection(&info_collections, &must_disable), + }, + declined: result_declined, + }; + // No PII here and this helps debug problems. + log::debug!("compute_engine_states: output {:?}", output); + output +} + +impl PersistedGlobalState { + fn set_declined(&mut self, new_declined: Vec<String>) { + match self { + Self::V2 { ref mut declined } => *declined = Some(new_declined), + } + } + pub(crate) fn get_declined(&self) -> &[String] { + match self { + Self::V2 { declined: Some(d) } => d, + Self::V2 { declined: None } => &[], + } + } +} + +/// Holds global Sync state, including server upload limits, the +/// last-fetched collection modified times, `meta/global` record, and +/// an encrypted copy of the crypto/keys resource (avoids keeping them +/// in memory longer than necessary; avoids key mismatches by ensuring the same KeyBundle +/// is used for both the keys and encrypted payloads.) +#[derive(Debug, Clone)] +pub struct GlobalState { + pub config: InfoConfiguration, + pub collections: InfoCollections, + pub global: MetaGlobalRecord, + pub global_timestamp: ServerTimestamp, + pub keys: EncryptedPayload, + pub keys_timestamp: ServerTimestamp, +} + +/// Creates a fresh `meta/global` record, using the default engine selections, +/// and declined engines from our PersistedGlobalState. +fn new_global(pgs: &PersistedGlobalState) -> MetaGlobalRecord { + let sync_id = Guid::random(); + let mut engines: HashMap<String, _> = HashMap::new(); + for (name, version) in DEFAULT_ENGINES.iter() { + let sync_id = Guid::random(); + engines.insert( + (*name).to_string(), + MetaGlobalEngine { + version: *version, + sync_id, + }, + ); + } + // We only need our PersistedGlobalState to fill out a new meta/global - if + // we previously saw a meta/global then we would have updated it with what + // it was at the time. + let declined = match pgs { + PersistedGlobalState::V2 { declined: Some(d) } => d.clone(), + _ => DEFAULT_DECLINED.iter().map(ToString::to_string).collect(), + }; + + MetaGlobalRecord { + sync_id, + storage_version: STORAGE_VERSION, + engines, + declined, + } +} + +fn fixup_meta_global(global: &mut MetaGlobalRecord) -> bool { + let mut changed_any = false; + for &(name, version) in DEFAULT_ENGINES.iter() { + let had_engine = global.engines.contains_key(name); + let should_have_engine = !global.declined.iter().any(|c| c == name); + if had_engine != should_have_engine { + if should_have_engine { + log::debug!("SyncID for engine {:?} was missing", name); + global.engines.insert( + name.to_string(), + MetaGlobalEngine { + version, + sync_id: Guid::random(), + }, + ); + } else { + log::debug!("SyncID for engine {:?} was present, but shouldn't be", name); + global.engines.remove(name); + } + changed_any = true; + } + } + changed_any +} + +pub struct SetupStateMachine<'a> { + client: &'a dyn SetupStorageClient, + root_key: &'a KeyBundle, + pgs: &'a mut PersistedGlobalState, + // `allowed_states` is designed so that we can arrange for the concept of + // a "fast" sync - so we decline to advance if we need to setup from scratch. + // The idea is that if we need to sync before going to sleep we should do + // it as fast as possible. However, in practice this isn't going to do + // what we expect - a "fast sync" that finds lots to do is almost certainly + // going to take longer than a "full sync" that finds nothing to do. + // We should almost certainly remove this and instead allow for a "time + // budget", after which we get interrupted. Later... + allowed_states: Vec<&'static str>, + sequence: Vec<&'static str>, + engine_updates: Option<&'a HashMap<String, bool>>, + interruptee: &'a dyn Interruptee, + pub(crate) changes_needed: Option<EngineChangesNeeded>, +} + +impl<'a> SetupStateMachine<'a> { + /// Creates a state machine for a "classic" Sync 1.5 client that supports + /// all states, including uploading a fresh `meta/global` and `crypto/keys` + /// after a node reassignment. + pub fn for_full_sync( + client: &'a dyn SetupStorageClient, + root_key: &'a KeyBundle, + pgs: &'a mut PersistedGlobalState, + engine_updates: Option<&'a HashMap<String, bool>>, + interruptee: &'a dyn Interruptee, + ) -> SetupStateMachine<'a> { + SetupStateMachine::with_allowed_states( + client, + root_key, + pgs, + interruptee, + engine_updates, + vec![ + "Initial", + "InitialWithConfig", + "InitialWithInfo", + "InitialWithMetaGlobal", + "Ready", + "FreshStartRequired", + "WithPreviousState", + ], + ) + } + + fn with_allowed_states( + client: &'a dyn SetupStorageClient, + root_key: &'a KeyBundle, + pgs: &'a mut PersistedGlobalState, + interruptee: &'a dyn Interruptee, + engine_updates: Option<&'a HashMap<String, bool>>, + allowed_states: Vec<&'static str>, + ) -> SetupStateMachine<'a> { + SetupStateMachine { + client, + root_key, + pgs, + sequence: Vec::new(), + allowed_states, + engine_updates, + interruptee, + changes_needed: None, + } + } + + fn advance(&mut self, from: SetupState) -> error::Result<SetupState> { + match from { + // Fetch `info/configuration` with current server limits, and + // `info/collections` with collection last modified times. + Initial => { + let config = match self.client.fetch_info_configuration()? { + Sync15ClientResponse::Success { record, .. } => record, + Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }) => { + InfoConfiguration::default() + } + other => return Err(other.create_storage_error()), + }; + Ok(InitialWithConfig { config }) + } + + // XXX - we could consider combining these Initial* states, because we don't + // attempt to support filling in "missing" global state - *any* 404 in them + // means `FreshStart`. + // IOW, in all cases, they either `Err()`, move to `FreshStartRequired`, or + // advance to a specific next state. + InitialWithConfig { config } => { + match self.client.fetch_info_collections()? { + Sync15ClientResponse::Success { + record: collections, + .. + } => Ok(InitialWithInfo { + config, + collections, + }), + // If the server doesn't have a `crypto/keys`, start over + // and reupload our `meta/global` and `crypto/keys`. + Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }) => { + Ok(FreshStartRequired { config }) + } + other => Err(other.create_storage_error()), + } + } + + InitialWithInfo { + config, + collections, + } => { + match self.client.fetch_meta_global()? { + Sync15ClientResponse::Success { + record: mut global, + last_modified: mut global_timestamp, + .. + } => { + // If the server has a newer storage version, we can't + // sync until our client is updated. + if global.storage_version > STORAGE_VERSION { + return Err(ErrorKind::ClientUpgradeRequired); + } + + // If the server has an older storage version, wipe and + // reupload. + if global.storage_version < STORAGE_VERSION { + Ok(FreshStartRequired { config }) + } else { + log::info!("Have info/collections and meta/global. Computing new engine states"); + let initial_global_declined: HashSet<String> = + global.declined.iter().cloned().collect(); + let result = compute_engine_states(EngineStateInput { + local_declined: self.pgs.get_declined().iter().cloned().collect(), + user_changes: self.engine_updates.cloned().unwrap_or_default(), + remote: Some(RemoteEngineState { + declined: initial_global_declined.clone(), + info_collections: collections.keys().cloned().collect(), + }), + }); + // Persist the new declined. + self.pgs + .set_declined(result.declined.iter().cloned().collect()); + // If the declined engines differ from remote, fix that. + let fixed_declined = if result.declined != initial_global_declined { + global.declined = result.declined.iter().cloned().collect(); + log::info!( + "Uploading new declined {:?} to meta/global with timestamp {:?}", + global.declined, + global_timestamp, + ); + true + } else { + false + }; + // If there are missing syncIds, we need to fix those as well + let fixed_ids = if fixup_meta_global(&mut global) { + log::info!( + "Uploading corrected meta/global with timestamp {:?}", + global_timestamp, + ); + true + } else { + false + }; + + if fixed_declined || fixed_ids { + global_timestamp = + self.client.put_meta_global(global_timestamp, &global)?; + log::debug!("new global_timestamp: {:?}", global_timestamp); + } + // Update the set of changes needed. + if self.changes_needed.is_some() { + // Should never happen (we prevent state machine + // loops elsewhere) but if it did, the info is stale + // anyway. + log::warn!("Already have a set of changes needed, Overwriting..."); + } + self.changes_needed = Some(result.changes_needed); + Ok(InitialWithMetaGlobal { + config, + collections, + global, + global_timestamp, + }) + } + } + Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }) => { + Ok(FreshStartRequired { config }) + } + other => Err(other.create_storage_error()), + } + } + + InitialWithMetaGlobal { + config, + collections, + global, + global_timestamp, + } => { + // Now try and get keys etc - if we fresh-start we'll re-use declined. + match self.client.fetch_crypto_keys()? { + Sync15ClientResponse::Success { + record, + last_modified, + .. + } => { + // Note that collection/keys is itself a bso, so the + // json body also carries the timestamp. If they aren't + // identical something has screwed up and we should die. + assert_eq!(last_modified, record.envelope.modified); + let state = GlobalState { + config, + collections, + global, + global_timestamp, + keys: record.payload, + keys_timestamp: last_modified, + }; + Ok(Ready { state }) + } + // If the server doesn't have a `crypto/keys`, start over + // and reupload our `meta/global` and `crypto/keys`. + Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }) => { + Ok(FreshStartRequired { config }) + } + other => Err(other.create_storage_error()), + } + } + + // We've got old state that's likely to be OK. + // We keep things simple here - if there's evidence of a new/missing + // meta/global or new/missing keys we just restart from scratch. + WithPreviousState { old_state } => match self.client.fetch_info_collections()? { + Sync15ClientResponse::Success { + record: collections, + .. + } => Ok( + if self.engine_updates.is_none() + && is_same_timestamp(old_state.global_timestamp, &collections, "meta") + && is_same_timestamp(old_state.keys_timestamp, &collections, "crypto") + { + Ready { + state: GlobalState { + collections, + ..old_state + }, + } + } else { + InitialWithConfig { + config: old_state.config, + } + }, + ), + _ => Ok(InitialWithConfig { + config: old_state.config, + }), + }, + + Ready { state } => Ok(Ready { state }), + + FreshStartRequired { config } => { + // Wipe the server. + log::info!("Fresh start: wiping remote"); + self.client.wipe_all_remote()?; + + // Upload a fresh `meta/global`... + log::info!("Uploading meta/global"); + let computed = compute_engine_states(EngineStateInput { + local_declined: self.pgs.get_declined().iter().cloned().collect(), + user_changes: self.engine_updates.cloned().unwrap_or_default(), + remote: None, + }); + self.pgs + .set_declined(computed.declined.iter().cloned().collect()); + + self.changes_needed = Some(computed.changes_needed); + + let new_global = new_global(self.pgs); + + self.client + .put_meta_global(ServerTimestamp::default(), &new_global)?; + + // ...And a fresh `crypto/keys`. + let new_keys = CollectionKeys::new_random()?.to_encrypted_payload(self.root_key)?; + let bso = OutgoingEncryptedBso::new(Guid::new("keys").into(), new_keys); + self.client + .put_crypto_keys(ServerTimestamp::default(), &bso)?; + + // TODO(lina): Can we pass along server timestamps from the PUTs + // above, and avoid re-fetching the `m/g` and `c/k` we just + // uploaded? + // OTOH(mark): restarting the state machine keeps life simple and rare. + Ok(InitialWithConfig { config }) + } + } + } + + /// Runs through the state machine to the ready state. + pub fn run_to_ready(&mut self, state: Option<GlobalState>) -> error::Result<GlobalState> { + let mut s = match state { + Some(old_state) => WithPreviousState { old_state }, + None => Initial, + }; + loop { + self.interruptee.err_if_interrupted()?; + let label = &s.label(); + log::trace!("global state: {:?}", label); + match s { + Ready { state } => { + self.sequence.push(label); + return Ok(state); + } + // If we already started over once before, we're likely in a + // cycle, and should try again later. Intermediate states + // aren't a problem, just the initial ones. + FreshStartRequired { .. } | WithPreviousState { .. } | Initial => { + if self.sequence.contains(label) { + // Is this really the correct error? + return Err(ErrorKind::SetupRace); + } + } + _ => { + if !self.allowed_states.contains(label) { + return Err(ErrorKind::SetupRequired); + } + } + }; + self.sequence.push(label); + s = self.advance(s)?; + } + } +} + +/// States in the remote setup process. +/// TODO(lina): Add link once #56 is merged. +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +enum SetupState { + // These "Initial" states are only ever used when starting from scratch. + Initial, + InitialWithConfig { + config: InfoConfiguration, + }, + InitialWithInfo { + config: InfoConfiguration, + collections: InfoCollections, + }, + InitialWithMetaGlobal { + config: InfoConfiguration, + collections: InfoCollections, + global: MetaGlobalRecord, + global_timestamp: ServerTimestamp, + }, + WithPreviousState { + old_state: GlobalState, + }, + Ready { + state: GlobalState, + }, + FreshStartRequired { + config: InfoConfiguration, + }, +} + +impl SetupState { + fn label(&self) -> &'static str { + match self { + Initial { .. } => "Initial", + InitialWithConfig { .. } => "InitialWithConfig", + InitialWithInfo { .. } => "InitialWithInfo", + InitialWithMetaGlobal { .. } => "InitialWithMetaGlobal", + Ready { .. } => "Ready", + WithPreviousState { .. } => "WithPreviousState", + FreshStartRequired { .. } => "FreshStartRequired", + } + } +} + +/// Whether we should skip fetching an item. Used when we already have timestamps +/// and want to check if we should reuse our existing state. The state's fairly +/// cheap to recreate and very bad to use if it is wrong, so we insist on the +/// *exact* timestamp matching and not a simple "later than" check. +fn is_same_timestamp(local: ServerTimestamp, collections: &InfoCollections, key: &str) -> bool { + collections.get(key).map_or(false, |ts| local == *ts) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::bso::{IncomingEncryptedBso, IncomingEnvelope}; + use interrupt_support::NeverInterrupts; + + struct InMemoryClient { + info_configuration: error::Result<Sync15ClientResponse<InfoConfiguration>>, + info_collections: error::Result<Sync15ClientResponse<InfoCollections>>, + meta_global: error::Result<Sync15ClientResponse<MetaGlobalRecord>>, + crypto_keys: error::Result<Sync15ClientResponse<IncomingEncryptedBso>>, + } + + impl SetupStorageClient for InMemoryClient { + fn fetch_info_configuration( + &self, + ) -> error::Result<Sync15ClientResponse<InfoConfiguration>> { + match &self.info_configuration { + Ok(client_response) => Ok(client_response.clone()), + Err(_) => Ok(Sync15ClientResponse::Error(ErrorResponse::ServerError { + status: 500, + route: "test/path".into(), + })), + } + } + + fn fetch_info_collections(&self) -> error::Result<Sync15ClientResponse<InfoCollections>> { + match &self.info_collections { + Ok(collections) => Ok(collections.clone()), + Err(_) => Ok(Sync15ClientResponse::Error(ErrorResponse::ServerError { + status: 500, + route: "test/path".into(), + })), + } + } + + fn fetch_meta_global(&self) -> error::Result<Sync15ClientResponse<MetaGlobalRecord>> { + match &self.meta_global { + Ok(global) => Ok(global.clone()), + // TODO(lina): Special handling for 404s, we want to ensure we + // handle missing keys and other server errors correctly. + Err(_) => Ok(Sync15ClientResponse::Error(ErrorResponse::ServerError { + status: 500, + route: "test/path".into(), + })), + } + } + + fn put_meta_global( + &self, + xius: ServerTimestamp, + global: &MetaGlobalRecord, + ) -> error::Result<ServerTimestamp> { + // Ensure that the meta/global record we uploaded is "fixed up" + assert!(DEFAULT_ENGINES + .iter() + .filter(|e| e.0 != "logins") + .all(|&(k, _v)| global.engines.contains_key(k))); + assert!(!global.engines.contains_key("logins")); + assert_eq!(global.declined, vec!["logins".to_string()]); + // return a different timestamp. + Ok(ServerTimestamp(xius.0 + 1)) + } + + fn fetch_crypto_keys(&self) -> error::Result<Sync15ClientResponse<IncomingEncryptedBso>> { + match &self.crypto_keys { + Ok(Sync15ClientResponse::Success { + status, + record, + last_modified, + route, + }) => Ok(Sync15ClientResponse::Success { + status: *status, + record: IncomingEncryptedBso::new( + record.envelope.clone(), + record.payload.clone(), + ), + last_modified: *last_modified, + route: route.clone(), + }), + // TODO(lina): Same as above, for 404s. + _ => Ok(Sync15ClientResponse::Error(ErrorResponse::ServerError { + status: 500, + route: "test/path".into(), + })), + } + } + + fn put_crypto_keys( + &self, + xius: ServerTimestamp, + _keys: &OutgoingEncryptedBso, + ) -> error::Result<()> { + assert_eq!(xius, ServerTimestamp(888_800)); + Err(ErrorKind::StorageHttpError(ErrorResponse::ServerError { + status: 500, + route: "crypto/keys".to_string(), + })) + } + + fn wipe_all_remote(&self) -> error::Result<()> { + Ok(()) + } + } + + #[allow(clippy::unnecessary_wraps)] + fn mocked_success_ts<T>(t: T, ts: i64) -> error::Result<Sync15ClientResponse<T>> { + Ok(Sync15ClientResponse::Success { + status: 200, + record: t, + last_modified: ServerTimestamp(ts), + route: "test/path".into(), + }) + } + + fn mocked_success<T>(t: T) -> error::Result<Sync15ClientResponse<T>> { + mocked_success_ts(t, 0) + } + + fn mocked_success_keys( + keys: CollectionKeys, + root_key: &KeyBundle, + ) -> error::Result<Sync15ClientResponse<IncomingEncryptedBso>> { + let timestamp = keys.timestamp; + let payload = keys.to_encrypted_payload(root_key).unwrap(); + let bso = IncomingEncryptedBso::new( + IncomingEnvelope { + id: Guid::new("keys"), + modified: timestamp, + sortindex: None, + ttl: None, + }, + payload, + ); + Ok(Sync15ClientResponse::Success { + status: 200, + record: bso, + last_modified: timestamp, + route: "test/path".into(), + }) + } + + #[test] + fn test_state_machine_ready_from_empty() { + let _ = env_logger::try_init(); + let root_key = KeyBundle::new_random().unwrap(); + let keys = CollectionKeys { + timestamp: ServerTimestamp(123_400), + default: KeyBundle::new_random().unwrap(), + collections: HashMap::new(), + }; + let mg = MetaGlobalRecord { + sync_id: "syncIDAAAAAA".into(), + storage_version: 5usize, + engines: vec![( + "bookmarks", + MetaGlobalEngine { + version: 1usize, + sync_id: "syncIDBBBBBB".into(), + }, + )] + .into_iter() + .map(|(key, value)| (key.to_owned(), value)) + .collect(), + // We ensure that the record we upload doesn't have a logins record. + declined: vec!["logins".to_string()], + }; + let client = InMemoryClient { + info_configuration: mocked_success(InfoConfiguration::default()), + info_collections: mocked_success(InfoCollections::new( + vec![("meta", 123_456), ("crypto", 145_000)] + .into_iter() + .map(|(key, value)| (key.to_owned(), ServerTimestamp(value))) + .collect(), + )), + meta_global: mocked_success_ts(mg, 999_000), + crypto_keys: mocked_success_keys(keys, &root_key), + }; + let mut pgs = PersistedGlobalState::V2 { declined: None }; + + let mut state_machine = + SetupStateMachine::for_full_sync(&client, &root_key, &mut pgs, None, &NeverInterrupts); + assert!( + state_machine.run_to_ready(None).is_ok(), + "Should drive state machine to ready" + ); + assert_eq!( + state_machine.sequence, + vec![ + "Initial", + "InitialWithConfig", + "InitialWithInfo", + "InitialWithMetaGlobal", + "Ready", + ], + "Should cycle through all states" + ); + } + + #[test] + fn test_from_previous_state_declined() { + let _ = env_logger::try_init(); + // The state-machine sequence where we didn't use the previous state + // (ie, where the state machine restarted) + let sm_seq_restarted = vec![ + "WithPreviousState", + "InitialWithConfig", + "InitialWithInfo", + "InitialWithMetaGlobal", + "Ready", + ]; + // The state-machine sequence where we used the previous state. + let sm_seq_used_previous = vec!["WithPreviousState", "Ready"]; + + // do the actual test. + fn do_test( + client: &dyn SetupStorageClient, + root_key: &KeyBundle, + pgs: &mut PersistedGlobalState, + engine_updates: Option<&HashMap<String, bool>>, + old_state: GlobalState, + expected_states: &[&str], + ) { + let mut state_machine = SetupStateMachine::for_full_sync( + client, + root_key, + pgs, + engine_updates, + &NeverInterrupts, + ); + assert!( + state_machine.run_to_ready(Some(old_state)).is_ok(), + "Should drive state machine to ready" + ); + assert_eq!(state_machine.sequence, expected_states); + } + + // and all the complicated setup... + let ts_metaglobal = 123_456; + let ts_keys = 145_000; + let root_key = KeyBundle::new_random().unwrap(); + let keys = CollectionKeys { + timestamp: ServerTimestamp(ts_keys + 1), + default: KeyBundle::new_random().unwrap(), + collections: HashMap::new(), + }; + let mg = MetaGlobalRecord { + sync_id: "syncIDAAAAAA".into(), + storage_version: 5usize, + engines: vec![( + "bookmarks", + MetaGlobalEngine { + version: 1usize, + sync_id: "syncIDBBBBBB".into(), + }, + )] + .into_iter() + .map(|(key, value)| (key.to_owned(), value)) + .collect(), + // We ensure that the record we upload doesn't have a logins record. + declined: vec!["logins".to_string()], + }; + let collections = InfoCollections::new( + vec![("meta", ts_metaglobal), ("crypto", ts_keys)] + .into_iter() + .map(|(key, value)| (key.to_owned(), ServerTimestamp(value))) + .collect(), + ); + let client = InMemoryClient { + info_configuration: mocked_success(InfoConfiguration::default()), + info_collections: mocked_success(collections.clone()), + meta_global: mocked_success_ts(mg.clone(), ts_metaglobal), + crypto_keys: mocked_success_keys(keys.clone(), &root_key), + }; + + // First a test where the "previous" global state is OK to reuse. + { + let mut pgs = PersistedGlobalState::V2 { declined: None }; + // A "previous" global state. + let old_state = GlobalState { + config: InfoConfiguration::default(), + collections: collections.clone(), + global: mg.clone(), + global_timestamp: ServerTimestamp(ts_metaglobal), + keys: keys + .to_encrypted_payload(&root_key) + .expect("should always work in this test"), + keys_timestamp: ServerTimestamp(ts_keys), + }; + do_test( + &client, + &root_key, + &mut pgs, + None, + old_state, + &sm_seq_used_previous, + ); + } + + // Now where the meta/global record on the server is later. + { + let mut pgs = PersistedGlobalState::V2 { declined: None }; + // A "previous" global state. + let old_state = GlobalState { + config: InfoConfiguration::default(), + collections: collections.clone(), + global: mg.clone(), + global_timestamp: ServerTimestamp(999_999), + keys: keys + .to_encrypted_payload(&root_key) + .expect("should always work in this test"), + keys_timestamp: ServerTimestamp(ts_keys), + }; + do_test( + &client, + &root_key, + &mut pgs, + None, + old_state, + &sm_seq_restarted, + ); + } + + // Where keys on the server is later. + { + let mut pgs = PersistedGlobalState::V2 { declined: None }; + // A "previous" global state. + let old_state = GlobalState { + config: InfoConfiguration::default(), + collections: collections.clone(), + global: mg.clone(), + global_timestamp: ServerTimestamp(ts_metaglobal), + keys: keys + .to_encrypted_payload(&root_key) + .expect("should always work in this test"), + keys_timestamp: ServerTimestamp(999_999), + }; + do_test( + &client, + &root_key, + &mut pgs, + None, + old_state, + &sm_seq_restarted, + ); + } + + // Where there are engine-state changes. + { + let mut pgs = PersistedGlobalState::V2 { declined: None }; + // A "previous" global state. + let old_state = GlobalState { + config: InfoConfiguration::default(), + collections, + global: mg, + global_timestamp: ServerTimestamp(ts_metaglobal), + keys: keys + .to_encrypted_payload(&root_key) + .expect("should always work in this test"), + keys_timestamp: ServerTimestamp(ts_keys), + }; + let mut engine_updates = HashMap::<String, bool>::new(); + engine_updates.insert("logins".to_string(), false); + do_test( + &client, + &root_key, + &mut pgs, + Some(&engine_updates), + old_state, + &sm_seq_restarted, + ); + let declined = match pgs { + PersistedGlobalState::V2 { declined: d } => d, + }; + // and check we now consider logins as declined. + assert_eq!(declined, Some(vec!["logins".to_string()])); + } + } + + fn string_set(s: &[&str]) -> HashSet<String> { + s.iter().map(ToString::to_string).collect() + } + fn string_map<T: Clone>(s: &[(&str, T)]) -> HashMap<String, T> { + s.iter().map(|v| (v.0.to_string(), v.1.clone())).collect() + } + #[test] + fn test_engine_states() { + assert_eq!( + compute_engine_states(EngineStateInput { + local_declined: string_set(&["foo", "bar"]), + remote: None, + user_changes: Default::default(), + }), + EngineStateOutput { + declined: string_set(&["foo", "bar"]), + // No wipes, no resets + changes_needed: Default::default(), + } + ); + assert_eq!( + compute_engine_states(EngineStateInput { + local_declined: string_set(&["foo", "bar"]), + remote: Some(RemoteEngineState { + declined: string_set(&["foo"]), + info_collections: string_set(&["bar"]) + }), + user_changes: Default::default(), + }), + EngineStateOutput { + // Now we have `foo`. + declined: string_set(&["foo"]), + // No wipes, no resets, should just be a local update. + changes_needed: Default::default(), + } + ); + assert_eq!( + compute_engine_states(EngineStateInput { + local_declined: string_set(&["foo", "bar"]), + remote: Some(RemoteEngineState { + declined: string_set(&["foo", "bar", "quux"]), + info_collections: string_set(&[]) + }), + user_changes: Default::default(), + }), + EngineStateOutput { + // Now we have `foo`. + declined: string_set(&["foo", "bar", "quux"]), + changes_needed: EngineChangesNeeded { + // Should reset `quux`. + local_resets: string_set(&["quux"]), + // No wipes, though. + remote_wipes: string_set(&[]), + } + } + ); + assert_eq!( + compute_engine_states(EngineStateInput { + local_declined: string_set(&["bar", "baz"]), + remote: Some(RemoteEngineState { + declined: string_set(&["bar", "baz",]), + info_collections: string_set(&["quux"]) + }), + // Change a declined engine to undeclined. + user_changes: string_map(&[("bar", true)]), + }), + EngineStateOutput { + declined: string_set(&["baz"]), + // No wipes, just undecline it. + changes_needed: Default::default() + } + ); + assert_eq!( + compute_engine_states(EngineStateInput { + local_declined: string_set(&["bar", "baz"]), + remote: Some(RemoteEngineState { + declined: string_set(&["bar", "baz"]), + info_collections: string_set(&["foo"]) + }), + // Change an engine which exists remotely to declined. + user_changes: string_map(&[("foo", false)]), + }), + EngineStateOutput { + declined: string_set(&["baz", "bar", "foo"]), + // No wipes, just undecline it. + changes_needed: EngineChangesNeeded { + // Should reset our local foo + local_resets: string_set(&["foo"]), + // And wipe the server. + remote_wipes: string_set(&["foo"]), + } + } + ); + } +} diff --git a/third_party/rust/sync15/src/client/status.rs b/third_party/rust/sync15/src/client/status.rs new file mode 100644 index 0000000000..407efeec12 --- /dev/null +++ b/third_party/rust/sync15/src/client/status.rs @@ -0,0 +1,106 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use crate::error::{Error, ErrorResponse}; +use crate::telemetry::SyncTelemetryPing; +use std::collections::HashMap; +use std::time::{Duration, SystemTime}; + +/// The general status of sync - should probably be moved to the "sync manager" +/// once we have one! +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ServiceStatus { + /// Everything is fine. + Ok, + /// Some general network issue. + NetworkError, + /// Some apparent issue with the servers. + ServiceError, + /// Some external FxA action needs to be taken. + AuthenticationError, + /// We declined to do anything for backoff or rate-limiting reasons. + BackedOff, + /// We were interrupted. + Interrupted, + /// Something else - you need to check the logs for more details. May + /// or may not be transient, we really don't know. + OtherError, +} + +impl ServiceStatus { + // This is a bit naive and probably will not survive in this form in the + // SyncManager - eg, we'll want to handle backoff etc. + pub fn from_err(err: &Error) -> ServiceStatus { + match err { + // HTTP based errors. + Error::TokenserverHttpError(status) => { + // bit of a shame the tokenserver is different to storage... + if *status == 401 { + ServiceStatus::AuthenticationError + } else { + ServiceStatus::ServiceError + } + } + // BackoffError is also from the tokenserver. + Error::BackoffError(_) => ServiceStatus::ServiceError, + Error::StorageHttpError(ref e) => match e { + ErrorResponse::Unauthorized { .. } => ServiceStatus::AuthenticationError, + _ => ServiceStatus::ServiceError, + }, + + // Network errors. + Error::RequestError(_) | Error::UnexpectedStatus(_) | Error::HawkError(_) => { + ServiceStatus::NetworkError + } + + Error::Interrupted(_) => ServiceStatus::Interrupted, + _ => ServiceStatus::OtherError, + } + } +} + +/// The result of a sync request. This too is from the "sync manager", but only +/// has a fraction of the things it will have when we actually build that. +#[derive(Debug)] +pub struct SyncResult { + /// The general health. + pub service_status: ServiceStatus, + + /// The set of declined engines, if we know them. + pub declined: Option<Vec<String>>, + + /// The result of the sync. + pub result: Result<(), Error>, + + /// The result for each engine. + /// Note that we expect the `String` to be replaced with an enum later. + pub engine_results: HashMap<String, Result<(), Error>>, + + pub telemetry: SyncTelemetryPing, + + pub next_sync_after: Option<std::time::SystemTime>, +} + +// If `r` has a BackoffError, then returns the later backoff value. +fn advance_backoff(cur_best: SystemTime, r: &Result<(), Error>) -> SystemTime { + if let Err(e) = r { + if let Some(time) = e.get_backoff() { + return std::cmp::max(time, cur_best); + } + } + cur_best +} + +impl SyncResult { + pub(crate) fn set_sync_after(&mut self, backoff_duration: Duration) { + let now = SystemTime::now(); + let toplevel = advance_backoff(now + backoff_duration, &self.result); + let sync_after = self.engine_results.values().fold(toplevel, advance_backoff); + if sync_after <= now { + self.next_sync_after = None; + } else { + self.next_sync_after = Some(sync_after); + } + } +} diff --git a/third_party/rust/sync15/src/client/storage_client.rs b/third_party/rust/sync15/src/client/storage_client.rs new file mode 100644 index 0000000000..7b8f650d24 --- /dev/null +++ b/third_party/rust/sync15/src/client/storage_client.rs @@ -0,0 +1,605 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::request::{ + BatchPoster, InfoCollections, InfoConfiguration, PostQueue, PostResponse, PostResponseHandler, +}; +use super::token; +use crate::bso::{IncomingBso, IncomingEncryptedBso, OutgoingBso, OutgoingEncryptedBso}; +use crate::engine::{CollectionPost, CollectionRequest}; +use crate::error::{self, Error, ErrorResponse}; +use crate::record_types::MetaGlobalRecord; +use crate::{CollectionName, Guid, ServerTimestamp}; +use serde_json::Value; +use std::str::FromStr; +use std::sync::atomic::{AtomicU32, Ordering}; +use url::Url; +use viaduct::{ + header_names::{self, AUTHORIZATION}, + Method, Request, Response, +}; + +/// A response from a GET request on a Sync15StorageClient, encapsulating all +/// the variants users of this client needs to care about. +#[derive(Debug, Clone)] +pub enum Sync15ClientResponse<T> { + Success { + status: u16, + record: T, + last_modified: ServerTimestamp, + route: String, + }, + Error(ErrorResponse), +} + +fn parse_seconds(seconds_str: &str) -> Option<u32> { + let secs = seconds_str.parse::<f64>().ok()?.ceil(); + // Note: u32 doesn't impl TryFrom<f64> :( + if !secs.is_finite() || secs < 0.0 || secs > f64::from(u32::max_value()) { + log::warn!("invalid backoff value: {}", secs); + None + } else { + Some(secs as u32) + } +} + +impl<T> Sync15ClientResponse<T> { + pub fn from_response(resp: Response, backoff_listener: &BackoffListener) -> error::Result<Self> + where + for<'a> T: serde::de::Deserialize<'a>, + { + let route: String = resp.url.path().into(); + // Android seems to respect retry_after even on success requests, so we + // will too if it's present. This also lets us handle both backoff-like + // properties in the same place. + let retry_after = resp + .headers + .get(header_names::RETRY_AFTER) + .and_then(parse_seconds); + + let backoff = resp + .headers + .get(header_names::X_WEAVE_BACKOFF) + .and_then(parse_seconds); + + if let Some(b) = backoff { + backoff_listener.note_backoff(b); + } + if let Some(ra) = retry_after { + backoff_listener.note_retry_after(ra); + } + + Ok(if resp.is_success() { + let record: T = resp.json()?; + let last_modified = resp + .headers + .get(header_names::X_LAST_MODIFIED) + .and_then(|s| ServerTimestamp::from_str(s).ok()) + .ok_or(Error::MissingServerTimestamp)?; + log::info!( + "Successful request to \"{}\", incoming x-last-modified={:?}", + route, + last_modified + ); + + Sync15ClientResponse::Success { + status: resp.status, + record, + last_modified, + route, + } + } else { + let status = resp.status; + log::info!("Request \"{}\" was an error (status={})", route, status); + match status { + 404 => Sync15ClientResponse::Error(ErrorResponse::NotFound { route }), + 401 => Sync15ClientResponse::Error(ErrorResponse::Unauthorized { route }), + 412 => Sync15ClientResponse::Error(ErrorResponse::PreconditionFailed { route }), + 500..=600 => { + Sync15ClientResponse::Error(ErrorResponse::ServerError { route, status }) + } + _ => Sync15ClientResponse::Error(ErrorResponse::RequestFailed { route, status }), + } + }) + } + + pub fn create_storage_error(self) -> Error { + let inner = match self { + Sync15ClientResponse::Success { status, route, .. } => { + // This should never happen as callers are expected to have + // already special-cased this response, so warn if it does. + // (or maybe we could panic?) + log::warn!("Converting success response into an error"); + ErrorResponse::RequestFailed { status, route } + } + Sync15ClientResponse::Error(e) => e, + }; + Error::StorageHttpError(inner) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Sync15StorageClientInit { + pub key_id: String, + pub access_token: String, + pub tokenserver_url: Url, +} + +/// A trait containing the methods required to run through the setup state +/// machine. This is factored out into a separate trait to make mocking +/// easier. +pub trait SetupStorageClient { + fn fetch_info_configuration(&self) -> error::Result<Sync15ClientResponse<InfoConfiguration>>; + fn fetch_info_collections(&self) -> error::Result<Sync15ClientResponse<InfoCollections>>; + fn fetch_meta_global(&self) -> error::Result<Sync15ClientResponse<MetaGlobalRecord>>; + fn fetch_crypto_keys(&self) -> error::Result<Sync15ClientResponse<IncomingEncryptedBso>>; + + fn put_meta_global( + &self, + xius: ServerTimestamp, + global: &MetaGlobalRecord, + ) -> error::Result<ServerTimestamp>; + fn put_crypto_keys( + &self, + xius: ServerTimestamp, + keys: &OutgoingEncryptedBso, + ) -> error::Result<()>; + fn wipe_all_remote(&self) -> error::Result<()>; +} + +#[derive(Debug, Default)] +pub struct BackoffState { + pub backoff_secs: AtomicU32, + pub retry_after_secs: AtomicU32, +} + +pub(crate) type BackoffListener = std::sync::Arc<BackoffState>; + +pub(crate) fn new_backoff_listener() -> BackoffListener { + std::sync::Arc::new(BackoffState::default()) +} + +impl BackoffState { + pub fn note_backoff(&self, noted: u32) { + super::util::atomic_update_max(&self.backoff_secs, noted) + } + + pub fn note_retry_after(&self, noted: u32) { + super::util::atomic_update_max(&self.retry_after_secs, noted) + } + + pub fn get_backoff_secs(&self) -> u32 { + self.backoff_secs.load(Ordering::SeqCst) + } + + pub fn get_retry_after_secs(&self) -> u32 { + self.retry_after_secs.load(Ordering::SeqCst) + } + + pub fn get_required_wait(&self, ignore_soft_backoff: bool) -> Option<std::time::Duration> { + let bo = self.get_backoff_secs(); + let ra = self.get_retry_after_secs(); + let secs = u64::from(if ignore_soft_backoff { ra } else { bo.max(ra) }); + if secs > 0 { + Some(std::time::Duration::from_secs(secs)) + } else { + None + } + } + + pub fn reset(&self) { + self.backoff_secs.store(0, Ordering::SeqCst); + self.retry_after_secs.store(0, Ordering::SeqCst); + } +} + +// meta/global is a clear-text Bso (ie, there's a String `payload` which has a MetaGlobalRecord) +// We don't use the 'content' helpers here because we want json errors to be fatal here +// (ie, we don't need tombstones and can't just skip a malformed record) +type IncMetaGlobalBso = IncomingBso; +type OutMetaGlobalBso = OutgoingBso; + +#[derive(Debug)] +pub struct Sync15StorageClient { + tsc: token::TokenProvider, + pub(crate) backoff: BackoffListener, +} + +impl SetupStorageClient for Sync15StorageClient { + fn fetch_info_configuration(&self) -> error::Result<Sync15ClientResponse<InfoConfiguration>> { + self.relative_storage_request(Method::Get, "info/configuration") + } + + fn fetch_info_collections(&self) -> error::Result<Sync15ClientResponse<InfoCollections>> { + self.relative_storage_request(Method::Get, "info/collections") + } + + fn fetch_meta_global(&self) -> error::Result<Sync15ClientResponse<MetaGlobalRecord>> { + let got: Sync15ClientResponse<IncMetaGlobalBso> = + self.relative_storage_request(Method::Get, "storage/meta/global")?; + Ok(match got { + Sync15ClientResponse::Success { + record, + last_modified, + route, + status, + } => { + log::debug!( + "Got meta global with modified = {}; last-modified = {}", + record.envelope.modified, + last_modified + ); + Sync15ClientResponse::Success { + record: serde_json::from_str(&record.payload)?, + last_modified, + route, + status, + } + } + Sync15ClientResponse::Error(e) => Sync15ClientResponse::Error(e), + }) + } + + fn fetch_crypto_keys(&self) -> error::Result<Sync15ClientResponse<IncomingEncryptedBso>> { + self.relative_storage_request(Method::Get, "storage/crypto/keys") + } + + fn put_meta_global( + &self, + xius: ServerTimestamp, + global: &MetaGlobalRecord, + ) -> error::Result<ServerTimestamp> { + let bso = OutMetaGlobalBso::new(Guid::new("global").into(), global)?; + self.put("storage/meta/global", xius, &bso) + } + + fn put_crypto_keys( + &self, + xius: ServerTimestamp, + keys: &OutgoingEncryptedBso, + ) -> error::Result<()> { + self.put("storage/crypto/keys", xius, keys)?; + Ok(()) + } + + fn wipe_all_remote(&self) -> error::Result<()> { + let s = self.tsc.api_endpoint()?; + let url = Url::parse(&s)?; + + let req = self.build_request(Method::Delete, url)?; + match self.exec_request::<Value>(req, false) { + Ok(Sync15ClientResponse::Error(ErrorResponse::NotFound { .. })) + | Ok(Sync15ClientResponse::Success { .. }) => Ok(()), + Ok(resp) => Err(resp.create_storage_error()), + Err(e) => Err(e), + } + } +} + +impl Sync15StorageClient { + pub fn new(init_params: Sync15StorageClientInit) -> error::Result<Sync15StorageClient> { + rc_crypto::ensure_initialized(); + let tsc = token::TokenProvider::new( + init_params.tokenserver_url, + init_params.access_token, + init_params.key_id, + )?; + Ok(Sync15StorageClient { + tsc, + backoff: new_backoff_listener(), + }) + } + + pub fn get_encrypted_records( + &self, + collection_request: CollectionRequest, + ) -> error::Result<Sync15ClientResponse<Vec<IncomingEncryptedBso>>> { + self.collection_request(Method::Get, collection_request) + } + + #[inline] + fn authorized(&self, req: Request) -> error::Result<Request> { + let hawk_header_value = self.tsc.authorization(&req)?; + Ok(req.header(AUTHORIZATION, hawk_header_value)?) + } + + // TODO: probably want a builder-like API to do collection requests (e.g. something + // that occupies roughly the same conceptual role as the Collection class in desktop) + fn build_request(&self, method: Method, url: Url) -> error::Result<Request> { + self.authorized(Request::new(method, url).header(header_names::ACCEPT, "application/json")?) + } + + fn relative_storage_request<P, T>( + &self, + method: Method, + relative_path: P, + ) -> error::Result<Sync15ClientResponse<T>> + where + P: AsRef<str>, + for<'a> T: serde::de::Deserialize<'a>, + { + let s = self.tsc.api_endpoint()? + "/"; + let url = Url::parse(&s)?.join(relative_path.as_ref())?; + self.exec_request(self.build_request(method, url)?, false) + } + + fn exec_request<T>( + &self, + req: Request, + require_success: bool, + ) -> error::Result<Sync15ClientResponse<T>> + where + for<'a> T: serde::de::Deserialize<'a>, + { + log::trace!( + "request: {} {} ({:?})", + req.method, + req.url.path(), + req.url.query() + ); + let resp = req.send()?; + + let result = Sync15ClientResponse::from_response(resp, &self.backoff)?; + match result { + Sync15ClientResponse::Success { .. } => Ok(result), + _ => { + if require_success { + Err(result.create_storage_error()) + } else { + Ok(result) + } + } + } + } + + fn collection_request<T>( + &self, + method: Method, + r: CollectionRequest, + ) -> error::Result<Sync15ClientResponse<T>> + where + for<'a> T: serde::de::Deserialize<'a>, + { + let url = build_collection_request_url(Url::parse(&self.tsc.api_endpoint()?)?, r)?; + self.exec_request(self.build_request(method, url)?, false) + } + + pub fn new_post_queue<'a, F: PostResponseHandler>( + &'a self, + coll: &'a CollectionName, + config: &InfoConfiguration, + ts: ServerTimestamp, + on_response: F, + ) -> error::Result<PostQueue<PostWrapper<'a>, F>> { + let pw = PostWrapper { client: self, coll }; + Ok(PostQueue::new(config, ts, pw, on_response)) + } + + fn put<P, B>( + &self, + relative_path: P, + xius: ServerTimestamp, + body: &B, + ) -> error::Result<ServerTimestamp> + where + P: AsRef<str>, + B: serde::ser::Serialize, + { + let s = self.tsc.api_endpoint()? + "/"; + let url = Url::parse(&s)?.join(relative_path.as_ref())?; + + let req = self + .build_request(Method::Put, url)? + .json(body) + .header(header_names::X_IF_UNMODIFIED_SINCE, format!("{}", xius))?; + + let resp = self.exec_request::<Value>(req, true)?; + // Note: we pass `true` for require_success, so this panic never happens. + if let Sync15ClientResponse::Success { last_modified, .. } = resp { + Ok(last_modified) + } else { + unreachable!("Error returned exec_request when `require_success` was true"); + } + } + + pub fn hashed_uid(&self) -> error::Result<String> { + self.tsc.hashed_uid() + } + + pub(crate) fn wipe_remote_engine(&self, engine: &str) -> error::Result<()> { + let s = self.tsc.api_endpoint()? + "/"; + let url = Url::parse(&s)?.join(&format!("storage/{}", engine))?; + log::debug!("Wiping: {:?}", url); + let req = self.build_request(Method::Delete, url)?; + match self.exec_request::<Value>(req, false) { + Ok(Sync15ClientResponse::Error(ErrorResponse::NotFound { .. })) + | Ok(Sync15ClientResponse::Success { .. }) => Ok(()), + Ok(resp) => Err(resp.create_storage_error()), + Err(e) => Err(e), + } + } +} + +pub struct PostWrapper<'a> { + client: &'a Sync15StorageClient, + coll: &'a CollectionName, +} + +impl<'a> BatchPoster for PostWrapper<'a> { + fn post<T, O>( + &self, + bytes: Vec<u8>, + xius: ServerTimestamp, + batch: Option<String>, + commit: bool, + _: &PostQueue<T, O>, + ) -> error::Result<PostResponse> { + let r = CollectionPost::new(self.coll.clone()) + .batch(batch) + .commit(commit); + let url = build_collection_post_url(Url::parse(&self.client.tsc.api_endpoint()?)?, r)?; + + let req = self + .client + .build_request(Method::Post, url)? + .header(header_names::CONTENT_TYPE, "application/json")? + .header(header_names::X_IF_UNMODIFIED_SINCE, format!("{}", xius))? + .body(bytes); + self.client.exec_request(req, false) + } +} + +fn build_collection_url(mut base_url: Url, collection: CollectionName) -> error::Result<Url> { + base_url + .path_segments_mut() + .map_err(|_| Error::UnacceptableUrl("Storage server URL is not a base".to_string()))? + .extend(&["storage", &collection]); + + // This is strange but just accessing query_pairs_mut makes you have + // a trailing question mark on your url. I don't think anything bad + // would happen here, but I don't know, and also, it looks dumb so + // I'd rather not have it. + if base_url.query() == Some("") { + base_url.set_query(None); + } + Ok(base_url) +} + +fn build_collection_request_url(mut base_url: Url, r: CollectionRequest) -> error::Result<Url> { + let mut pairs = base_url.query_pairs_mut(); + if r.full { + pairs.append_pair("full", "1"); + } + if let Some(ids) = &r.ids { + // Most ids are 12 characters, and we comma separate them, so 13. + let mut buf = String::with_capacity(ids.len() * 13); + for (i, id) in ids.iter().enumerate() { + if i > 0 { + buf.push(','); + } + buf.push_str(id.as_str()); + } + pairs.append_pair("ids", &buf); + } + if let Some(ts) = r.older { + pairs.append_pair("older", &ts.to_string()); + } + if let Some(ts) = r.newer { + pairs.append_pair("newer", &ts.to_string()); + } + if let Some(l) = r.limit { + pairs.append_pair("sort", l.order.as_str()); + pairs.append_pair("limit", &l.num.to_string()); + } + pairs.finish(); + drop(pairs); + build_collection_url(base_url, r.collection) +} + +#[cfg(feature = "sync-client")] +fn build_collection_post_url(mut base_url: Url, r: CollectionPost) -> error::Result<Url> { + let mut pairs = base_url.query_pairs_mut(); + if let Some(batch) = &r.batch { + pairs.append_pair("batch", batch); + } + if r.commit { + pairs.append_pair("commit", "true"); + } + pairs.finish(); + drop(pairs); + build_collection_url(base_url, r.collection) +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_send() { + fn ensure_send<T: Send>() {} + // Compile will fail if not send. + ensure_send::<Sync15StorageClient>(); + } + + #[test] + fn test_parse_seconds() { + assert_eq!(parse_seconds("1"), Some(1)); + assert_eq!(parse_seconds("1.4"), Some(2)); + assert_eq!(parse_seconds("1.5"), Some(2)); + assert_eq!(parse_seconds("3600.0"), Some(3600)); + assert_eq!(parse_seconds("3600"), Some(3600)); + assert_eq!(parse_seconds("-1"), None); + assert_eq!(parse_seconds("inf"), None); + assert_eq!(parse_seconds("-inf"), None); + assert_eq!(parse_seconds("one-thousand"), None); + assert_eq!(parse_seconds("4294967295"), Some(4294967295)); + assert_eq!(parse_seconds("4294967296"), None); + } + + #[test] + fn test_query_building() { + use crate::engine::RequestOrder; + let base = Url::parse("https://example.com/sync").unwrap(); + + let empty = + build_collection_request_url(base.clone(), CollectionRequest::new("foo".into())) + .unwrap(); + assert_eq!(empty.as_str(), "https://example.com/sync/storage/foo"); + + let idreq = build_collection_request_url( + base.clone(), + CollectionRequest::new("wutang".into()) + .full() + .ids(&["rza", "gza"]), + ) + .unwrap(); + assert_eq!( + idreq.as_str(), + "https://example.com/sync/storage/wutang?full=1&ids=rza%2Cgza" + ); + + let complex = build_collection_request_url( + base, + CollectionRequest::new("specific".into()) + .full() + .limit(10, RequestOrder::Oldest) + .older_than(ServerTimestamp(9_876_540)) + .newer_than(ServerTimestamp(1_234_560)), + ) + .unwrap(); + assert_eq!(complex.as_str(), + "https://example.com/sync/storage/specific?full=1&older=9876.54&newer=1234.56&sort=oldest&limit=10"); + } + + #[cfg(feature = "sync-client")] + #[test] + fn test_post_query_building() { + let base = Url::parse("https://example.com/sync").unwrap(); + + let empty = + build_collection_post_url(base.clone(), CollectionPost::new("foo".into())).unwrap(); + assert_eq!(empty.as_str(), "https://example.com/sync/storage/foo"); + let batch_start = build_collection_post_url( + base.clone(), + CollectionPost::new("bar".into()) + .batch(Some("true".into())) + .commit(false), + ) + .unwrap(); + assert_eq!( + batch_start.as_str(), + "https://example.com/sync/storage/bar?batch=true" + ); + let batch_commit = build_collection_post_url( + base, + CollectionPost::new("asdf".into()) + .batch(Some("1234abc".into())) + .commit(true), + ) + .unwrap(); + assert_eq!( + batch_commit.as_str(), + "https://example.com/sync/storage/asdf?batch=1234abc&commit=true" + ); + } +} diff --git a/third_party/rust/sync15/src/client/sync.rs b/third_party/rust/sync15/src/client/sync.rs new file mode 100644 index 0000000000..b25220032f --- /dev/null +++ b/third_party/rust/sync15/src/client/sync.rs @@ -0,0 +1,114 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::{CollectionUpdate, GlobalState, LocalCollStateMachine, Sync15StorageClient}; +use crate::clients_engine; +use crate::engine::SyncEngine; +use crate::error::Error; +use crate::telemetry; +use crate::KeyBundle; +use interrupt_support::Interruptee; + +#[allow(clippy::too_many_arguments)] +pub fn synchronize_with_clients_engine( + client: &Sync15StorageClient, + global_state: &GlobalState, + root_sync_key: &KeyBundle, + clients: Option<&clients_engine::Engine<'_>>, + engine: &dyn SyncEngine, + fully_atomic: bool, + telem_engine: &mut telemetry::Engine, + interruptee: &dyn Interruptee, +) -> Result<(), Error> { + let collection = engine.collection_name(); + log::info!("Syncing collection {}", collection); + + // our global state machine is ready - get the collection machine going. + let coll_state = match LocalCollStateMachine::get_state(engine, global_state, root_sync_key)? { + Some(coll_state) => coll_state, + None => { + // XXX - this is either "error" or "declined". + log::warn!( + "can't setup for the {} collection - hopefully it works later", + collection + ); + return Ok(()); + } + }; + + if let Some(clients) = clients { + engine.prepare_for_sync(&|| clients.get_client_data())?; + } + interruptee.err_if_interrupted()?; + // We assume an "engine" manages exactly one "collection" with the engine's name. + match engine.get_collection_request(coll_state.last_modified)? { + None => { + log::info!("skipping incoming for {} - not needed.", collection); + } + Some(collection_request) => { + // Ideally we would "batch" incoming records (eg, fetch just 1000 at a time) + // and ask the engine to "stage" them as they come in - but currently we just read + // them all in one request. + + // Doing this batching will involve specifying a "limit=" param and + // "x-if-unmodified-since" for each request, looking for an + // "X-Weave-Next-Offset header in the response and using that in subsequent + // requests. + // See https://mozilla-services.readthedocs.io/en/latest/storage/apis-1.5.html#syncstorage-paging + // + // But even if we had that, we need to deal with a 412 response on a subsequent batch, + // so we can't know if we've staged *every* record for that timestamp; the next + // sync must use an earlier one. + // + // For this reason, an engine can't really trust a server timestamp until the + // very end when we know we've staged them all. + let incoming = super::fetch_incoming(client, &coll_state, collection_request)?; + log::info!("Downloaded {} remote changes", incoming.len()); + engine.stage_incoming(incoming, telem_engine)?; + interruptee.err_if_interrupted()?; + } + }; + + // Should consider adding a new `fetch_outgoing()` and having `apply()` only apply. + // It *might* even make sense to only call `apply()` when something was staged, + // but that's not clear - see the discussion at + // https://github.com/mozilla/application-services/pull/5441/files/f36274f455a6299f10e7ce56b167882c369aa806#r1189267540 + log::info!("Applying changes"); + let outgoing = engine.apply(coll_state.last_modified, telem_engine)?; + interruptee.err_if_interrupted()?; + + // XXX - this upload strategy is buggy due to batching. With enough records, we will commit + // 2 batches on the server. If the second fails, we get an Err<> here, so can't tell the + // engine about the successful server batch commit. + // Most stuff below should be called per-batch rather than at the successful end of all + // batches, but that's not trivial. + log::info!("Uploading {} outgoing changes", outgoing.len()); + let upload_info = CollectionUpdate::new_from_changeset( + client, + &coll_state, + collection, + outgoing, + fully_atomic, + )? + .upload()?; + log::info!( + "Upload success ({} records success, {} records failed)", + upload_info.successful_ids.len(), + upload_info.failed_ids.len() + ); + + let mut telem_outgoing = telemetry::EngineOutgoing::new(); + telem_outgoing.sent(upload_info.successful_ids.len() + upload_info.failed_ids.len()); + telem_outgoing.failed(upload_info.failed_ids.len()); + telem_engine.outgoing(telem_outgoing); + + engine.set_uploaded(upload_info.modified_timestamp, upload_info.successful_ids)?; + + // The above should all be per-batch :( + + engine.sync_finished()?; + + log::info!("Sync finished!"); + Ok(()) +} diff --git a/third_party/rust/sync15/src/client/sync_multiple.rs b/third_party/rust/sync15/src/client/sync_multiple.rs new file mode 100644 index 0000000000..22727dc8a6 --- /dev/null +++ b/third_party/rust/sync15/src/client/sync_multiple.rs @@ -0,0 +1,491 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// This helps you perform a sync of multiple engines and helps you manage +// global and local state between syncs. + +use super::state::{EngineChangesNeeded, GlobalState, PersistedGlobalState, SetupStateMachine}; +use super::status::{ServiceStatus, SyncResult}; +use super::storage_client::{BackoffListener, Sync15StorageClient, Sync15StorageClientInit}; +use crate::clients_engine::{self, CommandProcessor, CLIENTS_TTL_REFRESH}; +use crate::engine::{EngineSyncAssociation, SyncEngine}; +use crate::error::Error; +use crate::telemetry; +use crate::KeyBundle; +use interrupt_support::Interruptee; +use std::collections::HashMap; +use std::result; +use std::time::{Duration, SystemTime}; + +/// Info about the client to use. We reuse the client unless +/// we discover the client_init has changed, in which case we re-create one. +#[derive(Debug)] +struct ClientInfo { + // the client_init used to create `client`. + client_init: Sync15StorageClientInit, + // the client (our tokenserver state machine state, and our http library's state) + client: Sync15StorageClient, +} + +impl ClientInfo { + fn new(ci: &Sync15StorageClientInit) -> Result<Self, Error> { + Ok(Self { + client_init: ci.clone(), + client: Sync15StorageClient::new(ci.clone())?, + }) + } +} + +/// Info we want callers to engine *in memory* for us so that subsequent +/// syncs are faster. This should never be persisted to storage as it holds +/// sensitive information, such as the sync decryption keys. +#[derive(Debug, Default)] +pub struct MemoryCachedState { + last_client_info: Option<ClientInfo>, + last_global_state: Option<GlobalState>, + // These are just engined in memory, as persisting an invalid value far in the + // future has the potential to break sync for good. + next_sync_after: Option<SystemTime>, + next_client_refresh_after: Option<SystemTime>, +} + +impl MemoryCachedState { + // Called we notice the cached state is stale. + pub fn clear_sensitive_info(&mut self) { + self.last_client_info = None; + self.last_global_state = None; + // Leave the backoff time, as there's no reason to think it's not still + // true. + } + pub fn get_next_sync_after(&self) -> Option<SystemTime> { + self.next_sync_after + } + pub fn should_refresh_client(&self) -> bool { + match self.next_client_refresh_after { + Some(t) => SystemTime::now() > t, + None => true, + } + } + pub fn note_client_refresh(&mut self) { + self.next_client_refresh_after = + Some(SystemTime::now() + Duration::from_secs(CLIENTS_TTL_REFRESH)); + } +} + +/// Sync multiple engines +/// * `engines` - The engines to sync +/// * `persisted_global_state` - The global state to use, or None if never +/// before provided. At the end of the sync, and even when the sync fails, +/// the value in this cell should be persisted to permanent storage and +/// provided next time the sync is called. +/// * `last_client_info` - The client state to use, or None if never before +/// provided. At the end of the sync, the value should be persisted +/// *in memory only* - it should not be persisted to disk. +/// * `storage_init` - Information about how the sync http client should be +/// configured. +/// * `root_sync_key` - The KeyBundle used for encryption. +/// +/// Returns a map, keyed by name and holding an error value - if any engine +/// fails, the sync will continue on to other engines, but the error will be +/// places in this map. The absence of a name in the map implies the engine +/// succeeded. +pub fn sync_multiple( + engines: &[&dyn SyncEngine], + persisted_global_state: &mut Option<String>, + mem_cached_state: &mut MemoryCachedState, + storage_init: &Sync15StorageClientInit, + root_sync_key: &KeyBundle, + interruptee: &dyn Interruptee, + req_info: Option<SyncRequestInfo<'_>>, +) -> SyncResult { + sync_multiple_with_command_processor( + None, + engines, + persisted_global_state, + mem_cached_state, + storage_init, + root_sync_key, + interruptee, + req_info, + ) +} + +/// Like `sync_multiple`, but specifies an optional command processor to handle +/// commands from the clients collection. This function is called by the sync +/// manager, which provides its own processor. +#[allow(clippy::too_many_arguments)] +pub fn sync_multiple_with_command_processor( + command_processor: Option<&dyn CommandProcessor>, + engines: &[&dyn SyncEngine], + persisted_global_state: &mut Option<String>, + mem_cached_state: &mut MemoryCachedState, + storage_init: &Sync15StorageClientInit, + root_sync_key: &KeyBundle, + interruptee: &dyn Interruptee, + req_info: Option<SyncRequestInfo<'_>>, +) -> SyncResult { + log::info!("Syncing {} engines", engines.len()); + let mut sync_result = SyncResult { + service_status: ServiceStatus::OtherError, + result: Ok(()), + declined: None, + next_sync_after: None, + engine_results: HashMap::with_capacity(engines.len()), + telemetry: telemetry::SyncTelemetryPing::new(), + }; + let backoff = super::storage_client::new_backoff_listener(); + let req_info = req_info.unwrap_or_default(); + let driver = SyncMultipleDriver { + command_processor, + engines, + storage_init, + interruptee, + engines_to_state_change: req_info.engines_to_state_change, + backoff: backoff.clone(), + root_sync_key, + result: &mut sync_result, + persisted_global_state, + mem_cached_state, + saw_auth_error: false, + ignore_soft_backoff: req_info.is_user_action, + }; + match driver.sync() { + Ok(()) => { + log::debug!( + "sync was successful, final status={:?}", + sync_result.service_status + ); + } + Err(e) => { + log::warn!( + "sync failed: {}, final status={:?}", + e, + sync_result.service_status, + ); + sync_result.result = Err(e); + } + } + // Respect `backoff` value when computing the next sync time even if we were + // ignoring it during the sync + sync_result.set_sync_after(backoff.get_required_wait(false).unwrap_or_default()); + mem_cached_state.next_sync_after = sync_result.next_sync_after; + log::trace!("Sync result: {:?}", sync_result); + sync_result +} + +/// This is essentially a bag of information that the sync manager knows, but +/// otherwise we won't. It should probably be rethought if it gains many more +/// fields. +#[derive(Debug, Default)] +pub struct SyncRequestInfo<'a> { + pub engines_to_state_change: Option<&'a HashMap<String, bool>>, + pub is_user_action: bool, +} + +// The sync multiple driver +struct SyncMultipleDriver<'info, 'res, 'pgs, 'mcs> { + command_processor: Option<&'info dyn CommandProcessor>, + engines: &'info [&'info dyn SyncEngine], + storage_init: &'info Sync15StorageClientInit, + root_sync_key: &'info KeyBundle, + interruptee: &'info dyn Interruptee, + backoff: BackoffListener, + engines_to_state_change: Option<&'info HashMap<String, bool>>, + result: &'res mut SyncResult, + persisted_global_state: &'pgs mut Option<String>, + mem_cached_state: &'mcs mut MemoryCachedState, + ignore_soft_backoff: bool, + saw_auth_error: bool, +} + +impl<'info, 'res, 'pgs, 'mcs> SyncMultipleDriver<'info, 'res, 'pgs, 'mcs> { + /// The actual worker for sync_multiple. + fn sync(mut self) -> result::Result<(), Error> { + log::info!("Loading/initializing persisted state"); + let mut pgs = self.prepare_persisted_state(); + + log::info!("Preparing client info"); + let client_info = self.prepare_client_info()?; + + if self.was_interrupted() { + return Ok(()); + } + + log::info!("Entering sync state machine"); + // Advance the state machine to the point where it can perform a full + // sync. This may involve uploading meta/global, crypto/keys etc. + let mut global_state = self.run_state_machine(&client_info, &mut pgs)?; + + if self.was_interrupted() { + return Ok(()); + } + + // Set the service status to OK here - we may adjust it based on an individual + // engine failing. + self.result.service_status = ServiceStatus::Ok; + + let clients_engine = if let Some(command_processor) = self.command_processor { + log::info!("Synchronizing clients engine"); + let should_refresh = self.mem_cached_state.should_refresh_client(); + let mut engine = clients_engine::Engine::new(command_processor, self.interruptee); + if let Err(e) = engine.sync( + &client_info.client, + &global_state, + self.root_sync_key, + should_refresh, + ) { + // Record telemetry with the error just in case... + let mut telem_sync = telemetry::SyncTelemetry::new(); + let mut telem_engine = telemetry::Engine::new("clients"); + telem_engine.failure(&e); + telem_sync.engine(telem_engine); + self.result.service_status = ServiceStatus::from_err(&e); + + // ...And bail, because a clients engine sync failure is fatal. + return Err(e); + } + // We don't record telemetry for successful clients engine + // syncs, since we only keep client records in memory, we + // expect the counts to be the same most times, and a + // failure aborts the entire sync. + if self.was_interrupted() { + return Ok(()); + } + self.mem_cached_state.note_client_refresh(); + Some(engine) + } else { + None + }; + + log::info!("Synchronizing engines"); + + let telem_sync = + self.sync_engines(&client_info, &mut global_state, clients_engine.as_ref()); + self.result.telemetry.sync(telem_sync); + + log::info!("Finished syncing engines."); + + if !self.saw_auth_error { + log::trace!("Updating persisted global state"); + self.mem_cached_state.last_client_info = Some(client_info); + self.mem_cached_state.last_global_state = Some(global_state); + } + + Ok(()) + } + + fn was_interrupted(&mut self) -> bool { + if self.interruptee.was_interrupted() { + log::info!("Interrupted, bailing out"); + self.result.service_status = ServiceStatus::Interrupted; + true + } else { + false + } + } + + fn sync_engines( + &mut self, + client_info: &ClientInfo, + global_state: &mut GlobalState, + clients: Option<&clients_engine::Engine<'_>>, + ) -> telemetry::SyncTelemetry { + let mut telem_sync = telemetry::SyncTelemetry::new(); + for engine in self.engines { + let name = engine.collection_name(); + if self + .backoff + .get_required_wait(self.ignore_soft_backoff) + .is_some() + { + log::warn!("Got backoff, bailing out of sync early"); + break; + } + if global_state.global.declined.iter().any(|e| e == &*name) { + log::info!("The {} engine is declined. Skipping", name); + continue; + } + log::info!("Syncing {} engine!", name); + + let mut telem_engine = telemetry::Engine::new(&*name); + let result = super::sync::synchronize_with_clients_engine( + &client_info.client, + global_state, + self.root_sync_key, + clients, + *engine, + true, + &mut telem_engine, + self.interruptee, + ); + + match result { + Ok(()) => log::info!("Sync of {} was successful!", name), + Err(ref e) => { + log::warn!("Sync of {} failed! {:?}", name, e); + let this_status = ServiceStatus::from_err(e); + // The only error which forces us to discard our state is an + // auth error. + self.saw_auth_error = + self.saw_auth_error || this_status == ServiceStatus::AuthenticationError; + telem_engine.failure(e); + // If the failure from the engine looks like anything other than + // a "engine error" we don't bother trying the others. + if this_status != ServiceStatus::OtherError { + telem_sync.engine(telem_engine); + self.result.engine_results.insert(name.into(), result); + self.result.service_status = this_status; + break; + } + } + } + telem_sync.engine(telem_engine); + self.result.engine_results.insert(name.into(), result); + if self.was_interrupted() { + break; + } + } + telem_sync + } + + fn run_state_machine( + &mut self, + client_info: &ClientInfo, + pgs: &mut PersistedGlobalState, + ) -> result::Result<GlobalState, Error> { + let last_state = self.mem_cached_state.last_global_state.take(); + + let mut state_machine = SetupStateMachine::for_full_sync( + &client_info.client, + self.root_sync_key, + pgs, + self.engines_to_state_change, + self.interruptee, + ); + + log::info!("Advancing state machine to ready (full)"); + let res = state_machine.run_to_ready(last_state); + // Grab this now even though we don't need it until later to avoid a + // lifetime issue + let changes = state_machine.changes_needed.take(); + // The state machine might have updated our persisted_global_state, so + // update the caller's repr of it. + *self.persisted_global_state = Some(serde_json::to_string(&pgs)?); + + // Now that we've gone through the state machine, engine the declined list in + // the sync_result + self.result.declined = Some(pgs.get_declined().to_vec()); + log::debug!( + "Declined engines list after state machine set to: {:?}", + self.result.declined, + ); + + if let Some(c) = changes { + self.wipe_or_reset_engines(c, &client_info.client)?; + } + let state = match res { + Err(e) => { + self.result.service_status = ServiceStatus::from_err(&e); + return Err(e); + } + Ok(state) => state, + }; + self.result.telemetry.uid(client_info.client.hashed_uid()?); + // As for client_info, put None back now so we start from scratch on error. + self.mem_cached_state.last_global_state = None; + Ok(state) + } + + fn wipe_or_reset_engines( + &mut self, + changes: EngineChangesNeeded, + client: &Sync15StorageClient, + ) -> result::Result<(), Error> { + if changes.local_resets.is_empty() && changes.remote_wipes.is_empty() { + return Ok(()); + } + for e in &changes.remote_wipes { + log::info!("Engine {:?} just got disabled locally, wiping server", e); + client.wipe_remote_engine(e)?; + } + + for s in self.engines { + let name = s.collection_name(); + if changes.local_resets.contains(&*name) { + log::info!("Resetting engine {}, as it was declined remotely", name); + s.reset(&EngineSyncAssociation::Disconnected)?; + } + } + + Ok(()) + } + + fn prepare_client_info(&mut self) -> result::Result<ClientInfo, Error> { + let mut client_info = match self.mem_cached_state.last_client_info.take() { + Some(client_info) => { + // if our storage_init has changed it probably means the user has + // changed, courtesy of the 'kid' in the structure. Thus, we can't + // reuse the client or the memory cached state. We do keep the disk + // state as currently that's only the declined list. + if client_info.client_init != *self.storage_init { + log::info!("Discarding all state as the account might have changed"); + *self.mem_cached_state = MemoryCachedState::default(); + ClientInfo::new(self.storage_init)? + } else { + log::debug!("Reusing memory-cached client_info"); + // we can reuse it (which should be the common path) + client_info + } + } + None => { + log::debug!("mem_cached_state was stale or missing, need setup"); + // We almost certainly have no other state here, but to be safe, we + // throw away any memory state we do have. + self.mem_cached_state.clear_sensitive_info(); + ClientInfo::new(self.storage_init)? + } + }; + // Ensure we use the correct listener here rather than on all the branches + // above, since it seems less error prone. + client_info.client.backoff = self.backoff.clone(); + Ok(client_info) + } + + fn prepare_persisted_state(&mut self) -> PersistedGlobalState { + // Note that any failure to use a persisted state means we also decline + // to use our memory cached state, so that we fully rebuild that + // persisted state for next time. + match self.persisted_global_state { + Some(persisted_string) if !persisted_string.is_empty() => { + match serde_json::from_str::<PersistedGlobalState>(persisted_string) { + Ok(state) => { + log::trace!("Read persisted state: {:?}", state); + // Note that we don't set `result.declined` from the + // data in state - it remains None, which explicitly + // indicates "we don't have updated info". + state + } + _ => { + // Don't log the error since it might contain sensitive + // info (although currently it only contains the declined engines list) + error_support::report_error!( + "sync15-prepare-persisted-state", + "Failed to parse PersistedGlobalState from JSON! Falling back to default" + ); + *self.mem_cached_state = MemoryCachedState::default(); + PersistedGlobalState::default() + } + } + } + _ => { + log::info!( + "The application didn't give us persisted state - \ + this is only expected on the very first run for a given user." + ); + *self.mem_cached_state = MemoryCachedState::default(); + PersistedGlobalState::default() + } + } + } +} diff --git a/third_party/rust/sync15/src/client/token.rs b/third_party/rust/sync15/src/client/token.rs new file mode 100644 index 0000000000..b416c0c12a --- /dev/null +++ b/third_party/rust/sync15/src/client/token.rs @@ -0,0 +1,602 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use crate::error::{self, Error as ErrorKind, Result}; +use crate::ServerTimestamp; +use rc_crypto::hawk; +use serde_derive::*; +use std::borrow::{Borrow, Cow}; +use std::cell::RefCell; +use std::fmt; +use std::time::{Duration, SystemTime}; +use url::Url; +use viaduct::{header_names, Request}; + +const RETRY_AFTER_DEFAULT_MS: u64 = 10000; + +// The TokenserverToken is the token as received directly from the token server +// and deserialized from JSON. +#[derive(Deserialize, Clone, PartialEq, Eq)] +struct TokenserverToken { + id: String, + key: String, + api_endpoint: String, + uid: u64, + duration: u64, + hashed_fxa_uid: String, +} + +impl std::fmt::Debug for TokenserverToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TokenserverToken") + .field("api_endpoint", &self.api_endpoint) + .field("uid", &self.uid) + .field("duration", &self.duration) + .field("hashed_fxa_uid", &self.hashed_fxa_uid) + .finish() + } +} + +// The struct returned by the TokenFetcher - the token itself and the +// server timestamp. +struct TokenFetchResult { + token: TokenserverToken, + server_timestamp: ServerTimestamp, +} + +// The trait for fetching tokens - we'll provide a "real" implementation but +// tests will re-implement it. +trait TokenFetcher { + fn fetch_token(&self) -> crate::Result<TokenFetchResult>; + // We allow the trait to tell us what the time is so tests can get funky. + fn now(&self) -> SystemTime; +} + +// Our "real" token fetcher, implementing the TokenFetcher trait, which hits +// the token server +#[derive(Debug)] +struct TokenServerFetcher { + // The stuff needed to fetch a token. + server_url: Url, + access_token: String, + key_id: String, +} + +fn fixup_server_url(mut url: Url) -> url::Url { + // The given `url` is the end-point as returned by .well-known/fxa-client-configuration, + // or as directly specified by self-hosters. As a result, it may or may not have + // the sync 1.5 suffix of "/1.0/sync/1.5", so add it on here if it does not. + if url.as_str().ends_with("1.0/sync/1.5") { + // ok! + } else if url.as_str().ends_with("1.0/sync/1.5/") { + // Shouldn't ever be Err() here, but the result is `Result<PathSegmentsMut, ()>` + // and I don't want to unwrap or add a new error type just for PathSegmentsMut failing. + if let Ok(mut path) = url.path_segments_mut() { + path.pop(); + } + } else { + // We deliberately don't use `.join()` here in order to preserve all path components. + // For example, "http://example.com/token" should produce "http://example.com/token/1.0/sync/1.5" + // but using `.join()` would produce "http://example.com/1.0/sync/1.5". + if let Ok(mut path) = url.path_segments_mut() { + path.pop_if_empty(); + path.extend(&["1.0", "sync", "1.5"]); + } + }; + url +} + +impl TokenServerFetcher { + fn new(base_url: Url, access_token: String, key_id: String) -> TokenServerFetcher { + TokenServerFetcher { + server_url: fixup_server_url(base_url), + access_token, + key_id, + } + } +} + +impl TokenFetcher for TokenServerFetcher { + fn fetch_token(&self) -> Result<TokenFetchResult> { + log::debug!("Fetching token from {}", self.server_url); + let resp = Request::get(self.server_url.clone()) + .header( + header_names::AUTHORIZATION, + format!("Bearer {}", self.access_token), + )? + .header(header_names::X_KEYID, self.key_id.clone())? + .send()?; + + if !resp.is_success() { + log::warn!("Non-success status when fetching token: {}", resp.status); + // TODO: the body should be JSON and contain a status parameter we might need? + log::trace!(" Response body {}", resp.text()); + // XXX - shouldn't we "chain" these errors - ie, a BackoffError could + // have a TokenserverHttpError as its cause? + if let Some(res) = resp.headers.get_as::<f64, _>(header_names::RETRY_AFTER) { + let ms = res + .ok() + .map_or(RETRY_AFTER_DEFAULT_MS, |f| (f * 1000f64) as u64); + let when = self.now() + Duration::from_millis(ms); + return Err(ErrorKind::BackoffError(when)); + } + let status = resp.status; + return Err(ErrorKind::TokenserverHttpError(status)); + } + + let token: TokenserverToken = resp.json()?; + let server_timestamp = resp + .headers + .try_get::<ServerTimestamp, _>(header_names::X_TIMESTAMP) + .ok_or(ErrorKind::MissingServerTimestamp)?; + Ok(TokenFetchResult { + token, + server_timestamp, + }) + } + + fn now(&self) -> SystemTime { + SystemTime::now() + } +} + +// The context stored by our TokenProvider when it has a TokenState::Token +// state. +struct TokenContext { + token: TokenserverToken, + credentials: hawk::Credentials, + server_timestamp: ServerTimestamp, + valid_until: SystemTime, +} + +// hawk::Credentials doesn't implement debug -_- +impl fmt::Debug for TokenContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> ::std::result::Result<(), fmt::Error> { + f.debug_struct("TokenContext") + .field("token", &self.token) + .field("credentials", &"(omitted)") + .field("server_timestamp", &self.server_timestamp) + .field("valid_until", &self.valid_until) + .finish() + } +} + +impl TokenContext { + fn new( + token: TokenserverToken, + credentials: hawk::Credentials, + server_timestamp: ServerTimestamp, + valid_until: SystemTime, + ) -> Self { + Self { + token, + credentials, + server_timestamp, + valid_until, + } + } + + fn is_valid(&self, now: SystemTime) -> bool { + // We could consider making the duration a little shorter - if it + // only has 1 second validity there seems a reasonable chance it will + // have expired by the time it gets presented to the remote that wants + // it. + // Either way though, we will eventually need to handle a token being + // rejected as a non-fatal error and recover, so maybe we don't care? + now < self.valid_until + } + + fn authorization(&self, req: &Request) -> Result<String> { + let url = &req.url; + + let path_and_query = match url.query() { + None => Cow::from(url.path()), + Some(qs) => Cow::from(format!("{}?{}", url.path(), qs)), + }; + + let host = url + .host_str() + .ok_or_else(|| ErrorKind::UnacceptableUrl("Storage URL has no host".into()))?; + + // Known defaults exist for https? (among others), so this should be impossible + let port = url.port_or_known_default().ok_or_else(|| { + ErrorKind::UnacceptableUrl( + "Storage URL has no port and no default port is known for the protocol".into(), + ) + })?; + + let header = + hawk::RequestBuilder::new(req.method.as_str(), host, port, path_and_query.borrow()) + .request() + .make_header(&self.credentials)?; + + Ok(format!("Hawk {}", header)) + } +} + +// The state our TokenProvider holds to reflect the state of the token. +#[derive(Debug)] +enum TokenState { + // We've never fetched a token. + NoToken, + // Have a token and last we checked it remained valid. + Token(TokenContext), + // We failed to fetch a token. First elt is the error, second elt is + // the api_endpoint we had before we failed to fetch a new token (or + // None if the very first attempt at fetching a token failed) + Failed(Option<error::Error>, Option<String>), + // Previously failed and told to back-off for SystemTime duration. Second + // elt is the api_endpoint we had before we hit the backoff error. + // XXX - should we roll Backoff and Failed together? + Backoff(SystemTime, Option<String>), + // api_endpoint changed - we are never going to get a token nor move out + // of this state. + NodeReassigned, +} + +/// The generic TokenProvider implementation - long lived and fetches tokens +/// on demand (eg, when first needed, or when an existing one expires.) +#[derive(Debug)] +struct TokenProviderImpl<TF: TokenFetcher> { + fetcher: TF, + // Our token state (ie, whether we have a token, and if not, why not) + current_state: RefCell<TokenState>, +} + +impl<TF: TokenFetcher> TokenProviderImpl<TF> { + fn new(fetcher: TF) -> Self { + // We check this at the real entrypoint of the application, but tests + // can/do bypass that, so check this here too. + rc_crypto::ensure_initialized(); + TokenProviderImpl { + fetcher, + current_state: RefCell::new(TokenState::NoToken), + } + } + + // Uses our fetcher to grab a new token and if successfull, derives other + // info from that token into a usable TokenContext. + fn fetch_context(&self) -> Result<TokenContext> { + let result = self.fetcher.fetch_token()?; + let token = result.token; + let valid_until = SystemTime::now() + Duration::from_secs(token.duration); + + let credentials = hawk::Credentials { + id: token.id.clone(), + key: hawk::Key::new(token.key.as_bytes(), hawk::SHA256)?, + }; + + Ok(TokenContext::new( + token, + credentials, + result.server_timestamp, + valid_until, + )) + } + + // Attempt to fetch a new token and return a new state reflecting that + // operation. If it worked a TokenState will be returned, but errors may + // cause other states. + fn fetch_token(&self, previous_endpoint: Option<&str>) -> TokenState { + match self.fetch_context() { + Ok(tc) => { + // We got a new token - check that the endpoint is the same + // as a previous endpoint we saw (if any) + match previous_endpoint { + Some(prev) => { + if prev == tc.token.api_endpoint { + TokenState::Token(tc) + } else { + log::warn!( + "api_endpoint changed from {} to {}", + prev, + tc.token.api_endpoint + ); + TokenState::NodeReassigned + } + } + None => { + // Never had an api_endpoint in the past, so this is OK. + TokenState::Token(tc) + } + } + } + Err(e) => { + // Early to avoid nll issues... + if let ErrorKind::BackoffError(be) = e { + return TokenState::Backoff(be, previous_endpoint.map(ToString::to_string)); + } + TokenState::Failed(Some(e), previous_endpoint.map(ToString::to_string)) + } + } + } + + // Given the state we are currently in, return a new current state. + // Returns None if the current state should be used (eg, if we are + // holding a token that remains valid) or Some() if the state has changed + // (which may have changed to a state with a token or an error state) + fn advance_state(&self, state: &TokenState) -> Option<TokenState> { + match state { + TokenState::NoToken => Some(self.fetch_token(None)), + TokenState::Failed(_, existing_endpoint) => { + Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str))) + } + TokenState::Token(existing_context) => { + if existing_context.is_valid(self.fetcher.now()) { + None + } else { + Some(self.fetch_token(Some(existing_context.token.api_endpoint.as_str()))) + } + } + TokenState::Backoff(ref until, ref existing_endpoint) => { + if let Ok(remaining) = until.duration_since(self.fetcher.now()) { + log::debug!("enforcing existing backoff - {:?} remains", remaining); + None + } else { + // backoff period is over + Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str))) + } + } + TokenState::NodeReassigned => { + // We never leave this state. + None + } + } + } + + fn with_token<T, F>(&self, func: F) -> Result<T> + where + F: FnOnce(&TokenContext) -> Result<T>, + { + // first get a mutable ref to our existing state, advance to the + // state we will use, then re-stash that state for next time. + let state: &mut TokenState = &mut self.current_state.borrow_mut(); + if let Some(new_state) = self.advance_state(state) { + *state = new_state; + } + + // Now re-fetch the state we should use for this call - if it's + // anything other than TokenState::Token we will fail. + match state { + TokenState::NoToken => { + // it should be impossible to get here. + panic!("Can't be in NoToken state after advancing"); + } + TokenState::Token(ref token_context) => { + // make the call. + func(token_context) + } + TokenState::Failed(e, _) => { + // We swap the error out of the state enum and return it. + Err(e.take().unwrap()) + } + TokenState::NodeReassigned => { + // this is unrecoverable. + Err(ErrorKind::StorageResetError) + } + TokenState::Backoff(ref remaining, _) => Err(ErrorKind::BackoffError(*remaining)), + } + } + + fn hashed_uid(&self) -> Result<String> { + self.with_token(|ctx| Ok(ctx.token.hashed_fxa_uid.clone())) + } + + fn authorization(&self, req: &Request) -> Result<String> { + self.with_token(|ctx| ctx.authorization(req)) + } + + fn api_endpoint(&self) -> Result<String> { + self.with_token(|ctx| Ok(ctx.token.api_endpoint.clone())) + } + // TODO: we probably want a "drop_token/context" type method so that when + // using a token with some validity fails the caller can force a new one + // (in which case the new token request will probably fail with a 401) +} + +// The public concrete object exposed by this module +#[derive(Debug)] +pub struct TokenProvider { + imp: TokenProviderImpl<TokenServerFetcher>, +} + +impl TokenProvider { + pub fn new(url: Url, access_token: String, key_id: String) -> Result<Self> { + let fetcher = TokenServerFetcher::new(url, access_token, key_id); + Ok(Self { + imp: TokenProviderImpl::new(fetcher), + }) + } + + pub fn hashed_uid(&self) -> Result<String> { + self.imp.hashed_uid() + } + + pub fn authorization(&self, req: &Request) -> Result<String> { + self.imp.authorization(req) + } + + pub fn api_endpoint(&self) -> Result<String> { + self.imp.api_endpoint() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::cell::Cell; + + struct TestFetcher<FF, FN> + where + FF: Fn() -> Result<TokenFetchResult>, + FN: Fn() -> SystemTime, + { + fetch: FF, + now: FN, + } + impl<FF, FN> TokenFetcher for TestFetcher<FF, FN> + where + FF: Fn() -> Result<TokenFetchResult>, + FN: Fn() -> SystemTime, + { + fn fetch_token(&self) -> Result<TokenFetchResult> { + (self.fetch)() + } + fn now(&self) -> SystemTime { + (self.now)() + } + } + + fn make_tsc<FF, FN>(fetch: FF, now: FN) -> TokenProviderImpl<TestFetcher<FF, FN>> + where + FF: Fn() -> Result<TokenFetchResult>, + FN: Fn() -> SystemTime, + { + let fetcher: TestFetcher<FF, FN> = TestFetcher { fetch, now }; + TokenProviderImpl::new(fetcher) + } + + #[test] + fn test_endpoint() { + // Use a cell to avoid the closure having a mutable ref to this scope. + let counter: Cell<u32> = Cell::new(0); + let fetch = || { + counter.set(counter.get() + 1); + Ok(TokenFetchResult { + token: TokenserverToken { + id: "id".to_string(), + key: "key".to_string(), + api_endpoint: "api_endpoint".to_string(), + uid: 1, + duration: 1000, + hashed_fxa_uid: "hash".to_string(), + }, + server_timestamp: ServerTimestamp(0i64), + }) + }; + + let tsc = make_tsc(fetch, SystemTime::now); + + let e = tsc.api_endpoint().expect("should work"); + assert_eq!(e, "api_endpoint".to_string()); + assert_eq!(counter.get(), 1); + + let e2 = tsc.api_endpoint().expect("should work"); + assert_eq!(e2, "api_endpoint".to_string()); + // should not have re-fetched. + assert_eq!(counter.get(), 1); + } + + #[test] + fn test_backoff() { + let counter: Cell<u32> = Cell::new(0); + let fetch = || { + counter.set(counter.get() + 1); + let when = SystemTime::now() + Duration::from_millis(10000); + Err(ErrorKind::BackoffError(when)) + }; + let now: Cell<SystemTime> = Cell::new(SystemTime::now()); + let tsc = make_tsc(fetch, || now.get()); + + tsc.api_endpoint().expect_err("should bail"); + // XXX - check error type. + assert_eq!(counter.get(), 1); + // try and get another token - should not re-fetch as backoff is still + // in progress. + tsc.api_endpoint().expect_err("should bail"); + assert_eq!(counter.get(), 1); + + // Advance the clock. + now.set(now.get() + Duration::new(20, 0)); + + // Our token fetch mock is still returning a backoff error, so we + // still fail, but should have re-hit the fetch function. + tsc.api_endpoint().expect_err("should bail"); + assert_eq!(counter.get(), 2); + } + + #[test] + fn test_validity() { + let counter: Cell<u32> = Cell::new(0); + let fetch = || { + counter.set(counter.get() + 1); + Ok(TokenFetchResult { + token: TokenserverToken { + id: "id".to_string(), + key: "key".to_string(), + api_endpoint: "api_endpoint".to_string(), + uid: 1, + duration: 10, + hashed_fxa_uid: "hash".to_string(), + }, + server_timestamp: ServerTimestamp(0i64), + }) + }; + let now: Cell<SystemTime> = Cell::new(SystemTime::now()); + let tsc = make_tsc(fetch, || now.get()); + + tsc.api_endpoint().expect("should get a valid token"); + assert_eq!(counter.get(), 1); + + // try and get another token - should not re-fetch as the old one + // remains valid. + tsc.api_endpoint().expect("should reuse existing token"); + assert_eq!(counter.get(), 1); + + // Advance the clock. + now.set(now.get() + Duration::new(20, 0)); + + // We should discard our token and fetch a new one. + tsc.api_endpoint().expect("should re-fetch"); + assert_eq!(counter.get(), 2); + } + + #[test] + fn test_server_url() { + assert_eq!( + fixup_server_url( + Url::parse("https://token.services.mozilla.com/1.0/sync/1.5").unwrap() + ) + .as_str(), + "https://token.services.mozilla.com/1.0/sync/1.5" + ); + assert_eq!( + fixup_server_url( + Url::parse("https://token.services.mozilla.com/1.0/sync/1.5/").unwrap() + ) + .as_str(), + "https://token.services.mozilla.com/1.0/sync/1.5" + ); + assert_eq!( + fixup_server_url(Url::parse("https://token.services.mozilla.com").unwrap()).as_str(), + "https://token.services.mozilla.com/1.0/sync/1.5" + ); + assert_eq!( + fixup_server_url(Url::parse("https://token.services.mozilla.com/").unwrap()).as_str(), + "https://token.services.mozilla.com/1.0/sync/1.5" + ); + assert_eq!( + fixup_server_url( + Url::parse("https://selfhosted.example.com/token/1.0/sync/1.5").unwrap() + ) + .as_str(), + "https://selfhosted.example.com/token/1.0/sync/1.5" + ); + assert_eq!( + fixup_server_url( + Url::parse("https://selfhosted.example.com/token/1.0/sync/1.5/").unwrap() + ) + .as_str(), + "https://selfhosted.example.com/token/1.0/sync/1.5" + ); + assert_eq!( + fixup_server_url(Url::parse("https://selfhosted.example.com/token/").unwrap()).as_str(), + "https://selfhosted.example.com/token/1.0/sync/1.5" + ); + assert_eq!( + fixup_server_url(Url::parse("https://selfhosted.example.com/token").unwrap()).as_str(), + "https://selfhosted.example.com/token/1.0/sync/1.5" + ); + } +} diff --git a/third_party/rust/sync15/src/client/util.rs b/third_party/rust/sync15/src/client/util.rs new file mode 100644 index 0000000000..01fff77afa --- /dev/null +++ b/third_party/rust/sync15/src/client/util.rs @@ -0,0 +1,102 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::{AtomicU32, Ordering}; + +/// Finds the maximum of the current value and the argument `val`, and sets the +/// new value to the result. +/// +/// Note: `AtomicFoo::fetch_max` is unstable, and can't really be implemented as +/// a single atomic operation from outside the stdlib ;-; +pub(crate) fn atomic_update_max(v: &AtomicU32, new: u32) { + // For loads (and the compare_exchange_weak second ordering argument) this + // is too strong, we could probably get away with Acquire (or maybe Relaxed + // because we don't need the result?). In either case, this fn isn't called + // from a hot spot so whatever. + let mut cur = v.load(Ordering::SeqCst); + while cur < new { + // we're already handling the failure case so there's no reason not to + // use _weak here. + match v.compare_exchange_weak(cur, new, Ordering::SeqCst, Ordering::SeqCst) { + Ok(_) => { + // Success. + break; + } + Err(new_cur) => { + // Interrupted, keep trying. + cur = new_cur + } + } + } +} + +// Slight wrappers around the builtin methods for doing this. +pub(crate) fn set_union(a: &HashSet<String>, b: &HashSet<String>) -> HashSet<String> { + a.union(b).cloned().collect() +} + +pub(crate) fn set_difference(a: &HashSet<String>, b: &HashSet<String>) -> HashSet<String> { + a.difference(b).cloned().collect() +} + +pub(crate) fn set_intersection(a: &HashSet<String>, b: &HashSet<String>) -> HashSet<String> { + a.intersection(b).cloned().collect() +} + +pub(crate) fn partition_by_value(v: &HashMap<String, bool>) -> (HashSet<String>, HashSet<String>) { + let mut true_: HashSet<String> = HashSet::new(); + let mut false_: HashSet<String> = HashSet::new(); + for (s, val) in v { + if *val { + true_.insert(s.clone()); + } else { + false_.insert(s.clone()); + } + } + (true_, false_) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_set_ops() { + fn hash_set(s: &[&str]) -> HashSet<String> { + s.iter() + .copied() + .map(ToOwned::to_owned) + .collect::<HashSet<_>>() + } + + assert_eq!( + set_union(&hash_set(&["a", "b", "c"]), &hash_set(&["b", "d"])), + hash_set(&["a", "b", "c", "d"]), + ); + + assert_eq!( + set_difference(&hash_set(&["a", "b", "c"]), &hash_set(&["b", "d"])), + hash_set(&["a", "c"]), + ); + assert_eq!( + set_intersection(&hash_set(&["a", "b", "c"]), &hash_set(&["b", "d"])), + hash_set(&["b"]), + ); + let m: HashMap<String, bool> = [ + ("foo", true), + ("bar", true), + ("baz", false), + ("quux", false), + ] + .iter() + .copied() + .map(|(a, b)| (a.to_owned(), b)) + .collect(); + assert_eq!( + partition_by_value(&m), + (hash_set(&["foo", "bar"]), hash_set(&["baz", "quux"])), + ); + } +} |