diff options
Diffstat (limited to 'third_party/rust/sql-support/src/open_database.rs')
-rw-r--r-- | third_party/rust/sql-support/src/open_database.rs | 204 |
1 files changed, 126 insertions, 78 deletions
diff --git a/third_party/rust/sql-support/src/open_database.rs b/third_party/rust/sql-support/src/open_database.rs index d92a94a9ed..9096b796a3 100644 --- a/third_party/rust/sql-support/src/open_database.rs +++ b/third_party/rust/sql-support/src/open_database.rs @@ -46,6 +46,8 @@ pub enum Error { SqlError(rusqlite::Error), #[error("Failed to recover a corrupt database due to an error deleting the file: {0}")] RecoveryError(std::io::Error), + #[error("In shutdown mode")] + Shutdown, } impl From<rusqlite::Error> for Error { @@ -241,99 +243,28 @@ fn set_schema_version(conn: &Connection, version: u32) -> Result<()> { // our other crates. pub mod test_utils { use super::*; - use std::path::PathBuf; + use std::{cell::RefCell, collections::HashSet, path::PathBuf}; use tempfile::TempDir; - // Database file that we can programatically run upgrades on - // - // We purposefully don't keep a connection to the database around to force upgrades to always - // run against a newly opened DB, like they would in the real world. See #4106 for - // details. - pub struct MigratedDatabaseFile<CI: ConnectionInitializer> { - // Keep around a TempDir to ensure the database file stays around until this struct is - // dropped - _tempdir: TempDir, - pub connection_initializer: CI, - pub path: PathBuf, + pub struct TestConnectionInitializer { + pub calls: RefCell<Vec<&'static str>>, + pub buggy_v3_upgrade: bool, } - impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> { - pub fn new(connection_initializer: CI, init_sql: &str) -> Self { - Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default()) - } - - pub fn new_with_flags( - connection_initializer: CI, - init_sql: &str, - open_flags: OpenFlags, - ) -> Self { - let tempdir = tempfile::tempdir().unwrap(); - let path = tempdir.path().join(Path::new("db.sql")); - let conn = Connection::open_with_flags(&path, open_flags).unwrap(); - conn.execute_batch(init_sql).unwrap(); - Self { - _tempdir: tempdir, - connection_initializer, - path, - } - } - - /// Attempt to run all upgrades up to a specific version. - /// - /// This will result in a panic if an upgrade fails to run. - pub fn upgrade_to(&self, version: u32) { - let mut conn = self.open(); - let tx = conn.transaction().unwrap(); - let mut current_version = get_schema_version(&tx).unwrap(); - while current_version < version { - self.connection_initializer - .upgrade_from(&tx, current_version) - .unwrap(); - current_version += 1; - } - set_schema_version(&tx, current_version).unwrap(); - self.connection_initializer.finish(&tx).unwrap(); - tx.commit().unwrap(); - } - - /// Attempt to run all upgrades - /// - /// This will result in a panic if an upgrade fails to run. - pub fn run_all_upgrades(&self) { - let current_version = get_schema_version(&self.open()).unwrap(); - for version in current_version..CI::END_VERSION { - self.upgrade_to(version + 1); - } - } - - pub fn open(&self) -> Connection { - Connection::open(&self.path).unwrap() + impl Default for TestConnectionInitializer { + fn default() -> Self { + Self::new() } } -} - -#[cfg(test)] -mod test { - use super::test_utils::MigratedDatabaseFile; - use super::*; - use std::cell::RefCell; - use std::io::Write; - - struct TestConnectionInitializer { - pub calls: RefCell<Vec<&'static str>>, - pub buggy_v3_upgrade: bool, - } impl TestConnectionInitializer { pub fn new() -> Self { - let _ = env_logger::try_init(); Self { calls: RefCell::new(Vec::new()), buggy_v3_upgrade: false, } } pub fn new_with_buggy_logic() -> Self { - let _ = env_logger::try_init(); Self { calls: RefCell::new(Vec::new()), buggy_v3_upgrade: true, @@ -427,6 +358,123 @@ mod test { } } + // Database file that we can programatically run upgrades on + // + // We purposefully don't keep a connection to the database around to force upgrades to always + // run against a newly opened DB, like they would in the real world. See #4106 for + // details. + pub struct MigratedDatabaseFile<CI: ConnectionInitializer> { + // Keep around a TempDir to ensure the database file stays around until this struct is + // dropped + _tempdir: TempDir, + pub connection_initializer: CI, + pub path: PathBuf, + } + + impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> { + pub fn new(connection_initializer: CI, init_sql: &str) -> Self { + Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default()) + } + + pub fn new_with_flags( + connection_initializer: CI, + init_sql: &str, + open_flags: OpenFlags, + ) -> Self { + let tempdir = tempfile::tempdir().unwrap(); + let path = tempdir.path().join(Path::new("db.sql")); + let conn = Connection::open_with_flags(&path, open_flags).unwrap(); + conn.execute_batch(init_sql).unwrap(); + Self { + _tempdir: tempdir, + connection_initializer, + path, + } + } + + /// Attempt to run all upgrades up to a specific version. + /// + /// This will result in a panic if an upgrade fails to run. + pub fn upgrade_to(&self, version: u32) { + let mut conn = self.open(); + let tx = conn.transaction().unwrap(); + let mut current_version = get_schema_version(&tx).unwrap(); + while current_version < version { + self.connection_initializer + .upgrade_from(&tx, current_version) + .unwrap(); + current_version += 1; + } + set_schema_version(&tx, current_version).unwrap(); + self.connection_initializer.finish(&tx).unwrap(); + tx.commit().unwrap(); + } + + /// Attempt to run all upgrades + /// + /// This will result in a panic if an upgrade fails to run. + pub fn run_all_upgrades(&self) { + let current_version = get_schema_version(&self.open()).unwrap(); + for version in current_version..CI::END_VERSION { + self.upgrade_to(version + 1); + } + } + + pub fn assert_schema_matches_new_database(&self) { + let db = self.open(); + let new_db = open_memory_database(&self.connection_initializer).unwrap(); + let table_names = get_table_names(&db); + let new_db_table_names = get_table_names(&new_db); + let extra_tables = Vec::from_iter(table_names.difference(&new_db_table_names)); + if !extra_tables.is_empty() { + panic!("Extra tables not present in new database: {extra_tables:?}"); + } + let new_db_extra_tables = Vec::from_iter(new_db_table_names.difference(&table_names)); + if !new_db_extra_tables.is_empty() { + panic!("Extra tables only present in new database: {new_db_extra_tables:?}"); + } + + for table_name in table_names { + assert_eq!( + get_table_sql(&db, &table_name), + get_table_sql(&new_db, &table_name), + "sql differs for table: {table_name}", + ); + } + } + + pub fn open(&self) -> Connection { + Connection::open(&self.path).unwrap() + } + } + + fn get_table_names(conn: &Connection) -> HashSet<String> { + conn.query_rows_and_then( + "SELECT name FROM sqlite_master WHERE type='table'", + (), + |row| row.get(0), + ) + .unwrap() + .into_iter() + .collect() + } + + fn get_table_sql(conn: &Connection, table_name: &str) -> String { + conn.query_row_and_then( + "SELECT sql FROM sqlite_master WHERE name = ? AND type='table'", + (&table_name,), + |row| row.get::<_, String>(0), + ) + .unwrap() + } +} + +#[cfg(test)] +mod test { + use super::test_utils::{MigratedDatabaseFile, TestConnectionInitializer}; + use super::*; + use std::io::Write; + // A special schema used to test the upgrade that forces the database to be // replaced. static INIT_V1: &str = " |