summaryrefslogtreecommitdiffstats
path: root/third_party/rust/sql-support/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/sql-support/src')
-rw-r--r--third_party/rust/sql-support/src/lazy.rs151
-rw-r--r--third_party/rust/sql-support/src/lib.rs10
-rw-r--r--third_party/rust/sql-support/src/open_database.rs204
3 files changed, 283 insertions, 82 deletions
diff --git a/third_party/rust/sql-support/src/lazy.rs b/third_party/rust/sql-support/src/lazy.rs
new file mode 100644
index 0000000000..b22d9c39e3
--- /dev/null
+++ b/third_party/rust/sql-support/src/lazy.rs
@@ -0,0 +1,151 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use crate::open_database::{open_database_with_flags, ConnectionInitializer, Error};
+use interrupt_support::{register_interrupt, SqlInterruptHandle, SqlInterruptScope};
+use parking_lot::{MappedMutexGuard, Mutex, MutexGuard};
+use rusqlite::{Connection, OpenFlags};
+use std::{
+ path::{Path, PathBuf},
+ sync::{Arc, Weak},
+};
+
+/// Lazily-loaded database with interruption support
+///
+/// In addition to the [Self::interrupt] method, LazyDb also calls
+/// [interrupt_support::register_interrupt] on any opened database. This means that if
+/// [interrupt_support::shutdown] is called it will interrupt this database if it's open and
+/// in-use.
+pub struct LazyDb<CI> {
+ path: PathBuf,
+ open_flags: OpenFlags,
+ connection_initializer: CI,
+ // Note: if you're going to lock both mutexes at once, make sure to lock the connection mutex
+ // first. Otherwise, you risk creating a deadlock where two threads each hold one of the locks
+ // and is waiting for the other.
+ connection: Mutex<Option<Connection>>,
+ // It's important to use a separate mutex for the interrupt handle, since the whole point is to
+ // interrupt while another thread holds the connection mutex. Since the only mutation is
+ // setting/unsetting the Arc, maybe this could be sped up by using something like
+ // `arc_swap::ArcSwap`, but that seems like overkill for our purposes. This mutex should rarely
+ // be contested and interrupt operations execute quickly.
+ interrupt_handle: Mutex<Option<Arc<SqlInterruptHandle>>>,
+}
+
+impl<CI: ConnectionInitializer> LazyDb<CI> {
+ /// Create a new LazyDb
+ ///
+ /// This does not open the connection and is non-blocking
+ pub fn new(path: &Path, open_flags: OpenFlags, connection_initializer: CI) -> Self {
+ Self {
+ path: path.to_owned(),
+ open_flags,
+ connection_initializer,
+ connection: Mutex::new(None),
+ interrupt_handle: Mutex::new(None),
+ }
+ }
+
+ /// Lock the database mutex and get a connection and interrupt scope.
+ ///
+ /// If the connection is closed, it will be opened.
+ ///
+ /// Calling `lock` again, or calling `close`, from the same thread while the mutex guard is
+ /// still alive will cause a deadlock.
+ pub fn lock(&self) -> Result<(MappedMutexGuard<'_, Connection>, SqlInterruptScope), Error> {
+ // Call get_conn first, then get_scope to ensure we acquire the locks in the correct order
+ let conn = self.get_conn()?;
+ let scope = self.get_scope(&conn)?;
+ Ok((conn, scope))
+ }
+
+ fn get_conn(&self) -> Result<MappedMutexGuard<'_, Connection>, Error> {
+ let mut guard = self.connection.lock();
+ // Open the database if it wasn't opened before. Do this outside of the MutexGuard::map call to simplify the error handling
+ if guard.is_none() {
+ *guard = Some(open_database_with_flags(
+ &self.path,
+ self.open_flags,
+ &self.connection_initializer,
+ )?);
+ };
+ // Use MutexGuard::map to get a Connection rather than Option<Connection>. The unwrap()
+ // call can't fail because of the previous code.
+ Ok(MutexGuard::map(guard, |conn_option| {
+ conn_option.as_mut().unwrap()
+ }))
+ }
+
+ fn get_scope(&self, conn: &Connection) -> Result<SqlInterruptScope, Error> {
+ let mut handle_guard = self.interrupt_handle.lock();
+ let result = match handle_guard.as_ref() {
+ Some(handle) => handle.begin_interrupt_scope(),
+ None => {
+ let handle = Arc::new(SqlInterruptHandle::new(conn));
+ register_interrupt(
+ Arc::downgrade(&handle) as Weak<dyn AsRef<SqlInterruptHandle> + Send + Sync>
+ );
+ handle_guard.insert(handle).begin_interrupt_scope()
+ }
+ };
+ // If we see an Interrupted error when beginning the scope, it means that we're in shutdown
+ // mode.
+ result.map_err(|_| Error::Shutdown)
+ }
+
+ /// Close the database if it's open
+ ///
+ /// Pass interrupt=true to interrupt any in-progress queries before closing the database.
+ ///
+ /// Do not call `close` if you already have a lock on the database in the current thread, as
+ /// this will cause a deadlock.
+ pub fn close(&self, interrupt: bool) {
+ let mut interrupt_handle = self.interrupt_handle.lock();
+ if let Some(handle) = interrupt_handle.as_ref() {
+ if interrupt {
+ handle.interrupt();
+ }
+ *interrupt_handle = None;
+ }
+ // Drop the interrupt handle lock to avoid holding both locks at once.
+ drop(interrupt_handle);
+ *self.connection.lock() = None;
+ }
+
+ /// Interrupt any in-progress queries
+ pub fn interrupt(&self) {
+ if let Some(handle) = self.interrupt_handle.lock().as_ref() {
+ handle.interrupt();
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::open_database::test_utils::TestConnectionInitializer;
+
+ fn open_test_db() -> LazyDb<TestConnectionInitializer> {
+ LazyDb::new(
+ Path::new(":memory:"),
+ OpenFlags::default(),
+ TestConnectionInitializer::new(),
+ )
+ }
+
+ #[test]
+ fn test_interrupt() {
+ let lazy_db = open_test_db();
+ let (_, scope) = lazy_db.lock().unwrap();
+ assert!(!scope.was_interrupted());
+ lazy_db.interrupt();
+ assert!(scope.was_interrupted());
+ }
+
+ #[test]
+ fn interrupt_before_db_is_opened_should_not_fail() {
+ let lazy_db = open_test_db();
+ lazy_db.interrupt();
+ }
+}
diff --git a/third_party/rust/sql-support/src/lib.rs b/third_party/rust/sql-support/src/lib.rs
index 2ece560b4d..5e8dfcea29 100644
--- a/third_party/rust/sql-support/src/lib.rs
+++ b/third_party/rust/sql-support/src/lib.rs
@@ -10,14 +10,16 @@
mod conn_ext;
pub mod debug_tools;
mod each_chunk;
+mod lazy;
mod maybe_cached;
pub mod open_database;
mod repeat;
-pub use crate::conn_ext::*;
-pub use crate::each_chunk::*;
-pub use crate::maybe_cached::*;
-pub use crate::repeat::*;
+pub use conn_ext::*;
+pub use each_chunk::*;
+pub use lazy::*;
+pub use maybe_cached::*;
+pub use repeat::*;
/// In PRAGMA foo='bar', `'bar'` must be a constant string (it cannot be a
/// bound parameter), so we need to escape manually. According to
diff --git a/third_party/rust/sql-support/src/open_database.rs b/third_party/rust/sql-support/src/open_database.rs
index d92a94a9ed..9096b796a3 100644
--- a/third_party/rust/sql-support/src/open_database.rs
+++ b/third_party/rust/sql-support/src/open_database.rs
@@ -46,6 +46,8 @@ pub enum Error {
SqlError(rusqlite::Error),
#[error("Failed to recover a corrupt database due to an error deleting the file: {0}")]
RecoveryError(std::io::Error),
+ #[error("In shutdown mode")]
+ Shutdown,
}
impl From<rusqlite::Error> for Error {
@@ -241,99 +243,28 @@ fn set_schema_version(conn: &Connection, version: u32) -> Result<()> {
// our other crates.
pub mod test_utils {
use super::*;
- use std::path::PathBuf;
+ use std::{cell::RefCell, collections::HashSet, path::PathBuf};
use tempfile::TempDir;
- // Database file that we can programatically run upgrades on
- //
- // We purposefully don't keep a connection to the database around to force upgrades to always
- // run against a newly opened DB, like they would in the real world. See #4106 for
- // details.
- pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
- // Keep around a TempDir to ensure the database file stays around until this struct is
- // dropped
- _tempdir: TempDir,
- pub connection_initializer: CI,
- pub path: PathBuf,
+ pub struct TestConnectionInitializer {
+ pub calls: RefCell<Vec<&'static str>>,
+ pub buggy_v3_upgrade: bool,
}
- impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
- pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
- Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
- }
-
- pub fn new_with_flags(
- connection_initializer: CI,
- init_sql: &str,
- open_flags: OpenFlags,
- ) -> Self {
- let tempdir = tempfile::tempdir().unwrap();
- let path = tempdir.path().join(Path::new("db.sql"));
- let conn = Connection::open_with_flags(&path, open_flags).unwrap();
- conn.execute_batch(init_sql).unwrap();
- Self {
- _tempdir: tempdir,
- connection_initializer,
- path,
- }
- }
-
- /// Attempt to run all upgrades up to a specific version.
- ///
- /// This will result in a panic if an upgrade fails to run.
- pub fn upgrade_to(&self, version: u32) {
- let mut conn = self.open();
- let tx = conn.transaction().unwrap();
- let mut current_version = get_schema_version(&tx).unwrap();
- while current_version < version {
- self.connection_initializer
- .upgrade_from(&tx, current_version)
- .unwrap();
- current_version += 1;
- }
- set_schema_version(&tx, current_version).unwrap();
- self.connection_initializer.finish(&tx).unwrap();
- tx.commit().unwrap();
- }
-
- /// Attempt to run all upgrades
- ///
- /// This will result in a panic if an upgrade fails to run.
- pub fn run_all_upgrades(&self) {
- let current_version = get_schema_version(&self.open()).unwrap();
- for version in current_version..CI::END_VERSION {
- self.upgrade_to(version + 1);
- }
- }
-
- pub fn open(&self) -> Connection {
- Connection::open(&self.path).unwrap()
+ impl Default for TestConnectionInitializer {
+ fn default() -> Self {
+ Self::new()
}
}
-}
-
-#[cfg(test)]
-mod test {
- use super::test_utils::MigratedDatabaseFile;
- use super::*;
- use std::cell::RefCell;
- use std::io::Write;
-
- struct TestConnectionInitializer {
- pub calls: RefCell<Vec<&'static str>>,
- pub buggy_v3_upgrade: bool,
- }
impl TestConnectionInitializer {
pub fn new() -> Self {
- let _ = env_logger::try_init();
Self {
calls: RefCell::new(Vec::new()),
buggy_v3_upgrade: false,
}
}
pub fn new_with_buggy_logic() -> Self {
- let _ = env_logger::try_init();
Self {
calls: RefCell::new(Vec::new()),
buggy_v3_upgrade: true,
@@ -427,6 +358,123 @@ mod test {
}
}
+ // Database file that we can programatically run upgrades on
+ //
+ // We purposefully don't keep a connection to the database around to force upgrades to always
+ // run against a newly opened DB, like they would in the real world. See #4106 for
+ // details.
+ pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
+ // Keep around a TempDir to ensure the database file stays around until this struct is
+ // dropped
+ _tempdir: TempDir,
+ pub connection_initializer: CI,
+ pub path: PathBuf,
+ }
+
+ impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
+ pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
+ Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
+ }
+
+ pub fn new_with_flags(
+ connection_initializer: CI,
+ init_sql: &str,
+ open_flags: OpenFlags,
+ ) -> Self {
+ let tempdir = tempfile::tempdir().unwrap();
+ let path = tempdir.path().join(Path::new("db.sql"));
+ let conn = Connection::open_with_flags(&path, open_flags).unwrap();
+ conn.execute_batch(init_sql).unwrap();
+ Self {
+ _tempdir: tempdir,
+ connection_initializer,
+ path,
+ }
+ }
+
+ /// Attempt to run all upgrades up to a specific version.
+ ///
+ /// This will result in a panic if an upgrade fails to run.
+ pub fn upgrade_to(&self, version: u32) {
+ let mut conn = self.open();
+ let tx = conn.transaction().unwrap();
+ let mut current_version = get_schema_version(&tx).unwrap();
+ while current_version < version {
+ self.connection_initializer
+ .upgrade_from(&tx, current_version)
+ .unwrap();
+ current_version += 1;
+ }
+ set_schema_version(&tx, current_version).unwrap();
+ self.connection_initializer.finish(&tx).unwrap();
+ tx.commit().unwrap();
+ }
+
+ /// Attempt to run all upgrades
+ ///
+ /// This will result in a panic if an upgrade fails to run.
+ pub fn run_all_upgrades(&self) {
+ let current_version = get_schema_version(&self.open()).unwrap();
+ for version in current_version..CI::END_VERSION {
+ self.upgrade_to(version + 1);
+ }
+ }
+
+ pub fn assert_schema_matches_new_database(&self) {
+ let db = self.open();
+ let new_db = open_memory_database(&self.connection_initializer).unwrap();
+ let table_names = get_table_names(&db);
+ let new_db_table_names = get_table_names(&new_db);
+ let extra_tables = Vec::from_iter(table_names.difference(&new_db_table_names));
+ if !extra_tables.is_empty() {
+ panic!("Extra tables not present in new database: {extra_tables:?}");
+ }
+ let new_db_extra_tables = Vec::from_iter(new_db_table_names.difference(&table_names));
+ if !new_db_extra_tables.is_empty() {
+ panic!("Extra tables only present in new database: {new_db_extra_tables:?}");
+ }
+
+ for table_name in table_names {
+ assert_eq!(
+ get_table_sql(&db, &table_name),
+ get_table_sql(&new_db, &table_name),
+ "sql differs for table: {table_name}",
+ );
+ }
+ }
+
+ pub fn open(&self) -> Connection {
+ Connection::open(&self.path).unwrap()
+ }
+ }
+
+ fn get_table_names(conn: &Connection) -> HashSet<String> {
+ conn.query_rows_and_then(
+ "SELECT name FROM sqlite_master WHERE type='table'",
+ (),
+ |row| row.get(0),
+ )
+ .unwrap()
+ .into_iter()
+ .collect()
+ }
+
+ fn get_table_sql(conn: &Connection, table_name: &str) -> String {
+ conn.query_row_and_then(
+ "SELECT sql FROM sqlite_master WHERE name = ? AND type='table'",
+ (&table_name,),
+ |row| row.get::<_, String>(0),
+ )
+ .unwrap()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::test_utils::{MigratedDatabaseFile, TestConnectionInitializer};
+ use super::*;
+ use std::io::Write;
+
// A special schema used to test the upgrade that forces the database to be
// replaced.
static INIT_V1: &str = "