summaryrefslogtreecommitdiffstats
path: root/third_party/rust/sync15/src/client
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/sync15/src/client')
-rw-r--r--third_party/rust/sync15/src/client/coll_state.rs354
-rw-r--r--third_party/rust/sync15/src/client/coll_update.rs137
-rw-r--r--third_party/rust/sync15/src/client/collection_keys.rs61
-rw-r--r--third_party/rust/sync15/src/client/mod.rs39
-rw-r--r--third_party/rust/sync15/src/client/request.rs1199
-rw-r--r--third_party/rust/sync15/src/client/state.rs1089
-rw-r--r--third_party/rust/sync15/src/client/status.rs106
-rw-r--r--third_party/rust/sync15/src/client/storage_client.rs587
-rw-r--r--third_party/rust/sync15/src/client/sync.rs105
-rw-r--r--third_party/rust/sync15/src/client/sync_multiple.rs493
-rw-r--r--third_party/rust/sync15/src/client/token.rs602
-rw-r--r--third_party/rust/sync15/src/client/util.rs102
12 files changed, 4874 insertions, 0 deletions
diff --git a/third_party/rust/sync15/src/client/coll_state.rs b/third_party/rust/sync15/src/client/coll_state.rs
new file mode 100644
index 0000000000..df8be5f5b5
--- /dev/null
+++ b/third_party/rust/sync15/src/client/coll_state.rs
@@ -0,0 +1,354 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use super::request::InfoConfiguration;
+use super::{CollectionKeys, GlobalState};
+use crate::engine::{CollSyncIds, EngineSyncAssociation, SyncEngine};
+use crate::error;
+use crate::KeyBundle;
+use crate::ServerTimestamp;
+
+/// Holds state for a collection. In general, only the CollState is
+/// needed to sync a collection (but a valid GlobalState is needed to obtain
+/// a CollState)
+#[derive(Debug, Clone)]
+pub struct CollState {
+ pub config: InfoConfiguration,
+ // initially from meta/global, updated after an xius POST/PUT.
+ pub last_modified: ServerTimestamp,
+ pub key: KeyBundle,
+}
+
+#[derive(Debug)]
+pub enum LocalCollState {
+ /// The state is unknown, with the EngineSyncAssociation the collection
+ /// reports.
+ Unknown { assoc: EngineSyncAssociation },
+
+ /// The engine has been declined. This is a "terminal" state.
+ Declined,
+
+ /// There's no such collection in meta/global. We could possibly update
+ /// meta/global, but currently all known collections are there by default,
+ /// so this is, basically, an error condition.
+ NoSuchCollection,
+
+ /// Either the global or collection sync ID has changed - we will reset the engine.
+ SyncIdChanged { ids: CollSyncIds },
+
+ /// The collection is ready to sync.
+ Ready { key: KeyBundle },
+}
+
+pub struct LocalCollStateMachine<'state> {
+ global_state: &'state GlobalState,
+ root_key: &'state KeyBundle,
+}
+
+impl<'state> LocalCollStateMachine<'state> {
+ fn advance(
+ &self,
+ from: LocalCollState,
+ engine: &dyn SyncEngine,
+ ) -> error::Result<LocalCollState> {
+ let name = &engine.collection_name().to_string();
+ let meta_global = &self.global_state.global;
+ match from {
+ LocalCollState::Unknown { assoc } => {
+ if meta_global.declined.contains(name) {
+ return Ok(LocalCollState::Declined);
+ }
+ match meta_global.engines.get(name) {
+ Some(engine_meta) => match assoc {
+ EngineSyncAssociation::Disconnected => Ok(LocalCollState::SyncIdChanged {
+ ids: CollSyncIds {
+ global: meta_global.sync_id.clone(),
+ coll: engine_meta.sync_id.clone(),
+ },
+ }),
+ EngineSyncAssociation::Connected(ref ids)
+ if ids.global == meta_global.sync_id
+ && ids.coll == engine_meta.sync_id =>
+ {
+ let coll_keys = CollectionKeys::from_encrypted_payload(
+ self.global_state.keys.clone(),
+ self.global_state.keys_timestamp,
+ self.root_key,
+ )?;
+ Ok(LocalCollState::Ready {
+ key: coll_keys.key_for_collection(name).clone(),
+ })
+ }
+ _ => Ok(LocalCollState::SyncIdChanged {
+ ids: CollSyncIds {
+ global: meta_global.sync_id.clone(),
+ coll: engine_meta.sync_id.clone(),
+ },
+ }),
+ },
+ None => Ok(LocalCollState::NoSuchCollection),
+ }
+ }
+
+ LocalCollState::Declined => unreachable!("can't advance from declined"),
+
+ LocalCollState::NoSuchCollection => unreachable!("the collection is unknown"),
+
+ LocalCollState::SyncIdChanged { ids } => {
+ let assoc = EngineSyncAssociation::Connected(ids);
+ log::info!("Resetting {} engine", engine.collection_name());
+ engine.reset(&assoc)?;
+ Ok(LocalCollState::Unknown { assoc })
+ }
+
+ LocalCollState::Ready { .. } => unreachable!("can't advance from ready"),
+ }
+ }
+
+ // A little whimsy - a portmanteau of far and fast
+ fn run_and_run_as_farst_as_you_can(
+ &mut self,
+ engine: &dyn SyncEngine,
+ ) -> error::Result<Option<CollState>> {
+ let mut s = LocalCollState::Unknown {
+ assoc: engine.get_sync_assoc()?,
+ };
+ // This is a simple state machine and should never take more than
+ // 10 goes around.
+ let mut count = 0;
+ loop {
+ log::trace!("LocalCollState in {:?}", s);
+ match s {
+ LocalCollState::Ready { key } => {
+ let name = engine.collection_name();
+ let config = self.global_state.config.clone();
+ let last_modified = self
+ .global_state
+ .collections
+ .get(name.as_ref())
+ .cloned()
+ .unwrap_or_default();
+ return Ok(Some(CollState {
+ config,
+ last_modified,
+ key,
+ }));
+ }
+ LocalCollState::Declined | LocalCollState::NoSuchCollection => return Ok(None),
+
+ _ => {
+ count += 1;
+ if count > 10 {
+ log::warn!("LocalCollStateMachine appears to be looping");
+ return Ok(None);
+ }
+ // should we have better loop detection? Our limit of 10
+ // goes is probably OK for now, but not really ideal.
+ s = self.advance(s, engine)?;
+ }
+ };
+ }
+ }
+
+ pub fn get_state(
+ engine: &dyn SyncEngine,
+ global_state: &'state GlobalState,
+ root_key: &'state KeyBundle,
+ ) -> error::Result<Option<CollState>> {
+ let mut gingerbread_man = Self {
+ global_state,
+ root_key,
+ };
+ gingerbread_man.run_and_run_as_farst_as_you_can(engine)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::super::request::{InfoCollections, InfoConfiguration};
+ use super::super::CollectionKeys;
+ use super::*;
+ use crate::engine::CollectionRequest;
+ use crate::engine::{IncomingChangeset, OutgoingChangeset};
+ use crate::record_types::{MetaGlobalEngine, MetaGlobalRecord};
+ use crate::telemetry;
+ use anyhow::Result;
+ use std::cell::{Cell, RefCell};
+ use std::collections::HashMap;
+ use sync_guid::Guid;
+
+ fn get_global_state(root_key: &KeyBundle) -> GlobalState {
+ let keys = CollectionKeys::new_random()
+ .unwrap()
+ .to_encrypted_payload(root_key)
+ .unwrap();
+ GlobalState {
+ config: InfoConfiguration::default(),
+ collections: InfoCollections::new(HashMap::new()),
+ global: MetaGlobalRecord {
+ sync_id: "syncIDAAAAAA".into(),
+ storage_version: 5usize,
+ engines: vec![(
+ "bookmarks",
+ MetaGlobalEngine {
+ version: 1usize,
+ sync_id: "syncIDBBBBBB".into(),
+ },
+ )]
+ .into_iter()
+ .map(|(key, value)| (key.to_owned(), value))
+ .collect(),
+ declined: vec![],
+ },
+ global_timestamp: ServerTimestamp::default(),
+ keys,
+ keys_timestamp: ServerTimestamp::default(),
+ }
+ }
+
+ struct TestSyncEngine {
+ collection_name: &'static str,
+ assoc: Cell<EngineSyncAssociation>,
+ num_resets: RefCell<usize>,
+ }
+
+ impl TestSyncEngine {
+ fn new(collection_name: &'static str, assoc: EngineSyncAssociation) -> Self {
+ Self {
+ collection_name,
+ assoc: Cell::new(assoc),
+ num_resets: RefCell::new(0),
+ }
+ }
+ fn get_num_resets(&self) -> usize {
+ *self.num_resets.borrow()
+ }
+ }
+
+ impl SyncEngine for TestSyncEngine {
+ fn collection_name(&self) -> std::borrow::Cow<'static, str> {
+ self.collection_name.into()
+ }
+
+ fn apply_incoming(
+ &self,
+ _inbound: Vec<IncomingChangeset>,
+ _telem: &mut telemetry::Engine,
+ ) -> Result<OutgoingChangeset> {
+ unreachable!("these tests shouldn't call these");
+ }
+
+ fn sync_finished(
+ &self,
+ _new_timestamp: ServerTimestamp,
+ _records_synced: Vec<Guid>,
+ ) -> Result<()> {
+ unreachable!("these tests shouldn't call these");
+ }
+
+ fn get_collection_requests(
+ &self,
+ _server_timestamp: ServerTimestamp,
+ ) -> Result<Vec<CollectionRequest>> {
+ unreachable!("these tests shouldn't call these");
+ }
+
+ fn get_sync_assoc(&self) -> Result<EngineSyncAssociation> {
+ Ok(self.assoc.replace(EngineSyncAssociation::Disconnected))
+ }
+
+ fn reset(&self, new_assoc: &EngineSyncAssociation) -> Result<()> {
+ self.assoc.replace(new_assoc.clone());
+ *self.num_resets.borrow_mut() += 1;
+ Ok(())
+ }
+
+ fn wipe(&self) -> Result<()> {
+ unreachable!("these tests shouldn't call these");
+ }
+ }
+
+ #[test]
+ fn test_unknown() {
+ let root_key = KeyBundle::new_random().expect("should work");
+ let gs = get_global_state(&root_key);
+ let engine = TestSyncEngine::new("unknown", EngineSyncAssociation::Disconnected);
+ let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work");
+ assert!(cs.is_none(), "unknown collection name can't sync");
+ assert_eq!(engine.get_num_resets(), 0);
+ }
+
+ #[test]
+ fn test_known_no_state() {
+ let root_key = KeyBundle::new_random().expect("should work");
+ let gs = get_global_state(&root_key);
+ let engine = TestSyncEngine::new("bookmarks", EngineSyncAssociation::Disconnected);
+ let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work");
+ assert!(cs.is_some(), "collection can sync");
+ assert_eq!(
+ engine.assoc.replace(EngineSyncAssociation::Disconnected),
+ EngineSyncAssociation::Connected(CollSyncIds {
+ global: "syncIDAAAAAA".into(),
+ coll: "syncIDBBBBBB".into(),
+ })
+ );
+ assert_eq!(engine.get_num_resets(), 1);
+ }
+
+ #[test]
+ fn test_known_wrong_state() {
+ let root_key = KeyBundle::new_random().expect("should work");
+ let gs = get_global_state(&root_key);
+ let engine = TestSyncEngine::new(
+ "bookmarks",
+ EngineSyncAssociation::Connected(CollSyncIds {
+ global: "syncIDXXXXXX".into(),
+ coll: "syncIDYYYYYY".into(),
+ }),
+ );
+ let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work");
+ assert!(cs.is_some(), "collection can sync");
+ assert_eq!(
+ engine.assoc.replace(EngineSyncAssociation::Disconnected),
+ EngineSyncAssociation::Connected(CollSyncIds {
+ global: "syncIDAAAAAA".into(),
+ coll: "syncIDBBBBBB".into(),
+ })
+ );
+ assert_eq!(engine.get_num_resets(), 1);
+ }
+
+ #[test]
+ fn test_known_good_state() {
+ let root_key = KeyBundle::new_random().expect("should work");
+ let gs = get_global_state(&root_key);
+ let engine = TestSyncEngine::new(
+ "bookmarks",
+ EngineSyncAssociation::Connected(CollSyncIds {
+ global: "syncIDAAAAAA".into(),
+ coll: "syncIDBBBBBB".into(),
+ }),
+ );
+ let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work");
+ assert!(cs.is_some(), "collection can sync");
+ assert_eq!(engine.get_num_resets(), 0);
+ }
+
+ #[test]
+ fn test_declined() {
+ let root_key = KeyBundle::new_random().expect("should work");
+ let mut gs = get_global_state(&root_key);
+ gs.global.declined.push("bookmarks".to_string());
+ let engine = TestSyncEngine::new(
+ "bookmarks",
+ EngineSyncAssociation::Connected(CollSyncIds {
+ global: "syncIDAAAAAA".into(),
+ coll: "syncIDBBBBBB".into(),
+ }),
+ );
+ let cs = LocalCollStateMachine::get_state(&engine, &gs, &root_key).expect("should work");
+ assert!(cs.is_none(), "declined collection can sync");
+ assert_eq!(engine.get_num_resets(), 0);
+ }
+}
diff --git a/third_party/rust/sync15/src/client/coll_update.rs b/third_party/rust/sync15/src/client/coll_update.rs
new file mode 100644
index 0000000000..baa551c9a4
--- /dev/null
+++ b/third_party/rust/sync15/src/client/coll_update.rs
@@ -0,0 +1,137 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use super::{
+ request::{NormalResponseHandler, UploadInfo},
+ CollState, Sync15ClientResponse, Sync15StorageClient,
+};
+use crate::bso::OutgoingEncryptedBso;
+use crate::engine::{CollectionRequest, IncomingChangeset, OutgoingChangeset};
+use crate::error::{self, Error, ErrorResponse, Result};
+use crate::{KeyBundle, ServerTimestamp};
+use std::borrow::Cow;
+
+pub fn encrypt_outgoing(
+ o: OutgoingChangeset,
+ key: &KeyBundle,
+) -> Result<Vec<OutgoingEncryptedBso>> {
+ o.changes
+ .into_iter()
+ .map(|change| change.into_encrypted(key))
+ .collect()
+}
+
+pub fn fetch_incoming(
+ client: &Sync15StorageClient,
+ state: &mut CollState,
+ collection_request: &CollectionRequest,
+) -> Result<IncomingChangeset> {
+ let collection = collection_request.collection.clone();
+ let (records, timestamp) = match client.get_encrypted_records(collection_request)? {
+ Sync15ClientResponse::Success {
+ record,
+ last_modified,
+ ..
+ } => (record, last_modified),
+ other => return Err(other.create_storage_error()),
+ };
+ // xxx - duplication below of `timestamp` smells wrong
+ state.last_modified = timestamp;
+ let mut result = IncomingChangeset::new(collection, timestamp);
+ result.changes.reserve(records.len());
+ for record in records {
+ // if we see a HMAC error, we've made an explicit decision to
+ // NOT handle it here, but restart the global state machine.
+ // That should cause us to re-read crypto/keys and things should
+ // work (although if for some reason crypto/keys was updated but
+ // not all storage was wiped we are probably screwed.)
+ result.changes.push(record.into_decrypted(&state.key)?);
+ }
+ Ok(result)
+}
+
+pub struct CollectionUpdate<'a> {
+ client: &'a Sync15StorageClient,
+ state: &'a CollState,
+ collection: Cow<'static, str>,
+ xius: ServerTimestamp,
+ to_update: Vec<OutgoingEncryptedBso>,
+ fully_atomic: bool,
+}
+
+impl<'a> CollectionUpdate<'a> {
+ pub fn new(
+ client: &'a Sync15StorageClient,
+ state: &'a CollState,
+ collection: Cow<'static, str>,
+ xius: ServerTimestamp,
+ records: Vec<OutgoingEncryptedBso>,
+ fully_atomic: bool,
+ ) -> CollectionUpdate<'a> {
+ CollectionUpdate {
+ client,
+ state,
+ collection,
+ xius,
+ to_update: records,
+ fully_atomic,
+ }
+ }
+
+ pub fn new_from_changeset(
+ client: &'a Sync15StorageClient,
+ state: &'a CollState,
+ changeset: OutgoingChangeset,
+ fully_atomic: bool,
+ ) -> Result<CollectionUpdate<'a>> {
+ let collection = changeset.collection.clone();
+ let xius = changeset.timestamp;
+ if xius < state.last_modified {
+ // We know we are going to fail the XIUS check...
+ return Err(Error::StorageHttpError(ErrorResponse::PreconditionFailed {
+ route: collection.into_owned(),
+ }));
+ }
+ let to_update = encrypt_outgoing(changeset, &state.key)?;
+ Ok(CollectionUpdate::new(
+ client,
+ state,
+ collection,
+ xius,
+ to_update,
+ fully_atomic,
+ ))
+ }
+
+ /// Returns a list of the IDs that failed if allowed_dropped_records is true, otherwise
+ /// returns an empty vec.
+ pub fn upload(self) -> error::Result<UploadInfo> {
+ let mut failed = vec![];
+ let mut q = self.client.new_post_queue(
+ &self.collection,
+ &self.state.config,
+ self.xius,
+ NormalResponseHandler::new(!self.fully_atomic),
+ )?;
+
+ for record in self.to_update.into_iter() {
+ let enqueued = q.enqueue(&record)?;
+ if !enqueued && self.fully_atomic {
+ return Err(Error::RecordTooLargeError);
+ }
+ }
+
+ q.flush(true)?;
+ let mut info = q.completed_upload_info();
+ info.failed_ids.append(&mut failed);
+ if self.fully_atomic {
+ assert_eq!(
+ info.failed_ids.len(),
+ 0,
+ "Bug: Should have failed by now if we aren't allowing dropped records"
+ );
+ }
+ Ok(info)
+ }
+}
diff --git a/third_party/rust/sync15/src/client/collection_keys.rs b/third_party/rust/sync15/src/client/collection_keys.rs
new file mode 100644
index 0000000000..f51894f756
--- /dev/null
+++ b/third_party/rust/sync15/src/client/collection_keys.rs
@@ -0,0 +1,61 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use crate::error::Result;
+use crate::record_types::CryptoKeysRecord;
+use crate::{EncryptedPayload, KeyBundle, ServerTimestamp};
+use std::collections::HashMap;
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct CollectionKeys {
+ pub timestamp: ServerTimestamp,
+ pub default: KeyBundle,
+ pub collections: HashMap<String, KeyBundle>,
+}
+
+impl CollectionKeys {
+ pub fn new_random() -> Result<CollectionKeys> {
+ let default = KeyBundle::new_random()?;
+ Ok(CollectionKeys {
+ timestamp: ServerTimestamp(0),
+ default,
+ collections: HashMap::new(),
+ })
+ }
+
+ pub fn from_encrypted_payload(
+ record: EncryptedPayload,
+ timestamp: ServerTimestamp,
+ root_key: &KeyBundle,
+ ) -> Result<CollectionKeys> {
+ let keys: CryptoKeysRecord = record.decrypt_into(root_key)?;
+ Ok(CollectionKeys {
+ timestamp,
+ default: KeyBundle::from_base64(&keys.default[0], &keys.default[1])?,
+ collections: keys
+ .collections
+ .into_iter()
+ .map(|kv| Ok((kv.0, KeyBundle::from_base64(&kv.1[0], &kv.1[1])?)))
+ .collect::<Result<HashMap<String, KeyBundle>>>()?,
+ })
+ }
+
+ pub fn to_encrypted_payload(&self, root_key: &KeyBundle) -> Result<EncryptedPayload> {
+ let record = CryptoKeysRecord {
+ id: "keys".into(),
+ collection: "crypto".into(),
+ default: self.default.to_b64_array(),
+ collections: self
+ .collections
+ .iter()
+ .map(|kv| (kv.0.clone(), kv.1.to_b64_array()))
+ .collect(),
+ };
+ EncryptedPayload::from_cleartext_payload(root_key, &record)
+ }
+
+ pub fn key_for_collection<'a>(&'a self, collection: &str) -> &'a KeyBundle {
+ self.collections.get(collection).unwrap_or(&self.default)
+ }
+}
diff --git a/third_party/rust/sync15/src/client/mod.rs b/third_party/rust/sync15/src/client/mod.rs
new file mode 100644
index 0000000000..84fe3678de
--- /dev/null
+++ b/third_party/rust/sync15/src/client/mod.rs
@@ -0,0 +1,39 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+//! A module for everything needed to be a "sync client" - ie, a device which
+//! can perform a full sync of any number of collections, including managing
+//! the server state.
+//!
+//! In general, the client is responsible for all communication with the sync server,
+//! including ensuring the state is correct, and encrypting/decrypting all records
+//! to and from the server. However, the actual syncing of the collections is
+//! delegated to an external [crate::engine](Sync Engine).
+//!
+//! One exception is that the "sync client" owns one sync engine - the
+//! [crate::clients_engine], which is managed internally.
+mod coll_state;
+mod coll_update;
+mod collection_keys;
+mod request;
+mod state;
+mod status;
+mod storage_client;
+mod sync;
+mod sync_multiple;
+mod token;
+mod util;
+
+pub(crate) use coll_state::CollState;
+pub(crate) use coll_update::{fetch_incoming, CollectionUpdate};
+pub(crate) use collection_keys::CollectionKeys;
+pub(crate) use request::InfoConfiguration;
+pub(crate) use state::GlobalState;
+pub use status::{ServiceStatus, SyncResult};
+pub use storage_client::{
+ SetupStorageClient, Sync15ClientResponse, Sync15StorageClient, Sync15StorageClientInit,
+};
+pub use sync_multiple::{
+ sync_multiple, sync_multiple_with_command_processor, MemoryCachedState, SyncRequestInfo,
+};
diff --git a/third_party/rust/sync15/src/client/request.rs b/third_party/rust/sync15/src/client/request.rs
new file mode 100644
index 0000000000..c69b630c8d
--- /dev/null
+++ b/third_party/rust/sync15/src/client/request.rs
@@ -0,0 +1,1199 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use super::storage_client::Sync15ClientResponse;
+use crate::bso::OutgoingEncryptedBso;
+use crate::error::{self, Error as ErrorKind, Result};
+use crate::ServerTimestamp;
+use serde_derive::*;
+use std::collections::HashMap;
+use std::default::Default;
+use std::ops::Deref;
+use sync_guid::Guid;
+use viaduct::status_codes;
+
+/// Manages a pair of (byte, count) limits for a PostQueue, such as
+/// (max_post_bytes, max_post_records) or (max_total_bytes, max_total_records).
+#[derive(Debug, Clone)]
+struct LimitTracker {
+ max_bytes: usize,
+ max_records: usize,
+ cur_bytes: usize,
+ cur_records: usize,
+}
+
+impl LimitTracker {
+ pub fn new(max_bytes: usize, max_records: usize) -> LimitTracker {
+ LimitTracker {
+ max_bytes,
+ max_records,
+ cur_bytes: 0,
+ cur_records: 0,
+ }
+ }
+
+ pub fn clear(&mut self) {
+ self.cur_records = 0;
+ self.cur_bytes = 0;
+ }
+
+ pub fn can_add_record(&self, payload_size: usize) -> bool {
+ // Desktop does the cur_bytes check as exclusive, but we shouldn't see any servers that
+ // don't have https://github.com/mozilla-services/server-syncstorage/issues/73
+ self.cur_records < self.max_records && self.cur_bytes + payload_size <= self.max_bytes
+ }
+
+ pub fn can_never_add(&self, record_size: usize) -> bool {
+ record_size >= self.max_bytes
+ }
+
+ pub fn record_added(&mut self, record_size: usize) {
+ assert!(
+ self.can_add_record(record_size),
+ "LimitTracker::record_added caller must check can_add_record"
+ );
+ self.cur_records += 1;
+ self.cur_bytes += record_size;
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct InfoConfiguration {
+ /// The maximum size in bytes of the overall HTTP request body that will be accepted by the
+ /// server.
+ #[serde(default = "default_max_request_bytes")]
+ pub max_request_bytes: usize,
+
+ /// The maximum number of records that can be uploaded to a collection in a single POST request.
+ #[serde(default = "usize::max_value")]
+ pub max_post_records: usize,
+
+ /// The maximum combined size in bytes of the record payloads that can be uploaded to a
+ /// collection in a single POST request.
+ #[serde(default = "usize::max_value")]
+ pub max_post_bytes: usize,
+
+ /// The maximum total number of records that can be uploaded to a collection as part of a
+ /// batched upload.
+ #[serde(default = "usize::max_value")]
+ pub max_total_records: usize,
+
+ /// The maximum total combined size in bytes of the record payloads that can be uploaded to a
+ /// collection as part of a batched upload.
+ #[serde(default = "usize::max_value")]
+ pub max_total_bytes: usize,
+
+ /// The maximum size of an individual BSO payload, in bytes.
+ #[serde(default = "default_max_record_payload_bytes")]
+ pub max_record_payload_bytes: usize,
+}
+
+// This is annoying but seems to be the only way to do it...
+fn default_max_request_bytes() -> usize {
+ 260 * 1024
+}
+fn default_max_record_payload_bytes() -> usize {
+ 256 * 1024
+}
+
+impl Default for InfoConfiguration {
+ #[inline]
+ fn default() -> InfoConfiguration {
+ InfoConfiguration {
+ max_request_bytes: default_max_request_bytes(),
+ max_record_payload_bytes: default_max_record_payload_bytes(),
+ max_post_records: usize::max_value(),
+ max_post_bytes: usize::max_value(),
+ max_total_records: usize::max_value(),
+ max_total_bytes: usize::max_value(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default, Deserialize, Serialize)]
+pub struct InfoCollections(pub(crate) HashMap<String, ServerTimestamp>);
+
+impl InfoCollections {
+ pub fn new(collections: HashMap<String, ServerTimestamp>) -> InfoCollections {
+ InfoCollections(collections)
+ }
+}
+
+impl Deref for InfoCollections {
+ type Target = HashMap<String, ServerTimestamp>;
+
+ fn deref(&self) -> &HashMap<String, ServerTimestamp> {
+ &self.0
+ }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct UploadResult {
+ batch: Option<String>,
+ /// Maps record id => why failed
+ #[serde(default = "HashMap::new")]
+ pub failed: HashMap<Guid, String>,
+ /// Vec of ids
+ #[serde(default = "Vec::new")]
+ pub success: Vec<Guid>,
+}
+
+pub type PostResponse = Sync15ClientResponse<UploadResult>;
+
+#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub enum BatchState {
+ Unsupported,
+ NoBatch,
+ InBatch(String),
+}
+
+#[derive(Debug)]
+pub struct PostQueue<Post, OnResponse> {
+ poster: Post,
+ on_response: OnResponse,
+ post_limits: LimitTracker,
+ batch_limits: LimitTracker,
+ max_payload_bytes: usize,
+ max_request_bytes: usize,
+ queued: Vec<u8>,
+ batch: BatchState,
+ last_modified: ServerTimestamp,
+}
+
+pub trait BatchPoster {
+ /// Note: Last argument (reference to the batch poster) is provided for the purposes of testing
+ /// Important: Poster should not report non-success HTTP statuses as errors!!
+ fn post<P, O>(
+ &self,
+ body: Vec<u8>,
+ xius: ServerTimestamp,
+ batch: Option<String>,
+ commit: bool,
+ queue: &PostQueue<P, O>,
+ ) -> Result<PostResponse>;
+}
+
+// We don't just use a FnMut here since we want to override it in mocking for RefCell<TestType>,
+// which we can't do for FnMut since neither FnMut nor RefCell are defined here. Also, this
+// is somewhat better for documentation.
+pub trait PostResponseHandler {
+ fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> Result<()>;
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct NormalResponseHandler {
+ pub failed_ids: Vec<Guid>,
+ pub successful_ids: Vec<Guid>,
+ pub allow_failed: bool,
+ pub pending_failed: Vec<Guid>,
+ pub pending_success: Vec<Guid>,
+}
+
+impl NormalResponseHandler {
+ pub fn new(allow_failed: bool) -> NormalResponseHandler {
+ NormalResponseHandler {
+ failed_ids: vec![],
+ successful_ids: vec![],
+ pending_failed: vec![],
+ pending_success: vec![],
+ allow_failed,
+ }
+ }
+}
+
+impl PostResponseHandler for NormalResponseHandler {
+ fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> error::Result<()> {
+ match r {
+ Sync15ClientResponse::Success { record, .. } => {
+ if !record.failed.is_empty() && !self.allow_failed {
+ return Err(ErrorKind::RecordUploadFailed);
+ }
+ for id in record.success.iter() {
+ self.pending_success.push(id.clone());
+ }
+ for kv in record.failed.iter() {
+ self.pending_failed.push(kv.0.clone());
+ }
+ if !mid_batch {
+ self.successful_ids.append(&mut self.pending_success);
+ self.failed_ids.append(&mut self.pending_failed);
+ }
+ Ok(())
+ }
+ _ => Err(r.create_storage_error()),
+ }
+ }
+}
+
+impl<Poster, OnResponse> PostQueue<Poster, OnResponse>
+where
+ Poster: BatchPoster,
+ OnResponse: PostResponseHandler,
+{
+ pub fn new(
+ config: &InfoConfiguration,
+ ts: ServerTimestamp,
+ poster: Poster,
+ on_response: OnResponse,
+ ) -> PostQueue<Poster, OnResponse> {
+ PostQueue {
+ poster,
+ on_response,
+ last_modified: ts,
+ post_limits: LimitTracker::new(config.max_post_bytes, config.max_post_records),
+ batch_limits: LimitTracker::new(config.max_total_bytes, config.max_total_records),
+ batch: BatchState::NoBatch,
+ max_payload_bytes: config.max_record_payload_bytes,
+ max_request_bytes: config.max_request_bytes,
+ queued: Vec::new(),
+ }
+ }
+
+ #[inline]
+ fn in_batch(&self) -> bool {
+ !matches!(&self.batch, BatchState::Unsupported | BatchState::NoBatch)
+ }
+
+ pub fn enqueue(&mut self, record: &OutgoingEncryptedBso) -> Result<bool> {
+ let payload_length = record.serialized_payload_len();
+
+ if self.post_limits.can_never_add(payload_length)
+ || self.batch_limits.can_never_add(payload_length)
+ || payload_length >= self.max_payload_bytes
+ {
+ log::warn!(
+ "Single record too large to submit to server ({} b)",
+ payload_length
+ );
+ return Ok(false);
+ }
+
+ // Write directly into `queued` but undo if necessary (the vast majority of the time
+ // it won't be necessary). If we hit a problem we need to undo that, but the only error
+ // case we have to worry about right now is in flush()
+ let item_start = self.queued.len();
+
+ // This is conservative but can't hurt.
+ self.queued.reserve(payload_length + 2);
+
+ // Either the first character in an array, or a comma separating
+ // it from the previous item.
+ let c = if self.queued.is_empty() { b'[' } else { b',' };
+ self.queued.push(c);
+
+ // This unwrap is fine, since serde_json's failure case is HashMaps that have non-object
+ // keys, which is impossible. If you decide to change this part, you *need* to call
+ // `self.queued.truncate(item_start)` here in the failure case!
+ serde_json::to_writer(&mut self.queued, &record).unwrap();
+
+ let item_end = self.queued.len();
+
+ debug_assert!(
+ item_end >= payload_length,
+ "EncryptedPayload::serialized_len is bugged"
+ );
+
+ // The + 1 is only relevant for the final record, which will have a trailing ']'.
+ let item_len = item_end - item_start + 1;
+
+ if item_len >= self.max_request_bytes {
+ self.queued.truncate(item_start);
+ log::warn!(
+ "Single record too large to submit to server ({} b)",
+ item_len
+ );
+ return Ok(false);
+ }
+
+ let can_post_record = self.post_limits.can_add_record(payload_length);
+ let can_batch_record = self.batch_limits.can_add_record(payload_length);
+ let can_send_record = self.queued.len() < self.max_request_bytes;
+
+ if !can_post_record || !can_send_record || !can_batch_record {
+ log::debug!(
+ "PostQueue flushing! (can_post = {}, can_send = {}, can_batch = {})",
+ can_post_record,
+ can_send_record,
+ can_batch_record
+ );
+ // "unwrite" the record.
+ self.queued.truncate(item_start);
+ // Flush whatever we have queued.
+ self.flush(!can_batch_record)?;
+ // And write it again.
+ let c = if self.queued.is_empty() { b'[' } else { b',' };
+ self.queued.push(c);
+ serde_json::to_writer(&mut self.queued, &record).unwrap();
+ }
+
+ self.post_limits.record_added(payload_length);
+ self.batch_limits.record_added(payload_length);
+
+ Ok(true)
+ }
+
+ pub fn flush(&mut self, want_commit: bool) -> Result<()> {
+ if self.queued.is_empty() {
+ assert!(
+ !self.in_batch(),
+ "Bug: Somehow we're in a batch but have no queued records"
+ );
+ // Nothing to do!
+ return Ok(());
+ }
+
+ self.queued.push(b']');
+ let batch_id = match &self.batch {
+ // Not the first post and we know we have no batch semantics.
+ BatchState::Unsupported => None,
+ // First commit in possible batch
+ BatchState::NoBatch => Some("true".into()),
+ // In a batch and we have a batch id.
+ BatchState::InBatch(ref s) => Some(s.clone()),
+ };
+
+ log::info!(
+ "Posting {} records of {} bytes",
+ self.post_limits.cur_records,
+ self.queued.len()
+ );
+
+ let is_commit = want_commit && batch_id.is_some();
+ // Weird syntax for calling a function object that is a property.
+ let resp_or_error = self.poster.post(
+ self.queued.clone(),
+ self.last_modified,
+ batch_id,
+ is_commit,
+ self,
+ );
+
+ self.queued.truncate(0);
+
+ if want_commit || self.batch == BatchState::Unsupported {
+ self.batch_limits.clear();
+ }
+ self.post_limits.clear();
+
+ let resp = resp_or_error?;
+
+ let (status, last_modified, record) = match resp {
+ Sync15ClientResponse::Success {
+ status,
+ last_modified,
+ ref record,
+ ..
+ } => (status, last_modified, record),
+ _ => {
+ self.on_response.handle_response(resp, !want_commit)?;
+ // on_response() should always fail!
+ unreachable!();
+ }
+ };
+
+ if want_commit || self.batch == BatchState::Unsupported {
+ self.last_modified = last_modified;
+ }
+
+ if want_commit {
+ log::debug!("Committed batch {:?}", self.batch);
+ self.batch = BatchState::NoBatch;
+ self.on_response.handle_response(resp, false)?;
+ return Ok(());
+ }
+
+ if status != status_codes::ACCEPTED {
+ if self.in_batch() {
+ return Err(ErrorKind::ServerBatchProblem(
+ "Server responded non-202 success code while a batch was in progress",
+ ));
+ }
+ self.last_modified = last_modified;
+ self.batch = BatchState::Unsupported;
+ self.batch_limits.clear();
+ self.on_response.handle_response(resp, false)?;
+ return Ok(());
+ }
+
+ let batch_id = record
+ .batch
+ .as_ref()
+ .ok_or({
+ ErrorKind::ServerBatchProblem("Invalid server response: 202 without a batch ID")
+ })?
+ .clone();
+
+ match &self.batch {
+ BatchState::Unsupported => {
+ log::warn!("Server changed its mind about supporting batching mid-batch...");
+ }
+
+ BatchState::InBatch(ref cur_id) => {
+ if cur_id != &batch_id {
+ return Err(ErrorKind::ServerBatchProblem(
+ "Invalid server response: 202 without a batch ID",
+ ));
+ }
+ }
+ _ => {}
+ }
+
+ // Can't change this in match arms without NLL
+ self.batch = BatchState::InBatch(batch_id);
+ self.last_modified = last_modified;
+
+ self.on_response.handle_response(resp, true)?;
+
+ Ok(())
+ }
+}
+
+#[derive(Clone)]
+pub struct UploadInfo {
+ pub successful_ids: Vec<Guid>,
+ pub failed_ids: Vec<Guid>,
+ pub modified_timestamp: ServerTimestamp,
+}
+
+impl<Poster> PostQueue<Poster, NormalResponseHandler> {
+ // TODO: should take by move
+ pub fn completed_upload_info(&mut self) -> UploadInfo {
+ let mut result = UploadInfo {
+ successful_ids: Vec::with_capacity(self.on_response.successful_ids.len()),
+ failed_ids: Vec::with_capacity(
+ self.on_response.failed_ids.len()
+ + self.on_response.pending_failed.len()
+ + self.on_response.pending_success.len(),
+ ),
+ modified_timestamp: self.last_modified,
+ };
+
+ result
+ .successful_ids
+ .append(&mut self.on_response.successful_ids);
+
+ result.failed_ids.append(&mut self.on_response.failed_ids);
+ result
+ .failed_ids
+ .append(&mut self.on_response.pending_failed);
+ result
+ .failed_ids
+ .append(&mut self.on_response.pending_success);
+
+ result
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use crate::bso::{IncomingEncryptedBso, OutgoingEncryptedBso, OutgoingEnvelope};
+ use crate::EncryptedPayload;
+ use lazy_static::lazy_static;
+ use std::cell::RefCell;
+ use std::collections::VecDeque;
+ use std::rc::Rc;
+
+ #[derive(Debug, Clone)]
+ struct PostedData {
+ body: String,
+ _xius: ServerTimestamp,
+ batch: Option<String>,
+ commit: bool,
+ payload_bytes: usize,
+ records: usize,
+ }
+
+ impl PostedData {
+ fn records_as_json(&self) -> Vec<serde_json::Value> {
+ let values =
+ serde_json::from_str::<serde_json::Value>(&self.body).expect("Posted invalid json");
+ // Check that they actually deserialize as what we want
+ let records_or_err =
+ serde_json::from_value::<Vec<IncomingEncryptedBso>>(values.clone());
+ records_or_err.expect("Failed to deserialize data");
+ serde_json::from_value(values).unwrap()
+ }
+ }
+
+ #[derive(Debug, Clone)]
+ struct BatchInfo {
+ id: Option<String>,
+ posts: Vec<PostedData>,
+ bytes: usize,
+ records: usize,
+ }
+
+ #[derive(Debug, Clone)]
+ struct TestPoster {
+ all_posts: Vec<PostedData>,
+ responses: VecDeque<PostResponse>,
+ batches: Vec<BatchInfo>,
+ cur_batch: Option<BatchInfo>,
+ cfg: InfoConfiguration,
+ }
+
+ type TestPosterRef = Rc<RefCell<TestPoster>>;
+ impl TestPoster {
+ pub fn new<T>(cfg: &InfoConfiguration, responses: T) -> TestPosterRef
+ where
+ T: Into<VecDeque<PostResponse>>,
+ {
+ Rc::new(RefCell::new(TestPoster {
+ all_posts: vec![],
+ responses: responses.into(),
+ batches: vec![],
+ cur_batch: None,
+ cfg: cfg.clone(),
+ }))
+ }
+ // Adds &mut
+ fn do_post<T, O>(
+ &mut self,
+ body: &[u8],
+ xius: ServerTimestamp,
+ batch: Option<String>,
+ commit: bool,
+ queue: &PostQueue<T, O>,
+ ) -> Sync15ClientResponse<UploadResult> {
+ let mut post = PostedData {
+ body: String::from_utf8(body.into()).expect("Posted invalid utf8..."),
+ batch: batch.clone(),
+ _xius: xius,
+ commit,
+ payload_bytes: 0,
+ records: 0,
+ };
+
+ assert!(body.len() <= self.cfg.max_request_bytes);
+
+ let (num_records, record_payload_bytes) = {
+ let recs = post.records_as_json();
+ assert!(recs.len() <= self.cfg.max_post_records);
+ assert!(recs.len() <= self.cfg.max_total_records);
+ let payload_bytes: usize = recs
+ .iter()
+ .map(|r| {
+ let len = r["payload"]
+ .as_str()
+ .expect("Non string payload property")
+ .len();
+ assert!(len <= self.cfg.max_record_payload_bytes);
+ len
+ })
+ .sum();
+ assert!(payload_bytes <= self.cfg.max_post_bytes);
+ assert!(payload_bytes <= self.cfg.max_total_bytes);
+
+ assert_eq!(queue.post_limits.cur_bytes, payload_bytes);
+ assert_eq!(queue.post_limits.cur_records, recs.len());
+ (recs.len(), payload_bytes)
+ };
+ post.payload_bytes = record_payload_bytes;
+ post.records = num_records;
+
+ self.all_posts.push(post.clone());
+ let response = self.responses.pop_front().unwrap();
+
+ let record = match response {
+ Sync15ClientResponse::Success { ref record, .. } => record,
+ _ => {
+ panic!("only success codes are used in this test");
+ }
+ };
+
+ if self.cur_batch.is_none() {
+ assert!(
+ batch.is_none() || batch == Some("true".into()),
+ "We shouldn't be in a batch now"
+ );
+ self.cur_batch = Some(BatchInfo {
+ id: record.batch.clone(),
+ posts: vec![],
+ records: 0,
+ bytes: 0,
+ });
+ } else {
+ assert_eq!(
+ batch,
+ self.cur_batch.as_ref().unwrap().id,
+ "We're in a batch but got the wrong batch id"
+ );
+ }
+
+ {
+ let batch = self.cur_batch.as_mut().unwrap();
+ batch.posts.push(post);
+ batch.records += num_records;
+ batch.bytes += record_payload_bytes;
+
+ assert!(batch.bytes <= self.cfg.max_total_bytes);
+ assert!(batch.records <= self.cfg.max_total_records);
+
+ assert_eq!(batch.records, queue.batch_limits.cur_records);
+ assert_eq!(batch.bytes, queue.batch_limits.cur_bytes);
+ }
+
+ if commit || record.batch.is_none() {
+ let batch = self.cur_batch.take().unwrap();
+ self.batches.push(batch);
+ }
+
+ response
+ }
+
+ fn do_handle_response(&mut self, _: PostResponse, mid_batch: bool) {
+ assert_eq!(mid_batch, self.cur_batch.is_some());
+ }
+ }
+ impl BatchPoster for TestPosterRef {
+ fn post<T, O>(
+ &self,
+ body: Vec<u8>,
+ xius: ServerTimestamp,
+ batch: Option<String>,
+ commit: bool,
+ queue: &PostQueue<T, O>,
+ ) -> Result<PostResponse> {
+ Ok(self.borrow_mut().do_post(&body, xius, batch, commit, queue))
+ }
+ }
+
+ impl PostResponseHandler for TestPosterRef {
+ fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> Result<()> {
+ self.borrow_mut().do_handle_response(r, mid_batch);
+ Ok(())
+ }
+ }
+
+ type MockedPostQueue = PostQueue<TestPosterRef, TestPosterRef>;
+
+ fn pq_test_setup(
+ cfg: InfoConfiguration,
+ lm: i64,
+ resps: Vec<PostResponse>,
+ ) -> (MockedPostQueue, TestPosterRef) {
+ let tester = TestPoster::new(&cfg, resps);
+ let pq = PostQueue::new(&cfg, ServerTimestamp(lm), tester.clone(), tester.clone());
+ (pq, tester)
+ }
+
+ fn fake_response<'a, T: Into<Option<&'a str>>>(status: u16, lm: i64, batch: T) -> PostResponse {
+ assert!(status_codes::is_success_code(status));
+ Sync15ClientResponse::Success {
+ status,
+ last_modified: ServerTimestamp(lm),
+ record: UploadResult {
+ batch: batch.into().map(Into::into),
+ failed: HashMap::new(),
+ success: vec![],
+ },
+ route: "test/path".into(),
+ }
+ }
+
+ lazy_static! {
+ // ~40b
+ static ref PAYLOAD_OVERHEAD: usize = {
+ let payload = EncryptedPayload {
+ iv: "".into(),
+ hmac: "".into(),
+ ciphertext: "".into()
+ };
+ serde_json::to_string(&payload).unwrap().len()
+ };
+ // ~80b
+ static ref TOTAL_RECORD_OVERHEAD: usize = {
+ let val = serde_json::to_value(OutgoingEncryptedBso::new(OutgoingEnvelope {
+ id: "".into(),
+ sortindex: None,
+ ttl: None,
+ },
+ EncryptedPayload {
+ iv: "".into(),
+ hmac: "".into(),
+ ciphertext: "".into()
+ },
+ )).unwrap();
+ serde_json::to_string(&val).unwrap().len()
+ };
+ // There's some subtlety in how we calulate this having to do with the fact that
+ // the quotes in the payload are escaped but the escape chars count to the request len
+ // and *not* to the payload len (the payload len check happens after json parsing the
+ // top level object).
+ static ref NON_PAYLOAD_OVERHEAD: usize = {
+ *TOTAL_RECORD_OVERHEAD - *PAYLOAD_OVERHEAD
+ };
+ }
+
+ // Actual record size (for max_request_len) will be larger by some amount
+ fn make_record(payload_size: usize) -> OutgoingEncryptedBso {
+ assert!(payload_size > *PAYLOAD_OVERHEAD);
+ let ciphertext_len = payload_size - *PAYLOAD_OVERHEAD;
+ OutgoingEncryptedBso::new(
+ OutgoingEnvelope {
+ id: "".into(),
+ sortindex: None,
+ ttl: None,
+ },
+ EncryptedPayload {
+ iv: "".into(),
+ hmac: "".into(),
+ ciphertext: "x".repeat(ciphertext_len),
+ },
+ )
+ }
+
+ fn request_bytes_for_payloads(payloads: &[usize]) -> usize {
+ 1 + payloads
+ .iter()
+ .map(|&size| size + 1 + *NON_PAYLOAD_OVERHEAD)
+ .sum::<usize>()
+ }
+
+ #[test]
+ fn test_pq_basic() {
+ let cfg = InfoConfiguration {
+ max_request_bytes: 1000,
+ max_record_payload_bytes: 1000,
+ ..InfoConfiguration::default()
+ };
+ let time = 11_111_111_000;
+ let (mut pq, tester) = pq_test_setup(
+ cfg,
+ time,
+ vec![fake_response(status_codes::OK, time + 100_000, None)],
+ );
+
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.flush(true).unwrap();
+
+ let t = tester.borrow();
+ assert!(t.cur_batch.is_none());
+ assert_eq!(t.all_posts.len(), 1);
+ assert_eq!(t.batches.len(), 1);
+ assert_eq!(t.batches[0].posts.len(), 1);
+ assert_eq!(t.batches[0].records, 1);
+ assert_eq!(t.batches[0].bytes, 100);
+ assert_eq!(
+ t.batches[0].posts[0].body.len(),
+ request_bytes_for_payloads(&[100])
+ );
+ }
+
+ #[test]
+ fn test_pq_max_request_bytes_no_batch() {
+ let cfg = InfoConfiguration {
+ max_request_bytes: 250,
+ ..InfoConfiguration::default()
+ };
+ let time = 11_111_111_000;
+ let (mut pq, tester) = pq_test_setup(
+ cfg,
+ time,
+ vec![
+ fake_response(status_codes::OK, time + 100_000, None),
+ fake_response(status_codes::OK, time + 200_000, None),
+ ],
+ );
+
+ // Note that the total record overhead is around 85 bytes
+ let payload_size = 100 - *NON_PAYLOAD_OVERHEAD;
+ pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 102; [r]
+ pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 203; [r,r]
+ pq.enqueue(&make_record(payload_size)).unwrap(); // too big, 2nd post.
+ pq.flush(true).unwrap();
+
+ let t = tester.borrow();
+ assert!(t.cur_batch.is_none());
+ assert_eq!(t.all_posts.len(), 2);
+ assert_eq!(t.batches.len(), 2);
+ assert_eq!(t.batches[0].posts.len(), 1);
+ assert_eq!(t.batches[0].records, 2);
+ assert_eq!(t.batches[0].bytes, payload_size * 2);
+ assert_eq!(t.batches[0].posts[0].batch, Some("true".into()));
+ assert_eq!(
+ t.batches[0].posts[0].body.len(),
+ request_bytes_for_payloads(&[payload_size, payload_size])
+ );
+
+ assert_eq!(t.batches[1].posts.len(), 1);
+ assert_eq!(t.batches[1].records, 1);
+ assert_eq!(t.batches[1].bytes, payload_size);
+ // We know at this point that the server does not support batching.
+ assert_eq!(t.batches[1].posts[0].batch, None);
+ assert!(!t.batches[1].posts[0].commit);
+ assert_eq!(
+ t.batches[1].posts[0].body.len(),
+ request_bytes_for_payloads(&[payload_size])
+ );
+ }
+
+ #[test]
+ fn test_pq_max_record_payload_bytes_no_batch() {
+ let cfg = InfoConfiguration {
+ max_record_payload_bytes: 150,
+ max_request_bytes: 350,
+ ..InfoConfiguration::default()
+ };
+ let time = 11_111_111_000;
+ let (mut pq, tester) = pq_test_setup(
+ cfg,
+ time,
+ vec![
+ fake_response(status_codes::OK, time + 100_000, None),
+ fake_response(status_codes::OK, time + 200_000, None),
+ ],
+ );
+
+ // Note that the total record overhead is around 85 bytes
+ let payload_size = 100 - *NON_PAYLOAD_OVERHEAD;
+ pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 102; [r]
+ let enqueued = pq.enqueue(&make_record(151)).unwrap(); // still 102
+ assert!(!enqueued, "Should not have fit");
+ pq.enqueue(&make_record(payload_size)).unwrap();
+ pq.flush(true).unwrap();
+
+ let t = tester.borrow();
+ assert!(t.cur_batch.is_none());
+ assert_eq!(t.all_posts.len(), 1);
+ assert_eq!(t.batches.len(), 1);
+ assert_eq!(t.batches[0].posts.len(), 1);
+ assert_eq!(t.batches[0].records, 2);
+ assert_eq!(t.batches[0].bytes, payload_size * 2);
+ assert_eq!(
+ t.batches[0].posts[0].body.len(),
+ request_bytes_for_payloads(&[payload_size, payload_size])
+ );
+ }
+
+ #[test]
+ fn test_pq_single_batch() {
+ let cfg = InfoConfiguration::default();
+ let time = 11_111_111_000;
+ let (mut pq, tester) = pq_test_setup(
+ cfg,
+ time,
+ vec![fake_response(
+ status_codes::ACCEPTED,
+ time + 100_000,
+ Some("1234"),
+ )],
+ );
+
+ let payload_size = 100 - *NON_PAYLOAD_OVERHEAD;
+ pq.enqueue(&make_record(payload_size)).unwrap();
+ pq.enqueue(&make_record(payload_size)).unwrap();
+ pq.enqueue(&make_record(payload_size)).unwrap();
+ pq.flush(true).unwrap();
+
+ let t = tester.borrow();
+ assert!(t.cur_batch.is_none());
+ assert_eq!(t.all_posts.len(), 1);
+ assert_eq!(t.batches.len(), 1);
+ assert_eq!(t.batches[0].id.as_ref().unwrap(), "1234");
+ assert_eq!(t.batches[0].posts.len(), 1);
+ assert_eq!(t.batches[0].records, 3);
+ assert_eq!(t.batches[0].bytes, payload_size * 3);
+ assert!(t.batches[0].posts[0].commit);
+ assert_eq!(
+ t.batches[0].posts[0].body.len(),
+ request_bytes_for_payloads(&[payload_size, payload_size, payload_size])
+ );
+ }
+
+ #[test]
+ fn test_pq_multi_post_batch_bytes() {
+ let cfg = InfoConfiguration {
+ max_post_bytes: 200,
+ ..InfoConfiguration::default()
+ };
+ let time = 11_111_111_000;
+ let (mut pq, tester) = pq_test_setup(
+ cfg,
+ time,
+ vec![
+ fake_response(status_codes::ACCEPTED, time, Some("1234")),
+ fake_response(status_codes::ACCEPTED, time + 100_000, Some("1234")),
+ ],
+ );
+
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ // POST
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.flush(true).unwrap(); // COMMIT
+
+ let t = tester.borrow();
+ assert!(t.cur_batch.is_none());
+ assert_eq!(t.all_posts.len(), 2);
+ assert_eq!(t.batches.len(), 1);
+ assert_eq!(t.batches[0].posts.len(), 2);
+ assert_eq!(t.batches[0].records, 3);
+ assert_eq!(t.batches[0].bytes, 300);
+
+ assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true");
+ assert_eq!(t.batches[0].posts[0].records, 2);
+ assert_eq!(t.batches[0].posts[0].payload_bytes, 200);
+ assert!(!t.batches[0].posts[0].commit);
+ assert_eq!(
+ t.batches[0].posts[0].body.len(),
+ request_bytes_for_payloads(&[100, 100])
+ );
+
+ assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234");
+ assert_eq!(t.batches[0].posts[1].records, 1);
+ assert_eq!(t.batches[0].posts[1].payload_bytes, 100);
+ assert!(t.batches[0].posts[1].commit);
+ assert_eq!(
+ t.batches[0].posts[1].body.len(),
+ request_bytes_for_payloads(&[100])
+ );
+ }
+
+ #[test]
+ fn test_pq_multi_post_batch_records() {
+ let cfg = InfoConfiguration {
+ max_post_records: 3,
+ ..InfoConfiguration::default()
+ };
+ let time = 11_111_111_000;
+ let (mut pq, tester) = pq_test_setup(
+ cfg,
+ time,
+ vec![
+ fake_response(status_codes::ACCEPTED, time, Some("1234")),
+ fake_response(status_codes::ACCEPTED, time, Some("1234")),
+ fake_response(status_codes::ACCEPTED, time + 100_000, Some("1234")),
+ ],
+ );
+
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ // POST
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ // POST
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.flush(true).unwrap(); // COMMIT
+
+ let t = tester.borrow();
+ assert!(t.cur_batch.is_none());
+ assert_eq!(t.all_posts.len(), 3);
+ assert_eq!(t.batches.len(), 1);
+ assert_eq!(t.batches[0].posts.len(), 3);
+ assert_eq!(t.batches[0].records, 7);
+ assert_eq!(t.batches[0].bytes, 700);
+
+ assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true");
+ assert_eq!(t.batches[0].posts[0].records, 3);
+ assert_eq!(t.batches[0].posts[0].payload_bytes, 300);
+ assert!(!t.batches[0].posts[0].commit);
+ assert_eq!(
+ t.batches[0].posts[0].body.len(),
+ request_bytes_for_payloads(&[100, 100, 100])
+ );
+
+ assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234");
+ assert_eq!(t.batches[0].posts[1].records, 3);
+ assert_eq!(t.batches[0].posts[1].payload_bytes, 300);
+ assert!(!t.batches[0].posts[1].commit);
+ assert_eq!(
+ t.batches[0].posts[1].body.len(),
+ request_bytes_for_payloads(&[100, 100, 100])
+ );
+
+ assert_eq!(t.batches[0].posts[2].batch.as_ref().unwrap(), "1234");
+ assert_eq!(t.batches[0].posts[2].records, 1);
+ assert_eq!(t.batches[0].posts[2].payload_bytes, 100);
+ assert!(t.batches[0].posts[2].commit);
+ assert_eq!(
+ t.batches[0].posts[2].body.len(),
+ request_bytes_for_payloads(&[100])
+ );
+ }
+
+ #[test]
+ #[allow(clippy::cognitive_complexity)]
+ fn test_pq_multi_post_multi_batch_records() {
+ let cfg = InfoConfiguration {
+ max_post_records: 3,
+ max_total_records: 5,
+ ..InfoConfiguration::default()
+ };
+ let time = 11_111_111_000;
+ let (mut pq, tester) = pq_test_setup(
+ cfg,
+ time,
+ vec![
+ fake_response(status_codes::ACCEPTED, time, Some("1234")),
+ fake_response(status_codes::ACCEPTED, time + 100_000, Some("1234")),
+ fake_response(status_codes::ACCEPTED, time + 100_000, Some("abcd")),
+ fake_response(status_codes::ACCEPTED, time + 200_000, Some("abcd")),
+ ],
+ );
+
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ // POST
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ // POST + COMMIT
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ // POST
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.flush(true).unwrap(); // COMMIT
+
+ let t = tester.borrow();
+ assert!(t.cur_batch.is_none());
+ assert_eq!(t.all_posts.len(), 4);
+ assert_eq!(t.batches.len(), 2);
+ assert_eq!(t.batches[0].posts.len(), 2);
+ assert_eq!(t.batches[1].posts.len(), 2);
+
+ assert_eq!(t.batches[0].records, 5);
+ assert_eq!(t.batches[1].records, 4);
+
+ assert_eq!(t.batches[0].bytes, 500);
+ assert_eq!(t.batches[1].bytes, 400);
+
+ assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true");
+ assert_eq!(t.batches[0].posts[0].records, 3);
+ assert_eq!(t.batches[0].posts[0].payload_bytes, 300);
+ assert!(!t.batches[0].posts[0].commit);
+ assert_eq!(
+ t.batches[0].posts[0].body.len(),
+ request_bytes_for_payloads(&[100, 100, 100])
+ );
+
+ assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234");
+ assert_eq!(t.batches[0].posts[1].records, 2);
+ assert_eq!(t.batches[0].posts[1].payload_bytes, 200);
+ assert!(t.batches[0].posts[1].commit);
+ assert_eq!(
+ t.batches[0].posts[1].body.len(),
+ request_bytes_for_payloads(&[100, 100])
+ );
+
+ assert_eq!(t.batches[1].posts[0].batch.as_ref().unwrap(), "true");
+ assert_eq!(t.batches[1].posts[0].records, 3);
+ assert_eq!(t.batches[1].posts[0].payload_bytes, 300);
+ assert!(!t.batches[1].posts[0].commit);
+ assert_eq!(
+ t.batches[1].posts[0].body.len(),
+ request_bytes_for_payloads(&[100, 100, 100])
+ );
+
+ assert_eq!(t.batches[1].posts[1].batch.as_ref().unwrap(), "abcd");
+ assert_eq!(t.batches[1].posts[1].records, 1);
+ assert_eq!(t.batches[1].posts[1].payload_bytes, 100);
+ assert!(t.batches[1].posts[1].commit);
+ assert_eq!(
+ t.batches[1].posts[1].body.len(),
+ request_bytes_for_payloads(&[100])
+ );
+ }
+
+ #[test]
+ #[allow(clippy::cognitive_complexity)]
+ fn test_pq_multi_post_multi_batch_bytes() {
+ let cfg = InfoConfiguration {
+ max_post_bytes: 300,
+ max_total_bytes: 500,
+ ..InfoConfiguration::default()
+ };
+ let time = 11_111_111_000;
+ let (mut pq, tester) = pq_test_setup(
+ cfg,
+ time,
+ vec![
+ fake_response(status_codes::ACCEPTED, time, Some("1234")),
+ fake_response(status_codes::ACCEPTED, time + 100_000, Some("1234")), // should commit
+ fake_response(status_codes::ACCEPTED, time + 100_000, Some("abcd")),
+ fake_response(status_codes::ACCEPTED, time + 200_000, Some("abcd")), // should commit
+ ],
+ );
+
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ assert_eq!(pq.last_modified.0, time);
+ // POST
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+ // POST + COMMIT
+ pq.enqueue(&make_record(100)).unwrap();
+ assert_eq!(pq.last_modified.0, time + 100_000);
+ pq.enqueue(&make_record(100)).unwrap();
+ pq.enqueue(&make_record(100)).unwrap();
+
+ // POST
+ pq.enqueue(&make_record(100)).unwrap();
+ assert_eq!(pq.last_modified.0, time + 100_000);
+ pq.flush(true).unwrap(); // COMMIT
+
+ assert_eq!(pq.last_modified.0, time + 200_000);
+
+ let t = tester.borrow();
+ assert!(t.cur_batch.is_none());
+ assert_eq!(t.all_posts.len(), 4);
+ assert_eq!(t.batches.len(), 2);
+ assert_eq!(t.batches[0].posts.len(), 2);
+ assert_eq!(t.batches[1].posts.len(), 2);
+
+ assert_eq!(t.batches[0].records, 5);
+ assert_eq!(t.batches[1].records, 4);
+
+ assert_eq!(t.batches[0].bytes, 500);
+ assert_eq!(t.batches[1].bytes, 400);
+
+ assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true");
+ assert_eq!(t.batches[0].posts[0].records, 3);
+ assert_eq!(t.batches[0].posts[0].payload_bytes, 300);
+ assert!(!t.batches[0].posts[0].commit);
+ assert_eq!(
+ t.batches[0].posts[0].body.len(),
+ request_bytes_for_payloads(&[100, 100, 100])
+ );
+
+ assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234");
+ assert_eq!(t.batches[0].posts[1].records, 2);
+ assert_eq!(t.batches[0].posts[1].payload_bytes, 200);
+ assert!(t.batches[0].posts[1].commit);
+ assert_eq!(
+ t.batches[0].posts[1].body.len(),
+ request_bytes_for_payloads(&[100, 100])
+ );
+
+ assert_eq!(t.batches[1].posts[0].batch.as_ref().unwrap(), "true");
+ assert_eq!(t.batches[1].posts[0].records, 3);
+ assert_eq!(t.batches[1].posts[0].payload_bytes, 300);
+ assert!(!t.batches[1].posts[0].commit);
+ assert_eq!(
+ t.batches[1].posts[0].body.len(),
+ request_bytes_for_payloads(&[100, 100, 100])
+ );
+
+ assert_eq!(t.batches[1].posts[1].batch.as_ref().unwrap(), "abcd");
+ assert_eq!(t.batches[1].posts[1].records, 1);
+ assert_eq!(t.batches[1].posts[1].payload_bytes, 100);
+ assert!(t.batches[1].posts[1].commit);
+ assert_eq!(
+ t.batches[1].posts[1].body.len(),
+ request_bytes_for_payloads(&[100])
+ );
+ }
+
+ // TODO: Test
+ //
+ // - error cases!!! We don't test our handling of server errors at all!
+ // - mixed bytes/record limits
+ //
+ // A lot of these have good examples in test_postqueue.js on deskftop sync
+}
diff --git a/third_party/rust/sync15/src/client/state.rs b/third_party/rust/sync15/src/client/state.rs
new file mode 100644
index 0000000000..78e9a6a718
--- /dev/null
+++ b/third_party/rust/sync15/src/client/state.rs
@@ -0,0 +1,1089 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use std::collections::{HashMap, HashSet};
+
+use super::request::{InfoCollections, InfoConfiguration};
+use super::storage_client::{SetupStorageClient, Sync15ClientResponse};
+use super::CollectionKeys;
+use crate::bso::OutgoingEncryptedBso;
+use crate::error::{self, Error as ErrorKind, ErrorResponse};
+use crate::record_types::{MetaGlobalEngine, MetaGlobalRecord};
+use crate::EncryptedPayload;
+use crate::{Guid, KeyBundle, ServerTimestamp};
+use interrupt_support::Interruptee;
+use serde_derive::*;
+
+use self::SetupState::*;
+
+const STORAGE_VERSION: usize = 5;
+
+/// Maps names to storage versions for engines to include in a fresh
+/// `meta/global` record. We include engines that we don't implement
+/// because they'll be disabled on other clients if we omit them
+/// (bug 1479929).
+const DEFAULT_ENGINES: &[(&str, usize)] = &[
+ ("passwords", 1),
+ ("clients", 1),
+ ("addons", 1),
+ ("addresses", 1),
+ ("bookmarks", 2),
+ ("creditcards", 1),
+ ("forms", 1),
+ ("history", 1),
+ ("prefs", 2),
+ ("tabs", 1),
+];
+
+// Declined engines to include in a fresh `meta/global` record.
+const DEFAULT_DECLINED: &[&str] = &[];
+
+/// State that we require the app to persist to storage for us.
+/// It's a little unfortunate we need this, because it's only tracking
+/// "declined engines", and even then, only needed in practice when there's
+/// no meta/global so we need to create one. It's extra unfortunate because we
+/// want to move away from "globally declined" engines anyway, moving towards
+/// allowing engines to be enabled or disabled per client rather than globally.
+///
+/// Apps are expected to treat this as opaque, so we support serializing it.
+/// Note that this structure is *not* used to *change* the declined engines
+/// list - that will be done in the future, but the API exposed for that
+/// purpose will also take a mutable PersistedGlobalState.
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "schema_version")]
+pub enum PersistedGlobalState {
+ /// V1 was when we persisted the entire GlobalState, keys and all!
+
+ /// V2 is just tracking the globally declined list.
+ /// None means "I've no idea" and theoretically should only happen on the
+ /// very first sync for an app.
+ V2 { declined: Option<Vec<String>> },
+}
+
+impl Default for PersistedGlobalState {
+ #[inline]
+ fn default() -> PersistedGlobalState {
+ PersistedGlobalState::V2 { declined: None }
+ }
+}
+
+#[derive(Debug, Default, Clone, PartialEq)]
+pub(crate) struct EngineChangesNeeded {
+ pub local_resets: HashSet<String>,
+ pub remote_wipes: HashSet<String>,
+}
+
+#[derive(Debug, Default, Clone, PartialEq)]
+struct RemoteEngineState {
+ info_collections: HashSet<String>,
+ declined: HashSet<String>,
+}
+
+#[derive(Debug, Default, Clone, PartialEq)]
+struct EngineStateInput {
+ local_declined: HashSet<String>,
+ remote: Option<RemoteEngineState>,
+ user_changes: HashMap<String, bool>,
+}
+
+#[derive(Debug, Default, Clone, PartialEq)]
+struct EngineStateOutput {
+ // The new declined.
+ declined: HashSet<String>,
+ // Which engines need resets or wipes.
+ changes_needed: EngineChangesNeeded,
+}
+
+fn compute_engine_states(input: EngineStateInput) -> EngineStateOutput {
+ use super::util::*;
+ log::debug!("compute_engine_states: input {:?}", input);
+ let (must_enable, must_disable) = partition_by_value(&input.user_changes);
+ let have_remote = input.remote.is_some();
+ let RemoteEngineState {
+ info_collections,
+ declined: remote_declined,
+ } = input.remote.clone().unwrap_or_default();
+
+ let both_declined_and_remote = set_intersection(&info_collections, &remote_declined);
+ if !both_declined_and_remote.is_empty() {
+ // Should we wipe these too?
+ log::warn!(
+ "Remote state contains engines which are in both info/collections and meta/global's declined: {:?}",
+ both_declined_and_remote,
+ );
+ }
+
+ let most_recent_declined_list = if have_remote {
+ &remote_declined
+ } else {
+ &input.local_declined
+ };
+
+ let result_declined = set_difference(
+ &set_union(most_recent_declined_list, &must_disable),
+ &must_enable,
+ );
+
+ let output = EngineStateOutput {
+ changes_needed: EngineChangesNeeded {
+ // Anything now declined which wasn't in our declined list before gets a reset.
+ local_resets: set_difference(&result_declined, &input.local_declined),
+ // Anything remote that we just declined gets a wipe. In the future
+ // we might want to consider wiping things in both remote declined
+ // and info/collections, but we'll let other clients pick up their
+ // own mess for now.
+ remote_wipes: set_intersection(&info_collections, &must_disable),
+ },
+ declined: result_declined,
+ };
+ // No PII here and this helps debug problems.
+ log::debug!("compute_engine_states: output {:?}", output);
+ output
+}
+
+impl PersistedGlobalState {
+ fn set_declined(&mut self, new_declined: Vec<String>) {
+ match self {
+ Self::V2 { ref mut declined } => *declined = Some(new_declined),
+ }
+ }
+ pub(crate) fn get_declined(&self) -> &[String] {
+ match self {
+ Self::V2 { declined: Some(d) } => d,
+ Self::V2 { declined: None } => &[],
+ }
+ }
+}
+
+/// Holds global Sync state, including server upload limits, the
+/// last-fetched collection modified times, `meta/global` record, and
+/// an encrypted copy of the crypto/keys resource (avoids keeping them
+/// in memory longer than necessary; avoids key mismatches by ensuring the same KeyBundle
+/// is used for both the keys and encrypted payloads.)
+#[derive(Debug, Clone)]
+pub struct GlobalState {
+ pub config: InfoConfiguration,
+ pub collections: InfoCollections,
+ pub global: MetaGlobalRecord,
+ pub global_timestamp: ServerTimestamp,
+ pub keys: EncryptedPayload,
+ pub keys_timestamp: ServerTimestamp,
+}
+
+/// Creates a fresh `meta/global` record, using the default engine selections,
+/// and declined engines from our PersistedGlobalState.
+fn new_global(pgs: &PersistedGlobalState) -> MetaGlobalRecord {
+ let sync_id = Guid::random();
+ let mut engines: HashMap<String, _> = HashMap::new();
+ for (name, version) in DEFAULT_ENGINES.iter() {
+ let sync_id = Guid::random();
+ engines.insert(
+ (*name).to_string(),
+ MetaGlobalEngine {
+ version: *version,
+ sync_id,
+ },
+ );
+ }
+ // We only need our PersistedGlobalState to fill out a new meta/global - if
+ // we previously saw a meta/global then we would have updated it with what
+ // it was at the time.
+ let declined = match pgs {
+ PersistedGlobalState::V2 { declined: Some(d) } => d.clone(),
+ _ => DEFAULT_DECLINED.iter().map(ToString::to_string).collect(),
+ };
+
+ MetaGlobalRecord {
+ sync_id,
+ storage_version: STORAGE_VERSION,
+ engines,
+ declined,
+ }
+}
+
+fn fixup_meta_global(global: &mut MetaGlobalRecord) -> bool {
+ let mut changed_any = false;
+ for &(name, version) in DEFAULT_ENGINES.iter() {
+ let had_engine = global.engines.contains_key(name);
+ let should_have_engine = !global.declined.iter().any(|c| c == name);
+ if had_engine != should_have_engine {
+ if should_have_engine {
+ log::debug!("SyncID for engine {:?} was missing", name);
+ global.engines.insert(
+ name.to_string(),
+ MetaGlobalEngine {
+ version,
+ sync_id: Guid::random(),
+ },
+ );
+ } else {
+ log::debug!("SyncID for engine {:?} was present, but shouldn't be", name);
+ global.engines.remove(name);
+ }
+ changed_any = true;
+ }
+ }
+ changed_any
+}
+
+pub struct SetupStateMachine<'a> {
+ client: &'a dyn SetupStorageClient,
+ root_key: &'a KeyBundle,
+ pgs: &'a mut PersistedGlobalState,
+ // `allowed_states` is designed so that we can arrange for the concept of
+ // a "fast" sync - so we decline to advance if we need to setup from scratch.
+ // The idea is that if we need to sync before going to sleep we should do
+ // it as fast as possible. However, in practice this isn't going to do
+ // what we expect - a "fast sync" that finds lots to do is almost certainly
+ // going to take longer than a "full sync" that finds nothing to do.
+ // We should almost certainly remove this and instead allow for a "time
+ // budget", after which we get interrupted. Later...
+ allowed_states: Vec<&'static str>,
+ sequence: Vec<&'static str>,
+ engine_updates: Option<&'a HashMap<String, bool>>,
+ interruptee: &'a dyn Interruptee,
+ pub(crate) changes_needed: Option<EngineChangesNeeded>,
+}
+
+impl<'a> SetupStateMachine<'a> {
+ /// Creates a state machine for a "classic" Sync 1.5 client that supports
+ /// all states, including uploading a fresh `meta/global` and `crypto/keys`
+ /// after a node reassignment.
+ pub fn for_full_sync(
+ client: &'a dyn SetupStorageClient,
+ root_key: &'a KeyBundle,
+ pgs: &'a mut PersistedGlobalState,
+ engine_updates: Option<&'a HashMap<String, bool>>,
+ interruptee: &'a dyn Interruptee,
+ ) -> SetupStateMachine<'a> {
+ SetupStateMachine::with_allowed_states(
+ client,
+ root_key,
+ pgs,
+ interruptee,
+ engine_updates,
+ vec![
+ "Initial",
+ "InitialWithConfig",
+ "InitialWithInfo",
+ "InitialWithMetaGlobal",
+ "Ready",
+ "FreshStartRequired",
+ "WithPreviousState",
+ ],
+ )
+ }
+
+ fn with_allowed_states(
+ client: &'a dyn SetupStorageClient,
+ root_key: &'a KeyBundle,
+ pgs: &'a mut PersistedGlobalState,
+ interruptee: &'a dyn Interruptee,
+ engine_updates: Option<&'a HashMap<String, bool>>,
+ allowed_states: Vec<&'static str>,
+ ) -> SetupStateMachine<'a> {
+ SetupStateMachine {
+ client,
+ root_key,
+ pgs,
+ sequence: Vec::new(),
+ allowed_states,
+ engine_updates,
+ interruptee,
+ changes_needed: None,
+ }
+ }
+
+ fn advance(&mut self, from: SetupState) -> error::Result<SetupState> {
+ match from {
+ // Fetch `info/configuration` with current server limits, and
+ // `info/collections` with collection last modified times.
+ Initial => {
+ let config = match self.client.fetch_info_configuration()? {
+ Sync15ClientResponse::Success { record, .. } => record,
+ Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }) => {
+ InfoConfiguration::default()
+ }
+ other => return Err(other.create_storage_error()),
+ };
+ Ok(InitialWithConfig { config })
+ }
+
+ // XXX - we could consider combining these Initial* states, because we don't
+ // attempt to support filling in "missing" global state - *any* 404 in them
+ // means `FreshStart`.
+ // IOW, in all cases, they either `Err()`, move to `FreshStartRequired`, or
+ // advance to a specific next state.
+ InitialWithConfig { config } => {
+ match self.client.fetch_info_collections()? {
+ Sync15ClientResponse::Success {
+ record: collections,
+ ..
+ } => Ok(InitialWithInfo {
+ config,
+ collections,
+ }),
+ // If the server doesn't have a `crypto/keys`, start over
+ // and reupload our `meta/global` and `crypto/keys`.
+ Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }) => {
+ Ok(FreshStartRequired { config })
+ }
+ other => Err(other.create_storage_error()),
+ }
+ }
+
+ InitialWithInfo {
+ config,
+ collections,
+ } => {
+ match self.client.fetch_meta_global()? {
+ Sync15ClientResponse::Success {
+ record: mut global,
+ last_modified: mut global_timestamp,
+ ..
+ } => {
+ // If the server has a newer storage version, we can't
+ // sync until our client is updated.
+ if global.storage_version > STORAGE_VERSION {
+ return Err(ErrorKind::ClientUpgradeRequired);
+ }
+
+ // If the server has an older storage version, wipe and
+ // reupload.
+ if global.storage_version < STORAGE_VERSION {
+ Ok(FreshStartRequired { config })
+ } else {
+ log::info!("Have info/collections and meta/global. Computing new engine states");
+ let initial_global_declined: HashSet<String> =
+ global.declined.iter().cloned().collect();
+ let result = compute_engine_states(EngineStateInput {
+ local_declined: self.pgs.get_declined().iter().cloned().collect(),
+ user_changes: self.engine_updates.cloned().unwrap_or_default(),
+ remote: Some(RemoteEngineState {
+ declined: initial_global_declined.clone(),
+ info_collections: collections.keys().cloned().collect(),
+ }),
+ });
+ // Persist the new declined.
+ self.pgs
+ .set_declined(result.declined.iter().cloned().collect());
+ // If the declined engines differ from remote, fix that.
+ let fixed_declined = if result.declined != initial_global_declined {
+ global.declined = result.declined.iter().cloned().collect();
+ log::info!(
+ "Uploading new declined {:?} to meta/global with timestamp {:?}",
+ global.declined,
+ global_timestamp,
+ );
+ true
+ } else {
+ false
+ };
+ // If there are missing syncIds, we need to fix those as well
+ let fixed_ids = if fixup_meta_global(&mut global) {
+ log::info!(
+ "Uploading corrected meta/global with timestamp {:?}",
+ global_timestamp,
+ );
+ true
+ } else {
+ false
+ };
+
+ if fixed_declined || fixed_ids {
+ global_timestamp =
+ self.client.put_meta_global(global_timestamp, &global)?;
+ log::debug!("new global_timestamp: {:?}", global_timestamp);
+ }
+ // Update the set of changes needed.
+ if self.changes_needed.is_some() {
+ // Should never happen (we prevent state machine
+ // loops elsewhere) but if it did, the info is stale
+ // anyway.
+ log::warn!("Already have a set of changes needed, Overwriting...");
+ }
+ self.changes_needed = Some(result.changes_needed);
+ Ok(InitialWithMetaGlobal {
+ config,
+ collections,
+ global,
+ global_timestamp,
+ })
+ }
+ }
+ Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }) => {
+ Ok(FreshStartRequired { config })
+ }
+ other => Err(other.create_storage_error()),
+ }
+ }
+
+ InitialWithMetaGlobal {
+ config,
+ collections,
+ global,
+ global_timestamp,
+ } => {
+ // Now try and get keys etc - if we fresh-start we'll re-use declined.
+ match self.client.fetch_crypto_keys()? {
+ Sync15ClientResponse::Success {
+ record,
+ last_modified,
+ ..
+ } => {
+ // Note that collection/keys is itself a bso, so the
+ // json body also carries the timestamp. If they aren't
+ // identical something has screwed up and we should die.
+ assert_eq!(last_modified, record.envelope.modified);
+ let state = GlobalState {
+ config,
+ collections,
+ global,
+ global_timestamp,
+ keys: record.payload,
+ keys_timestamp: last_modified,
+ };
+ Ok(Ready { state })
+ }
+ // If the server doesn't have a `crypto/keys`, start over
+ // and reupload our `meta/global` and `crypto/keys`.
+ Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }) => {
+ Ok(FreshStartRequired { config })
+ }
+ other => Err(other.create_storage_error()),
+ }
+ }
+
+ // We've got old state that's likely to be OK.
+ // We keep things simple here - if there's evidence of a new/missing
+ // meta/global or new/missing keys we just restart from scratch.
+ WithPreviousState { old_state } => match self.client.fetch_info_collections()? {
+ Sync15ClientResponse::Success {
+ record: collections,
+ ..
+ } => Ok(
+ if self.engine_updates.is_none()
+ && is_same_timestamp(old_state.global_timestamp, &collections, "meta")
+ && is_same_timestamp(old_state.keys_timestamp, &collections, "crypto")
+ {
+ Ready {
+ state: GlobalState {
+ collections,
+ ..old_state
+ },
+ }
+ } else {
+ InitialWithConfig {
+ config: old_state.config,
+ }
+ },
+ ),
+ _ => Ok(InitialWithConfig {
+ config: old_state.config,
+ }),
+ },
+
+ Ready { state } => Ok(Ready { state }),
+
+ FreshStartRequired { config } => {
+ // Wipe the server.
+ log::info!("Fresh start: wiping remote");
+ self.client.wipe_all_remote()?;
+
+ // Upload a fresh `meta/global`...
+ log::info!("Uploading meta/global");
+ let computed = compute_engine_states(EngineStateInput {
+ local_declined: self.pgs.get_declined().iter().cloned().collect(),
+ user_changes: self.engine_updates.cloned().unwrap_or_default(),
+ remote: None,
+ });
+ self.pgs
+ .set_declined(computed.declined.iter().cloned().collect());
+
+ self.changes_needed = Some(computed.changes_needed);
+
+ let new_global = new_global(self.pgs);
+
+ self.client
+ .put_meta_global(ServerTimestamp::default(), &new_global)?;
+
+ // ...And a fresh `crypto/keys`.
+ let new_keys = CollectionKeys::new_random()?.to_encrypted_payload(self.root_key)?;
+ let bso = OutgoingEncryptedBso::new(Guid::new("keys").into(), new_keys);
+ self.client
+ .put_crypto_keys(ServerTimestamp::default(), &bso)?;
+
+ // TODO(lina): Can we pass along server timestamps from the PUTs
+ // above, and avoid re-fetching the `m/g` and `c/k` we just
+ // uploaded?
+ // OTOH(mark): restarting the state machine keeps life simple and rare.
+ Ok(InitialWithConfig { config })
+ }
+ }
+ }
+
+ /// Runs through the state machine to the ready state.
+ pub fn run_to_ready(&mut self, state: Option<GlobalState>) -> error::Result<GlobalState> {
+ let mut s = match state {
+ Some(old_state) => WithPreviousState { old_state },
+ None => Initial,
+ };
+ loop {
+ self.interruptee.err_if_interrupted()?;
+ let label = &s.label();
+ log::trace!("global state: {:?}", label);
+ match s {
+ Ready { state } => {
+ self.sequence.push(label);
+ return Ok(state);
+ }
+ // If we already started over once before, we're likely in a
+ // cycle, and should try again later. Intermediate states
+ // aren't a problem, just the initial ones.
+ FreshStartRequired { .. } | WithPreviousState { .. } | Initial => {
+ if self.sequence.contains(label) {
+ // Is this really the correct error?
+ return Err(ErrorKind::SetupRace);
+ }
+ }
+ _ => {
+ if !self.allowed_states.contains(label) {
+ return Err(ErrorKind::SetupRequired);
+ }
+ }
+ };
+ self.sequence.push(label);
+ s = self.advance(s)?;
+ }
+ }
+}
+
+/// States in the remote setup process.
+/// TODO(lina): Add link once #56 is merged.
+#[derive(Debug)]
+#[allow(clippy::large_enum_variant)]
+enum SetupState {
+ // These "Initial" states are only ever used when starting from scratch.
+ Initial,
+ InitialWithConfig {
+ config: InfoConfiguration,
+ },
+ InitialWithInfo {
+ config: InfoConfiguration,
+ collections: InfoCollections,
+ },
+ InitialWithMetaGlobal {
+ config: InfoConfiguration,
+ collections: InfoCollections,
+ global: MetaGlobalRecord,
+ global_timestamp: ServerTimestamp,
+ },
+ WithPreviousState {
+ old_state: GlobalState,
+ },
+ Ready {
+ state: GlobalState,
+ },
+ FreshStartRequired {
+ config: InfoConfiguration,
+ },
+}
+
+impl SetupState {
+ fn label(&self) -> &'static str {
+ match self {
+ Initial { .. } => "Initial",
+ InitialWithConfig { .. } => "InitialWithConfig",
+ InitialWithInfo { .. } => "InitialWithInfo",
+ InitialWithMetaGlobal { .. } => "InitialWithMetaGlobal",
+ Ready { .. } => "Ready",
+ WithPreviousState { .. } => "WithPreviousState",
+ FreshStartRequired { .. } => "FreshStartRequired",
+ }
+ }
+}
+
+/// Whether we should skip fetching an item. Used when we already have timestamps
+/// and want to check if we should reuse our existing state. The state's fairly
+/// cheap to recreate and very bad to use if it is wrong, so we insist on the
+/// *exact* timestamp matching and not a simple "later than" check.
+fn is_same_timestamp(local: ServerTimestamp, collections: &InfoCollections, key: &str) -> bool {
+ collections.get(key).map_or(false, |ts| local == *ts)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::bso::{IncomingEncryptedBso, IncomingEnvelope};
+ use interrupt_support::NeverInterrupts;
+
+ struct InMemoryClient {
+ info_configuration: error::Result<Sync15ClientResponse<InfoConfiguration>>,
+ info_collections: error::Result<Sync15ClientResponse<InfoCollections>>,
+ meta_global: error::Result<Sync15ClientResponse<MetaGlobalRecord>>,
+ crypto_keys: error::Result<Sync15ClientResponse<IncomingEncryptedBso>>,
+ }
+
+ impl SetupStorageClient for InMemoryClient {
+ fn fetch_info_configuration(
+ &self,
+ ) -> error::Result<Sync15ClientResponse<InfoConfiguration>> {
+ match &self.info_configuration {
+ Ok(client_response) => Ok(client_response.clone()),
+ Err(_) => Ok(Sync15ClientResponse::Error(ErrorResponse::ServerError {
+ status: 500,
+ route: "test/path".into(),
+ })),
+ }
+ }
+
+ fn fetch_info_collections(&self) -> error::Result<Sync15ClientResponse<InfoCollections>> {
+ match &self.info_collections {
+ Ok(collections) => Ok(collections.clone()),
+ Err(_) => Ok(Sync15ClientResponse::Error(ErrorResponse::ServerError {
+ status: 500,
+ route: "test/path".into(),
+ })),
+ }
+ }
+
+ fn fetch_meta_global(&self) -> error::Result<Sync15ClientResponse<MetaGlobalRecord>> {
+ match &self.meta_global {
+ Ok(global) => Ok(global.clone()),
+ // TODO(lina): Special handling for 404s, we want to ensure we
+ // handle missing keys and other server errors correctly.
+ Err(_) => Ok(Sync15ClientResponse::Error(ErrorResponse::ServerError {
+ status: 500,
+ route: "test/path".into(),
+ })),
+ }
+ }
+
+ fn put_meta_global(
+ &self,
+ xius: ServerTimestamp,
+ global: &MetaGlobalRecord,
+ ) -> error::Result<ServerTimestamp> {
+ // Ensure that the meta/global record we uploaded is "fixed up"
+ assert!(DEFAULT_ENGINES
+ .iter()
+ .filter(|e| e.0 != "logins")
+ .all(|&(k, _v)| global.engines.contains_key(k)));
+ assert!(!global.engines.contains_key("logins"));
+ assert_eq!(global.declined, vec!["logins".to_string()]);
+ // return a different timestamp.
+ Ok(ServerTimestamp(xius.0 + 1))
+ }
+
+ fn fetch_crypto_keys(&self) -> error::Result<Sync15ClientResponse<IncomingEncryptedBso>> {
+ match &self.crypto_keys {
+ Ok(Sync15ClientResponse::Success {
+ status,
+ record,
+ last_modified,
+ route,
+ }) => Ok(Sync15ClientResponse::Success {
+ status: *status,
+ record: IncomingEncryptedBso::new(
+ record.envelope.clone(),
+ record.payload.clone(),
+ ),
+ last_modified: *last_modified,
+ route: route.clone(),
+ }),
+ // TODO(lina): Same as above, for 404s.
+ _ => Ok(Sync15ClientResponse::Error(ErrorResponse::ServerError {
+ status: 500,
+ route: "test/path".into(),
+ })),
+ }
+ }
+
+ fn put_crypto_keys(
+ &self,
+ xius: ServerTimestamp,
+ _keys: &OutgoingEncryptedBso,
+ ) -> error::Result<()> {
+ assert_eq!(xius, ServerTimestamp(888_800));
+ Err(ErrorKind::StorageHttpError(ErrorResponse::ServerError {
+ status: 500,
+ route: "crypto/keys".to_string(),
+ }))
+ }
+
+ fn wipe_all_remote(&self) -> error::Result<()> {
+ Ok(())
+ }
+ }
+
+ #[allow(clippy::unnecessary_wraps)]
+ fn mocked_success_ts<T>(t: T, ts: i64) -> error::Result<Sync15ClientResponse<T>> {
+ Ok(Sync15ClientResponse::Success {
+ status: 200,
+ record: t,
+ last_modified: ServerTimestamp(ts),
+ route: "test/path".into(),
+ })
+ }
+
+ fn mocked_success<T>(t: T) -> error::Result<Sync15ClientResponse<T>> {
+ mocked_success_ts(t, 0)
+ }
+
+ fn mocked_success_keys(
+ keys: CollectionKeys,
+ root_key: &KeyBundle,
+ ) -> error::Result<Sync15ClientResponse<IncomingEncryptedBso>> {
+ let timestamp = keys.timestamp;
+ let payload = keys.to_encrypted_payload(root_key).unwrap();
+ let bso = IncomingEncryptedBso::new(
+ IncomingEnvelope {
+ id: Guid::new("keys"),
+ modified: timestamp,
+ sortindex: None,
+ ttl: None,
+ },
+ payload,
+ );
+ Ok(Sync15ClientResponse::Success {
+ status: 200,
+ record: bso,
+ last_modified: timestamp,
+ route: "test/path".into(),
+ })
+ }
+
+ #[test]
+ fn test_state_machine_ready_from_empty() {
+ let _ = env_logger::try_init();
+ let root_key = KeyBundle::new_random().unwrap();
+ let keys = CollectionKeys {
+ timestamp: ServerTimestamp(123_400),
+ default: KeyBundle::new_random().unwrap(),
+ collections: HashMap::new(),
+ };
+ let mg = MetaGlobalRecord {
+ sync_id: "syncIDAAAAAA".into(),
+ storage_version: 5usize,
+ engines: vec![(
+ "bookmarks",
+ MetaGlobalEngine {
+ version: 1usize,
+ sync_id: "syncIDBBBBBB".into(),
+ },
+ )]
+ .into_iter()
+ .map(|(key, value)| (key.to_owned(), value))
+ .collect(),
+ // We ensure that the record we upload doesn't have a logins record.
+ declined: vec!["logins".to_string()],
+ };
+ let client = InMemoryClient {
+ info_configuration: mocked_success(InfoConfiguration::default()),
+ info_collections: mocked_success(InfoCollections::new(
+ vec![("meta", 123_456), ("crypto", 145_000)]
+ .into_iter()
+ .map(|(key, value)| (key.to_owned(), ServerTimestamp(value)))
+ .collect(),
+ )),
+ meta_global: mocked_success_ts(mg, 999_000),
+ crypto_keys: mocked_success_keys(keys, &root_key),
+ };
+ let mut pgs = PersistedGlobalState::V2 { declined: None };
+
+ let mut state_machine =
+ SetupStateMachine::for_full_sync(&client, &root_key, &mut pgs, None, &NeverInterrupts);
+ assert!(
+ state_machine.run_to_ready(None).is_ok(),
+ "Should drive state machine to ready"
+ );
+ assert_eq!(
+ state_machine.sequence,
+ vec![
+ "Initial",
+ "InitialWithConfig",
+ "InitialWithInfo",
+ "InitialWithMetaGlobal",
+ "Ready",
+ ],
+ "Should cycle through all states"
+ );
+ }
+
+ #[test]
+ fn test_from_previous_state_declined() {
+ let _ = env_logger::try_init();
+ // The state-machine sequence where we didn't use the previous state
+ // (ie, where the state machine restarted)
+ let sm_seq_restarted = vec![
+ "WithPreviousState",
+ "InitialWithConfig",
+ "InitialWithInfo",
+ "InitialWithMetaGlobal",
+ "Ready",
+ ];
+ // The state-machine sequence where we used the previous state.
+ let sm_seq_used_previous = vec!["WithPreviousState", "Ready"];
+
+ // do the actual test.
+ fn do_test(
+ client: &dyn SetupStorageClient,
+ root_key: &KeyBundle,
+ pgs: &mut PersistedGlobalState,
+ engine_updates: Option<&HashMap<String, bool>>,
+ old_state: GlobalState,
+ expected_states: &[&str],
+ ) {
+ let mut state_machine = SetupStateMachine::for_full_sync(
+ client,
+ root_key,
+ pgs,
+ engine_updates,
+ &NeverInterrupts,
+ );
+ assert!(
+ state_machine.run_to_ready(Some(old_state)).is_ok(),
+ "Should drive state machine to ready"
+ );
+ assert_eq!(state_machine.sequence, expected_states);
+ }
+
+ // and all the complicated setup...
+ let ts_metaglobal = 123_456;
+ let ts_keys = 145_000;
+ let root_key = KeyBundle::new_random().unwrap();
+ let keys = CollectionKeys {
+ timestamp: ServerTimestamp(ts_keys + 1),
+ default: KeyBundle::new_random().unwrap(),
+ collections: HashMap::new(),
+ };
+ let mg = MetaGlobalRecord {
+ sync_id: "syncIDAAAAAA".into(),
+ storage_version: 5usize,
+ engines: vec![(
+ "bookmarks",
+ MetaGlobalEngine {
+ version: 1usize,
+ sync_id: "syncIDBBBBBB".into(),
+ },
+ )]
+ .into_iter()
+ .map(|(key, value)| (key.to_owned(), value))
+ .collect(),
+ // We ensure that the record we upload doesn't have a logins record.
+ declined: vec!["logins".to_string()],
+ };
+ let collections = InfoCollections::new(
+ vec![("meta", ts_metaglobal), ("crypto", ts_keys)]
+ .into_iter()
+ .map(|(key, value)| (key.to_owned(), ServerTimestamp(value)))
+ .collect(),
+ );
+ let client = InMemoryClient {
+ info_configuration: mocked_success(InfoConfiguration::default()),
+ info_collections: mocked_success(collections.clone()),
+ meta_global: mocked_success_ts(mg.clone(), ts_metaglobal),
+ crypto_keys: mocked_success_keys(keys.clone(), &root_key),
+ };
+
+ // First a test where the "previous" global state is OK to reuse.
+ {
+ let mut pgs = PersistedGlobalState::V2 { declined: None };
+ // A "previous" global state.
+ let old_state = GlobalState {
+ config: InfoConfiguration::default(),
+ collections: collections.clone(),
+ global: mg.clone(),
+ global_timestamp: ServerTimestamp(ts_metaglobal),
+ keys: keys
+ .to_encrypted_payload(&root_key)
+ .expect("should always work in this test"),
+ keys_timestamp: ServerTimestamp(ts_keys),
+ };
+ do_test(
+ &client,
+ &root_key,
+ &mut pgs,
+ None,
+ old_state,
+ &sm_seq_used_previous,
+ );
+ }
+
+ // Now where the meta/global record on the server is later.
+ {
+ let mut pgs = PersistedGlobalState::V2 { declined: None };
+ // A "previous" global state.
+ let old_state = GlobalState {
+ config: InfoConfiguration::default(),
+ collections: collections.clone(),
+ global: mg.clone(),
+ global_timestamp: ServerTimestamp(999_999),
+ keys: keys
+ .to_encrypted_payload(&root_key)
+ .expect("should always work in this test"),
+ keys_timestamp: ServerTimestamp(ts_keys),
+ };
+ do_test(
+ &client,
+ &root_key,
+ &mut pgs,
+ None,
+ old_state,
+ &sm_seq_restarted,
+ );
+ }
+
+ // Where keys on the server is later.
+ {
+ let mut pgs = PersistedGlobalState::V2 { declined: None };
+ // A "previous" global state.
+ let old_state = GlobalState {
+ config: InfoConfiguration::default(),
+ collections: collections.clone(),
+ global: mg.clone(),
+ global_timestamp: ServerTimestamp(ts_metaglobal),
+ keys: keys
+ .to_encrypted_payload(&root_key)
+ .expect("should always work in this test"),
+ keys_timestamp: ServerTimestamp(999_999),
+ };
+ do_test(
+ &client,
+ &root_key,
+ &mut pgs,
+ None,
+ old_state,
+ &sm_seq_restarted,
+ );
+ }
+
+ // Where there are engine-state changes.
+ {
+ let mut pgs = PersistedGlobalState::V2 { declined: None };
+ // A "previous" global state.
+ let old_state = GlobalState {
+ config: InfoConfiguration::default(),
+ collections,
+ global: mg,
+ global_timestamp: ServerTimestamp(ts_metaglobal),
+ keys: keys
+ .to_encrypted_payload(&root_key)
+ .expect("should always work in this test"),
+ keys_timestamp: ServerTimestamp(ts_keys),
+ };
+ let mut engine_updates = HashMap::<String, bool>::new();
+ engine_updates.insert("logins".to_string(), false);
+ do_test(
+ &client,
+ &root_key,
+ &mut pgs,
+ Some(&engine_updates),
+ old_state,
+ &sm_seq_restarted,
+ );
+ let declined = match pgs {
+ PersistedGlobalState::V2 { declined: d } => d,
+ };
+ // and check we now consider logins as declined.
+ assert_eq!(declined, Some(vec!["logins".to_string()]));
+ }
+ }
+
+ fn string_set(s: &[&str]) -> HashSet<String> {
+ s.iter().map(ToString::to_string).collect()
+ }
+ fn string_map<T: Clone>(s: &[(&str, T)]) -> HashMap<String, T> {
+ s.iter().map(|v| (v.0.to_string(), v.1.clone())).collect()
+ }
+ #[test]
+ fn test_engine_states() {
+ assert_eq!(
+ compute_engine_states(EngineStateInput {
+ local_declined: string_set(&["foo", "bar"]),
+ remote: None,
+ user_changes: Default::default(),
+ }),
+ EngineStateOutput {
+ declined: string_set(&["foo", "bar"]),
+ // No wipes, no resets
+ changes_needed: Default::default(),
+ }
+ );
+ assert_eq!(
+ compute_engine_states(EngineStateInput {
+ local_declined: string_set(&["foo", "bar"]),
+ remote: Some(RemoteEngineState {
+ declined: string_set(&["foo"]),
+ info_collections: string_set(&["bar"])
+ }),
+ user_changes: Default::default(),
+ }),
+ EngineStateOutput {
+ // Now we have `foo`.
+ declined: string_set(&["foo"]),
+ // No wipes, no resets, should just be a local update.
+ changes_needed: Default::default(),
+ }
+ );
+ assert_eq!(
+ compute_engine_states(EngineStateInput {
+ local_declined: string_set(&["foo", "bar"]),
+ remote: Some(RemoteEngineState {
+ declined: string_set(&["foo", "bar", "quux"]),
+ info_collections: string_set(&[])
+ }),
+ user_changes: Default::default(),
+ }),
+ EngineStateOutput {
+ // Now we have `foo`.
+ declined: string_set(&["foo", "bar", "quux"]),
+ changes_needed: EngineChangesNeeded {
+ // Should reset `quux`.
+ local_resets: string_set(&["quux"]),
+ // No wipes, though.
+ remote_wipes: string_set(&[]),
+ }
+ }
+ );
+ assert_eq!(
+ compute_engine_states(EngineStateInput {
+ local_declined: string_set(&["bar", "baz"]),
+ remote: Some(RemoteEngineState {
+ declined: string_set(&["bar", "baz",]),
+ info_collections: string_set(&["quux"])
+ }),
+ // Change a declined engine to undeclined.
+ user_changes: string_map(&[("bar", true)]),
+ }),
+ EngineStateOutput {
+ declined: string_set(&["baz"]),
+ // No wipes, just undecline it.
+ changes_needed: Default::default()
+ }
+ );
+ assert_eq!(
+ compute_engine_states(EngineStateInput {
+ local_declined: string_set(&["bar", "baz"]),
+ remote: Some(RemoteEngineState {
+ declined: string_set(&["bar", "baz"]),
+ info_collections: string_set(&["foo"])
+ }),
+ // Change an engine which exists remotely to declined.
+ user_changes: string_map(&[("foo", false)]),
+ }),
+ EngineStateOutput {
+ declined: string_set(&["baz", "bar", "foo"]),
+ // No wipes, just undecline it.
+ changes_needed: EngineChangesNeeded {
+ // Should reset our local foo
+ local_resets: string_set(&["foo"]),
+ // And wipe the server.
+ remote_wipes: string_set(&["foo"]),
+ }
+ }
+ );
+ }
+}
diff --git a/third_party/rust/sync15/src/client/status.rs b/third_party/rust/sync15/src/client/status.rs
new file mode 100644
index 0000000000..407efeec12
--- /dev/null
+++ b/third_party/rust/sync15/src/client/status.rs
@@ -0,0 +1,106 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use crate::error::{Error, ErrorResponse};
+use crate::telemetry::SyncTelemetryPing;
+use std::collections::HashMap;
+use std::time::{Duration, SystemTime};
+
+/// The general status of sync - should probably be moved to the "sync manager"
+/// once we have one!
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum ServiceStatus {
+ /// Everything is fine.
+ Ok,
+ /// Some general network issue.
+ NetworkError,
+ /// Some apparent issue with the servers.
+ ServiceError,
+ /// Some external FxA action needs to be taken.
+ AuthenticationError,
+ /// We declined to do anything for backoff or rate-limiting reasons.
+ BackedOff,
+ /// We were interrupted.
+ Interrupted,
+ /// Something else - you need to check the logs for more details. May
+ /// or may not be transient, we really don't know.
+ OtherError,
+}
+
+impl ServiceStatus {
+ // This is a bit naive and probably will not survive in this form in the
+ // SyncManager - eg, we'll want to handle backoff etc.
+ pub fn from_err(err: &Error) -> ServiceStatus {
+ match err {
+ // HTTP based errors.
+ Error::TokenserverHttpError(status) => {
+ // bit of a shame the tokenserver is different to storage...
+ if *status == 401 {
+ ServiceStatus::AuthenticationError
+ } else {
+ ServiceStatus::ServiceError
+ }
+ }
+ // BackoffError is also from the tokenserver.
+ Error::BackoffError(_) => ServiceStatus::ServiceError,
+ Error::StorageHttpError(ref e) => match e {
+ ErrorResponse::Unauthorized { .. } => ServiceStatus::AuthenticationError,
+ _ => ServiceStatus::ServiceError,
+ },
+
+ // Network errors.
+ Error::RequestError(_) | Error::UnexpectedStatus(_) | Error::HawkError(_) => {
+ ServiceStatus::NetworkError
+ }
+
+ Error::Interrupted(_) => ServiceStatus::Interrupted,
+ _ => ServiceStatus::OtherError,
+ }
+ }
+}
+
+/// The result of a sync request. This too is from the "sync manager", but only
+/// has a fraction of the things it will have when we actually build that.
+#[derive(Debug)]
+pub struct SyncResult {
+ /// The general health.
+ pub service_status: ServiceStatus,
+
+ /// The set of declined engines, if we know them.
+ pub declined: Option<Vec<String>>,
+
+ /// The result of the sync.
+ pub result: Result<(), Error>,
+
+ /// The result for each engine.
+ /// Note that we expect the `String` to be replaced with an enum later.
+ pub engine_results: HashMap<String, Result<(), Error>>,
+
+ pub telemetry: SyncTelemetryPing,
+
+ pub next_sync_after: Option<std::time::SystemTime>,
+}
+
+// If `r` has a BackoffError, then returns the later backoff value.
+fn advance_backoff(cur_best: SystemTime, r: &Result<(), Error>) -> SystemTime {
+ if let Err(e) = r {
+ if let Some(time) = e.get_backoff() {
+ return std::cmp::max(time, cur_best);
+ }
+ }
+ cur_best
+}
+
+impl SyncResult {
+ pub(crate) fn set_sync_after(&mut self, backoff_duration: Duration) {
+ let now = SystemTime::now();
+ let toplevel = advance_backoff(now + backoff_duration, &self.result);
+ let sync_after = self.engine_results.values().fold(toplevel, advance_backoff);
+ if sync_after <= now {
+ self.next_sync_after = None;
+ } else {
+ self.next_sync_after = Some(sync_after);
+ }
+ }
+}
diff --git a/third_party/rust/sync15/src/client/storage_client.rs b/third_party/rust/sync15/src/client/storage_client.rs
new file mode 100644
index 0000000000..83dbbf294e
--- /dev/null
+++ b/third_party/rust/sync15/src/client/storage_client.rs
@@ -0,0 +1,587 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use super::request::{
+ BatchPoster, InfoCollections, InfoConfiguration, PostQueue, PostResponse, PostResponseHandler,
+};
+use super::token;
+use crate::bso::{IncomingBso, IncomingEncryptedBso, OutgoingBso, OutgoingEncryptedBso};
+use crate::engine::CollectionRequest;
+use crate::error::{self, Error, ErrorResponse};
+use crate::record_types::MetaGlobalRecord;
+use crate::{Guid, ServerTimestamp};
+use serde_json::Value;
+use std::str::FromStr;
+use std::sync::atomic::{AtomicU32, Ordering};
+use url::Url;
+use viaduct::{
+ header_names::{self, AUTHORIZATION},
+ Method, Request, Response,
+};
+
+/// A response from a GET request on a Sync15StorageClient, encapsulating all
+/// the variants users of this client needs to care about.
+#[derive(Debug, Clone)]
+pub enum Sync15ClientResponse<T> {
+ Success {
+ status: u16,
+ record: T,
+ last_modified: ServerTimestamp,
+ route: String,
+ },
+ Error(ErrorResponse),
+}
+
+fn parse_seconds(seconds_str: &str) -> Option<u32> {
+ let secs = seconds_str.parse::<f64>().ok()?.ceil();
+ // Note: u32 doesn't impl TryFrom<f64> :(
+ if !secs.is_finite() || secs < 0.0 || secs > f64::from(u32::max_value()) {
+ log::warn!("invalid backoff value: {}", secs);
+ None
+ } else {
+ Some(secs as u32)
+ }
+}
+
+impl<T> Sync15ClientResponse<T> {
+ pub fn from_response(resp: Response, backoff_listener: &BackoffListener) -> error::Result<Self>
+ where
+ for<'a> T: serde::de::Deserialize<'a>,
+ {
+ let route: String = resp.url.path().into();
+ // Android seems to respect retry_after even on success requests, so we
+ // will too if it's present. This also lets us handle both backoff-like
+ // properties in the same place.
+ let retry_after = resp
+ .headers
+ .get(header_names::RETRY_AFTER)
+ .and_then(parse_seconds);
+
+ let backoff = resp
+ .headers
+ .get(header_names::X_WEAVE_BACKOFF)
+ .and_then(parse_seconds);
+
+ if let Some(b) = backoff {
+ backoff_listener.note_backoff(b);
+ }
+ if let Some(ra) = retry_after {
+ backoff_listener.note_retry_after(ra);
+ }
+
+ Ok(if resp.is_success() {
+ let record: T = resp.json()?;
+ let last_modified = resp
+ .headers
+ .get(header_names::X_LAST_MODIFIED)
+ .and_then(|s| ServerTimestamp::from_str(s).ok())
+ .ok_or(Error::MissingServerTimestamp)?;
+ log::info!(
+ "Successful request to \"{}\", incoming x-last-modified={:?}",
+ route,
+ last_modified
+ );
+
+ Sync15ClientResponse::Success {
+ status: resp.status,
+ record,
+ last_modified,
+ route,
+ }
+ } else {
+ let status = resp.status;
+ log::info!("Request \"{}\" was an error (status={})", route, status);
+ match status {
+ 404 => Sync15ClientResponse::Error(ErrorResponse::NotFound { route }),
+ 401 => Sync15ClientResponse::Error(ErrorResponse::Unauthorized { route }),
+ 412 => Sync15ClientResponse::Error(ErrorResponse::PreconditionFailed { route }),
+ 500..=600 => {
+ Sync15ClientResponse::Error(ErrorResponse::ServerError { route, status })
+ }
+ _ => Sync15ClientResponse::Error(ErrorResponse::RequestFailed { route, status }),
+ }
+ })
+ }
+
+ pub fn create_storage_error(self) -> Error {
+ let inner = match self {
+ Sync15ClientResponse::Success { status, route, .. } => {
+ // This should never happen as callers are expected to have
+ // already special-cased this response, so warn if it does.
+ // (or maybe we could panic?)
+ log::warn!("Converting success response into an error");
+ ErrorResponse::RequestFailed { status, route }
+ }
+ Sync15ClientResponse::Error(e) => e,
+ };
+ Error::StorageHttpError(inner)
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct Sync15StorageClientInit {
+ pub key_id: String,
+ pub access_token: String,
+ pub tokenserver_url: Url,
+}
+
+/// A trait containing the methods required to run through the setup state
+/// machine. This is factored out into a separate trait to make mocking
+/// easier.
+pub trait SetupStorageClient {
+ fn fetch_info_configuration(&self) -> error::Result<Sync15ClientResponse<InfoConfiguration>>;
+ fn fetch_info_collections(&self) -> error::Result<Sync15ClientResponse<InfoCollections>>;
+ fn fetch_meta_global(&self) -> error::Result<Sync15ClientResponse<MetaGlobalRecord>>;
+ fn fetch_crypto_keys(&self) -> error::Result<Sync15ClientResponse<IncomingEncryptedBso>>;
+
+ fn put_meta_global(
+ &self,
+ xius: ServerTimestamp,
+ global: &MetaGlobalRecord,
+ ) -> error::Result<ServerTimestamp>;
+ fn put_crypto_keys(
+ &self,
+ xius: ServerTimestamp,
+ keys: &OutgoingEncryptedBso,
+ ) -> error::Result<()>;
+ fn wipe_all_remote(&self) -> error::Result<()>;
+}
+
+#[derive(Debug, Default)]
+pub struct BackoffState {
+ pub backoff_secs: AtomicU32,
+ pub retry_after_secs: AtomicU32,
+}
+
+pub(crate) type BackoffListener = std::sync::Arc<BackoffState>;
+
+pub(crate) fn new_backoff_listener() -> BackoffListener {
+ std::sync::Arc::new(BackoffState::default())
+}
+
+impl BackoffState {
+ pub fn note_backoff(&self, noted: u32) {
+ super::util::atomic_update_max(&self.backoff_secs, noted)
+ }
+
+ pub fn note_retry_after(&self, noted: u32) {
+ super::util::atomic_update_max(&self.retry_after_secs, noted)
+ }
+
+ pub fn get_backoff_secs(&self) -> u32 {
+ self.backoff_secs.load(Ordering::SeqCst)
+ }
+
+ pub fn get_retry_after_secs(&self) -> u32 {
+ self.retry_after_secs.load(Ordering::SeqCst)
+ }
+
+ pub fn get_required_wait(&self, ignore_soft_backoff: bool) -> Option<std::time::Duration> {
+ let bo = self.get_backoff_secs();
+ let ra = self.get_retry_after_secs();
+ let secs = u64::from(if ignore_soft_backoff { ra } else { bo.max(ra) });
+ if secs > 0 {
+ Some(std::time::Duration::from_secs(secs))
+ } else {
+ None
+ }
+ }
+
+ pub fn reset(&self) {
+ self.backoff_secs.store(0, Ordering::SeqCst);
+ self.retry_after_secs.store(0, Ordering::SeqCst);
+ }
+}
+
+// meta/global is a clear-text Bso (ie, there's a String `payload` which has a MetaGlobalRecord)
+// We don't use the 'content' helpers here because we want json errors to be fatal here
+// (ie, we don't need tombstones and can't just skip a malformed record)
+type IncMetaGlobalBso = IncomingBso;
+type OutMetaGlobalBso = OutgoingBso;
+
+#[derive(Debug)]
+pub struct Sync15StorageClient {
+ tsc: token::TokenProvider,
+ pub(crate) backoff: BackoffListener,
+}
+
+impl SetupStorageClient for Sync15StorageClient {
+ fn fetch_info_configuration(&self) -> error::Result<Sync15ClientResponse<InfoConfiguration>> {
+ self.relative_storage_request(Method::Get, "info/configuration")
+ }
+
+ fn fetch_info_collections(&self) -> error::Result<Sync15ClientResponse<InfoCollections>> {
+ self.relative_storage_request(Method::Get, "info/collections")
+ }
+
+ fn fetch_meta_global(&self) -> error::Result<Sync15ClientResponse<MetaGlobalRecord>> {
+ let got: Sync15ClientResponse<IncMetaGlobalBso> =
+ self.relative_storage_request(Method::Get, "storage/meta/global")?;
+ Ok(match got {
+ Sync15ClientResponse::Success {
+ record,
+ last_modified,
+ route,
+ status,
+ } => {
+ log::debug!(
+ "Got meta global with modified = {}; last-modified = {}",
+ record.envelope.modified,
+ last_modified
+ );
+ Sync15ClientResponse::Success {
+ record: serde_json::from_str(&record.payload)?,
+ last_modified,
+ route,
+ status,
+ }
+ }
+ Sync15ClientResponse::Error(e) => Sync15ClientResponse::Error(e),
+ })
+ }
+
+ fn fetch_crypto_keys(&self) -> error::Result<Sync15ClientResponse<IncomingEncryptedBso>> {
+ self.relative_storage_request(Method::Get, "storage/crypto/keys")
+ }
+
+ fn put_meta_global(
+ &self,
+ xius: ServerTimestamp,
+ global: &MetaGlobalRecord,
+ ) -> error::Result<ServerTimestamp> {
+ let bso = OutMetaGlobalBso::new(Guid::new("global").into(), global)?;
+ self.put("storage/meta/global", xius, &bso)
+ }
+
+ fn put_crypto_keys(
+ &self,
+ xius: ServerTimestamp,
+ keys: &OutgoingEncryptedBso,
+ ) -> error::Result<()> {
+ self.put("storage/crypto/keys", xius, keys)?;
+ Ok(())
+ }
+
+ fn wipe_all_remote(&self) -> error::Result<()> {
+ let s = self.tsc.api_endpoint()?;
+ let url = Url::parse(&s)?;
+
+ let req = self.build_request(Method::Delete, url)?;
+ match self.exec_request::<Value>(req, false) {
+ Ok(Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }))
+ | Ok(Sync15ClientResponse::Success { .. }) => Ok(()),
+ Ok(resp) => Err(resp.create_storage_error()),
+ Err(e) => Err(e),
+ }
+ }
+}
+
+impl Sync15StorageClient {
+ pub fn new(init_params: Sync15StorageClientInit) -> error::Result<Sync15StorageClient> {
+ rc_crypto::ensure_initialized();
+ let tsc = token::TokenProvider::new(
+ init_params.tokenserver_url,
+ init_params.access_token,
+ init_params.key_id,
+ )?;
+ Ok(Sync15StorageClient {
+ tsc,
+ backoff: new_backoff_listener(),
+ })
+ }
+
+ pub fn get_encrypted_records(
+ &self,
+ collection_request: &CollectionRequest,
+ ) -> error::Result<Sync15ClientResponse<Vec<IncomingEncryptedBso>>> {
+ self.collection_request(Method::Get, collection_request)
+ }
+
+ #[inline]
+ fn authorized(&self, req: Request) -> error::Result<Request> {
+ let hawk_header_value = self.tsc.authorization(&req)?;
+ Ok(req.header(AUTHORIZATION, hawk_header_value)?)
+ }
+
+ // TODO: probably want a builder-like API to do collection requests (e.g. something
+ // that occupies roughly the same conceptual role as the Collection class in desktop)
+ fn build_request(&self, method: Method, url: Url) -> error::Result<Request> {
+ self.authorized(Request::new(method, url).header(header_names::ACCEPT, "application/json")?)
+ }
+
+ fn relative_storage_request<P, T>(
+ &self,
+ method: Method,
+ relative_path: P,
+ ) -> error::Result<Sync15ClientResponse<T>>
+ where
+ P: AsRef<str>,
+ for<'a> T: serde::de::Deserialize<'a>,
+ {
+ let s = self.tsc.api_endpoint()? + "/";
+ let url = Url::parse(&s)?.join(relative_path.as_ref())?;
+ self.exec_request(self.build_request(method, url)?, false)
+ }
+
+ fn exec_request<T>(
+ &self,
+ req: Request,
+ require_success: bool,
+ ) -> error::Result<Sync15ClientResponse<T>>
+ where
+ for<'a> T: serde::de::Deserialize<'a>,
+ {
+ log::trace!(
+ "request: {} {} ({:?})",
+ req.method,
+ req.url.path(),
+ req.url.query()
+ );
+ let resp = req.send()?;
+
+ let result = Sync15ClientResponse::from_response(resp, &self.backoff)?;
+ match result {
+ Sync15ClientResponse::Success { .. } => Ok(result),
+ _ => {
+ if require_success {
+ Err(result.create_storage_error())
+ } else {
+ Ok(result)
+ }
+ }
+ }
+ }
+
+ fn collection_request<T>(
+ &self,
+ method: Method,
+ r: &CollectionRequest,
+ ) -> error::Result<Sync15ClientResponse<T>>
+ where
+ for<'a> T: serde::de::Deserialize<'a>,
+ {
+ let url = build_collection_request_url(Url::parse(&self.tsc.api_endpoint()?)?, r)?;
+ self.exec_request(self.build_request(method, url)?, false)
+ }
+
+ pub fn new_post_queue<'a, F: PostResponseHandler>(
+ &'a self,
+ coll: &str,
+ config: &InfoConfiguration,
+ ts: ServerTimestamp,
+ on_response: F,
+ ) -> error::Result<PostQueue<PostWrapper<'a>, F>> {
+ let pw = PostWrapper {
+ client: self,
+ coll: coll.into(),
+ };
+ Ok(PostQueue::new(config, ts, pw, on_response))
+ }
+
+ fn put<P, B>(
+ &self,
+ relative_path: P,
+ xius: ServerTimestamp,
+ body: &B,
+ ) -> error::Result<ServerTimestamp>
+ where
+ P: AsRef<str>,
+ B: serde::ser::Serialize,
+ {
+ let s = self.tsc.api_endpoint()? + "/";
+ let url = Url::parse(&s)?.join(relative_path.as_ref())?;
+
+ let req = self
+ .build_request(Method::Put, url)?
+ .json(body)
+ .header(header_names::X_IF_UNMODIFIED_SINCE, format!("{}", xius))?;
+
+ let resp = self.exec_request::<Value>(req, true)?;
+ // Note: we pass `true` for require_success, so this panic never happens.
+ if let Sync15ClientResponse::Success { last_modified, .. } = resp {
+ Ok(last_modified)
+ } else {
+ unreachable!("Error returned exec_request when `require_success` was true");
+ }
+ }
+
+ pub fn hashed_uid(&self) -> error::Result<String> {
+ self.tsc.hashed_uid()
+ }
+
+ pub(crate) fn wipe_remote_engine(&self, engine: &str) -> error::Result<()> {
+ let s = self.tsc.api_endpoint()? + "/";
+ let url = Url::parse(&s)?.join(&format!("storage/{}", engine))?;
+ log::debug!("Wiping: {:?}", url);
+ let req = self.build_request(Method::Delete, url)?;
+ match self.exec_request::<Value>(req, false) {
+ Ok(Sync15ClientResponse::Error(ErrorResponse::NotFound { .. }))
+ | Ok(Sync15ClientResponse::Success { .. }) => Ok(()),
+ Ok(resp) => Err(resp.create_storage_error()),
+ Err(e) => Err(e),
+ }
+ }
+}
+
+pub struct PostWrapper<'a> {
+ client: &'a Sync15StorageClient,
+ coll: String,
+}
+
+impl<'a> BatchPoster for PostWrapper<'a> {
+ fn post<T, O>(
+ &self,
+ bytes: Vec<u8>,
+ xius: ServerTimestamp,
+ batch: Option<String>,
+ commit: bool,
+ _: &PostQueue<T, O>,
+ ) -> error::Result<PostResponse> {
+ let r = CollectionRequest::new(self.coll.clone())
+ .batch(batch)
+ .commit(commit);
+ let url = build_collection_request_url(Url::parse(&self.client.tsc.api_endpoint()?)?, &r)?;
+
+ let req = self
+ .client
+ .build_request(Method::Post, url)?
+ .header(header_names::CONTENT_TYPE, "application/json")?
+ .header(header_names::X_IF_UNMODIFIED_SINCE, format!("{}", xius))?
+ .body(bytes);
+ self.client.exec_request(req, false)
+ }
+}
+
+fn build_collection_request_url(mut base_url: Url, r: &CollectionRequest) -> error::Result<Url> {
+ base_url
+ .path_segments_mut()
+ .map_err(|_| Error::UnacceptableUrl("Storage server URL is not a base".to_string()))?
+ .extend(&["storage", &r.collection]);
+
+ let mut pairs = base_url.query_pairs_mut();
+ if r.full {
+ pairs.append_pair("full", "1");
+ }
+ if r.limit > 0 {
+ pairs.append_pair("limit", &r.limit.to_string());
+ }
+ if let Some(ids) = &r.ids {
+ // Most ids are 12 characters, and we comma separate them, so 13.
+ let mut buf = String::with_capacity(ids.len() * 13);
+ for (i, id) in ids.iter().enumerate() {
+ if i > 0 {
+ buf.push(',');
+ }
+ buf.push_str(id.as_str());
+ }
+ pairs.append_pair("ids", &buf);
+ }
+ if let Some(batch) = &r.batch {
+ pairs.append_pair("batch", batch);
+ }
+ if r.commit {
+ pairs.append_pair("commit", "true");
+ }
+ if let Some(ts) = r.older {
+ pairs.append_pair("older", &ts.to_string());
+ }
+ if let Some(ts) = r.newer {
+ pairs.append_pair("newer", &ts.to_string());
+ }
+ if let Some(o) = r.order {
+ pairs.append_pair("sort", o.as_str());
+ }
+ pairs.finish();
+ drop(pairs);
+
+ // This is strange but just accessing query_pairs_mut makes you have
+ // a trailing question mark on your url. I don't think anything bad
+ // would happen here, but I don't know, and also, it looks dumb so
+ // I'd rather not have it.
+ if base_url.query() == Some("") {
+ base_url.set_query(None);
+ }
+ Ok(base_url)
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ #[test]
+ fn test_send() {
+ fn ensure_send<T: Send>() {}
+ // Compile will fail if not send.
+ ensure_send::<Sync15StorageClient>();
+ }
+
+ #[test]
+ fn test_parse_seconds() {
+ assert_eq!(parse_seconds("1"), Some(1));
+ assert_eq!(parse_seconds("1.4"), Some(2));
+ assert_eq!(parse_seconds("1.5"), Some(2));
+ assert_eq!(parse_seconds("3600.0"), Some(3600));
+ assert_eq!(parse_seconds("3600"), Some(3600));
+ assert_eq!(parse_seconds("-1"), None);
+ assert_eq!(parse_seconds("inf"), None);
+ assert_eq!(parse_seconds("-inf"), None);
+ assert_eq!(parse_seconds("one-thousand"), None);
+ assert_eq!(parse_seconds("4294967295"), Some(4294967295));
+ assert_eq!(parse_seconds("4294967296"), None);
+ }
+
+ #[test]
+ fn test_query_building() {
+ use crate::engine::RequestOrder;
+ let base = Url::parse("https://example.com/sync").unwrap();
+
+ let empty =
+ build_collection_request_url(base.clone(), &CollectionRequest::new("foo")).unwrap();
+ assert_eq!(empty.as_str(), "https://example.com/sync/storage/foo");
+ let batch_start = build_collection_request_url(
+ base.clone(),
+ &CollectionRequest::new("bar")
+ .batch(Some("true".into()))
+ .commit(false),
+ )
+ .unwrap();
+ assert_eq!(
+ batch_start.as_str(),
+ "https://example.com/sync/storage/bar?batch=true"
+ );
+ let batch_commit = build_collection_request_url(
+ base.clone(),
+ &CollectionRequest::new("asdf")
+ .batch(Some("1234abc".into()))
+ .commit(true),
+ )
+ .unwrap();
+ assert_eq!(
+ batch_commit.as_str(),
+ "https://example.com/sync/storage/asdf?batch=1234abc&commit=true"
+ );
+
+ let idreq = build_collection_request_url(
+ base.clone(),
+ &CollectionRequest::new("wutang").full().ids(&["rza", "gza"]),
+ )
+ .unwrap();
+ assert_eq!(
+ idreq.as_str(),
+ "https://example.com/sync/storage/wutang?full=1&ids=rza%2Cgza"
+ );
+
+ let complex = build_collection_request_url(
+ base,
+ &CollectionRequest::new("specific")
+ .full()
+ .limit(10)
+ .sort_by(RequestOrder::Oldest)
+ .older_than(ServerTimestamp(9_876_540))
+ .newer_than(ServerTimestamp(1_234_560)),
+ )
+ .unwrap();
+ assert_eq!(complex.as_str(),
+ "https://example.com/sync/storage/specific?full=1&limit=10&older=9876.54&newer=1234.56&sort=oldest");
+ }
+}
diff --git a/third_party/rust/sync15/src/client/sync.rs b/third_party/rust/sync15/src/client/sync.rs
new file mode 100644
index 0000000000..808dae9c79
--- /dev/null
+++ b/third_party/rust/sync15/src/client/sync.rs
@@ -0,0 +1,105 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use super::coll_state::LocalCollStateMachine;
+use super::coll_update::CollectionUpdate;
+use super::state::GlobalState;
+use super::storage_client::Sync15StorageClient;
+use crate::clients_engine;
+use crate::engine::{IncomingChangeset, SyncEngine};
+use crate::error::Error;
+use crate::telemetry;
+use crate::KeyBundle;
+use interrupt_support::Interruptee;
+
+#[allow(clippy::too_many_arguments)]
+pub fn synchronize_with_clients_engine(
+ client: &Sync15StorageClient,
+ global_state: &GlobalState,
+ root_sync_key: &KeyBundle,
+ clients: Option<&clients_engine::Engine<'_>>,
+ engine: &dyn SyncEngine,
+ fully_atomic: bool,
+ telem_engine: &mut telemetry::Engine,
+ interruptee: &dyn Interruptee,
+) -> Result<(), Error> {
+ let collection = engine.collection_name();
+ log::info!("Syncing collection {}", collection);
+
+ // our global state machine is ready - get the collection machine going.
+ let mut coll_state =
+ match LocalCollStateMachine::get_state(engine, global_state, root_sync_key)? {
+ Some(coll_state) => coll_state,
+ None => {
+ // XXX - this is either "error" or "declined".
+ log::warn!(
+ "can't setup for the {} collection - hopefully it works later",
+ collection
+ );
+ return Ok(());
+ }
+ };
+
+ if let Some(clients) = clients {
+ engine.prepare_for_sync(&|| clients.get_client_data())?;
+ }
+
+ let collection_requests = engine.get_collection_requests(coll_state.last_modified)?;
+ let incoming = if collection_requests.is_empty() {
+ log::info!("skipping incoming for {} - not needed.", collection);
+ vec![IncomingChangeset::new(collection, coll_state.last_modified)]
+ } else {
+ assert_eq!(collection_requests.last().unwrap().collection, collection);
+
+ let count = collection_requests.len();
+ collection_requests
+ .into_iter()
+ .enumerate()
+ .map(|(idx, collection_request)| {
+ interruptee.err_if_interrupted()?;
+ let incoming_changes =
+ super::fetch_incoming(client, &mut coll_state, &collection_request)?;
+
+ log::info!(
+ "Downloaded {} remote changes (request {} of {})",
+ incoming_changes.changes.len(),
+ idx,
+ count,
+ );
+ Ok(incoming_changes)
+ })
+ .collect::<Result<Vec<_>, Error>>()?
+ };
+
+ let new_timestamp = incoming.last().expect("must have >= 1").timestamp;
+ let mut outgoing = engine.apply_incoming(incoming, telem_engine)?;
+
+ interruptee.err_if_interrupted()?;
+ // Bump the timestamps now just incase the upload fails.
+ // xxx - duplication below smells wrong
+ outgoing.timestamp = new_timestamp;
+ coll_state.last_modified = new_timestamp;
+
+ log::info!("Uploading {} outgoing changes", outgoing.changes.len());
+ let upload_info =
+ CollectionUpdate::new_from_changeset(client, &coll_state, outgoing, fully_atomic)?
+ .upload()?;
+
+ log::info!(
+ "Upload success ({} records success, {} records failed)",
+ upload_info.successful_ids.len(),
+ upload_info.failed_ids.len()
+ );
+ // ideally we'd report this per-batch, but for now, let's just report it
+ // as a total.
+ let mut telem_outgoing = telemetry::EngineOutgoing::new();
+ telem_outgoing.sent(upload_info.successful_ids.len() + upload_info.failed_ids.len());
+ telem_outgoing.failed(upload_info.failed_ids.len());
+ telem_engine.outgoing(telem_outgoing);
+
+ engine.sync_finished(upload_info.modified_timestamp, upload_info.successful_ids)?;
+
+ log::info!("Sync finished!");
+ Ok(())
+}
diff --git a/third_party/rust/sync15/src/client/sync_multiple.rs b/third_party/rust/sync15/src/client/sync_multiple.rs
new file mode 100644
index 0000000000..79ddceff3c
--- /dev/null
+++ b/third_party/rust/sync15/src/client/sync_multiple.rs
@@ -0,0 +1,493 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+// This helps you perform a sync of multiple engines and helps you manage
+// global and local state between syncs.
+
+use super::state::{EngineChangesNeeded, GlobalState, PersistedGlobalState, SetupStateMachine};
+use super::status::{ServiceStatus, SyncResult};
+use super::storage_client::{BackoffListener, Sync15StorageClient, Sync15StorageClientInit};
+use crate::clients_engine::{self, CommandProcessor, CLIENTS_TTL_REFRESH};
+use crate::engine::{EngineSyncAssociation, SyncEngine};
+use crate::error::Error;
+use crate::telemetry;
+use crate::KeyBundle;
+use interrupt_support::Interruptee;
+use std::collections::HashMap;
+use std::mem;
+use std::result;
+use std::time::{Duration, SystemTime};
+
+/// Info about the client to use. We reuse the client unless
+/// we discover the client_init has changed, in which case we re-create one.
+#[derive(Debug)]
+struct ClientInfo {
+ // the client_init used to create `client`.
+ client_init: Sync15StorageClientInit,
+ // the client (our tokenserver state machine state, and our http library's state)
+ client: Sync15StorageClient,
+}
+
+impl ClientInfo {
+ fn new(ci: &Sync15StorageClientInit) -> Result<Self, Error> {
+ Ok(Self {
+ client_init: ci.clone(),
+ client: Sync15StorageClient::new(ci.clone())?,
+ })
+ }
+}
+
+/// Info we want callers to engine *in memory* for us so that subsequent
+/// syncs are faster. This should never be persisted to storage as it holds
+/// sensitive information, such as the sync decryption keys.
+#[derive(Debug, Default)]
+pub struct MemoryCachedState {
+ last_client_info: Option<ClientInfo>,
+ last_global_state: Option<GlobalState>,
+ // These are just engined in memory, as persisting an invalid value far in the
+ // future has the potential to break sync for good.
+ next_sync_after: Option<SystemTime>,
+ next_client_refresh_after: Option<SystemTime>,
+}
+
+impl MemoryCachedState {
+ // Called we notice the cached state is stale.
+ pub fn clear_sensitive_info(&mut self) {
+ self.last_client_info = None;
+ self.last_global_state = None;
+ // Leave the backoff time, as there's no reason to think it's not still
+ // true.
+ }
+ pub fn get_next_sync_after(&self) -> Option<SystemTime> {
+ self.next_sync_after
+ }
+ pub fn should_refresh_client(&self) -> bool {
+ match self.next_client_refresh_after {
+ Some(t) => SystemTime::now() > t,
+ None => true,
+ }
+ }
+ pub fn note_client_refresh(&mut self) {
+ self.next_client_refresh_after =
+ Some(SystemTime::now() + Duration::from_secs(CLIENTS_TTL_REFRESH));
+ }
+}
+
+/// Sync multiple engines
+/// * `engines` - The engines to sync
+/// * `persisted_global_state` - The global state to use, or None if never
+/// before provided. At the end of the sync, and even when the sync fails,
+/// the value in this cell should be persisted to permanent storage and
+/// provided next time the sync is called.
+/// * `last_client_info` - The client state to use, or None if never before
+/// provided. At the end of the sync, the value should be persisted
+/// *in memory only* - it should not be persisted to disk.
+/// * `storage_init` - Information about how the sync http client should be
+/// configured.
+/// * `root_sync_key` - The KeyBundle used for encryption.
+///
+/// Returns a map, keyed by name and holding an error value - if any engine
+/// fails, the sync will continue on to other engines, but the error will be
+/// places in this map. The absence of a name in the map implies the engine
+/// succeeded.
+pub fn sync_multiple(
+ engines: &[&dyn SyncEngine],
+ persisted_global_state: &mut Option<String>,
+ mem_cached_state: &mut MemoryCachedState,
+ storage_init: &Sync15StorageClientInit,
+ root_sync_key: &KeyBundle,
+ interruptee: &dyn Interruptee,
+ req_info: Option<SyncRequestInfo<'_>>,
+) -> SyncResult {
+ sync_multiple_with_command_processor(
+ None,
+ engines,
+ persisted_global_state,
+ mem_cached_state,
+ storage_init,
+ root_sync_key,
+ interruptee,
+ req_info,
+ )
+}
+
+/// Like `sync_multiple`, but specifies an optional command processor to handle
+/// commands from the clients collection. This function is called by the sync
+/// manager, which provides its own processor.
+#[allow(clippy::too_many_arguments)]
+pub fn sync_multiple_with_command_processor(
+ command_processor: Option<&dyn CommandProcessor>,
+ engines: &[&dyn SyncEngine],
+ persisted_global_state: &mut Option<String>,
+ mem_cached_state: &mut MemoryCachedState,
+ storage_init: &Sync15StorageClientInit,
+ root_sync_key: &KeyBundle,
+ interruptee: &dyn Interruptee,
+ req_info: Option<SyncRequestInfo<'_>>,
+) -> SyncResult {
+ log::info!("Syncing {} engines", engines.len());
+ let mut sync_result = SyncResult {
+ service_status: ServiceStatus::OtherError,
+ result: Ok(()),
+ declined: None,
+ next_sync_after: None,
+ engine_results: HashMap::with_capacity(engines.len()),
+ telemetry: telemetry::SyncTelemetryPing::new(),
+ };
+ let backoff = super::storage_client::new_backoff_listener();
+ let req_info = req_info.unwrap_or_default();
+ let driver = SyncMultipleDriver {
+ command_processor,
+ engines,
+ storage_init,
+ interruptee,
+ engines_to_state_change: req_info.engines_to_state_change,
+ backoff: backoff.clone(),
+ root_sync_key,
+ result: &mut sync_result,
+ persisted_global_state,
+ mem_cached_state,
+ saw_auth_error: false,
+ ignore_soft_backoff: req_info.is_user_action,
+ };
+ match driver.sync() {
+ Ok(()) => {
+ log::debug!(
+ "sync was successful, final status={:?}",
+ sync_result.service_status
+ );
+ }
+ Err(e) => {
+ log::warn!(
+ "sync failed: {}, final status={:?}",
+ e,
+ sync_result.service_status,
+ );
+ sync_result.result = Err(e);
+ }
+ }
+ // Respect `backoff` value when computing the next sync time even if we were
+ // ignoring it during the sync
+ sync_result.set_sync_after(backoff.get_required_wait(false).unwrap_or_default());
+ mem_cached_state.next_sync_after = sync_result.next_sync_after;
+ log::trace!("Sync result: {:?}", sync_result);
+ sync_result
+}
+
+/// This is essentially a bag of information that the sync manager knows, but
+/// otherwise we won't. It should probably be rethought if it gains many more
+/// fields.
+#[derive(Debug, Default)]
+pub struct SyncRequestInfo<'a> {
+ pub engines_to_state_change: Option<&'a HashMap<String, bool>>,
+ pub is_user_action: bool,
+}
+
+// The sync multiple driver
+struct SyncMultipleDriver<'info, 'res, 'pgs, 'mcs> {
+ command_processor: Option<&'info dyn CommandProcessor>,
+ engines: &'info [&'info dyn SyncEngine],
+ storage_init: &'info Sync15StorageClientInit,
+ root_sync_key: &'info KeyBundle,
+ interruptee: &'info dyn Interruptee,
+ backoff: BackoffListener,
+ engines_to_state_change: Option<&'info HashMap<String, bool>>,
+ result: &'res mut SyncResult,
+ persisted_global_state: &'pgs mut Option<String>,
+ mem_cached_state: &'mcs mut MemoryCachedState,
+ ignore_soft_backoff: bool,
+ saw_auth_error: bool,
+}
+
+impl<'info, 'res, 'pgs, 'mcs> SyncMultipleDriver<'info, 'res, 'pgs, 'mcs> {
+ /// The actual worker for sync_multiple.
+ fn sync(mut self) -> result::Result<(), Error> {
+ log::info!("Loading/initializing persisted state");
+ let mut pgs = self.prepare_persisted_state();
+
+ log::info!("Preparing client info");
+ let client_info = self.prepare_client_info()?;
+
+ if self.was_interrupted() {
+ return Ok(());
+ }
+
+ log::info!("Entering sync state machine");
+ // Advance the state machine to the point where it can perform a full
+ // sync. This may involve uploading meta/global, crypto/keys etc.
+ let mut global_state = self.run_state_machine(&client_info, &mut pgs)?;
+
+ if self.was_interrupted() {
+ return Ok(());
+ }
+
+ // Set the service status to OK here - we may adjust it based on an individual
+ // engine failing.
+ self.result.service_status = ServiceStatus::Ok;
+
+ let clients_engine = if let Some(command_processor) = self.command_processor {
+ log::info!("Synchronizing clients engine");
+ let should_refresh = self.mem_cached_state.should_refresh_client();
+ let mut engine = clients_engine::Engine::new(command_processor, self.interruptee);
+ if let Err(e) = engine.sync(
+ &client_info.client,
+ &global_state,
+ self.root_sync_key,
+ should_refresh,
+ ) {
+ // Record telemetry with the error just in case...
+ let mut telem_sync = telemetry::SyncTelemetry::new();
+ let mut telem_engine = telemetry::Engine::new("clients");
+ telem_engine.failure(&e);
+ telem_sync.engine(telem_engine);
+ self.result.service_status = ServiceStatus::from_err(&e);
+
+ // ...And bail, because a clients engine sync failure is fatal.
+ return Err(e);
+ }
+ // We don't record telemetry for successful clients engine
+ // syncs, since we only keep client records in memory, we
+ // expect the counts to be the same most times, and a
+ // failure aborts the entire sync.
+ if self.was_interrupted() {
+ return Ok(());
+ }
+ self.mem_cached_state.note_client_refresh();
+ Some(engine)
+ } else {
+ None
+ };
+
+ log::info!("Synchronizing engines");
+
+ let telem_sync =
+ self.sync_engines(&client_info, &mut global_state, clients_engine.as_ref());
+ self.result.telemetry.sync(telem_sync);
+
+ log::info!("Finished syncing engines.");
+
+ if !self.saw_auth_error {
+ log::trace!("Updating persisted global state");
+ self.mem_cached_state.last_client_info = Some(client_info);
+ self.mem_cached_state.last_global_state = Some(global_state);
+ }
+
+ Ok(())
+ }
+
+ fn was_interrupted(&mut self) -> bool {
+ if self.interruptee.was_interrupted() {
+ log::info!("Interrupted, bailing out");
+ self.result.service_status = ServiceStatus::Interrupted;
+ true
+ } else {
+ false
+ }
+ }
+
+ fn sync_engines(
+ &mut self,
+ client_info: &ClientInfo,
+ global_state: &mut GlobalState,
+ clients: Option<&clients_engine::Engine<'_>>,
+ ) -> telemetry::SyncTelemetry {
+ let mut telem_sync = telemetry::SyncTelemetry::new();
+ for engine in self.engines {
+ let name = engine.collection_name();
+ if self
+ .backoff
+ .get_required_wait(self.ignore_soft_backoff)
+ .is_some()
+ {
+ log::warn!("Got backoff, bailing out of sync early");
+ break;
+ }
+ if global_state.global.declined.iter().any(|e| e == &*name) {
+ log::info!("The {} engine is declined. Skipping", name);
+ continue;
+ }
+ log::info!("Syncing {} engine!", name);
+
+ let mut telem_engine = telemetry::Engine::new(&*name);
+ let result = super::sync::synchronize_with_clients_engine(
+ &client_info.client,
+ global_state,
+ self.root_sync_key,
+ clients,
+ *engine,
+ true,
+ &mut telem_engine,
+ self.interruptee,
+ );
+
+ match result {
+ Ok(()) => log::info!("Sync of {} was successful!", name),
+ Err(ref e) => {
+ log::warn!("Sync of {} failed! {:?}", name, e);
+ let this_status = ServiceStatus::from_err(e);
+ // The only error which forces us to discard our state is an
+ // auth error.
+ self.saw_auth_error =
+ self.saw_auth_error || this_status == ServiceStatus::AuthenticationError;
+ telem_engine.failure(e);
+ // If the failure from the engine looks like anything other than
+ // a "engine error" we don't bother trying the others.
+ if this_status != ServiceStatus::OtherError {
+ telem_sync.engine(telem_engine);
+ self.result.engine_results.insert(name.into(), result);
+ self.result.service_status = this_status;
+ break;
+ }
+ }
+ }
+ telem_sync.engine(telem_engine);
+ self.result.engine_results.insert(name.into(), result);
+ if self.was_interrupted() {
+ break;
+ }
+ }
+ telem_sync
+ }
+
+ fn run_state_machine(
+ &mut self,
+ client_info: &ClientInfo,
+ pgs: &mut PersistedGlobalState,
+ ) -> result::Result<GlobalState, Error> {
+ let last_state = mem::replace(&mut self.mem_cached_state.last_global_state, None);
+
+ let mut state_machine = SetupStateMachine::for_full_sync(
+ &client_info.client,
+ self.root_sync_key,
+ pgs,
+ self.engines_to_state_change,
+ self.interruptee,
+ );
+
+ log::info!("Advancing state machine to ready (full)");
+ let res = state_machine.run_to_ready(last_state);
+ // Grab this now even though we don't need it until later to avoid a
+ // lifetime issue
+ let changes = state_machine.changes_needed.take();
+ // The state machine might have updated our persisted_global_state, so
+ // update the caller's repr of it.
+ *self.persisted_global_state = Some(serde_json::to_string(&pgs)?);
+
+ // Now that we've gone through the state machine, engine the declined list in
+ // the sync_result
+ self.result.declined = Some(pgs.get_declined().to_vec());
+ log::debug!(
+ "Declined engines list after state machine set to: {:?}",
+ self.result.declined,
+ );
+
+ if let Some(c) = changes {
+ self.wipe_or_reset_engines(c, &client_info.client)?;
+ }
+ let state = match res {
+ Err(e) => {
+ self.result.service_status = ServiceStatus::from_err(&e);
+ return Err(e);
+ }
+ Ok(state) => state,
+ };
+ self.result.telemetry.uid(client_info.client.hashed_uid()?);
+ // As for client_info, put None back now so we start from scratch on error.
+ self.mem_cached_state.last_global_state = None;
+ Ok(state)
+ }
+
+ fn wipe_or_reset_engines(
+ &mut self,
+ changes: EngineChangesNeeded,
+ client: &Sync15StorageClient,
+ ) -> result::Result<(), Error> {
+ if changes.local_resets.is_empty() && changes.remote_wipes.is_empty() {
+ return Ok(());
+ }
+ for e in &changes.remote_wipes {
+ log::info!("Engine {:?} just got disabled locally, wiping server", e);
+ client.wipe_remote_engine(e)?;
+ }
+
+ for s in self.engines {
+ let name = s.collection_name();
+ if changes.local_resets.contains(&*name) {
+ log::info!("Resetting engine {}, as it was declined remotely", name);
+ s.reset(&EngineSyncAssociation::Disconnected)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn prepare_client_info(&mut self) -> result::Result<ClientInfo, Error> {
+ let mut client_info = match mem::replace(&mut self.mem_cached_state.last_client_info, None)
+ {
+ Some(client_info) => {
+ // if our storage_init has changed it probably means the user has
+ // changed, courtesy of the 'kid' in the structure. Thus, we can't
+ // reuse the client or the memory cached state. We do keep the disk
+ // state as currently that's only the declined list.
+ if client_info.client_init != *self.storage_init {
+ log::info!("Discarding all state as the account might have changed");
+ *self.mem_cached_state = MemoryCachedState::default();
+ ClientInfo::new(self.storage_init)?
+ } else {
+ log::debug!("Reusing memory-cached client_info");
+ // we can reuse it (which should be the common path)
+ client_info
+ }
+ }
+ None => {
+ log::debug!("mem_cached_state was stale or missing, need setup");
+ // We almost certainly have no other state here, but to be safe, we
+ // throw away any memory state we do have.
+ self.mem_cached_state.clear_sensitive_info();
+ ClientInfo::new(self.storage_init)?
+ }
+ };
+ // Ensure we use the correct listener here rather than on all the branches
+ // above, since it seems less error prone.
+ client_info.client.backoff = self.backoff.clone();
+ Ok(client_info)
+ }
+
+ fn prepare_persisted_state(&mut self) -> PersistedGlobalState {
+ // Note that any failure to use a persisted state means we also decline
+ // to use our memory cached state, so that we fully rebuild that
+ // persisted state for next time.
+ match self.persisted_global_state {
+ Some(persisted_string) if !persisted_string.is_empty() => {
+ match serde_json::from_str::<PersistedGlobalState>(persisted_string) {
+ Ok(state) => {
+ log::trace!("Read persisted state: {:?}", state);
+ // Note that we don't set `result.declined` from the
+ // data in state - it remains None, which explicitly
+ // indicates "we don't have updated info".
+ state
+ }
+ _ => {
+ // Don't log the error since it might contain sensitive
+ // info (although currently it only contains the declined engines list)
+ error_support::report_error!(
+ "sync15-prepare-persisted-state",
+ "Failed to parse PersistedGlobalState from JSON! Falling back to default"
+ );
+ *self.mem_cached_state = MemoryCachedState::default();
+ PersistedGlobalState::default()
+ }
+ }
+ }
+ _ => {
+ log::info!(
+ "The application didn't give us persisted state - \
+ this is only expected on the very first run for a given user."
+ );
+ *self.mem_cached_state = MemoryCachedState::default();
+ PersistedGlobalState::default()
+ }
+ }
+ }
+}
diff --git a/third_party/rust/sync15/src/client/token.rs b/third_party/rust/sync15/src/client/token.rs
new file mode 100644
index 0000000000..b416c0c12a
--- /dev/null
+++ b/third_party/rust/sync15/src/client/token.rs
@@ -0,0 +1,602 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use crate::error::{self, Error as ErrorKind, Result};
+use crate::ServerTimestamp;
+use rc_crypto::hawk;
+use serde_derive::*;
+use std::borrow::{Borrow, Cow};
+use std::cell::RefCell;
+use std::fmt;
+use std::time::{Duration, SystemTime};
+use url::Url;
+use viaduct::{header_names, Request};
+
+const RETRY_AFTER_DEFAULT_MS: u64 = 10000;
+
+// The TokenserverToken is the token as received directly from the token server
+// and deserialized from JSON.
+#[derive(Deserialize, Clone, PartialEq, Eq)]
+struct TokenserverToken {
+ id: String,
+ key: String,
+ api_endpoint: String,
+ uid: u64,
+ duration: u64,
+ hashed_fxa_uid: String,
+}
+
+impl std::fmt::Debug for TokenserverToken {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("TokenserverToken")
+ .field("api_endpoint", &self.api_endpoint)
+ .field("uid", &self.uid)
+ .field("duration", &self.duration)
+ .field("hashed_fxa_uid", &self.hashed_fxa_uid)
+ .finish()
+ }
+}
+
+// The struct returned by the TokenFetcher - the token itself and the
+// server timestamp.
+struct TokenFetchResult {
+ token: TokenserverToken,
+ server_timestamp: ServerTimestamp,
+}
+
+// The trait for fetching tokens - we'll provide a "real" implementation but
+// tests will re-implement it.
+trait TokenFetcher {
+ fn fetch_token(&self) -> crate::Result<TokenFetchResult>;
+ // We allow the trait to tell us what the time is so tests can get funky.
+ fn now(&self) -> SystemTime;
+}
+
+// Our "real" token fetcher, implementing the TokenFetcher trait, which hits
+// the token server
+#[derive(Debug)]
+struct TokenServerFetcher {
+ // The stuff needed to fetch a token.
+ server_url: Url,
+ access_token: String,
+ key_id: String,
+}
+
+fn fixup_server_url(mut url: Url) -> url::Url {
+ // The given `url` is the end-point as returned by .well-known/fxa-client-configuration,
+ // or as directly specified by self-hosters. As a result, it may or may not have
+ // the sync 1.5 suffix of "/1.0/sync/1.5", so add it on here if it does not.
+ if url.as_str().ends_with("1.0/sync/1.5") {
+ // ok!
+ } else if url.as_str().ends_with("1.0/sync/1.5/") {
+ // Shouldn't ever be Err() here, but the result is `Result<PathSegmentsMut, ()>`
+ // and I don't want to unwrap or add a new error type just for PathSegmentsMut failing.
+ if let Ok(mut path) = url.path_segments_mut() {
+ path.pop();
+ }
+ } else {
+ // We deliberately don't use `.join()` here in order to preserve all path components.
+ // For example, "http://example.com/token" should produce "http://example.com/token/1.0/sync/1.5"
+ // but using `.join()` would produce "http://example.com/1.0/sync/1.5".
+ if let Ok(mut path) = url.path_segments_mut() {
+ path.pop_if_empty();
+ path.extend(&["1.0", "sync", "1.5"]);
+ }
+ };
+ url
+}
+
+impl TokenServerFetcher {
+ fn new(base_url: Url, access_token: String, key_id: String) -> TokenServerFetcher {
+ TokenServerFetcher {
+ server_url: fixup_server_url(base_url),
+ access_token,
+ key_id,
+ }
+ }
+}
+
+impl TokenFetcher for TokenServerFetcher {
+ fn fetch_token(&self) -> Result<TokenFetchResult> {
+ log::debug!("Fetching token from {}", self.server_url);
+ let resp = Request::get(self.server_url.clone())
+ .header(
+ header_names::AUTHORIZATION,
+ format!("Bearer {}", self.access_token),
+ )?
+ .header(header_names::X_KEYID, self.key_id.clone())?
+ .send()?;
+
+ if !resp.is_success() {
+ log::warn!("Non-success status when fetching token: {}", resp.status);
+ // TODO: the body should be JSON and contain a status parameter we might need?
+ log::trace!(" Response body {}", resp.text());
+ // XXX - shouldn't we "chain" these errors - ie, a BackoffError could
+ // have a TokenserverHttpError as its cause?
+ if let Some(res) = resp.headers.get_as::<f64, _>(header_names::RETRY_AFTER) {
+ let ms = res
+ .ok()
+ .map_or(RETRY_AFTER_DEFAULT_MS, |f| (f * 1000f64) as u64);
+ let when = self.now() + Duration::from_millis(ms);
+ return Err(ErrorKind::BackoffError(when));
+ }
+ let status = resp.status;
+ return Err(ErrorKind::TokenserverHttpError(status));
+ }
+
+ let token: TokenserverToken = resp.json()?;
+ let server_timestamp = resp
+ .headers
+ .try_get::<ServerTimestamp, _>(header_names::X_TIMESTAMP)
+ .ok_or(ErrorKind::MissingServerTimestamp)?;
+ Ok(TokenFetchResult {
+ token,
+ server_timestamp,
+ })
+ }
+
+ fn now(&self) -> SystemTime {
+ SystemTime::now()
+ }
+}
+
+// The context stored by our TokenProvider when it has a TokenState::Token
+// state.
+struct TokenContext {
+ token: TokenserverToken,
+ credentials: hawk::Credentials,
+ server_timestamp: ServerTimestamp,
+ valid_until: SystemTime,
+}
+
+// hawk::Credentials doesn't implement debug -_-
+impl fmt::Debug for TokenContext {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> ::std::result::Result<(), fmt::Error> {
+ f.debug_struct("TokenContext")
+ .field("token", &self.token)
+ .field("credentials", &"(omitted)")
+ .field("server_timestamp", &self.server_timestamp)
+ .field("valid_until", &self.valid_until)
+ .finish()
+ }
+}
+
+impl TokenContext {
+ fn new(
+ token: TokenserverToken,
+ credentials: hawk::Credentials,
+ server_timestamp: ServerTimestamp,
+ valid_until: SystemTime,
+ ) -> Self {
+ Self {
+ token,
+ credentials,
+ server_timestamp,
+ valid_until,
+ }
+ }
+
+ fn is_valid(&self, now: SystemTime) -> bool {
+ // We could consider making the duration a little shorter - if it
+ // only has 1 second validity there seems a reasonable chance it will
+ // have expired by the time it gets presented to the remote that wants
+ // it.
+ // Either way though, we will eventually need to handle a token being
+ // rejected as a non-fatal error and recover, so maybe we don't care?
+ now < self.valid_until
+ }
+
+ fn authorization(&self, req: &Request) -> Result<String> {
+ let url = &req.url;
+
+ let path_and_query = match url.query() {
+ None => Cow::from(url.path()),
+ Some(qs) => Cow::from(format!("{}?{}", url.path(), qs)),
+ };
+
+ let host = url
+ .host_str()
+ .ok_or_else(|| ErrorKind::UnacceptableUrl("Storage URL has no host".into()))?;
+
+ // Known defaults exist for https? (among others), so this should be impossible
+ let port = url.port_or_known_default().ok_or_else(|| {
+ ErrorKind::UnacceptableUrl(
+ "Storage URL has no port and no default port is known for the protocol".into(),
+ )
+ })?;
+
+ let header =
+ hawk::RequestBuilder::new(req.method.as_str(), host, port, path_and_query.borrow())
+ .request()
+ .make_header(&self.credentials)?;
+
+ Ok(format!("Hawk {}", header))
+ }
+}
+
+// The state our TokenProvider holds to reflect the state of the token.
+#[derive(Debug)]
+enum TokenState {
+ // We've never fetched a token.
+ NoToken,
+ // Have a token and last we checked it remained valid.
+ Token(TokenContext),
+ // We failed to fetch a token. First elt is the error, second elt is
+ // the api_endpoint we had before we failed to fetch a new token (or
+ // None if the very first attempt at fetching a token failed)
+ Failed(Option<error::Error>, Option<String>),
+ // Previously failed and told to back-off for SystemTime duration. Second
+ // elt is the api_endpoint we had before we hit the backoff error.
+ // XXX - should we roll Backoff and Failed together?
+ Backoff(SystemTime, Option<String>),
+ // api_endpoint changed - we are never going to get a token nor move out
+ // of this state.
+ NodeReassigned,
+}
+
+/// The generic TokenProvider implementation - long lived and fetches tokens
+/// on demand (eg, when first needed, or when an existing one expires.)
+#[derive(Debug)]
+struct TokenProviderImpl<TF: TokenFetcher> {
+ fetcher: TF,
+ // Our token state (ie, whether we have a token, and if not, why not)
+ current_state: RefCell<TokenState>,
+}
+
+impl<TF: TokenFetcher> TokenProviderImpl<TF> {
+ fn new(fetcher: TF) -> Self {
+ // We check this at the real entrypoint of the application, but tests
+ // can/do bypass that, so check this here too.
+ rc_crypto::ensure_initialized();
+ TokenProviderImpl {
+ fetcher,
+ current_state: RefCell::new(TokenState::NoToken),
+ }
+ }
+
+ // Uses our fetcher to grab a new token and if successfull, derives other
+ // info from that token into a usable TokenContext.
+ fn fetch_context(&self) -> Result<TokenContext> {
+ let result = self.fetcher.fetch_token()?;
+ let token = result.token;
+ let valid_until = SystemTime::now() + Duration::from_secs(token.duration);
+
+ let credentials = hawk::Credentials {
+ id: token.id.clone(),
+ key: hawk::Key::new(token.key.as_bytes(), hawk::SHA256)?,
+ };
+
+ Ok(TokenContext::new(
+ token,
+ credentials,
+ result.server_timestamp,
+ valid_until,
+ ))
+ }
+
+ // Attempt to fetch a new token and return a new state reflecting that
+ // operation. If it worked a TokenState will be returned, but errors may
+ // cause other states.
+ fn fetch_token(&self, previous_endpoint: Option<&str>) -> TokenState {
+ match self.fetch_context() {
+ Ok(tc) => {
+ // We got a new token - check that the endpoint is the same
+ // as a previous endpoint we saw (if any)
+ match previous_endpoint {
+ Some(prev) => {
+ if prev == tc.token.api_endpoint {
+ TokenState::Token(tc)
+ } else {
+ log::warn!(
+ "api_endpoint changed from {} to {}",
+ prev,
+ tc.token.api_endpoint
+ );
+ TokenState::NodeReassigned
+ }
+ }
+ None => {
+ // Never had an api_endpoint in the past, so this is OK.
+ TokenState::Token(tc)
+ }
+ }
+ }
+ Err(e) => {
+ // Early to avoid nll issues...
+ if let ErrorKind::BackoffError(be) = e {
+ return TokenState::Backoff(be, previous_endpoint.map(ToString::to_string));
+ }
+ TokenState::Failed(Some(e), previous_endpoint.map(ToString::to_string))
+ }
+ }
+ }
+
+ // Given the state we are currently in, return a new current state.
+ // Returns None if the current state should be used (eg, if we are
+ // holding a token that remains valid) or Some() if the state has changed
+ // (which may have changed to a state with a token or an error state)
+ fn advance_state(&self, state: &TokenState) -> Option<TokenState> {
+ match state {
+ TokenState::NoToken => Some(self.fetch_token(None)),
+ TokenState::Failed(_, existing_endpoint) => {
+ Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str)))
+ }
+ TokenState::Token(existing_context) => {
+ if existing_context.is_valid(self.fetcher.now()) {
+ None
+ } else {
+ Some(self.fetch_token(Some(existing_context.token.api_endpoint.as_str())))
+ }
+ }
+ TokenState::Backoff(ref until, ref existing_endpoint) => {
+ if let Ok(remaining) = until.duration_since(self.fetcher.now()) {
+ log::debug!("enforcing existing backoff - {:?} remains", remaining);
+ None
+ } else {
+ // backoff period is over
+ Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str)))
+ }
+ }
+ TokenState::NodeReassigned => {
+ // We never leave this state.
+ None
+ }
+ }
+ }
+
+ fn with_token<T, F>(&self, func: F) -> Result<T>
+ where
+ F: FnOnce(&TokenContext) -> Result<T>,
+ {
+ // first get a mutable ref to our existing state, advance to the
+ // state we will use, then re-stash that state for next time.
+ let state: &mut TokenState = &mut self.current_state.borrow_mut();
+ if let Some(new_state) = self.advance_state(state) {
+ *state = new_state;
+ }
+
+ // Now re-fetch the state we should use for this call - if it's
+ // anything other than TokenState::Token we will fail.
+ match state {
+ TokenState::NoToken => {
+ // it should be impossible to get here.
+ panic!("Can't be in NoToken state after advancing");
+ }
+ TokenState::Token(ref token_context) => {
+ // make the call.
+ func(token_context)
+ }
+ TokenState::Failed(e, _) => {
+ // We swap the error out of the state enum and return it.
+ Err(e.take().unwrap())
+ }
+ TokenState::NodeReassigned => {
+ // this is unrecoverable.
+ Err(ErrorKind::StorageResetError)
+ }
+ TokenState::Backoff(ref remaining, _) => Err(ErrorKind::BackoffError(*remaining)),
+ }
+ }
+
+ fn hashed_uid(&self) -> Result<String> {
+ self.with_token(|ctx| Ok(ctx.token.hashed_fxa_uid.clone()))
+ }
+
+ fn authorization(&self, req: &Request) -> Result<String> {
+ self.with_token(|ctx| ctx.authorization(req))
+ }
+
+ fn api_endpoint(&self) -> Result<String> {
+ self.with_token(|ctx| Ok(ctx.token.api_endpoint.clone()))
+ }
+ // TODO: we probably want a "drop_token/context" type method so that when
+ // using a token with some validity fails the caller can force a new one
+ // (in which case the new token request will probably fail with a 401)
+}
+
+// The public concrete object exposed by this module
+#[derive(Debug)]
+pub struct TokenProvider {
+ imp: TokenProviderImpl<TokenServerFetcher>,
+}
+
+impl TokenProvider {
+ pub fn new(url: Url, access_token: String, key_id: String) -> Result<Self> {
+ let fetcher = TokenServerFetcher::new(url, access_token, key_id);
+ Ok(Self {
+ imp: TokenProviderImpl::new(fetcher),
+ })
+ }
+
+ pub fn hashed_uid(&self) -> Result<String> {
+ self.imp.hashed_uid()
+ }
+
+ pub fn authorization(&self, req: &Request) -> Result<String> {
+ self.imp.authorization(req)
+ }
+
+ pub fn api_endpoint(&self) -> Result<String> {
+ self.imp.api_endpoint()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::cell::Cell;
+
+ struct TestFetcher<FF, FN>
+ where
+ FF: Fn() -> Result<TokenFetchResult>,
+ FN: Fn() -> SystemTime,
+ {
+ fetch: FF,
+ now: FN,
+ }
+ impl<FF, FN> TokenFetcher for TestFetcher<FF, FN>
+ where
+ FF: Fn() -> Result<TokenFetchResult>,
+ FN: Fn() -> SystemTime,
+ {
+ fn fetch_token(&self) -> Result<TokenFetchResult> {
+ (self.fetch)()
+ }
+ fn now(&self) -> SystemTime {
+ (self.now)()
+ }
+ }
+
+ fn make_tsc<FF, FN>(fetch: FF, now: FN) -> TokenProviderImpl<TestFetcher<FF, FN>>
+ where
+ FF: Fn() -> Result<TokenFetchResult>,
+ FN: Fn() -> SystemTime,
+ {
+ let fetcher: TestFetcher<FF, FN> = TestFetcher { fetch, now };
+ TokenProviderImpl::new(fetcher)
+ }
+
+ #[test]
+ fn test_endpoint() {
+ // Use a cell to avoid the closure having a mutable ref to this scope.
+ let counter: Cell<u32> = Cell::new(0);
+ let fetch = || {
+ counter.set(counter.get() + 1);
+ Ok(TokenFetchResult {
+ token: TokenserverToken {
+ id: "id".to_string(),
+ key: "key".to_string(),
+ api_endpoint: "api_endpoint".to_string(),
+ uid: 1,
+ duration: 1000,
+ hashed_fxa_uid: "hash".to_string(),
+ },
+ server_timestamp: ServerTimestamp(0i64),
+ })
+ };
+
+ let tsc = make_tsc(fetch, SystemTime::now);
+
+ let e = tsc.api_endpoint().expect("should work");
+ assert_eq!(e, "api_endpoint".to_string());
+ assert_eq!(counter.get(), 1);
+
+ let e2 = tsc.api_endpoint().expect("should work");
+ assert_eq!(e2, "api_endpoint".to_string());
+ // should not have re-fetched.
+ assert_eq!(counter.get(), 1);
+ }
+
+ #[test]
+ fn test_backoff() {
+ let counter: Cell<u32> = Cell::new(0);
+ let fetch = || {
+ counter.set(counter.get() + 1);
+ let when = SystemTime::now() + Duration::from_millis(10000);
+ Err(ErrorKind::BackoffError(when))
+ };
+ let now: Cell<SystemTime> = Cell::new(SystemTime::now());
+ let tsc = make_tsc(fetch, || now.get());
+
+ tsc.api_endpoint().expect_err("should bail");
+ // XXX - check error type.
+ assert_eq!(counter.get(), 1);
+ // try and get another token - should not re-fetch as backoff is still
+ // in progress.
+ tsc.api_endpoint().expect_err("should bail");
+ assert_eq!(counter.get(), 1);
+
+ // Advance the clock.
+ now.set(now.get() + Duration::new(20, 0));
+
+ // Our token fetch mock is still returning a backoff error, so we
+ // still fail, but should have re-hit the fetch function.
+ tsc.api_endpoint().expect_err("should bail");
+ assert_eq!(counter.get(), 2);
+ }
+
+ #[test]
+ fn test_validity() {
+ let counter: Cell<u32> = Cell::new(0);
+ let fetch = || {
+ counter.set(counter.get() + 1);
+ Ok(TokenFetchResult {
+ token: TokenserverToken {
+ id: "id".to_string(),
+ key: "key".to_string(),
+ api_endpoint: "api_endpoint".to_string(),
+ uid: 1,
+ duration: 10,
+ hashed_fxa_uid: "hash".to_string(),
+ },
+ server_timestamp: ServerTimestamp(0i64),
+ })
+ };
+ let now: Cell<SystemTime> = Cell::new(SystemTime::now());
+ let tsc = make_tsc(fetch, || now.get());
+
+ tsc.api_endpoint().expect("should get a valid token");
+ assert_eq!(counter.get(), 1);
+
+ // try and get another token - should not re-fetch as the old one
+ // remains valid.
+ tsc.api_endpoint().expect("should reuse existing token");
+ assert_eq!(counter.get(), 1);
+
+ // Advance the clock.
+ now.set(now.get() + Duration::new(20, 0));
+
+ // We should discard our token and fetch a new one.
+ tsc.api_endpoint().expect("should re-fetch");
+ assert_eq!(counter.get(), 2);
+ }
+
+ #[test]
+ fn test_server_url() {
+ assert_eq!(
+ fixup_server_url(
+ Url::parse("https://token.services.mozilla.com/1.0/sync/1.5").unwrap()
+ )
+ .as_str(),
+ "https://token.services.mozilla.com/1.0/sync/1.5"
+ );
+ assert_eq!(
+ fixup_server_url(
+ Url::parse("https://token.services.mozilla.com/1.0/sync/1.5/").unwrap()
+ )
+ .as_str(),
+ "https://token.services.mozilla.com/1.0/sync/1.5"
+ );
+ assert_eq!(
+ fixup_server_url(Url::parse("https://token.services.mozilla.com").unwrap()).as_str(),
+ "https://token.services.mozilla.com/1.0/sync/1.5"
+ );
+ assert_eq!(
+ fixup_server_url(Url::parse("https://token.services.mozilla.com/").unwrap()).as_str(),
+ "https://token.services.mozilla.com/1.0/sync/1.5"
+ );
+ assert_eq!(
+ fixup_server_url(
+ Url::parse("https://selfhosted.example.com/token/1.0/sync/1.5").unwrap()
+ )
+ .as_str(),
+ "https://selfhosted.example.com/token/1.0/sync/1.5"
+ );
+ assert_eq!(
+ fixup_server_url(
+ Url::parse("https://selfhosted.example.com/token/1.0/sync/1.5/").unwrap()
+ )
+ .as_str(),
+ "https://selfhosted.example.com/token/1.0/sync/1.5"
+ );
+ assert_eq!(
+ fixup_server_url(Url::parse("https://selfhosted.example.com/token/").unwrap()).as_str(),
+ "https://selfhosted.example.com/token/1.0/sync/1.5"
+ );
+ assert_eq!(
+ fixup_server_url(Url::parse("https://selfhosted.example.com/token").unwrap()).as_str(),
+ "https://selfhosted.example.com/token/1.0/sync/1.5"
+ );
+ }
+}
diff --git a/third_party/rust/sync15/src/client/util.rs b/third_party/rust/sync15/src/client/util.rs
new file mode 100644
index 0000000000..01fff77afa
--- /dev/null
+++ b/third_party/rust/sync15/src/client/util.rs
@@ -0,0 +1,102 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use std::collections::{HashMap, HashSet};
+use std::sync::atomic::{AtomicU32, Ordering};
+
+/// Finds the maximum of the current value and the argument `val`, and sets the
+/// new value to the result.
+///
+/// Note: `AtomicFoo::fetch_max` is unstable, and can't really be implemented as
+/// a single atomic operation from outside the stdlib ;-;
+pub(crate) fn atomic_update_max(v: &AtomicU32, new: u32) {
+ // For loads (and the compare_exchange_weak second ordering argument) this
+ // is too strong, we could probably get away with Acquire (or maybe Relaxed
+ // because we don't need the result?). In either case, this fn isn't called
+ // from a hot spot so whatever.
+ let mut cur = v.load(Ordering::SeqCst);
+ while cur < new {
+ // we're already handling the failure case so there's no reason not to
+ // use _weak here.
+ match v.compare_exchange_weak(cur, new, Ordering::SeqCst, Ordering::SeqCst) {
+ Ok(_) => {
+ // Success.
+ break;
+ }
+ Err(new_cur) => {
+ // Interrupted, keep trying.
+ cur = new_cur
+ }
+ }
+ }
+}
+
+// Slight wrappers around the builtin methods for doing this.
+pub(crate) fn set_union(a: &HashSet<String>, b: &HashSet<String>) -> HashSet<String> {
+ a.union(b).cloned().collect()
+}
+
+pub(crate) fn set_difference(a: &HashSet<String>, b: &HashSet<String>) -> HashSet<String> {
+ a.difference(b).cloned().collect()
+}
+
+pub(crate) fn set_intersection(a: &HashSet<String>, b: &HashSet<String>) -> HashSet<String> {
+ a.intersection(b).cloned().collect()
+}
+
+pub(crate) fn partition_by_value(v: &HashMap<String, bool>) -> (HashSet<String>, HashSet<String>) {
+ let mut true_: HashSet<String> = HashSet::new();
+ let mut false_: HashSet<String> = HashSet::new();
+ for (s, val) in v {
+ if *val {
+ true_.insert(s.clone());
+ } else {
+ false_.insert(s.clone());
+ }
+ }
+ (true_, false_)
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_set_ops() {
+ fn hash_set(s: &[&str]) -> HashSet<String> {
+ s.iter()
+ .copied()
+ .map(ToOwned::to_owned)
+ .collect::<HashSet<_>>()
+ }
+
+ assert_eq!(
+ set_union(&hash_set(&["a", "b", "c"]), &hash_set(&["b", "d"])),
+ hash_set(&["a", "b", "c", "d"]),
+ );
+
+ assert_eq!(
+ set_difference(&hash_set(&["a", "b", "c"]), &hash_set(&["b", "d"])),
+ hash_set(&["a", "c"]),
+ );
+ assert_eq!(
+ set_intersection(&hash_set(&["a", "b", "c"]), &hash_set(&["b", "d"])),
+ hash_set(&["b"]),
+ );
+ let m: HashMap<String, bool> = [
+ ("foo", true),
+ ("bar", true),
+ ("baz", false),
+ ("quux", false),
+ ]
+ .iter()
+ .copied()
+ .map(|(a, b)| (a.to_owned(), b))
+ .collect();
+ assert_eq!(
+ partition_by_value(&m),
+ (hash_set(&["foo", "bar"]), hash_set(&["baz", "quux"])),
+ );
+ }
+}