diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/sql-support/src | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/sql-support/src')
-rw-r--r-- | third_party/rust/sql-support/src/conn_ext.rs | 402 | ||||
-rw-r--r-- | third_party/rust/sql-support/src/each_chunk.rs | 311 | ||||
-rw-r--r-- | third_party/rust/sql-support/src/lib.rs | 36 | ||||
-rw-r--r-- | third_party/rust/sql-support/src/maybe_cached.rs | 64 | ||||
-rw-r--r-- | third_party/rust/sql-support/src/open_database.rs | 536 | ||||
-rw-r--r-- | third_party/rust/sql-support/src/repeat.rs | 113 |
6 files changed, 1462 insertions, 0 deletions
diff --git a/third_party/rust/sql-support/src/conn_ext.rs b/third_party/rust/sql-support/src/conn_ext.rs new file mode 100644 index 0000000000..2efec41b7c --- /dev/null +++ b/third_party/rust/sql-support/src/conn_ext.rs @@ -0,0 +1,402 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use rusqlite::{ + self, + types::{FromSql, ToSql}, + Connection, Params, Result as SqlResult, Row, Savepoint, Transaction, TransactionBehavior, +}; +use std::iter::FromIterator; +use std::ops::Deref; +use std::time::Instant; + +use crate::maybe_cached::MaybeCached; + +pub struct Conn(rusqlite::Connection); + +/// This trait exists so that we can use these helpers on `rusqlite::{Transaction, Connection}`. +/// Note that you must import ConnExt in order to call these methods on anything. +pub trait ConnExt { + /// The method you need to implement to opt in to all of this. + fn conn(&self) -> &Connection; + + /// Set the value of the pragma on the main database. Returns the same object, for chaining. + fn set_pragma<T>(&self, pragma_name: &str, pragma_value: T) -> SqlResult<&Self> + where + T: ToSql, + Self: Sized, + { + // None == Schema name, e.g. `PRAGMA some_attached_db.something = blah` + self.conn() + .pragma_update(None, pragma_name, &pragma_value)?; + Ok(self) + } + + /// Get a cached or uncached statement based on a flag. + fn prepare_maybe_cached<'conn>( + &'conn self, + sql: &str, + cache: bool, + ) -> SqlResult<MaybeCached<'conn>> { + MaybeCached::prepare(self.conn(), sql, cache) + } + + /// Execute all the provided statements. + fn execute_all(&self, stmts: &[&str]) -> SqlResult<()> { + let conn = self.conn(); + for sql in stmts { + let r = conn.execute(sql, []); + match r { + Ok(_) => {} + // Ignore ExecuteReturnedResults error because they're pointless + // and annoying. + Err(rusqlite::Error::ExecuteReturnedResults) => {} + Err(e) => return Err(e), + } + } + Ok(()) + } + + /// Execute a single statement. + fn execute_one(&self, stmt: &str) -> SqlResult<()> { + self.execute_all(&[stmt]) + } + + /// Equivalent to `Connection::execute` but caches the statement so that subsequent + /// calls to `execute_cached` will have improved performance. + fn execute_cached<P: Params>(&self, sql: &str, params: P) -> SqlResult<usize> { + let mut stmt = self.conn().prepare_cached(sql)?; + stmt.execute(params) + } + + /// Execute a query that returns a single result column, and return that result. + fn query_one<T: FromSql>(&self, sql: &str) -> SqlResult<T> { + let res: T = self.conn().query_row_and_then(sql, [], |row| row.get(0))?; + Ok(res) + } + + /// Execute a query that returns 0 or 1 result columns, returning None + /// if there were no rows, or if the only result was NULL. + fn try_query_one<T: FromSql, P: Params>( + &self, + sql: &str, + params: P, + cache: bool, + ) -> SqlResult<Option<T>> + where + Self: Sized, + { + use rusqlite::OptionalExtension; + // The outer option is if we got rows, the inner option is + // if the first row was null. + let res: Option<Option<T>> = self + .conn() + .query_row_and_then_cachable(sql, params, |row| row.get(0), cache) + .optional()?; + // go from Option<Option<T>> to Option<T> + Ok(res.unwrap_or_default()) + } + + /// Equivalent to `rusqlite::Connection::query_row_and_then` but allows + /// passing a flag to indicate that it's cached. + fn query_row_and_then_cachable<T, E, P, F>( + &self, + sql: &str, + params: P, + mapper: F, + cache: bool, + ) -> Result<T, E> + where + Self: Sized, + P: Params, + E: From<rusqlite::Error>, + F: FnOnce(&Row<'_>) -> Result<T, E>, + { + Ok(self + .try_query_row(sql, params, mapper, cache)? + .ok_or(rusqlite::Error::QueryReturnedNoRows)?) + } + + /// Helper for when you'd like to get a Vec<T> of all the rows returned by a + /// query that takes named arguments. See also + /// `query_rows_and_then_cached`. + fn query_rows_and_then<T, E, P, F>(&self, sql: &str, params: P, mapper: F) -> Result<Vec<T>, E> + where + Self: Sized, + P: Params, + E: From<rusqlite::Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, + { + query_rows_and_then_cachable(self.conn(), sql, params, mapper, false) + } + + /// Helper for when you'd like to get a Vec<T> of all the rows returned by a + /// query that takes named arguments. + fn query_rows_and_then_cached<T, E, P, F>( + &self, + sql: &str, + params: P, + mapper: F, + ) -> Result<Vec<T>, E> + where + Self: Sized, + P: Params, + E: From<rusqlite::Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, + { + query_rows_and_then_cachable(self.conn(), sql, params, mapper, true) + } + + /// Like `query_rows_and_then_cachable`, but works if you want a non-Vec as a result. + /// # Example: + /// ```rust,no_run + /// # use std::collections::HashSet; + /// # use sql_support::ConnExt; + /// # use rusqlite::Connection; + /// fn get_visit_tombstones(conn: &Connection, id: i64) -> rusqlite::Result<HashSet<i64>> { + /// Ok(conn.query_rows_into( + /// "SELECT visit_date FROM moz_historyvisit_tombstones + /// WHERE place_id = :place_id", + /// &[(":place_id", &id)], + /// |row| row.get::<_, i64>(0))?) + /// } + /// ``` + /// Note if the type isn't inferred, you'll have to do something gross like + /// `conn.query_rows_into::<HashSet<_>, _, _, _>(...)`. + fn query_rows_into<Coll, T, E, P, F>(&self, sql: &str, params: P, mapper: F) -> Result<Coll, E> + where + Self: Sized, + E: From<rusqlite::Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, + Coll: FromIterator<T>, + P: Params, + { + query_rows_and_then_cachable(self.conn(), sql, params, mapper, false) + } + + /// Same as `query_rows_into`, but caches the stmt if possible. + fn query_rows_into_cached<Coll, T, E, P, F>( + &self, + sql: &str, + params: P, + mapper: F, + ) -> Result<Coll, E> + where + Self: Sized, + P: Params, + E: From<rusqlite::Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, + Coll: FromIterator<T>, + { + query_rows_and_then_cachable(self.conn(), sql, params, mapper, true) + } + + // This should probably have a longer name... + /// Like `query_row_and_then_cachable` but returns None instead of erroring + /// if no such row exists. + fn try_query_row<T, E, P, F>( + &self, + sql: &str, + params: P, + mapper: F, + cache: bool, + ) -> Result<Option<T>, E> + where + Self: Sized, + P: Params, + E: From<rusqlite::Error>, + F: FnOnce(&Row<'_>) -> Result<T, E>, + { + let conn = self.conn(); + let mut stmt = MaybeCached::prepare(conn, sql, cache)?; + let mut rows = stmt.query(params)?; + rows.next()?.map(mapper).transpose() + } + + /// Caveat: This won't actually get used most of the time, and calls will + /// usually invoke rusqlite's method with the same name. See comment on + /// `UncheckedTransaction` for details (generally you probably don't need to + /// care) + fn unchecked_transaction(&self) -> SqlResult<UncheckedTransaction<'_>> { + UncheckedTransaction::new(self.conn(), TransactionBehavior::Deferred) + } + + /// Begin `unchecked_transaction` with `TransactionBehavior::Immediate`. Use + /// when the first operation will be a read operation, that further writes + /// depend on for correctness. + fn unchecked_transaction_imm(&self) -> SqlResult<UncheckedTransaction<'_>> { + UncheckedTransaction::new(self.conn(), TransactionBehavior::Immediate) + } + + /// Get the DB size in bytes + fn get_db_size(&self) -> Result<u32, rusqlite::Error> { + let page_count: u32 = self.query_one("SELECT * from pragma_page_count()")?; + let page_size: u32 = self.query_one("SELECT * from pragma_page_size()")?; + let freelist_count: u32 = self.query_one("SELECT * from pragma_freelist_count()")?; + + Ok((page_count - freelist_count) * page_size) + } +} + +impl ConnExt for Connection { + #[inline] + fn conn(&self) -> &Connection { + self + } +} + +impl<'conn> ConnExt for Transaction<'conn> { + #[inline] + fn conn(&self) -> &Connection { + self + } +} + +impl<'conn> ConnExt for Savepoint<'conn> { + #[inline] + fn conn(&self) -> &Connection { + self + } +} + +/// rusqlite, in an attempt to save us from ourselves, needs a mutable ref to a +/// connection to start a transaction. That is a bit of a PITA in some cases, so +/// we offer this as an alternative - but the responsibility of ensuring there +/// are no concurrent transactions is on our head. +/// +/// This is very similar to the rusqlite `Transaction` - it doesn't prevent +/// against nested transactions but does allow you to use an immutable +/// `Connection`. +/// +/// FIXME: This currently won't actually be used most of the time, because +/// `rusqlite` added [`Connection::unchecked_transaction`] (and +/// `Transaction::new_unchecked`, which can be used to reimplement +/// `unchecked_transaction_imm`), which will be preferred in a call to +/// `c.unchecked_transaction()`, because inherent methods have precedence over +/// methods on extension traits. The exception here is that this will still be +/// used by code which takes `&impl ConnExt` (I believe it would also be used if +/// you attempted to call `unchecked_transaction()` on a non-Connection that +/// implements ConnExt, such as a `Safepoint`, `UncheckedTransaction`, or +/// `Transaction` itself, but such code is clearly broken, so is not worth +/// considering). +/// +/// The difference is that `rusqlite`'s version returns a normal +/// `rusqlite::Transaction`, rather than the `UncheckedTransaction` from this +/// crate. Aside from type's name and location (and the fact that `rusqlite`'s +/// detects slightly more misuse at compile time, and has more features), the +/// main difference is: `rusqlite`'s does not track when a transaction began, +/// which unfortunatly seems to be used by the coop-transaction management in +/// places in some fashion. +/// +/// There are at least two options for how to fix this: +/// 1. Decide we don't need this version, and delete it, and moving the +/// transaction timing into the coop-transaction code directly (or something +/// like this). +/// 2. Decide this difference *is* important, and rename +/// `ConnExt::unchecked_transaction` to something like +/// `ConnExt::transaction_unchecked`. +pub struct UncheckedTransaction<'conn> { + pub conn: &'conn Connection, + pub started_at: Instant, + pub finished: bool, + // we could add drop_behavior etc too, but we don't need it yet - we + // always rollback. +} + +impl<'conn> UncheckedTransaction<'conn> { + /// Begin a new unchecked transaction. Cannot be nested, but this is not + /// enforced by Rust (hence 'unchecked') - however, it is enforced by + /// SQLite; use a rusqlite `savepoint` for nested transactions. + pub fn new(conn: &'conn Connection, behavior: TransactionBehavior) -> SqlResult<Self> { + let query = match behavior { + TransactionBehavior::Deferred => "BEGIN DEFERRED", + TransactionBehavior::Immediate => "BEGIN IMMEDIATE", + TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE", + _ => unreachable!(), + }; + conn.execute_batch(query) + .map(move |_| UncheckedTransaction { + conn, + started_at: Instant::now(), + finished: false, + }) + } + + /// Consumes and commits an unchecked transaction. + pub fn commit(mut self) -> SqlResult<()> { + if self.finished { + log::warn!("ignoring request to commit an already finished transaction"); + return Ok(()); + } + self.finished = true; + self.conn.execute_batch("COMMIT")?; + log::debug!("Transaction commited after {:?}", self.started_at.elapsed()); + Ok(()) + } + + /// Consumes and rolls back an unchecked transaction. + pub fn rollback(mut self) -> SqlResult<()> { + if self.finished { + log::warn!("ignoring request to rollback an already finished transaction"); + return Ok(()); + } + self.rollback_() + } + + fn rollback_(&mut self) -> SqlResult<()> { + self.finished = true; + self.conn.execute_batch("ROLLBACK")?; + Ok(()) + } + + fn finish_(&mut self) -> SqlResult<()> { + if self.finished || self.conn.is_autocommit() { + return Ok(()); + } + self.rollback_()?; + Ok(()) + } +} + +impl<'conn> Deref for UncheckedTransaction<'conn> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.conn + } +} + +impl<'conn> Drop for UncheckedTransaction<'conn> { + fn drop(&mut self) { + if let Err(e) = self.finish_() { + log::warn!("Error dropping an unchecked transaction: {}", e); + } + } +} + +impl<'conn> ConnExt for UncheckedTransaction<'conn> { + #[inline] + fn conn(&self) -> &Connection { + self + } +} + +fn query_rows_and_then_cachable<Coll, T, E, P, F>( + conn: &Connection, + sql: &str, + params: P, + mapper: F, + cache: bool, +) -> Result<Coll, E> +where + E: From<rusqlite::Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, + Coll: FromIterator<T>, + P: Params, +{ + let mut stmt = conn.prepare_maybe_cached(sql, cache)?; + let iter = stmt.query_and_then(params, mapper)?; + iter.collect::<Result<Coll, E>>() +} diff --git a/third_party/rust/sql-support/src/each_chunk.rs b/third_party/rust/sql-support/src/each_chunk.rs new file mode 100644 index 0000000000..2d738bcb37 --- /dev/null +++ b/third_party/rust/sql-support/src/each_chunk.rs @@ -0,0 +1,311 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use lazy_static::lazy_static; +use rusqlite::{self, limits::Limit, types::ToSql}; +use std::iter::Map; +use std::slice::Iter; + +/// Returns SQLITE_LIMIT_VARIABLE_NUMBER as read from an in-memory connection and cached. +/// connection and cached. That means this will return the wrong value if it's set to a lower +/// value for a connection using this will return the wrong thing, but doing so is rare enough +/// that we explicitly don't support it (why would you want to lower this at runtime?). +/// +/// If you call this and the actual value was set to a negative number or zero (nothing prevents +/// this beyond a warning in the SQLite documentation), we panic. However, it's unlikely you can +/// run useful queries if this happened anyway. +pub fn default_max_variable_number() -> usize { + lazy_static! { + static ref MAX_VARIABLE_NUMBER: usize = { + let conn = rusqlite::Connection::open_in_memory() + .expect("Failed to initialize in-memory connection (out of memory?)"); + + let limit = conn.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER); + assert!( + limit > 0, + "Illegal value for SQLITE_LIMIT_VARIABLE_NUMBER (must be > 0) {}", + limit + ); + limit as usize + }; + } + *MAX_VARIABLE_NUMBER +} + +/// Helper for the case where you have a `&[impl ToSql]` of arbitrary length, but need one +/// of no more than the connection's `MAX_VARIABLE_NUMBER` (rather, +/// `default_max_variable_number()`). This is useful when performing batched updates. +/// +/// The `do_chunk` callback is called with a slice of no more than `default_max_variable_number()` +/// items as it's first argument, and the offset from the start as it's second. +/// +/// See `each_chunk_mapped` for the case where `T` doesn't implement `ToSql`, but can be +/// converted to something that does. +pub fn each_chunk<'a, T, E, F>(items: &'a [T], do_chunk: F) -> Result<(), E> +where + T: 'a, + F: FnMut(&'a [T], usize) -> Result<(), E>, +{ + each_sized_chunk(items, default_max_variable_number(), do_chunk) +} + +/// A version of `each_chunk` for the case when the conversion to `to_sql` requires an custom +/// intermediate step. For example, you might want to grab a property off of an arrray of records +pub fn each_chunk_mapped<'a, T, U, E, Mapper, DoChunk>( + items: &'a [T], + to_sql: Mapper, + do_chunk: DoChunk, +) -> Result<(), E> +where + T: 'a, + U: ToSql + 'a, + Mapper: Fn(&'a T) -> U, + DoChunk: FnMut(Map<Iter<'a, T>, &'_ Mapper>, usize) -> Result<(), E>, +{ + each_sized_chunk_mapped(items, default_max_variable_number(), to_sql, do_chunk) +} + +// Split out for testing. Separate so that we can pass an actual slice +// to the callback if they don't need mapping. We could probably unify +// this with each_sized_chunk_mapped with a lot of type system trickery, +// but one of the benefits to each_chunk over the mapped versions is +// that the declaration is simpler. +pub fn each_sized_chunk<'a, T, E, F>( + items: &'a [T], + chunk_size: usize, + mut do_chunk: F, +) -> Result<(), E> +where + T: 'a, + F: FnMut(&'a [T], usize) -> Result<(), E>, +{ + if items.is_empty() { + return Ok(()); + } + let mut offset = 0; + for chunk in items.chunks(chunk_size) { + do_chunk(chunk, offset)?; + offset += chunk.len(); + } + Ok(()) +} + +/// Utility to help perform batched updates, inserts, queries, etc. This is the low-level version +/// of this utility which is wrapped by `each_chunk` and `each_chunk_mapped`, and it allows you to +/// provide both the mapping function, and the chunk size. +/// +/// Note: `mapped` basically just refers to the translating of `T` to some `U` where `U: ToSql` +/// using the `to_sql` function. This is useful for e.g. inserting the IDs of a large list +/// of records. +pub fn each_sized_chunk_mapped<'a, T, U, E, Mapper, DoChunk>( + items: &'a [T], + chunk_size: usize, + to_sql: Mapper, + mut do_chunk: DoChunk, +) -> Result<(), E> +where + T: 'a, + U: ToSql + 'a, + Mapper: Fn(&'a T) -> U, + DoChunk: FnMut(Map<Iter<'a, T>, &'_ Mapper>, usize) -> Result<(), E>, +{ + if items.is_empty() { + return Ok(()); + } + let mut offset = 0; + for chunk in items.chunks(chunk_size) { + let mapped = chunk.iter().map(&to_sql); + do_chunk(mapped, offset)?; + offset += chunk.len(); + } + Ok(()) +} + +#[cfg(test)] +fn check_chunk<T, C>(items: C, expect: &[T], desc: &str) +where + C: IntoIterator, + <C as IntoIterator>::Item: ToSql, + T: ToSql, +{ + let items = items.into_iter().collect::<Vec<_>>(); + assert_eq!(items.len(), expect.len()); + // Can't quite make the borrowing work out here w/o a loop, oh well. + for (idx, (got, want)) in items.iter().zip(expect.iter()).enumerate() { + assert_eq!( + got.to_sql().unwrap(), + want.to_sql().unwrap(), + // ToSqlOutput::Owned(Value::Integer(*num)), + "{}: Bad value at index {}", + desc, + idx + ); + } +} + +#[cfg(test)] +mod test_mapped { + use super::*; + + #[test] + fn test_separate() { + let mut iteration = 0; + each_sized_chunk_mapped( + &[1, 2, 3, 4, 5], + 3, + |item| item as &dyn ToSql, + |chunk, offset| { + match offset { + 0 => { + assert_eq!(iteration, 0); + check_chunk(chunk, &[1, 2, 3], "first chunk"); + } + 3 => { + assert_eq!(iteration, 1); + check_chunk(chunk, &[4, 5], "second chunk"); + } + n => { + panic!("Unexpected offset {}", n); + } + } + iteration += 1; + Ok::<(), ()>(()) + }, + ) + .unwrap(); + } + + #[test] + fn test_leq_chunk_size() { + for &check_size in &[5, 6] { + let mut iteration = 0; + each_sized_chunk_mapped( + &[1, 2, 3, 4, 5], + check_size, + |item| item as &dyn ToSql, + |chunk, offset| { + assert_eq!(iteration, 0); + iteration += 1; + assert_eq!(offset, 0); + check_chunk(chunk, &[1, 2, 3, 4, 5], "only iteration"); + Ok::<(), ()>(()) + }, + ) + .unwrap(); + } + } + + #[test] + fn test_empty_chunk() { + let items: &[i64] = &[]; + each_sized_chunk_mapped::<_, _, (), _, _>( + items, + 100, + |item| item as &dyn ToSql, + |_, _| { + panic!("Should never be called"); + }, + ) + .unwrap(); + } + + #[test] + fn test_error() { + let mut iteration = 0; + let e = each_sized_chunk_mapped( + &[1, 2, 3, 4, 5, 6, 7], + 3, + |item| item as &dyn ToSql, + |_, offset| { + if offset == 0 { + assert_eq!(iteration, 0); + iteration += 1; + Ok(()) + } else if offset == 3 { + assert_eq!(iteration, 1); + iteration += 1; + Err("testing".to_string()) + } else { + // Make sure we stopped after the error. + panic!("Shouldn't get called with offset of {}", offset); + } + }, + ) + .expect_err("Should be an error"); + assert_eq!(e, "testing"); + } +} + +#[cfg(test)] +mod test_unmapped { + use super::*; + + #[test] + fn test_separate() { + let mut iteration = 0; + each_sized_chunk(&[1, 2, 3, 4, 5], 3, |chunk, offset| { + match offset { + 0 => { + assert_eq!(iteration, 0); + check_chunk(chunk, &[1, 2, 3], "first chunk"); + } + 3 => { + assert_eq!(iteration, 1); + check_chunk(chunk, &[4, 5], "second chunk"); + } + n => { + panic!("Unexpected offset {}", n); + } + } + iteration += 1; + Ok::<(), ()>(()) + }) + .unwrap(); + } + + #[test] + fn test_leq_chunk_size() { + for &check_size in &[5, 6] { + let mut iteration = 0; + each_sized_chunk(&[1, 2, 3, 4, 5], check_size, |chunk, offset| { + assert_eq!(iteration, 0); + iteration += 1; + assert_eq!(offset, 0); + check_chunk(chunk, &[1, 2, 3, 4, 5], "only iteration"); + Ok::<(), ()>(()) + }) + .unwrap(); + } + } + + #[test] + fn test_empty_chunk() { + let items: &[i64] = &[]; + each_sized_chunk::<_, (), _>(items, 100, |_, _| { + panic!("Should never be called"); + }) + .unwrap(); + } + + #[test] + fn test_error() { + let mut iteration = 0; + let e = each_sized_chunk(&[1, 2, 3, 4, 5, 6, 7], 3, |_, offset| { + if offset == 0 { + assert_eq!(iteration, 0); + iteration += 1; + Ok(()) + } else if offset == 3 { + assert_eq!(iteration, 1); + iteration += 1; + Err("testing".to_string()) + } else { + // Make sure we stopped after the error. + panic!("Shouldn't get called with offset of {}", offset); + } + }) + .expect_err("Should be an error"); + assert_eq!(e, "testing"); + } +} diff --git a/third_party/rust/sql-support/src/lib.rs b/third_party/rust/sql-support/src/lib.rs new file mode 100644 index 0000000000..17068dcfea --- /dev/null +++ b/third_party/rust/sql-support/src/lib.rs @@ -0,0 +1,36 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#![allow(unknown_lints)] +#![warn(rust_2018_idioms)] + +mod conn_ext; +mod each_chunk; +mod maybe_cached; +pub mod open_database; +mod repeat; + +pub use crate::conn_ext::*; +pub use crate::each_chunk::*; +pub use crate::maybe_cached::*; +pub use crate::repeat::*; + +/// In PRAGMA foo='bar', `'bar'` must be a constant string (it cannot be a +/// bound parameter), so we need to escape manually. According to +/// https://www.sqlite.org/faq.html, the only character that must be escaped is +/// the single quote, which is escaped by placing two single quotes in a row. +pub fn escape_string_for_pragma(s: &str) -> String { + s.replace('\'', "''") +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_escape_string_for_pragma() { + assert_eq!(escape_string_for_pragma("foobar"), "foobar"); + assert_eq!(escape_string_for_pragma("'foo'bar'"), "''foo''bar''"); + assert_eq!(escape_string_for_pragma("''"), "''''"); + } +} diff --git a/third_party/rust/sql-support/src/maybe_cached.rs b/third_party/rust/sql-support/src/maybe_cached.rs new file mode 100644 index 0000000000..96f99f490c --- /dev/null +++ b/third_party/rust/sql-support/src/maybe_cached.rs @@ -0,0 +1,64 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use rusqlite::{self, CachedStatement, Connection, Statement}; + +use std::ops::{Deref, DerefMut}; + +/// MaybeCached is a type that can be used to help abstract +/// over cached and uncached rusqlite statements in a transparent manner. +pub enum MaybeCached<'conn> { + Uncached(Statement<'conn>), + Cached(CachedStatement<'conn>), +} + +impl<'conn> Deref for MaybeCached<'conn> { + type Target = Statement<'conn>; + #[inline] + fn deref(&self) -> &Statement<'conn> { + match self { + MaybeCached::Cached(cached) => Deref::deref(cached), + MaybeCached::Uncached(uncached) => uncached, + } + } +} + +impl<'conn> DerefMut for MaybeCached<'conn> { + #[inline] + fn deref_mut(&mut self) -> &mut Statement<'conn> { + match self { + MaybeCached::Cached(cached) => DerefMut::deref_mut(cached), + MaybeCached::Uncached(uncached) => uncached, + } + } +} + +impl<'conn> From<Statement<'conn>> for MaybeCached<'conn> { + #[inline] + fn from(stmt: Statement<'conn>) -> Self { + MaybeCached::Uncached(stmt) + } +} + +impl<'conn> From<CachedStatement<'conn>> for MaybeCached<'conn> { + #[inline] + fn from(stmt: CachedStatement<'conn>) -> Self { + MaybeCached::Cached(stmt) + } +} + +impl<'conn> MaybeCached<'conn> { + #[inline] + pub fn prepare( + conn: &'conn Connection, + sql: &str, + cached: bool, + ) -> rusqlite::Result<MaybeCached<'conn>> { + if cached { + Ok(MaybeCached::Cached(conn.prepare_cached(sql)?)) + } else { + Ok(MaybeCached::Uncached(conn.prepare(sql)?)) + } + } +} diff --git a/third_party/rust/sql-support/src/open_database.rs b/third_party/rust/sql-support/src/open_database.rs new file mode 100644 index 0000000000..43c0d3f30d --- /dev/null +++ b/third_party/rust/sql-support/src/open_database.rs @@ -0,0 +1,536 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +/// Use this module to open a new SQLite database connection. +/// +/// Usage: +/// - Define a struct that implements ConnectionInitializer. This handles: +/// - Initializing the schema for a new database +/// - Upgrading the schema for an existing database +/// - Extra preparation/finishing steps, for example setting up SQLite functions +/// +/// - Call open_database() in your database constructor: +/// - The first method called is `prepare()`. This is executed outside of a transaction +/// and is suitable for executing pragmas (eg, `PRAGMA journal_mode=wal`), defining +/// functions, etc. +/// - If the database file is not present and the connection is writable, open_database() +/// will create a new DB and call init(), then finish(). If the connection is not +/// writable it will panic, meaning that if you support ReadOnly connections, they must +/// be created after a writable connection is open. +/// - If the database file exists and the connection is writable, open_database() will open +/// it and call prepare(), upgrade_from() for each upgrade that needs to be applied, then +/// finish(). As above, a read-only connection will panic if upgrades are necessary, so +/// you should ensure the first connection opened is writable. +/// - If the connection is not writable, `finish()` will be called (ie, `finish()`, like +/// `prepare()`, is called for all connections) +/// +/// See the autofill DB code for an example. +/// +use crate::ConnExt; +use rusqlite::{ + Connection, Error as RusqliteError, ErrorCode, OpenFlags, Transaction, TransactionBehavior, +}; +use std::path::Path; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Incompatible database version: {0}")] + IncompatibleVersion(u32), + #[error("Error executing SQL: {0}")] + SqlError(#[from] rusqlite::Error), + // `.0` is the original `Error` in string form. + #[error("Failed to recover a corrupt database ('{0}') due to an error deleting the file: {1}")] + RecoveryError(String, std::io::Error), +} + +pub type Result<T> = std::result::Result<T, Error>; + +pub trait ConnectionInitializer { + // Name to display in the logs + const NAME: &'static str; + + // The version that the last upgrade function upgrades to. + const END_VERSION: u32; + + // Functions called only for writable connections all take a Transaction + // Initialize a newly created database to END_VERSION + fn init(&self, tx: &Transaction<'_>) -> Result<()>; + + // Upgrade schema from version -> version + 1 + fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()>; + + // Runs immediately after creation for all types of connections. If writable, + // will *not* be in the transaction created for the "only writable" functions above. + fn prepare(&self, _conn: &Connection, _db_empty: bool) -> Result<()> { + Ok(()) + } + + // Runs for all types of connections. If a writable connection is being + // initialized, this will be called after all initialization functions, + // but inside their transaction. + fn finish(&self, _conn: &Connection) -> Result<()> { + Ok(()) + } +} + +pub fn open_database<CI: ConnectionInitializer, P: AsRef<Path>>( + path: P, + connection_initializer: &CI, +) -> Result<Connection> { + open_database_with_flags(path, OpenFlags::default(), connection_initializer) +} + +pub fn open_memory_database<CI: ConnectionInitializer>( + conn_initializer: &CI, +) -> Result<Connection> { + open_memory_database_with_flags(OpenFlags::default(), conn_initializer) +} + +pub fn open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>( + path: P, + open_flags: OpenFlags, + connection_initializer: &CI, +) -> Result<Connection> { + do_open_database_with_flags(&path, open_flags, connection_initializer).or_else(|e| { + // See if we can recover from the error and try a second time + try_handle_db_failure(&path, open_flags, connection_initializer, e)?; + do_open_database_with_flags(&path, open_flags, connection_initializer) + }) +} + +fn do_open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>( + path: P, + open_flags: OpenFlags, + connection_initializer: &CI, +) -> Result<Connection> { + // Try running the migration logic with an existing file + log::debug!("{}: opening database", CI::NAME); + let mut conn = Connection::open_with_flags(path, open_flags)?; + log::debug!("{}: checking if initialization is necessary", CI::NAME); + let db_empty = is_db_empty(&conn)?; + + log::debug!("{}: preparing", CI::NAME); + connection_initializer.prepare(&conn, db_empty)?; + + if open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) { + let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate)?; + if db_empty { + log::debug!("{}: initializing new database", CI::NAME); + connection_initializer.init(&tx)?; + } else { + let mut current_version = get_schema_version(&tx)?; + if current_version > CI::END_VERSION { + return Err(Error::IncompatibleVersion(current_version)); + } + while current_version < CI::END_VERSION { + log::debug!( + "{}: upgrading database to {}", + CI::NAME, + current_version + 1 + ); + connection_initializer.upgrade_from(&tx, current_version)?; + current_version += 1; + } + } + log::debug!("{}: finishing writable database open", CI::NAME); + connection_initializer.finish(&tx)?; + set_schema_version(&tx, CI::END_VERSION)?; + tx.commit()?; + } else { + // There's an implied requirement that the first connection to a DB is + // writable, so read-only connections do much less, but panic if stuff is wrong + assert!(!db_empty, "existing writer must have initialized"); + assert!( + get_schema_version(&conn)? == CI::END_VERSION, + "existing writer must have migrated" + ); + log::debug!("{}: finishing readonly database open", CI::NAME); + connection_initializer.finish(&conn)?; + } + log::debug!("{}: database open successful", CI::NAME); + Ok(conn) +} + +pub fn open_memory_database_with_flags<CI: ConnectionInitializer>( + flags: OpenFlags, + conn_initializer: &CI, +) -> Result<Connection> { + open_database_with_flags(":memory:", flags, conn_initializer) +} + +// Attempt to handle failure when opening the database. +// +// Returns: +// - Ok(()) the failure is potentially handled and we should make a second open attempt +// - Err(e) the failure couldn't be handled and we should return this error +fn try_handle_db_failure<CI: ConnectionInitializer, P: AsRef<Path>>( + path: P, + open_flags: OpenFlags, + _connection_initializer: &CI, + err: Error, +) -> Result<()> { + if !open_flags.contains(OpenFlags::SQLITE_OPEN_CREATE) + && matches!(err, Error::SqlError(rusqlite::Error::SqliteFailure(code, _)) if code.code == rusqlite::ErrorCode::CannotOpen) + { + log::info!( + "{}: database doesn't exist, but we weren't requested to create it", + CI::NAME + ); + return Err(err); + } + log::warn!("{}: database operation failed: {}", CI::NAME, err); + if !open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) { + log::warn!( + "{}: not attempting recovery as this is a read-only connection request", + CI::NAME + ); + return Err(err); + } + + let delete = match err { + Error::SqlError(RusqliteError::SqliteFailure(e, _)) => { + matches!(e.code, ErrorCode::DatabaseCorrupt | ErrorCode::NotADatabase) + } + _ => false, + }; + if delete { + log::info!( + "{}: the database is fatally damaged; deleting and starting fresh", + CI::NAME + ); + // Note we explicitly decline to move the path to, say ".corrupt", as it's difficult to + // identify any value there - actually getting our hands on the file from a mobile device + // is tricky and it would just take up disk space forever. + if let Err(io_err) = std::fs::remove_file(path) { + return Err(Error::RecoveryError(err.to_string(), io_err)); + } + Ok(()) + } else { + Err(err) + } +} + +fn is_db_empty(conn: &Connection) -> Result<bool> { + Ok(conn.query_one::<u32>("SELECT COUNT(*) FROM sqlite_master")? == 0) +} + +fn get_schema_version(conn: &Connection) -> Result<u32> { + let version = conn.query_row_and_then("PRAGMA user_version", [], |row| row.get(0))?; + Ok(version) +} + +fn set_schema_version(conn: &Connection, version: u32) -> Result<()> { + conn.set_pragma("user_version", version)?; + Ok(()) +} + +// It would be nice for this to be #[cfg(test)], but that doesn't allow it to be used in tests for +// our other crates. +pub mod test_utils { + use super::*; + use std::path::PathBuf; + use tempfile::TempDir; + + // Database file that we can programatically run upgrades on + // + // We purposefully don't keep a connection to the database around to force upgrades to always + // run against a newly opened DB, like they would in the real world. See #4106 for + // details. + pub struct MigratedDatabaseFile<CI: ConnectionInitializer> { + // Keep around a TempDir to ensure the database file stays around until this struct is + // dropped + _tempdir: TempDir, + pub connection_initializer: CI, + pub path: PathBuf, + } + + impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> { + pub fn new(connection_initializer: CI, init_sql: &str) -> Self { + Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default()) + } + + pub fn new_with_flags( + connection_initializer: CI, + init_sql: &str, + open_flags: OpenFlags, + ) -> Self { + let tempdir = tempfile::tempdir().unwrap(); + let path = tempdir.path().join(Path::new("db.sql")); + let conn = Connection::open_with_flags(&path, open_flags).unwrap(); + conn.execute_batch(init_sql).unwrap(); + Self { + _tempdir: tempdir, + connection_initializer, + path, + } + } + + pub fn upgrade_to(&self, version: u32) { + let mut conn = self.open(); + let tx = conn.transaction().unwrap(); + let mut current_version = get_schema_version(&tx).unwrap(); + while current_version < version { + self.connection_initializer + .upgrade_from(&tx, current_version) + .unwrap(); + current_version += 1; + } + set_schema_version(&tx, current_version).unwrap(); + self.connection_initializer.finish(&tx).unwrap(); + tx.commit().unwrap(); + } + + pub fn run_all_upgrades(&self) { + let current_version = get_schema_version(&self.open()).unwrap(); + for version in current_version..CI::END_VERSION { + self.upgrade_to(version + 1); + } + } + + pub fn open(&self) -> Connection { + Connection::open(&self.path).unwrap() + } + } +} + +#[cfg(test)] +mod test { + use super::test_utils::MigratedDatabaseFile; + use super::*; + use std::cell::RefCell; + use std::io::Write; + + struct TestConnectionInitializer { + pub calls: RefCell<Vec<&'static str>>, + pub buggy_v3_upgrade: bool, + } + + impl TestConnectionInitializer { + pub fn new() -> Self { + let _ = env_logger::try_init(); + Self { + calls: RefCell::new(Vec::new()), + buggy_v3_upgrade: false, + } + } + pub fn new_with_buggy_logic() -> Self { + let _ = env_logger::try_init(); + Self { + calls: RefCell::new(Vec::new()), + buggy_v3_upgrade: true, + } + } + + pub fn clear_calls(&self) { + self.calls.borrow_mut().clear(); + } + + pub fn push_call(&self, call: &'static str) { + self.calls.borrow_mut().push(call); + } + + pub fn check_calls(&self, expected: Vec<&'static str>) { + assert_eq!(*self.calls.borrow(), expected); + } + } + + impl ConnectionInitializer for TestConnectionInitializer { + const NAME: &'static str = "test db"; + const END_VERSION: u32 = 4; + + fn prepare(&self, conn: &Connection, _: bool) -> Result<()> { + self.push_call("prep"); + conn.execute_batch( + " + PRAGMA journal_mode = wal; + ", + )?; + Ok(()) + } + + fn init(&self, conn: &Transaction<'_>) -> Result<()> { + self.push_call("init"); + conn.execute_batch( + " + CREATE TABLE prep_table(col); + INSERT INTO prep_table(col) VALUES ('correct-value'); + CREATE TABLE my_table(col); + ", + ) + .map_err(|e| e.into()) + } + + fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()> { + match version { + 2 => { + self.push_call("upgrade_from_v2"); + conn.execute_batch( + " + ALTER TABLE my_old_table_name RENAME TO my_table; + ", + )?; + Ok(()) + } + 3 => { + self.push_call("upgrade_from_v3"); + + if self.buggy_v3_upgrade { + conn.execute_batch("ILLEGAL_SQL_CODE")?; + } + + conn.execute_batch( + " + ALTER TABLE my_table RENAME COLUMN old_col to col; + ", + )?; + Ok(()) + } + _ => { + panic!("Unexpected version: {}", version); + } + } + } + + fn finish(&self, conn: &Connection) -> Result<()> { + self.push_call("finish"); + conn.execute_batch( + " + INSERT INTO my_table(col) SELECT col FROM prep_table; + ", + )?; + Ok(()) + } + } + + // Initialize the database to v2 to test upgrading from there + static INIT_V2: &str = " + CREATE TABLE prep_table(col); + INSERT INTO prep_table(col) VALUES ('correct-value'); + CREATE TABLE my_old_table_name(old_col); + PRAGMA user_version=2; + "; + + fn check_final_data(conn: &Connection) { + let value: String = conn + .query_row("SELECT col FROM my_table", [], |r| r.get(0)) + .unwrap(); + assert_eq!(value, "correct-value"); + assert_eq!(get_schema_version(conn).unwrap(), 4); + } + + #[test] + fn test_init() { + let connection_initializer = TestConnectionInitializer::new(); + let conn = open_memory_database(&connection_initializer).unwrap(); + check_final_data(&conn); + connection_initializer.check_calls(vec!["prep", "init", "finish"]); + } + + #[test] + fn test_upgrades() { + let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2); + let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap(); + check_final_data(&conn); + db_file.connection_initializer.check_calls(vec![ + "prep", + "upgrade_from_v2", + "upgrade_from_v3", + "finish", + ]); + } + + #[test] + fn test_open_current_version() { + let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2); + db_file.upgrade_to(4); + db_file.connection_initializer.clear_calls(); + let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap(); + check_final_data(&conn); + db_file + .connection_initializer + .check_calls(vec!["prep", "finish"]); + } + + #[test] + fn test_pragmas() { + let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2); + let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap(); + assert_eq!( + conn.query_one::<String>("PRAGMA journal_mode").unwrap(), + "wal" + ); + } + + #[test] + fn test_migration_error() { + let db_file = + MigratedDatabaseFile::new(TestConnectionInitializer::new_with_buggy_logic(), INIT_V2); + db_file + .open() + .execute( + "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')", + [], + ) + .unwrap(); + + open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap_err(); + // Even though the upgrades failed, the data should still be there. The changes that + // upgrade_to_v3 made should have been rolled back. + assert_eq!( + db_file + .open() + .query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name") + .unwrap(), + 1 + ); + } + + #[test] + fn test_version_too_new() { + let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2); + set_schema_version(&db_file.open(), 5).unwrap(); + + db_file + .open() + .execute( + "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')", + [], + ) + .unwrap(); + + assert!(matches!( + open_database(db_file.path.clone(), &db_file.connection_initializer,), + Err(Error::IncompatibleVersion(5)) + )); + // Make sure that even when DeleteAndRecreate is specified, we don't delete the database + // file when the schema is newer + assert_eq!( + db_file + .open() + .query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name") + .unwrap(), + 1 + ); + } + + #[test] + fn test_corrupt_db() { + let tempdir = tempfile::tempdir().unwrap(); + let path = tempdir.path().join(Path::new("corrupt-db.sql")); + let mut file = std::fs::File::create(path.clone()).unwrap(); + // interestingly, sqlite seems to treat a 0-byte file as a missing one. + // Note that this will exercise the `ErrorCode::NotADatabase` error code. It's not clear + // how we could hit `ErrorCode::DatabaseCorrupt`, but even if we could, there's not much + // value as this test can't really observe which one it was. + file.write_all(b"not sql").unwrap(); + let metadata = std::fs::metadata(path.clone()).unwrap(); + assert_eq!(metadata.len(), 7); + drop(file); + open_database(path.clone(), &TestConnectionInitializer::new()).unwrap(); + let metadata = std::fs::metadata(path).unwrap(); + // just check the file is no longer what it was before. + assert_ne!(metadata.len(), 7); + } +} diff --git a/third_party/rust/sql-support/src/repeat.rs b/third_party/rust/sql-support/src/repeat.rs new file mode 100644 index 0000000000..40b582ec14 --- /dev/null +++ b/third_party/rust/sql-support/src/repeat.rs @@ -0,0 +1,113 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use std::fmt; + +/// Helper type for printing repeated strings more efficiently. You should use +/// [`repeat_display`](sql_support::repeat_display), or one of the `repeat_sql_*` helpers to +/// construct it. +#[derive(Debug, Clone)] +pub struct RepeatDisplay<'a, F> { + count: usize, + sep: &'a str, + fmt_one: F, +} + +impl<'a, F> fmt::Display for RepeatDisplay<'a, F> +where + F: Fn(usize, &mut fmt::Formatter<'_>) -> fmt::Result, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for i in 0..self.count { + if i != 0 { + f.write_str(self.sep)?; + } + (self.fmt_one)(i, f)?; + } + Ok(()) + } +} + +/// Construct a RepeatDisplay that will repeatedly call `fmt_one` with a formatter `count` times, +/// separated by `sep`. +/// +/// # Example +/// +/// ```rust +/// # use sql_support::repeat_display; +/// assert_eq!(format!("{}", repeat_display(1, ",", |i, f| write!(f, "({},?)", i))), +/// "(0,?)"); +/// assert_eq!(format!("{}", repeat_display(2, ",", |i, f| write!(f, "({},?)", i))), +/// "(0,?),(1,?)"); +/// assert_eq!(format!("{}", repeat_display(3, ",", |i, f| write!(f, "({},?)", i))), +/// "(0,?),(1,?),(2,?)"); +/// ``` +#[inline] +pub fn repeat_display<F>(count: usize, sep: &str, fmt_one: F) -> RepeatDisplay<'_, F> +where + F: Fn(usize, &mut fmt::Formatter<'_>) -> fmt::Result, +{ + RepeatDisplay { + count, + sep, + fmt_one, + } +} + +/// Returns a value that formats as `count` instances of `?` separated by commas. +/// +/// # Example +/// +/// ```rust +/// # use sql_support::repeat_sql_vars; +/// assert_eq!(format!("{}", repeat_sql_vars(0)), ""); +/// assert_eq!(format!("{}", repeat_sql_vars(1)), "?"); +/// assert_eq!(format!("{}", repeat_sql_vars(2)), "?,?"); +/// assert_eq!(format!("{}", repeat_sql_vars(3)), "?,?,?"); +/// ``` +pub fn repeat_sql_vars(count: usize) -> impl fmt::Display { + repeat_display(count, ",", |_, f| write!(f, "?")) +} + +/// Returns a value that formats as `count` instances of `(?)` separated by commas. +/// +/// # Example +/// +/// ```rust +/// # use sql_support::repeat_sql_values; +/// assert_eq!(format!("{}", repeat_sql_values(0)), ""); +/// assert_eq!(format!("{}", repeat_sql_values(1)), "(?)"); +/// assert_eq!(format!("{}", repeat_sql_values(2)), "(?),(?)"); +/// assert_eq!(format!("{}", repeat_sql_values(3)), "(?),(?),(?)"); +/// ``` +/// +pub fn repeat_sql_values(count: usize) -> impl fmt::Display { + // We could also implement this as `repeat_sql_multi_values(count, 1)`, + // but this is faster and no less clear IMO. + repeat_display(count, ",", |_, f| write!(f, "(?)")) +} + +/// Returns a value that formats as `num_values` instances of `(?,?,?,...)` (where there are +/// `vars_per_value` question marks separated by commas in between the `?`s). +/// +/// Panics if `vars_per_value` is zero (however, `num_values` is allowed to be zero). +/// +/// # Example +/// +/// ```rust +/// # use sql_support::repeat_multi_values; +/// assert_eq!(format!("{}", repeat_multi_values(0, 2)), ""); +/// assert_eq!(format!("{}", repeat_multi_values(1, 5)), "(?,?,?,?,?)"); +/// assert_eq!(format!("{}", repeat_multi_values(2, 3)), "(?,?,?),(?,?,?)"); +/// assert_eq!(format!("{}", repeat_multi_values(3, 1)), "(?),(?),(?)"); +/// ``` +pub fn repeat_multi_values(num_values: usize, vars_per_value: usize) -> impl fmt::Display { + assert_ne!( + vars_per_value, 0, + "Illegal value for `vars_per_value`, must not be zero" + ); + repeat_display(num_values, ",", move |_, f| { + write!(f, "({})", repeat_sql_vars(vars_per_value)) + }) +} |