diff options
Diffstat (limited to 'third_party/rust/sql-support')
-rw-r--r-- | third_party/rust/sql-support/.cargo-checksum.json | 2 | ||||
-rw-r--r-- | third_party/rust/sql-support/Cargo.toml | 1 | ||||
-rw-r--r-- | third_party/rust/sql-support/src/lazy.rs | 151 | ||||
-rw-r--r-- | third_party/rust/sql-support/src/lib.rs | 10 | ||||
-rw-r--r-- | third_party/rust/sql-support/src/open_database.rs | 204 |
5 files changed, 285 insertions, 83 deletions
diff --git a/third_party/rust/sql-support/.cargo-checksum.json b/third_party/rust/sql-support/.cargo-checksum.json index 93cef6e94d..ecf062beaf 100644 --- a/third_party/rust/sql-support/.cargo-checksum.json +++ b/third_party/rust/sql-support/.cargo-checksum.json @@ -1 +1 @@ -{"files":{"Cargo.toml":"812811e5a8e00abe3ec345cd8fd435e27fec7cb8f2e45a0e93e5becf564c46ad","src/conn_ext.rs":"e48e862e47c000c545dcc766fc1889498a8709bee00e240ed68d247b0fbef577","src/debug_tools.rs":"bece2bc3d35379b81ea2f942a0a3e909e0ab0553656505904745548eacaf402a","src/each_chunk.rs":"8aaba842e43b002fbc0fee95d14ce08faa7187b1979c765b2e270cd4802607a5","src/lib.rs":"af704ec04beb6c2c388d4566710e1167b18fb64acb248ccf37a67679daffddb6","src/maybe_cached.rs":"0b18425595055883a98807fbd62ff27a79c18af34e7cb3439f8c3438463ef2dd","src/open_database.rs":"ba290bfb39468e96f9b3ea865e0c13c2cc5a731ea8877a9feb6b1de4f7d666c4","src/repeat.rs":"b4c5ff5d083afba7f9f153f54aba2e6859b78b85c82d48dbd6bd58f67da9e6b9"},"package":null}
\ No newline at end of file +{"files":{"Cargo.toml":"2a0d414052d959098dcb3c22fce0eb008710ab594a6d0e5c58056b2dd497a359","src/conn_ext.rs":"e48e862e47c000c545dcc766fc1889498a8709bee00e240ed68d247b0fbef577","src/debug_tools.rs":"bece2bc3d35379b81ea2f942a0a3e909e0ab0553656505904745548eacaf402a","src/each_chunk.rs":"8aaba842e43b002fbc0fee95d14ce08faa7187b1979c765b2e270cd4802607a5","src/lazy.rs":"a96b4f4ec572538b49cdfa8fee981dcf5143a5f51163fb8a573d3ac128df70f9","src/lib.rs":"b2c120db4928c3e4abdd96405fd4c1016255699bdbc38c8cd60dbd3431fc0a12","src/maybe_cached.rs":"0b18425595055883a98807fbd62ff27a79c18af34e7cb3439f8c3438463ef2dd","src/open_database.rs":"dfc6f68354bf35ee1fc235986e5563e9f8c5cf7920dfe77a9a3d3ad4cfd3723f","src/repeat.rs":"b4c5ff5d083afba7f9f153f54aba2e6859b78b85c82d48dbd6bd58f67da9e6b9"},"package":null}
\ No newline at end of file diff --git a/third_party/rust/sql-support/Cargo.toml b/third_party/rust/sql-support/Cargo.toml index 0e6137ddbf..c09933b165 100644 --- a/third_party/rust/sql-support/Cargo.toml +++ b/third_party/rust/sql-support/Cargo.toml @@ -20,6 +20,7 @@ license = "MPL-2.0" ffi-support = "0.4" lazy_static = "1.4" log = "0.4" +parking_lot = ">=0.11,<=0.12" tempfile = "3.1.0" thiserror = "1.0" diff --git a/third_party/rust/sql-support/src/lazy.rs b/third_party/rust/sql-support/src/lazy.rs new file mode 100644 index 0000000000..b22d9c39e3 --- /dev/null +++ b/third_party/rust/sql-support/src/lazy.rs @@ -0,0 +1,151 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use crate::open_database::{open_database_with_flags, ConnectionInitializer, Error}; +use interrupt_support::{register_interrupt, SqlInterruptHandle, SqlInterruptScope}; +use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; +use rusqlite::{Connection, OpenFlags}; +use std::{ + path::{Path, PathBuf}, + sync::{Arc, Weak}, +}; + +/// Lazily-loaded database with interruption support +/// +/// In addition to the [Self::interrupt] method, LazyDb also calls +/// [interrupt_support::register_interrupt] on any opened database. This means that if +/// [interrupt_support::shutdown] is called it will interrupt this database if it's open and +/// in-use. +pub struct LazyDb<CI> { + path: PathBuf, + open_flags: OpenFlags, + connection_initializer: CI, + // Note: if you're going to lock both mutexes at once, make sure to lock the connection mutex + // first. Otherwise, you risk creating a deadlock where two threads each hold one of the locks + // and is waiting for the other. + connection: Mutex<Option<Connection>>, + // It's important to use a separate mutex for the interrupt handle, since the whole point is to + // interrupt while another thread holds the connection mutex. Since the only mutation is + // setting/unsetting the Arc, maybe this could be sped up by using something like + // `arc_swap::ArcSwap`, but that seems like overkill for our purposes. This mutex should rarely + // be contested and interrupt operations execute quickly. + interrupt_handle: Mutex<Option<Arc<SqlInterruptHandle>>>, +} + +impl<CI: ConnectionInitializer> LazyDb<CI> { + /// Create a new LazyDb + /// + /// This does not open the connection and is non-blocking + pub fn new(path: &Path, open_flags: OpenFlags, connection_initializer: CI) -> Self { + Self { + path: path.to_owned(), + open_flags, + connection_initializer, + connection: Mutex::new(None), + interrupt_handle: Mutex::new(None), + } + } + + /// Lock the database mutex and get a connection and interrupt scope. + /// + /// If the connection is closed, it will be opened. + /// + /// Calling `lock` again, or calling `close`, from the same thread while the mutex guard is + /// still alive will cause a deadlock. + pub fn lock(&self) -> Result<(MappedMutexGuard<'_, Connection>, SqlInterruptScope), Error> { + // Call get_conn first, then get_scope to ensure we acquire the locks in the correct order + let conn = self.get_conn()?; + let scope = self.get_scope(&conn)?; + Ok((conn, scope)) + } + + fn get_conn(&self) -> Result<MappedMutexGuard<'_, Connection>, Error> { + let mut guard = self.connection.lock(); + // Open the database if it wasn't opened before. Do this outside of the MutexGuard::map call to simplify the error handling + if guard.is_none() { + *guard = Some(open_database_with_flags( + &self.path, + self.open_flags, + &self.connection_initializer, + )?); + }; + // Use MutexGuard::map to get a Connection rather than Option<Connection>. The unwrap() + // call can't fail because of the previous code. + Ok(MutexGuard::map(guard, |conn_option| { + conn_option.as_mut().unwrap() + })) + } + + fn get_scope(&self, conn: &Connection) -> Result<SqlInterruptScope, Error> { + let mut handle_guard = self.interrupt_handle.lock(); + let result = match handle_guard.as_ref() { + Some(handle) => handle.begin_interrupt_scope(), + None => { + let handle = Arc::new(SqlInterruptHandle::new(conn)); + register_interrupt( + Arc::downgrade(&handle) as Weak<dyn AsRef<SqlInterruptHandle> + Send + Sync> + ); + handle_guard.insert(handle).begin_interrupt_scope() + } + }; + // If we see an Interrupted error when beginning the scope, it means that we're in shutdown + // mode. + result.map_err(|_| Error::Shutdown) + } + + /// Close the database if it's open + /// + /// Pass interrupt=true to interrupt any in-progress queries before closing the database. + /// + /// Do not call `close` if you already have a lock on the database in the current thread, as + /// this will cause a deadlock. + pub fn close(&self, interrupt: bool) { + let mut interrupt_handle = self.interrupt_handle.lock(); + if let Some(handle) = interrupt_handle.as_ref() { + if interrupt { + handle.interrupt(); + } + *interrupt_handle = None; + } + // Drop the interrupt handle lock to avoid holding both locks at once. + drop(interrupt_handle); + *self.connection.lock() = None; + } + + /// Interrupt any in-progress queries + pub fn interrupt(&self) { + if let Some(handle) = self.interrupt_handle.lock().as_ref() { + handle.interrupt(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::open_database::test_utils::TestConnectionInitializer; + + fn open_test_db() -> LazyDb<TestConnectionInitializer> { + LazyDb::new( + Path::new(":memory:"), + OpenFlags::default(), + TestConnectionInitializer::new(), + ) + } + + #[test] + fn test_interrupt() { + let lazy_db = open_test_db(); + let (_, scope) = lazy_db.lock().unwrap(); + assert!(!scope.was_interrupted()); + lazy_db.interrupt(); + assert!(scope.was_interrupted()); + } + + #[test] + fn interrupt_before_db_is_opened_should_not_fail() { + let lazy_db = open_test_db(); + lazy_db.interrupt(); + } +} diff --git a/third_party/rust/sql-support/src/lib.rs b/third_party/rust/sql-support/src/lib.rs index 2ece560b4d..5e8dfcea29 100644 --- a/third_party/rust/sql-support/src/lib.rs +++ b/third_party/rust/sql-support/src/lib.rs @@ -10,14 +10,16 @@ mod conn_ext; pub mod debug_tools; mod each_chunk; +mod lazy; mod maybe_cached; pub mod open_database; mod repeat; -pub use crate::conn_ext::*; -pub use crate::each_chunk::*; -pub use crate::maybe_cached::*; -pub use crate::repeat::*; +pub use conn_ext::*; +pub use each_chunk::*; +pub use lazy::*; +pub use maybe_cached::*; +pub use repeat::*; /// In PRAGMA foo='bar', `'bar'` must be a constant string (it cannot be a /// bound parameter), so we need to escape manually. According to diff --git a/third_party/rust/sql-support/src/open_database.rs b/third_party/rust/sql-support/src/open_database.rs index d92a94a9ed..9096b796a3 100644 --- a/third_party/rust/sql-support/src/open_database.rs +++ b/third_party/rust/sql-support/src/open_database.rs @@ -46,6 +46,8 @@ pub enum Error { SqlError(rusqlite::Error), #[error("Failed to recover a corrupt database due to an error deleting the file: {0}")] RecoveryError(std::io::Error), + #[error("In shutdown mode")] + Shutdown, } impl From<rusqlite::Error> for Error { @@ -241,99 +243,28 @@ fn set_schema_version(conn: &Connection, version: u32) -> Result<()> { // our other crates. pub mod test_utils { use super::*; - use std::path::PathBuf; + use std::{cell::RefCell, collections::HashSet, path::PathBuf}; use tempfile::TempDir; - // Database file that we can programatically run upgrades on - // - // We purposefully don't keep a connection to the database around to force upgrades to always - // run against a newly opened DB, like they would in the real world. See #4106 for - // details. - pub struct MigratedDatabaseFile<CI: ConnectionInitializer> { - // Keep around a TempDir to ensure the database file stays around until this struct is - // dropped - _tempdir: TempDir, - pub connection_initializer: CI, - pub path: PathBuf, + pub struct TestConnectionInitializer { + pub calls: RefCell<Vec<&'static str>>, + pub buggy_v3_upgrade: bool, } - impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> { - pub fn new(connection_initializer: CI, init_sql: &str) -> Self { - Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default()) - } - - pub fn new_with_flags( - connection_initializer: CI, - init_sql: &str, - open_flags: OpenFlags, - ) -> Self { - let tempdir = tempfile::tempdir().unwrap(); - let path = tempdir.path().join(Path::new("db.sql")); - let conn = Connection::open_with_flags(&path, open_flags).unwrap(); - conn.execute_batch(init_sql).unwrap(); - Self { - _tempdir: tempdir, - connection_initializer, - path, - } - } - - /// Attempt to run all upgrades up to a specific version. - /// - /// This will result in a panic if an upgrade fails to run. - pub fn upgrade_to(&self, version: u32) { - let mut conn = self.open(); - let tx = conn.transaction().unwrap(); - let mut current_version = get_schema_version(&tx).unwrap(); - while current_version < version { - self.connection_initializer - .upgrade_from(&tx, current_version) - .unwrap(); - current_version += 1; - } - set_schema_version(&tx, current_version).unwrap(); - self.connection_initializer.finish(&tx).unwrap(); - tx.commit().unwrap(); - } - - /// Attempt to run all upgrades - /// - /// This will result in a panic if an upgrade fails to run. - pub fn run_all_upgrades(&self) { - let current_version = get_schema_version(&self.open()).unwrap(); - for version in current_version..CI::END_VERSION { - self.upgrade_to(version + 1); - } - } - - pub fn open(&self) -> Connection { - Connection::open(&self.path).unwrap() + impl Default for TestConnectionInitializer { + fn default() -> Self { + Self::new() } } -} - -#[cfg(test)] -mod test { - use super::test_utils::MigratedDatabaseFile; - use super::*; - use std::cell::RefCell; - use std::io::Write; - - struct TestConnectionInitializer { - pub calls: RefCell<Vec<&'static str>>, - pub buggy_v3_upgrade: bool, - } impl TestConnectionInitializer { pub fn new() -> Self { - let _ = env_logger::try_init(); Self { calls: RefCell::new(Vec::new()), buggy_v3_upgrade: false, } } pub fn new_with_buggy_logic() -> Self { - let _ = env_logger::try_init(); Self { calls: RefCell::new(Vec::new()), buggy_v3_upgrade: true, @@ -427,6 +358,123 @@ mod test { } } + // Database file that we can programatically run upgrades on + // + // We purposefully don't keep a connection to the database around to force upgrades to always + // run against a newly opened DB, like they would in the real world. See #4106 for + // details. + pub struct MigratedDatabaseFile<CI: ConnectionInitializer> { + // Keep around a TempDir to ensure the database file stays around until this struct is + // dropped + _tempdir: TempDir, + pub connection_initializer: CI, + pub path: PathBuf, + } + + impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> { + pub fn new(connection_initializer: CI, init_sql: &str) -> Self { + Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default()) + } + + pub fn new_with_flags( + connection_initializer: CI, + init_sql: &str, + open_flags: OpenFlags, + ) -> Self { + let tempdir = tempfile::tempdir().unwrap(); + let path = tempdir.path().join(Path::new("db.sql")); + let conn = Connection::open_with_flags(&path, open_flags).unwrap(); + conn.execute_batch(init_sql).unwrap(); + Self { + _tempdir: tempdir, + connection_initializer, + path, + } + } + + /// Attempt to run all upgrades up to a specific version. + /// + /// This will result in a panic if an upgrade fails to run. + pub fn upgrade_to(&self, version: u32) { + let mut conn = self.open(); + let tx = conn.transaction().unwrap(); + let mut current_version = get_schema_version(&tx).unwrap(); + while current_version < version { + self.connection_initializer + .upgrade_from(&tx, current_version) + .unwrap(); + current_version += 1; + } + set_schema_version(&tx, current_version).unwrap(); + self.connection_initializer.finish(&tx).unwrap(); + tx.commit().unwrap(); + } + + /// Attempt to run all upgrades + /// + /// This will result in a panic if an upgrade fails to run. + pub fn run_all_upgrades(&self) { + let current_version = get_schema_version(&self.open()).unwrap(); + for version in current_version..CI::END_VERSION { + self.upgrade_to(version + 1); + } + } + + pub fn assert_schema_matches_new_database(&self) { + let db = self.open(); + let new_db = open_memory_database(&self.connection_initializer).unwrap(); + let table_names = get_table_names(&db); + let new_db_table_names = get_table_names(&new_db); + let extra_tables = Vec::from_iter(table_names.difference(&new_db_table_names)); + if !extra_tables.is_empty() { + panic!("Extra tables not present in new database: {extra_tables:?}"); + } + let new_db_extra_tables = Vec::from_iter(new_db_table_names.difference(&table_names)); + if !new_db_extra_tables.is_empty() { + panic!("Extra tables only present in new database: {new_db_extra_tables:?}"); + } + + for table_name in table_names { + assert_eq!( + get_table_sql(&db, &table_name), + get_table_sql(&new_db, &table_name), + "sql differs for table: {table_name}", + ); + } + } + + pub fn open(&self) -> Connection { + Connection::open(&self.path).unwrap() + } + } + + fn get_table_names(conn: &Connection) -> HashSet<String> { + conn.query_rows_and_then( + "SELECT name FROM sqlite_master WHERE type='table'", + (), + |row| row.get(0), + ) + .unwrap() + .into_iter() + .collect() + } + + fn get_table_sql(conn: &Connection, table_name: &str) -> String { + conn.query_row_and_then( + "SELECT sql FROM sqlite_master WHERE name = ? AND type='table'", + (&table_name,), + |row| row.get::<_, String>(0), + ) + .unwrap() + } +} + +#[cfg(test)] +mod test { + use super::test_utils::{MigratedDatabaseFile, TestConnectionInitializer}; + use super::*; + use std::io::Write; + // A special schema used to test the upgrade that forces the database to be // replaced. static INIT_V1: &str = " |