summaryrefslogtreecommitdiffstats
path: root/third_party/rust/sql-support/src/open_database.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/sql-support/src/open_database.rs')
-rw-r--r--third_party/rust/sql-support/src/open_database.rs204
1 files changed, 126 insertions, 78 deletions
diff --git a/third_party/rust/sql-support/src/open_database.rs b/third_party/rust/sql-support/src/open_database.rs
index d92a94a9ed..9096b796a3 100644
--- a/third_party/rust/sql-support/src/open_database.rs
+++ b/third_party/rust/sql-support/src/open_database.rs
@@ -46,6 +46,8 @@ pub enum Error {
SqlError(rusqlite::Error),
#[error("Failed to recover a corrupt database due to an error deleting the file: {0}")]
RecoveryError(std::io::Error),
+ #[error("In shutdown mode")]
+ Shutdown,
}
impl From<rusqlite::Error> for Error {
@@ -241,99 +243,28 @@ fn set_schema_version(conn: &Connection, version: u32) -> Result<()> {
// our other crates.
pub mod test_utils {
use super::*;
- use std::path::PathBuf;
+ use std::{cell::RefCell, collections::HashSet, path::PathBuf};
use tempfile::TempDir;
- // Database file that we can programatically run upgrades on
- //
- // We purposefully don't keep a connection to the database around to force upgrades to always
- // run against a newly opened DB, like they would in the real world. See #4106 for
- // details.
- pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
- // Keep around a TempDir to ensure the database file stays around until this struct is
- // dropped
- _tempdir: TempDir,
- pub connection_initializer: CI,
- pub path: PathBuf,
+ pub struct TestConnectionInitializer {
+ pub calls: RefCell<Vec<&'static str>>,
+ pub buggy_v3_upgrade: bool,
}
- impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
- pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
- Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
- }
-
- pub fn new_with_flags(
- connection_initializer: CI,
- init_sql: &str,
- open_flags: OpenFlags,
- ) -> Self {
- let tempdir = tempfile::tempdir().unwrap();
- let path = tempdir.path().join(Path::new("db.sql"));
- let conn = Connection::open_with_flags(&path, open_flags).unwrap();
- conn.execute_batch(init_sql).unwrap();
- Self {
- _tempdir: tempdir,
- connection_initializer,
- path,
- }
- }
-
- /// Attempt to run all upgrades up to a specific version.
- ///
- /// This will result in a panic if an upgrade fails to run.
- pub fn upgrade_to(&self, version: u32) {
- let mut conn = self.open();
- let tx = conn.transaction().unwrap();
- let mut current_version = get_schema_version(&tx).unwrap();
- while current_version < version {
- self.connection_initializer
- .upgrade_from(&tx, current_version)
- .unwrap();
- current_version += 1;
- }
- set_schema_version(&tx, current_version).unwrap();
- self.connection_initializer.finish(&tx).unwrap();
- tx.commit().unwrap();
- }
-
- /// Attempt to run all upgrades
- ///
- /// This will result in a panic if an upgrade fails to run.
- pub fn run_all_upgrades(&self) {
- let current_version = get_schema_version(&self.open()).unwrap();
- for version in current_version..CI::END_VERSION {
- self.upgrade_to(version + 1);
- }
- }
-
- pub fn open(&self) -> Connection {
- Connection::open(&self.path).unwrap()
+ impl Default for TestConnectionInitializer {
+ fn default() -> Self {
+ Self::new()
}
}
-}
-
-#[cfg(test)]
-mod test {
- use super::test_utils::MigratedDatabaseFile;
- use super::*;
- use std::cell::RefCell;
- use std::io::Write;
-
- struct TestConnectionInitializer {
- pub calls: RefCell<Vec<&'static str>>,
- pub buggy_v3_upgrade: bool,
- }
impl TestConnectionInitializer {
pub fn new() -> Self {
- let _ = env_logger::try_init();
Self {
calls: RefCell::new(Vec::new()),
buggy_v3_upgrade: false,
}
}
pub fn new_with_buggy_logic() -> Self {
- let _ = env_logger::try_init();
Self {
calls: RefCell::new(Vec::new()),
buggy_v3_upgrade: true,
@@ -427,6 +358,123 @@ mod test {
}
}
+ // Database file that we can programatically run upgrades on
+ //
+ // We purposefully don't keep a connection to the database around to force upgrades to always
+ // run against a newly opened DB, like they would in the real world. See #4106 for
+ // details.
+ pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
+ // Keep around a TempDir to ensure the database file stays around until this struct is
+ // dropped
+ _tempdir: TempDir,
+ pub connection_initializer: CI,
+ pub path: PathBuf,
+ }
+
+ impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
+ pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
+ Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
+ }
+
+ pub fn new_with_flags(
+ connection_initializer: CI,
+ init_sql: &str,
+ open_flags: OpenFlags,
+ ) -> Self {
+ let tempdir = tempfile::tempdir().unwrap();
+ let path = tempdir.path().join(Path::new("db.sql"));
+ let conn = Connection::open_with_flags(&path, open_flags).unwrap();
+ conn.execute_batch(init_sql).unwrap();
+ Self {
+ _tempdir: tempdir,
+ connection_initializer,
+ path,
+ }
+ }
+
+ /// Attempt to run all upgrades up to a specific version.
+ ///
+ /// This will result in a panic if an upgrade fails to run.
+ pub fn upgrade_to(&self, version: u32) {
+ let mut conn = self.open();
+ let tx = conn.transaction().unwrap();
+ let mut current_version = get_schema_version(&tx).unwrap();
+ while current_version < version {
+ self.connection_initializer
+ .upgrade_from(&tx, current_version)
+ .unwrap();
+ current_version += 1;
+ }
+ set_schema_version(&tx, current_version).unwrap();
+ self.connection_initializer.finish(&tx).unwrap();
+ tx.commit().unwrap();
+ }
+
+ /// Attempt to run all upgrades
+ ///
+ /// This will result in a panic if an upgrade fails to run.
+ pub fn run_all_upgrades(&self) {
+ let current_version = get_schema_version(&self.open()).unwrap();
+ for version in current_version..CI::END_VERSION {
+ self.upgrade_to(version + 1);
+ }
+ }
+
+ pub fn assert_schema_matches_new_database(&self) {
+ let db = self.open();
+ let new_db = open_memory_database(&self.connection_initializer).unwrap();
+ let table_names = get_table_names(&db);
+ let new_db_table_names = get_table_names(&new_db);
+ let extra_tables = Vec::from_iter(table_names.difference(&new_db_table_names));
+ if !extra_tables.is_empty() {
+ panic!("Extra tables not present in new database: {extra_tables:?}");
+ }
+ let new_db_extra_tables = Vec::from_iter(new_db_table_names.difference(&table_names));
+ if !new_db_extra_tables.is_empty() {
+ panic!("Extra tables only present in new database: {new_db_extra_tables:?}");
+ }
+
+ for table_name in table_names {
+ assert_eq!(
+ get_table_sql(&db, &table_name),
+ get_table_sql(&new_db, &table_name),
+ "sql differs for table: {table_name}",
+ );
+ }
+ }
+
+ pub fn open(&self) -> Connection {
+ Connection::open(&self.path).unwrap()
+ }
+ }
+
+ fn get_table_names(conn: &Connection) -> HashSet<String> {
+ conn.query_rows_and_then(
+ "SELECT name FROM sqlite_master WHERE type='table'",
+ (),
+ |row| row.get(0),
+ )
+ .unwrap()
+ .into_iter()
+ .collect()
+ }
+
+ fn get_table_sql(conn: &Connection, table_name: &str) -> String {
+ conn.query_row_and_then(
+ "SELECT sql FROM sqlite_master WHERE name = ? AND type='table'",
+ (&table_name,),
+ |row| row.get::<_, String>(0),
+ )
+ .unwrap()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::test_utils::{MigratedDatabaseFile, TestConnectionInitializer};
+ use super::*;
+ use std::io::Write;
+
// A special schema used to test the upgrade that forces the database to be
// replaced.
static INIT_V1: &str = "