summaryrefslogtreecommitdiffstats
path: root/third_party/rust/sql-support/src
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
commit43a97878ce14b72f0981164f87f2e35e14151312 (patch)
tree620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/rust/sql-support/src
parentInitial commit. (diff)
downloadfirefox-43a97878ce14b72f0981164f87f2e35e14151312.tar.xz
firefox-43a97878ce14b72f0981164f87f2e35e14151312.zip
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/sql-support/src')
-rw-r--r--third_party/rust/sql-support/src/conn_ext.rs402
-rw-r--r--third_party/rust/sql-support/src/each_chunk.rs311
-rw-r--r--third_party/rust/sql-support/src/lib.rs36
-rw-r--r--third_party/rust/sql-support/src/maybe_cached.rs64
-rw-r--r--third_party/rust/sql-support/src/open_database.rs536
-rw-r--r--third_party/rust/sql-support/src/repeat.rs113
6 files changed, 1462 insertions, 0 deletions
diff --git a/third_party/rust/sql-support/src/conn_ext.rs b/third_party/rust/sql-support/src/conn_ext.rs
new file mode 100644
index 0000000000..2efec41b7c
--- /dev/null
+++ b/third_party/rust/sql-support/src/conn_ext.rs
@@ -0,0 +1,402 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use rusqlite::{
+ self,
+ types::{FromSql, ToSql},
+ Connection, Params, Result as SqlResult, Row, Savepoint, Transaction, TransactionBehavior,
+};
+use std::iter::FromIterator;
+use std::ops::Deref;
+use std::time::Instant;
+
+use crate::maybe_cached::MaybeCached;
+
+pub struct Conn(rusqlite::Connection);
+
+/// This trait exists so that we can use these helpers on `rusqlite::{Transaction, Connection}`.
+/// Note that you must import ConnExt in order to call these methods on anything.
+pub trait ConnExt {
+ /// The method you need to implement to opt in to all of this.
+ fn conn(&self) -> &Connection;
+
+ /// Set the value of the pragma on the main database. Returns the same object, for chaining.
+ fn set_pragma<T>(&self, pragma_name: &str, pragma_value: T) -> SqlResult<&Self>
+ where
+ T: ToSql,
+ Self: Sized,
+ {
+ // None == Schema name, e.g. `PRAGMA some_attached_db.something = blah`
+ self.conn()
+ .pragma_update(None, pragma_name, &pragma_value)?;
+ Ok(self)
+ }
+
+ /// Get a cached or uncached statement based on a flag.
+ fn prepare_maybe_cached<'conn>(
+ &'conn self,
+ sql: &str,
+ cache: bool,
+ ) -> SqlResult<MaybeCached<'conn>> {
+ MaybeCached::prepare(self.conn(), sql, cache)
+ }
+
+ /// Execute all the provided statements.
+ fn execute_all(&self, stmts: &[&str]) -> SqlResult<()> {
+ let conn = self.conn();
+ for sql in stmts {
+ let r = conn.execute(sql, []);
+ match r {
+ Ok(_) => {}
+ // Ignore ExecuteReturnedResults error because they're pointless
+ // and annoying.
+ Err(rusqlite::Error::ExecuteReturnedResults) => {}
+ Err(e) => return Err(e),
+ }
+ }
+ Ok(())
+ }
+
+ /// Execute a single statement.
+ fn execute_one(&self, stmt: &str) -> SqlResult<()> {
+ self.execute_all(&[stmt])
+ }
+
+ /// Equivalent to `Connection::execute` but caches the statement so that subsequent
+ /// calls to `execute_cached` will have improved performance.
+ fn execute_cached<P: Params>(&self, sql: &str, params: P) -> SqlResult<usize> {
+ let mut stmt = self.conn().prepare_cached(sql)?;
+ stmt.execute(params)
+ }
+
+ /// Execute a query that returns a single result column, and return that result.
+ fn query_one<T: FromSql>(&self, sql: &str) -> SqlResult<T> {
+ let res: T = self.conn().query_row_and_then(sql, [], |row| row.get(0))?;
+ Ok(res)
+ }
+
+ /// Execute a query that returns 0 or 1 result columns, returning None
+ /// if there were no rows, or if the only result was NULL.
+ fn try_query_one<T: FromSql, P: Params>(
+ &self,
+ sql: &str,
+ params: P,
+ cache: bool,
+ ) -> SqlResult<Option<T>>
+ where
+ Self: Sized,
+ {
+ use rusqlite::OptionalExtension;
+ // The outer option is if we got rows, the inner option is
+ // if the first row was null.
+ let res: Option<Option<T>> = self
+ .conn()
+ .query_row_and_then_cachable(sql, params, |row| row.get(0), cache)
+ .optional()?;
+ // go from Option<Option<T>> to Option<T>
+ Ok(res.unwrap_or_default())
+ }
+
+ /// Equivalent to `rusqlite::Connection::query_row_and_then` but allows
+ /// passing a flag to indicate that it's cached.
+ fn query_row_and_then_cachable<T, E, P, F>(
+ &self,
+ sql: &str,
+ params: P,
+ mapper: F,
+ cache: bool,
+ ) -> Result<T, E>
+ where
+ Self: Sized,
+ P: Params,
+ E: From<rusqlite::Error>,
+ F: FnOnce(&Row<'_>) -> Result<T, E>,
+ {
+ Ok(self
+ .try_query_row(sql, params, mapper, cache)?
+ .ok_or(rusqlite::Error::QueryReturnedNoRows)?)
+ }
+
+ /// Helper for when you'd like to get a Vec<T> of all the rows returned by a
+ /// query that takes named arguments. See also
+ /// `query_rows_and_then_cached`.
+ fn query_rows_and_then<T, E, P, F>(&self, sql: &str, params: P, mapper: F) -> Result<Vec<T>, E>
+ where
+ Self: Sized,
+ P: Params,
+ E: From<rusqlite::Error>,
+ F: FnMut(&Row<'_>) -> Result<T, E>,
+ {
+ query_rows_and_then_cachable(self.conn(), sql, params, mapper, false)
+ }
+
+ /// Helper for when you'd like to get a Vec<T> of all the rows returned by a
+ /// query that takes named arguments.
+ fn query_rows_and_then_cached<T, E, P, F>(
+ &self,
+ sql: &str,
+ params: P,
+ mapper: F,
+ ) -> Result<Vec<T>, E>
+ where
+ Self: Sized,
+ P: Params,
+ E: From<rusqlite::Error>,
+ F: FnMut(&Row<'_>) -> Result<T, E>,
+ {
+ query_rows_and_then_cachable(self.conn(), sql, params, mapper, true)
+ }
+
+ /// Like `query_rows_and_then_cachable`, but works if you want a non-Vec as a result.
+ /// # Example:
+ /// ```rust,no_run
+ /// # use std::collections::HashSet;
+ /// # use sql_support::ConnExt;
+ /// # use rusqlite::Connection;
+ /// fn get_visit_tombstones(conn: &Connection, id: i64) -> rusqlite::Result<HashSet<i64>> {
+ /// Ok(conn.query_rows_into(
+ /// "SELECT visit_date FROM moz_historyvisit_tombstones
+ /// WHERE place_id = :place_id",
+ /// &[(":place_id", &id)],
+ /// |row| row.get::<_, i64>(0))?)
+ /// }
+ /// ```
+ /// Note if the type isn't inferred, you'll have to do something gross like
+ /// `conn.query_rows_into::<HashSet<_>, _, _, _>(...)`.
+ fn query_rows_into<Coll, T, E, P, F>(&self, sql: &str, params: P, mapper: F) -> Result<Coll, E>
+ where
+ Self: Sized,
+ E: From<rusqlite::Error>,
+ F: FnMut(&Row<'_>) -> Result<T, E>,
+ Coll: FromIterator<T>,
+ P: Params,
+ {
+ query_rows_and_then_cachable(self.conn(), sql, params, mapper, false)
+ }
+
+ /// Same as `query_rows_into`, but caches the stmt if possible.
+ fn query_rows_into_cached<Coll, T, E, P, F>(
+ &self,
+ sql: &str,
+ params: P,
+ mapper: F,
+ ) -> Result<Coll, E>
+ where
+ Self: Sized,
+ P: Params,
+ E: From<rusqlite::Error>,
+ F: FnMut(&Row<'_>) -> Result<T, E>,
+ Coll: FromIterator<T>,
+ {
+ query_rows_and_then_cachable(self.conn(), sql, params, mapper, true)
+ }
+
+ // This should probably have a longer name...
+ /// Like `query_row_and_then_cachable` but returns None instead of erroring
+ /// if no such row exists.
+ fn try_query_row<T, E, P, F>(
+ &self,
+ sql: &str,
+ params: P,
+ mapper: F,
+ cache: bool,
+ ) -> Result<Option<T>, E>
+ where
+ Self: Sized,
+ P: Params,
+ E: From<rusqlite::Error>,
+ F: FnOnce(&Row<'_>) -> Result<T, E>,
+ {
+ let conn = self.conn();
+ let mut stmt = MaybeCached::prepare(conn, sql, cache)?;
+ let mut rows = stmt.query(params)?;
+ rows.next()?.map(mapper).transpose()
+ }
+
+ /// Caveat: This won't actually get used most of the time, and calls will
+ /// usually invoke rusqlite's method with the same name. See comment on
+ /// `UncheckedTransaction` for details (generally you probably don't need to
+ /// care)
+ fn unchecked_transaction(&self) -> SqlResult<UncheckedTransaction<'_>> {
+ UncheckedTransaction::new(self.conn(), TransactionBehavior::Deferred)
+ }
+
+ /// Begin `unchecked_transaction` with `TransactionBehavior::Immediate`. Use
+ /// when the first operation will be a read operation, that further writes
+ /// depend on for correctness.
+ fn unchecked_transaction_imm(&self) -> SqlResult<UncheckedTransaction<'_>> {
+ UncheckedTransaction::new(self.conn(), TransactionBehavior::Immediate)
+ }
+
+ /// Get the DB size in bytes
+ fn get_db_size(&self) -> Result<u32, rusqlite::Error> {
+ let page_count: u32 = self.query_one("SELECT * from pragma_page_count()")?;
+ let page_size: u32 = self.query_one("SELECT * from pragma_page_size()")?;
+ let freelist_count: u32 = self.query_one("SELECT * from pragma_freelist_count()")?;
+
+ Ok((page_count - freelist_count) * page_size)
+ }
+}
+
+impl ConnExt for Connection {
+ #[inline]
+ fn conn(&self) -> &Connection {
+ self
+ }
+}
+
+impl<'conn> ConnExt for Transaction<'conn> {
+ #[inline]
+ fn conn(&self) -> &Connection {
+ self
+ }
+}
+
+impl<'conn> ConnExt for Savepoint<'conn> {
+ #[inline]
+ fn conn(&self) -> &Connection {
+ self
+ }
+}
+
+/// rusqlite, in an attempt to save us from ourselves, needs a mutable ref to a
+/// connection to start a transaction. That is a bit of a PITA in some cases, so
+/// we offer this as an alternative - but the responsibility of ensuring there
+/// are no concurrent transactions is on our head.
+///
+/// This is very similar to the rusqlite `Transaction` - it doesn't prevent
+/// against nested transactions but does allow you to use an immutable
+/// `Connection`.
+///
+/// FIXME: This currently won't actually be used most of the time, because
+/// `rusqlite` added [`Connection::unchecked_transaction`] (and
+/// `Transaction::new_unchecked`, which can be used to reimplement
+/// `unchecked_transaction_imm`), which will be preferred in a call to
+/// `c.unchecked_transaction()`, because inherent methods have precedence over
+/// methods on extension traits. The exception here is that this will still be
+/// used by code which takes `&impl ConnExt` (I believe it would also be used if
+/// you attempted to call `unchecked_transaction()` on a non-Connection that
+/// implements ConnExt, such as a `Safepoint`, `UncheckedTransaction`, or
+/// `Transaction` itself, but such code is clearly broken, so is not worth
+/// considering).
+///
+/// The difference is that `rusqlite`'s version returns a normal
+/// `rusqlite::Transaction`, rather than the `UncheckedTransaction` from this
+/// crate. Aside from type's name and location (and the fact that `rusqlite`'s
+/// detects slightly more misuse at compile time, and has more features), the
+/// main difference is: `rusqlite`'s does not track when a transaction began,
+/// which unfortunatly seems to be used by the coop-transaction management in
+/// places in some fashion.
+///
+/// There are at least two options for how to fix this:
+/// 1. Decide we don't need this version, and delete it, and moving the
+/// transaction timing into the coop-transaction code directly (or something
+/// like this).
+/// 2. Decide this difference *is* important, and rename
+/// `ConnExt::unchecked_transaction` to something like
+/// `ConnExt::transaction_unchecked`.
+pub struct UncheckedTransaction<'conn> {
+ pub conn: &'conn Connection,
+ pub started_at: Instant,
+ pub finished: bool,
+ // we could add drop_behavior etc too, but we don't need it yet - we
+ // always rollback.
+}
+
+impl<'conn> UncheckedTransaction<'conn> {
+ /// Begin a new unchecked transaction. Cannot be nested, but this is not
+ /// enforced by Rust (hence 'unchecked') - however, it is enforced by
+ /// SQLite; use a rusqlite `savepoint` for nested transactions.
+ pub fn new(conn: &'conn Connection, behavior: TransactionBehavior) -> SqlResult<Self> {
+ let query = match behavior {
+ TransactionBehavior::Deferred => "BEGIN DEFERRED",
+ TransactionBehavior::Immediate => "BEGIN IMMEDIATE",
+ TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE",
+ _ => unreachable!(),
+ };
+ conn.execute_batch(query)
+ .map(move |_| UncheckedTransaction {
+ conn,
+ started_at: Instant::now(),
+ finished: false,
+ })
+ }
+
+ /// Consumes and commits an unchecked transaction.
+ pub fn commit(mut self) -> SqlResult<()> {
+ if self.finished {
+ log::warn!("ignoring request to commit an already finished transaction");
+ return Ok(());
+ }
+ self.finished = true;
+ self.conn.execute_batch("COMMIT")?;
+ log::debug!("Transaction commited after {:?}", self.started_at.elapsed());
+ Ok(())
+ }
+
+ /// Consumes and rolls back an unchecked transaction.
+ pub fn rollback(mut self) -> SqlResult<()> {
+ if self.finished {
+ log::warn!("ignoring request to rollback an already finished transaction");
+ return Ok(());
+ }
+ self.rollback_()
+ }
+
+ fn rollback_(&mut self) -> SqlResult<()> {
+ self.finished = true;
+ self.conn.execute_batch("ROLLBACK")?;
+ Ok(())
+ }
+
+ fn finish_(&mut self) -> SqlResult<()> {
+ if self.finished || self.conn.is_autocommit() {
+ return Ok(());
+ }
+ self.rollback_()?;
+ Ok(())
+ }
+}
+
+impl<'conn> Deref for UncheckedTransaction<'conn> {
+ type Target = Connection;
+
+ #[inline]
+ fn deref(&self) -> &Connection {
+ self.conn
+ }
+}
+
+impl<'conn> Drop for UncheckedTransaction<'conn> {
+ fn drop(&mut self) {
+ if let Err(e) = self.finish_() {
+ log::warn!("Error dropping an unchecked transaction: {}", e);
+ }
+ }
+}
+
+impl<'conn> ConnExt for UncheckedTransaction<'conn> {
+ #[inline]
+ fn conn(&self) -> &Connection {
+ self
+ }
+}
+
+fn query_rows_and_then_cachable<Coll, T, E, P, F>(
+ conn: &Connection,
+ sql: &str,
+ params: P,
+ mapper: F,
+ cache: bool,
+) -> Result<Coll, E>
+where
+ E: From<rusqlite::Error>,
+ F: FnMut(&Row<'_>) -> Result<T, E>,
+ Coll: FromIterator<T>,
+ P: Params,
+{
+ let mut stmt = conn.prepare_maybe_cached(sql, cache)?;
+ let iter = stmt.query_and_then(params, mapper)?;
+ iter.collect::<Result<Coll, E>>()
+}
diff --git a/third_party/rust/sql-support/src/each_chunk.rs b/third_party/rust/sql-support/src/each_chunk.rs
new file mode 100644
index 0000000000..2d738bcb37
--- /dev/null
+++ b/third_party/rust/sql-support/src/each_chunk.rs
@@ -0,0 +1,311 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use lazy_static::lazy_static;
+use rusqlite::{self, limits::Limit, types::ToSql};
+use std::iter::Map;
+use std::slice::Iter;
+
+/// Returns SQLITE_LIMIT_VARIABLE_NUMBER as read from an in-memory connection and cached.
+/// connection and cached. That means this will return the wrong value if it's set to a lower
+/// value for a connection using this will return the wrong thing, but doing so is rare enough
+/// that we explicitly don't support it (why would you want to lower this at runtime?).
+///
+/// If you call this and the actual value was set to a negative number or zero (nothing prevents
+/// this beyond a warning in the SQLite documentation), we panic. However, it's unlikely you can
+/// run useful queries if this happened anyway.
+pub fn default_max_variable_number() -> usize {
+ lazy_static! {
+ static ref MAX_VARIABLE_NUMBER: usize = {
+ let conn = rusqlite::Connection::open_in_memory()
+ .expect("Failed to initialize in-memory connection (out of memory?)");
+
+ let limit = conn.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER);
+ assert!(
+ limit > 0,
+ "Illegal value for SQLITE_LIMIT_VARIABLE_NUMBER (must be > 0) {}",
+ limit
+ );
+ limit as usize
+ };
+ }
+ *MAX_VARIABLE_NUMBER
+}
+
+/// Helper for the case where you have a `&[impl ToSql]` of arbitrary length, but need one
+/// of no more than the connection's `MAX_VARIABLE_NUMBER` (rather,
+/// `default_max_variable_number()`). This is useful when performing batched updates.
+///
+/// The `do_chunk` callback is called with a slice of no more than `default_max_variable_number()`
+/// items as it's first argument, and the offset from the start as it's second.
+///
+/// See `each_chunk_mapped` for the case where `T` doesn't implement `ToSql`, but can be
+/// converted to something that does.
+pub fn each_chunk<'a, T, E, F>(items: &'a [T], do_chunk: F) -> Result<(), E>
+where
+ T: 'a,
+ F: FnMut(&'a [T], usize) -> Result<(), E>,
+{
+ each_sized_chunk(items, default_max_variable_number(), do_chunk)
+}
+
+/// A version of `each_chunk` for the case when the conversion to `to_sql` requires an custom
+/// intermediate step. For example, you might want to grab a property off of an arrray of records
+pub fn each_chunk_mapped<'a, T, U, E, Mapper, DoChunk>(
+ items: &'a [T],
+ to_sql: Mapper,
+ do_chunk: DoChunk,
+) -> Result<(), E>
+where
+ T: 'a,
+ U: ToSql + 'a,
+ Mapper: Fn(&'a T) -> U,
+ DoChunk: FnMut(Map<Iter<'a, T>, &'_ Mapper>, usize) -> Result<(), E>,
+{
+ each_sized_chunk_mapped(items, default_max_variable_number(), to_sql, do_chunk)
+}
+
+// Split out for testing. Separate so that we can pass an actual slice
+// to the callback if they don't need mapping. We could probably unify
+// this with each_sized_chunk_mapped with a lot of type system trickery,
+// but one of the benefits to each_chunk over the mapped versions is
+// that the declaration is simpler.
+pub fn each_sized_chunk<'a, T, E, F>(
+ items: &'a [T],
+ chunk_size: usize,
+ mut do_chunk: F,
+) -> Result<(), E>
+where
+ T: 'a,
+ F: FnMut(&'a [T], usize) -> Result<(), E>,
+{
+ if items.is_empty() {
+ return Ok(());
+ }
+ let mut offset = 0;
+ for chunk in items.chunks(chunk_size) {
+ do_chunk(chunk, offset)?;
+ offset += chunk.len();
+ }
+ Ok(())
+}
+
+/// Utility to help perform batched updates, inserts, queries, etc. This is the low-level version
+/// of this utility which is wrapped by `each_chunk` and `each_chunk_mapped`, and it allows you to
+/// provide both the mapping function, and the chunk size.
+///
+/// Note: `mapped` basically just refers to the translating of `T` to some `U` where `U: ToSql`
+/// using the `to_sql` function. This is useful for e.g. inserting the IDs of a large list
+/// of records.
+pub fn each_sized_chunk_mapped<'a, T, U, E, Mapper, DoChunk>(
+ items: &'a [T],
+ chunk_size: usize,
+ to_sql: Mapper,
+ mut do_chunk: DoChunk,
+) -> Result<(), E>
+where
+ T: 'a,
+ U: ToSql + 'a,
+ Mapper: Fn(&'a T) -> U,
+ DoChunk: FnMut(Map<Iter<'a, T>, &'_ Mapper>, usize) -> Result<(), E>,
+{
+ if items.is_empty() {
+ return Ok(());
+ }
+ let mut offset = 0;
+ for chunk in items.chunks(chunk_size) {
+ let mapped = chunk.iter().map(&to_sql);
+ do_chunk(mapped, offset)?;
+ offset += chunk.len();
+ }
+ Ok(())
+}
+
+#[cfg(test)]
+fn check_chunk<T, C>(items: C, expect: &[T], desc: &str)
+where
+ C: IntoIterator,
+ <C as IntoIterator>::Item: ToSql,
+ T: ToSql,
+{
+ let items = items.into_iter().collect::<Vec<_>>();
+ assert_eq!(items.len(), expect.len());
+ // Can't quite make the borrowing work out here w/o a loop, oh well.
+ for (idx, (got, want)) in items.iter().zip(expect.iter()).enumerate() {
+ assert_eq!(
+ got.to_sql().unwrap(),
+ want.to_sql().unwrap(),
+ // ToSqlOutput::Owned(Value::Integer(*num)),
+ "{}: Bad value at index {}",
+ desc,
+ idx
+ );
+ }
+}
+
+#[cfg(test)]
+mod test_mapped {
+ use super::*;
+
+ #[test]
+ fn test_separate() {
+ let mut iteration = 0;
+ each_sized_chunk_mapped(
+ &[1, 2, 3, 4, 5],
+ 3,
+ |item| item as &dyn ToSql,
+ |chunk, offset| {
+ match offset {
+ 0 => {
+ assert_eq!(iteration, 0);
+ check_chunk(chunk, &[1, 2, 3], "first chunk");
+ }
+ 3 => {
+ assert_eq!(iteration, 1);
+ check_chunk(chunk, &[4, 5], "second chunk");
+ }
+ n => {
+ panic!("Unexpected offset {}", n);
+ }
+ }
+ iteration += 1;
+ Ok::<(), ()>(())
+ },
+ )
+ .unwrap();
+ }
+
+ #[test]
+ fn test_leq_chunk_size() {
+ for &check_size in &[5, 6] {
+ let mut iteration = 0;
+ each_sized_chunk_mapped(
+ &[1, 2, 3, 4, 5],
+ check_size,
+ |item| item as &dyn ToSql,
+ |chunk, offset| {
+ assert_eq!(iteration, 0);
+ iteration += 1;
+ assert_eq!(offset, 0);
+ check_chunk(chunk, &[1, 2, 3, 4, 5], "only iteration");
+ Ok::<(), ()>(())
+ },
+ )
+ .unwrap();
+ }
+ }
+
+ #[test]
+ fn test_empty_chunk() {
+ let items: &[i64] = &[];
+ each_sized_chunk_mapped::<_, _, (), _, _>(
+ items,
+ 100,
+ |item| item as &dyn ToSql,
+ |_, _| {
+ panic!("Should never be called");
+ },
+ )
+ .unwrap();
+ }
+
+ #[test]
+ fn test_error() {
+ let mut iteration = 0;
+ let e = each_sized_chunk_mapped(
+ &[1, 2, 3, 4, 5, 6, 7],
+ 3,
+ |item| item as &dyn ToSql,
+ |_, offset| {
+ if offset == 0 {
+ assert_eq!(iteration, 0);
+ iteration += 1;
+ Ok(())
+ } else if offset == 3 {
+ assert_eq!(iteration, 1);
+ iteration += 1;
+ Err("testing".to_string())
+ } else {
+ // Make sure we stopped after the error.
+ panic!("Shouldn't get called with offset of {}", offset);
+ }
+ },
+ )
+ .expect_err("Should be an error");
+ assert_eq!(e, "testing");
+ }
+}
+
+#[cfg(test)]
+mod test_unmapped {
+ use super::*;
+
+ #[test]
+ fn test_separate() {
+ let mut iteration = 0;
+ each_sized_chunk(&[1, 2, 3, 4, 5], 3, |chunk, offset| {
+ match offset {
+ 0 => {
+ assert_eq!(iteration, 0);
+ check_chunk(chunk, &[1, 2, 3], "first chunk");
+ }
+ 3 => {
+ assert_eq!(iteration, 1);
+ check_chunk(chunk, &[4, 5], "second chunk");
+ }
+ n => {
+ panic!("Unexpected offset {}", n);
+ }
+ }
+ iteration += 1;
+ Ok::<(), ()>(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_leq_chunk_size() {
+ for &check_size in &[5, 6] {
+ let mut iteration = 0;
+ each_sized_chunk(&[1, 2, 3, 4, 5], check_size, |chunk, offset| {
+ assert_eq!(iteration, 0);
+ iteration += 1;
+ assert_eq!(offset, 0);
+ check_chunk(chunk, &[1, 2, 3, 4, 5], "only iteration");
+ Ok::<(), ()>(())
+ })
+ .unwrap();
+ }
+ }
+
+ #[test]
+ fn test_empty_chunk() {
+ let items: &[i64] = &[];
+ each_sized_chunk::<_, (), _>(items, 100, |_, _| {
+ panic!("Should never be called");
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_error() {
+ let mut iteration = 0;
+ let e = each_sized_chunk(&[1, 2, 3, 4, 5, 6, 7], 3, |_, offset| {
+ if offset == 0 {
+ assert_eq!(iteration, 0);
+ iteration += 1;
+ Ok(())
+ } else if offset == 3 {
+ assert_eq!(iteration, 1);
+ iteration += 1;
+ Err("testing".to_string())
+ } else {
+ // Make sure we stopped after the error.
+ panic!("Shouldn't get called with offset of {}", offset);
+ }
+ })
+ .expect_err("Should be an error");
+ assert_eq!(e, "testing");
+ }
+}
diff --git a/third_party/rust/sql-support/src/lib.rs b/third_party/rust/sql-support/src/lib.rs
new file mode 100644
index 0000000000..17068dcfea
--- /dev/null
+++ b/third_party/rust/sql-support/src/lib.rs
@@ -0,0 +1,36 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#![allow(unknown_lints)]
+#![warn(rust_2018_idioms)]
+
+mod conn_ext;
+mod each_chunk;
+mod maybe_cached;
+pub mod open_database;
+mod repeat;
+
+pub use crate::conn_ext::*;
+pub use crate::each_chunk::*;
+pub use crate::maybe_cached::*;
+pub use crate::repeat::*;
+
+/// In PRAGMA foo='bar', `'bar'` must be a constant string (it cannot be a
+/// bound parameter), so we need to escape manually. According to
+/// https://www.sqlite.org/faq.html, the only character that must be escaped is
+/// the single quote, which is escaped by placing two single quotes in a row.
+pub fn escape_string_for_pragma(s: &str) -> String {
+ s.replace('\'', "''")
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ #[test]
+ fn test_escape_string_for_pragma() {
+ assert_eq!(escape_string_for_pragma("foobar"), "foobar");
+ assert_eq!(escape_string_for_pragma("'foo'bar'"), "''foo''bar''");
+ assert_eq!(escape_string_for_pragma("''"), "''''");
+ }
+}
diff --git a/third_party/rust/sql-support/src/maybe_cached.rs b/third_party/rust/sql-support/src/maybe_cached.rs
new file mode 100644
index 0000000000..96f99f490c
--- /dev/null
+++ b/third_party/rust/sql-support/src/maybe_cached.rs
@@ -0,0 +1,64 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use rusqlite::{self, CachedStatement, Connection, Statement};
+
+use std::ops::{Deref, DerefMut};
+
+/// MaybeCached is a type that can be used to help abstract
+/// over cached and uncached rusqlite statements in a transparent manner.
+pub enum MaybeCached<'conn> {
+ Uncached(Statement<'conn>),
+ Cached(CachedStatement<'conn>),
+}
+
+impl<'conn> Deref for MaybeCached<'conn> {
+ type Target = Statement<'conn>;
+ #[inline]
+ fn deref(&self) -> &Statement<'conn> {
+ match self {
+ MaybeCached::Cached(cached) => Deref::deref(cached),
+ MaybeCached::Uncached(uncached) => uncached,
+ }
+ }
+}
+
+impl<'conn> DerefMut for MaybeCached<'conn> {
+ #[inline]
+ fn deref_mut(&mut self) -> &mut Statement<'conn> {
+ match self {
+ MaybeCached::Cached(cached) => DerefMut::deref_mut(cached),
+ MaybeCached::Uncached(uncached) => uncached,
+ }
+ }
+}
+
+impl<'conn> From<Statement<'conn>> for MaybeCached<'conn> {
+ #[inline]
+ fn from(stmt: Statement<'conn>) -> Self {
+ MaybeCached::Uncached(stmt)
+ }
+}
+
+impl<'conn> From<CachedStatement<'conn>> for MaybeCached<'conn> {
+ #[inline]
+ fn from(stmt: CachedStatement<'conn>) -> Self {
+ MaybeCached::Cached(stmt)
+ }
+}
+
+impl<'conn> MaybeCached<'conn> {
+ #[inline]
+ pub fn prepare(
+ conn: &'conn Connection,
+ sql: &str,
+ cached: bool,
+ ) -> rusqlite::Result<MaybeCached<'conn>> {
+ if cached {
+ Ok(MaybeCached::Cached(conn.prepare_cached(sql)?))
+ } else {
+ Ok(MaybeCached::Uncached(conn.prepare(sql)?))
+ }
+ }
+}
diff --git a/third_party/rust/sql-support/src/open_database.rs b/third_party/rust/sql-support/src/open_database.rs
new file mode 100644
index 0000000000..43c0d3f30d
--- /dev/null
+++ b/third_party/rust/sql-support/src/open_database.rs
@@ -0,0 +1,536 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/// Use this module to open a new SQLite database connection.
+///
+/// Usage:
+/// - Define a struct that implements ConnectionInitializer. This handles:
+/// - Initializing the schema for a new database
+/// - Upgrading the schema for an existing database
+/// - Extra preparation/finishing steps, for example setting up SQLite functions
+///
+/// - Call open_database() in your database constructor:
+/// - The first method called is `prepare()`. This is executed outside of a transaction
+/// and is suitable for executing pragmas (eg, `PRAGMA journal_mode=wal`), defining
+/// functions, etc.
+/// - If the database file is not present and the connection is writable, open_database()
+/// will create a new DB and call init(), then finish(). If the connection is not
+/// writable it will panic, meaning that if you support ReadOnly connections, they must
+/// be created after a writable connection is open.
+/// - If the database file exists and the connection is writable, open_database() will open
+/// it and call prepare(), upgrade_from() for each upgrade that needs to be applied, then
+/// finish(). As above, a read-only connection will panic if upgrades are necessary, so
+/// you should ensure the first connection opened is writable.
+/// - If the connection is not writable, `finish()` will be called (ie, `finish()`, like
+/// `prepare()`, is called for all connections)
+///
+/// See the autofill DB code for an example.
+///
+use crate::ConnExt;
+use rusqlite::{
+ Connection, Error as RusqliteError, ErrorCode, OpenFlags, Transaction, TransactionBehavior,
+};
+use std::path::Path;
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error("Incompatible database version: {0}")]
+ IncompatibleVersion(u32),
+ #[error("Error executing SQL: {0}")]
+ SqlError(#[from] rusqlite::Error),
+ // `.0` is the original `Error` in string form.
+ #[error("Failed to recover a corrupt database ('{0}') due to an error deleting the file: {1}")]
+ RecoveryError(String, std::io::Error),
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+pub trait ConnectionInitializer {
+ // Name to display in the logs
+ const NAME: &'static str;
+
+ // The version that the last upgrade function upgrades to.
+ const END_VERSION: u32;
+
+ // Functions called only for writable connections all take a Transaction
+ // Initialize a newly created database to END_VERSION
+ fn init(&self, tx: &Transaction<'_>) -> Result<()>;
+
+ // Upgrade schema from version -> version + 1
+ fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()>;
+
+ // Runs immediately after creation for all types of connections. If writable,
+ // will *not* be in the transaction created for the "only writable" functions above.
+ fn prepare(&self, _conn: &Connection, _db_empty: bool) -> Result<()> {
+ Ok(())
+ }
+
+ // Runs for all types of connections. If a writable connection is being
+ // initialized, this will be called after all initialization functions,
+ // but inside their transaction.
+ fn finish(&self, _conn: &Connection) -> Result<()> {
+ Ok(())
+ }
+}
+
+pub fn open_database<CI: ConnectionInitializer, P: AsRef<Path>>(
+ path: P,
+ connection_initializer: &CI,
+) -> Result<Connection> {
+ open_database_with_flags(path, OpenFlags::default(), connection_initializer)
+}
+
+pub fn open_memory_database<CI: ConnectionInitializer>(
+ conn_initializer: &CI,
+) -> Result<Connection> {
+ open_memory_database_with_flags(OpenFlags::default(), conn_initializer)
+}
+
+pub fn open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>(
+ path: P,
+ open_flags: OpenFlags,
+ connection_initializer: &CI,
+) -> Result<Connection> {
+ do_open_database_with_flags(&path, open_flags, connection_initializer).or_else(|e| {
+ // See if we can recover from the error and try a second time
+ try_handle_db_failure(&path, open_flags, connection_initializer, e)?;
+ do_open_database_with_flags(&path, open_flags, connection_initializer)
+ })
+}
+
+fn do_open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>(
+ path: P,
+ open_flags: OpenFlags,
+ connection_initializer: &CI,
+) -> Result<Connection> {
+ // Try running the migration logic with an existing file
+ log::debug!("{}: opening database", CI::NAME);
+ let mut conn = Connection::open_with_flags(path, open_flags)?;
+ log::debug!("{}: checking if initialization is necessary", CI::NAME);
+ let db_empty = is_db_empty(&conn)?;
+
+ log::debug!("{}: preparing", CI::NAME);
+ connection_initializer.prepare(&conn, db_empty)?;
+
+ if open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) {
+ let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate)?;
+ if db_empty {
+ log::debug!("{}: initializing new database", CI::NAME);
+ connection_initializer.init(&tx)?;
+ } else {
+ let mut current_version = get_schema_version(&tx)?;
+ if current_version > CI::END_VERSION {
+ return Err(Error::IncompatibleVersion(current_version));
+ }
+ while current_version < CI::END_VERSION {
+ log::debug!(
+ "{}: upgrading database to {}",
+ CI::NAME,
+ current_version + 1
+ );
+ connection_initializer.upgrade_from(&tx, current_version)?;
+ current_version += 1;
+ }
+ }
+ log::debug!("{}: finishing writable database open", CI::NAME);
+ connection_initializer.finish(&tx)?;
+ set_schema_version(&tx, CI::END_VERSION)?;
+ tx.commit()?;
+ } else {
+ // There's an implied requirement that the first connection to a DB is
+ // writable, so read-only connections do much less, but panic if stuff is wrong
+ assert!(!db_empty, "existing writer must have initialized");
+ assert!(
+ get_schema_version(&conn)? == CI::END_VERSION,
+ "existing writer must have migrated"
+ );
+ log::debug!("{}: finishing readonly database open", CI::NAME);
+ connection_initializer.finish(&conn)?;
+ }
+ log::debug!("{}: database open successful", CI::NAME);
+ Ok(conn)
+}
+
+pub fn open_memory_database_with_flags<CI: ConnectionInitializer>(
+ flags: OpenFlags,
+ conn_initializer: &CI,
+) -> Result<Connection> {
+ open_database_with_flags(":memory:", flags, conn_initializer)
+}
+
+// Attempt to handle failure when opening the database.
+//
+// Returns:
+// - Ok(()) the failure is potentially handled and we should make a second open attempt
+// - Err(e) the failure couldn't be handled and we should return this error
+fn try_handle_db_failure<CI: ConnectionInitializer, P: AsRef<Path>>(
+ path: P,
+ open_flags: OpenFlags,
+ _connection_initializer: &CI,
+ err: Error,
+) -> Result<()> {
+ if !open_flags.contains(OpenFlags::SQLITE_OPEN_CREATE)
+ && matches!(err, Error::SqlError(rusqlite::Error::SqliteFailure(code, _)) if code.code == rusqlite::ErrorCode::CannotOpen)
+ {
+ log::info!(
+ "{}: database doesn't exist, but we weren't requested to create it",
+ CI::NAME
+ );
+ return Err(err);
+ }
+ log::warn!("{}: database operation failed: {}", CI::NAME, err);
+ if !open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) {
+ log::warn!(
+ "{}: not attempting recovery as this is a read-only connection request",
+ CI::NAME
+ );
+ return Err(err);
+ }
+
+ let delete = match err {
+ Error::SqlError(RusqliteError::SqliteFailure(e, _)) => {
+ matches!(e.code, ErrorCode::DatabaseCorrupt | ErrorCode::NotADatabase)
+ }
+ _ => false,
+ };
+ if delete {
+ log::info!(
+ "{}: the database is fatally damaged; deleting and starting fresh",
+ CI::NAME
+ );
+ // Note we explicitly decline to move the path to, say ".corrupt", as it's difficult to
+ // identify any value there - actually getting our hands on the file from a mobile device
+ // is tricky and it would just take up disk space forever.
+ if let Err(io_err) = std::fs::remove_file(path) {
+ return Err(Error::RecoveryError(err.to_string(), io_err));
+ }
+ Ok(())
+ } else {
+ Err(err)
+ }
+}
+
+fn is_db_empty(conn: &Connection) -> Result<bool> {
+ Ok(conn.query_one::<u32>("SELECT COUNT(*) FROM sqlite_master")? == 0)
+}
+
+fn get_schema_version(conn: &Connection) -> Result<u32> {
+ let version = conn.query_row_and_then("PRAGMA user_version", [], |row| row.get(0))?;
+ Ok(version)
+}
+
+fn set_schema_version(conn: &Connection, version: u32) -> Result<()> {
+ conn.set_pragma("user_version", version)?;
+ Ok(())
+}
+
+// It would be nice for this to be #[cfg(test)], but that doesn't allow it to be used in tests for
+// our other crates.
+pub mod test_utils {
+ use super::*;
+ use std::path::PathBuf;
+ use tempfile::TempDir;
+
+ // Database file that we can programatically run upgrades on
+ //
+ // We purposefully don't keep a connection to the database around to force upgrades to always
+ // run against a newly opened DB, like they would in the real world. See #4106 for
+ // details.
+ pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
+ // Keep around a TempDir to ensure the database file stays around until this struct is
+ // dropped
+ _tempdir: TempDir,
+ pub connection_initializer: CI,
+ pub path: PathBuf,
+ }
+
+ impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
+ pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
+ Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
+ }
+
+ pub fn new_with_flags(
+ connection_initializer: CI,
+ init_sql: &str,
+ open_flags: OpenFlags,
+ ) -> Self {
+ let tempdir = tempfile::tempdir().unwrap();
+ let path = tempdir.path().join(Path::new("db.sql"));
+ let conn = Connection::open_with_flags(&path, open_flags).unwrap();
+ conn.execute_batch(init_sql).unwrap();
+ Self {
+ _tempdir: tempdir,
+ connection_initializer,
+ path,
+ }
+ }
+
+ pub fn upgrade_to(&self, version: u32) {
+ let mut conn = self.open();
+ let tx = conn.transaction().unwrap();
+ let mut current_version = get_schema_version(&tx).unwrap();
+ while current_version < version {
+ self.connection_initializer
+ .upgrade_from(&tx, current_version)
+ .unwrap();
+ current_version += 1;
+ }
+ set_schema_version(&tx, current_version).unwrap();
+ self.connection_initializer.finish(&tx).unwrap();
+ tx.commit().unwrap();
+ }
+
+ pub fn run_all_upgrades(&self) {
+ let current_version = get_schema_version(&self.open()).unwrap();
+ for version in current_version..CI::END_VERSION {
+ self.upgrade_to(version + 1);
+ }
+ }
+
+ pub fn open(&self) -> Connection {
+ Connection::open(&self.path).unwrap()
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::test_utils::MigratedDatabaseFile;
+ use super::*;
+ use std::cell::RefCell;
+ use std::io::Write;
+
+ struct TestConnectionInitializer {
+ pub calls: RefCell<Vec<&'static str>>,
+ pub buggy_v3_upgrade: bool,
+ }
+
+ impl TestConnectionInitializer {
+ pub fn new() -> Self {
+ let _ = env_logger::try_init();
+ Self {
+ calls: RefCell::new(Vec::new()),
+ buggy_v3_upgrade: false,
+ }
+ }
+ pub fn new_with_buggy_logic() -> Self {
+ let _ = env_logger::try_init();
+ Self {
+ calls: RefCell::new(Vec::new()),
+ buggy_v3_upgrade: true,
+ }
+ }
+
+ pub fn clear_calls(&self) {
+ self.calls.borrow_mut().clear();
+ }
+
+ pub fn push_call(&self, call: &'static str) {
+ self.calls.borrow_mut().push(call);
+ }
+
+ pub fn check_calls(&self, expected: Vec<&'static str>) {
+ assert_eq!(*self.calls.borrow(), expected);
+ }
+ }
+
+ impl ConnectionInitializer for TestConnectionInitializer {
+ const NAME: &'static str = "test db";
+ const END_VERSION: u32 = 4;
+
+ fn prepare(&self, conn: &Connection, _: bool) -> Result<()> {
+ self.push_call("prep");
+ conn.execute_batch(
+ "
+ PRAGMA journal_mode = wal;
+ ",
+ )?;
+ Ok(())
+ }
+
+ fn init(&self, conn: &Transaction<'_>) -> Result<()> {
+ self.push_call("init");
+ conn.execute_batch(
+ "
+ CREATE TABLE prep_table(col);
+ INSERT INTO prep_table(col) VALUES ('correct-value');
+ CREATE TABLE my_table(col);
+ ",
+ )
+ .map_err(|e| e.into())
+ }
+
+ fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()> {
+ match version {
+ 2 => {
+ self.push_call("upgrade_from_v2");
+ conn.execute_batch(
+ "
+ ALTER TABLE my_old_table_name RENAME TO my_table;
+ ",
+ )?;
+ Ok(())
+ }
+ 3 => {
+ self.push_call("upgrade_from_v3");
+
+ if self.buggy_v3_upgrade {
+ conn.execute_batch("ILLEGAL_SQL_CODE")?;
+ }
+
+ conn.execute_batch(
+ "
+ ALTER TABLE my_table RENAME COLUMN old_col to col;
+ ",
+ )?;
+ Ok(())
+ }
+ _ => {
+ panic!("Unexpected version: {}", version);
+ }
+ }
+ }
+
+ fn finish(&self, conn: &Connection) -> Result<()> {
+ self.push_call("finish");
+ conn.execute_batch(
+ "
+ INSERT INTO my_table(col) SELECT col FROM prep_table;
+ ",
+ )?;
+ Ok(())
+ }
+ }
+
+ // Initialize the database to v2 to test upgrading from there
+ static INIT_V2: &str = "
+ CREATE TABLE prep_table(col);
+ INSERT INTO prep_table(col) VALUES ('correct-value');
+ CREATE TABLE my_old_table_name(old_col);
+ PRAGMA user_version=2;
+ ";
+
+ fn check_final_data(conn: &Connection) {
+ let value: String = conn
+ .query_row("SELECT col FROM my_table", [], |r| r.get(0))
+ .unwrap();
+ assert_eq!(value, "correct-value");
+ assert_eq!(get_schema_version(conn).unwrap(), 4);
+ }
+
+ #[test]
+ fn test_init() {
+ let connection_initializer = TestConnectionInitializer::new();
+ let conn = open_memory_database(&connection_initializer).unwrap();
+ check_final_data(&conn);
+ connection_initializer.check_calls(vec!["prep", "init", "finish"]);
+ }
+
+ #[test]
+ fn test_upgrades() {
+ let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
+ let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap();
+ check_final_data(&conn);
+ db_file.connection_initializer.check_calls(vec![
+ "prep",
+ "upgrade_from_v2",
+ "upgrade_from_v3",
+ "finish",
+ ]);
+ }
+
+ #[test]
+ fn test_open_current_version() {
+ let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
+ db_file.upgrade_to(4);
+ db_file.connection_initializer.clear_calls();
+ let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap();
+ check_final_data(&conn);
+ db_file
+ .connection_initializer
+ .check_calls(vec!["prep", "finish"]);
+ }
+
+ #[test]
+ fn test_pragmas() {
+ let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
+ let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap();
+ assert_eq!(
+ conn.query_one::<String>("PRAGMA journal_mode").unwrap(),
+ "wal"
+ );
+ }
+
+ #[test]
+ fn test_migration_error() {
+ let db_file =
+ MigratedDatabaseFile::new(TestConnectionInitializer::new_with_buggy_logic(), INIT_V2);
+ db_file
+ .open()
+ .execute(
+ "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')",
+ [],
+ )
+ .unwrap();
+
+ open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap_err();
+ // Even though the upgrades failed, the data should still be there. The changes that
+ // upgrade_to_v3 made should have been rolled back.
+ assert_eq!(
+ db_file
+ .open()
+ .query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name")
+ .unwrap(),
+ 1
+ );
+ }
+
+ #[test]
+ fn test_version_too_new() {
+ let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
+ set_schema_version(&db_file.open(), 5).unwrap();
+
+ db_file
+ .open()
+ .execute(
+ "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')",
+ [],
+ )
+ .unwrap();
+
+ assert!(matches!(
+ open_database(db_file.path.clone(), &db_file.connection_initializer,),
+ Err(Error::IncompatibleVersion(5))
+ ));
+ // Make sure that even when DeleteAndRecreate is specified, we don't delete the database
+ // file when the schema is newer
+ assert_eq!(
+ db_file
+ .open()
+ .query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name")
+ .unwrap(),
+ 1
+ );
+ }
+
+ #[test]
+ fn test_corrupt_db() {
+ let tempdir = tempfile::tempdir().unwrap();
+ let path = tempdir.path().join(Path::new("corrupt-db.sql"));
+ let mut file = std::fs::File::create(path.clone()).unwrap();
+ // interestingly, sqlite seems to treat a 0-byte file as a missing one.
+ // Note that this will exercise the `ErrorCode::NotADatabase` error code. It's not clear
+ // how we could hit `ErrorCode::DatabaseCorrupt`, but even if we could, there's not much
+ // value as this test can't really observe which one it was.
+ file.write_all(b"not sql").unwrap();
+ let metadata = std::fs::metadata(path.clone()).unwrap();
+ assert_eq!(metadata.len(), 7);
+ drop(file);
+ open_database(path.clone(), &TestConnectionInitializer::new()).unwrap();
+ let metadata = std::fs::metadata(path).unwrap();
+ // just check the file is no longer what it was before.
+ assert_ne!(metadata.len(), 7);
+ }
+}
diff --git a/third_party/rust/sql-support/src/repeat.rs b/third_party/rust/sql-support/src/repeat.rs
new file mode 100644
index 0000000000..40b582ec14
--- /dev/null
+++ b/third_party/rust/sql-support/src/repeat.rs
@@ -0,0 +1,113 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use std::fmt;
+
+/// Helper type for printing repeated strings more efficiently. You should use
+/// [`repeat_display`](sql_support::repeat_display), or one of the `repeat_sql_*` helpers to
+/// construct it.
+#[derive(Debug, Clone)]
+pub struct RepeatDisplay<'a, F> {
+ count: usize,
+ sep: &'a str,
+ fmt_one: F,
+}
+
+impl<'a, F> fmt::Display for RepeatDisplay<'a, F>
+where
+ F: Fn(usize, &mut fmt::Formatter<'_>) -> fmt::Result,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ for i in 0..self.count {
+ if i != 0 {
+ f.write_str(self.sep)?;
+ }
+ (self.fmt_one)(i, f)?;
+ }
+ Ok(())
+ }
+}
+
+/// Construct a RepeatDisplay that will repeatedly call `fmt_one` with a formatter `count` times,
+/// separated by `sep`.
+///
+/// # Example
+///
+/// ```rust
+/// # use sql_support::repeat_display;
+/// assert_eq!(format!("{}", repeat_display(1, ",", |i, f| write!(f, "({},?)", i))),
+/// "(0,?)");
+/// assert_eq!(format!("{}", repeat_display(2, ",", |i, f| write!(f, "({},?)", i))),
+/// "(0,?),(1,?)");
+/// assert_eq!(format!("{}", repeat_display(3, ",", |i, f| write!(f, "({},?)", i))),
+/// "(0,?),(1,?),(2,?)");
+/// ```
+#[inline]
+pub fn repeat_display<F>(count: usize, sep: &str, fmt_one: F) -> RepeatDisplay<'_, F>
+where
+ F: Fn(usize, &mut fmt::Formatter<'_>) -> fmt::Result,
+{
+ RepeatDisplay {
+ count,
+ sep,
+ fmt_one,
+ }
+}
+
+/// Returns a value that formats as `count` instances of `?` separated by commas.
+///
+/// # Example
+///
+/// ```rust
+/// # use sql_support::repeat_sql_vars;
+/// assert_eq!(format!("{}", repeat_sql_vars(0)), "");
+/// assert_eq!(format!("{}", repeat_sql_vars(1)), "?");
+/// assert_eq!(format!("{}", repeat_sql_vars(2)), "?,?");
+/// assert_eq!(format!("{}", repeat_sql_vars(3)), "?,?,?");
+/// ```
+pub fn repeat_sql_vars(count: usize) -> impl fmt::Display {
+ repeat_display(count, ",", |_, f| write!(f, "?"))
+}
+
+/// Returns a value that formats as `count` instances of `(?)` separated by commas.
+///
+/// # Example
+///
+/// ```rust
+/// # use sql_support::repeat_sql_values;
+/// assert_eq!(format!("{}", repeat_sql_values(0)), "");
+/// assert_eq!(format!("{}", repeat_sql_values(1)), "(?)");
+/// assert_eq!(format!("{}", repeat_sql_values(2)), "(?),(?)");
+/// assert_eq!(format!("{}", repeat_sql_values(3)), "(?),(?),(?)");
+/// ```
+///
+pub fn repeat_sql_values(count: usize) -> impl fmt::Display {
+ // We could also implement this as `repeat_sql_multi_values(count, 1)`,
+ // but this is faster and no less clear IMO.
+ repeat_display(count, ",", |_, f| write!(f, "(?)"))
+}
+
+/// Returns a value that formats as `num_values` instances of `(?,?,?,...)` (where there are
+/// `vars_per_value` question marks separated by commas in between the `?`s).
+///
+/// Panics if `vars_per_value` is zero (however, `num_values` is allowed to be zero).
+///
+/// # Example
+///
+/// ```rust
+/// # use sql_support::repeat_multi_values;
+/// assert_eq!(format!("{}", repeat_multi_values(0, 2)), "");
+/// assert_eq!(format!("{}", repeat_multi_values(1, 5)), "(?,?,?,?,?)");
+/// assert_eq!(format!("{}", repeat_multi_values(2, 3)), "(?,?,?),(?,?,?)");
+/// assert_eq!(format!("{}", repeat_multi_values(3, 1)), "(?),(?),(?)");
+/// ```
+pub fn repeat_multi_values(num_values: usize, vars_per_value: usize) -> impl fmt::Display {
+ assert_ne!(
+ vars_per_value, 0,
+ "Illegal value for `vars_per_value`, must not be zero"
+ );
+ repeat_display(num_values, ",", move |_, f| {
+ write!(f, "({})", repeat_sql_vars(vars_per_value))
+ })
+}