diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/rusqlite/src | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/rusqlite/src')
45 files changed, 18773 insertions, 0 deletions
diff --git a/third_party/rust/rusqlite/src/backup.rs b/third_party/rust/rusqlite/src/backup.rs new file mode 100644 index 0000000000..f28ae9a8ac --- /dev/null +++ b/third_party/rust/rusqlite/src/backup.rs @@ -0,0 +1,428 @@ +//! Online SQLite backup API. +//! +//! To create a [`Backup`], you must have two distinct [`Connection`]s - one +//! for the source (which can be used while the backup is running) and one for +//! the destination (which cannot). A [`Backup`] handle exposes three methods: +//! [`step`](Backup::step) will attempt to back up a specified number of pages, +//! [`progress`](Backup::progress) gets the current progress of the backup as of +//! the last call to [`step`](Backup::step), and +//! [`run_to_completion`](Backup::run_to_completion) will attempt to back up the +//! entire source database, allowing you to specify how many pages are backed up +//! at a time and how long the thread should sleep between chunks of pages. +//! +//! The following example is equivalent to "Example 2: Online Backup of a +//! Running Database" from [SQLite's Online Backup API +//! documentation](https://www.sqlite.org/backup.html). +//! +//! ```rust,no_run +//! # use rusqlite::{backup, Connection, Result}; +//! # use std::path::Path; +//! # use std::time; +//! +//! fn backup_db<P: AsRef<Path>>( +//! src: &Connection, +//! dst: P, +//! progress: fn(backup::Progress), +//! ) -> Result<()> { +//! let mut dst = Connection::open(dst)?; +//! let backup = backup::Backup::new(src, &mut dst)?; +//! backup.run_to_completion(5, time::Duration::from_millis(250), Some(progress)) +//! } +//! ``` + +use std::marker::PhantomData; +use std::path::Path; +use std::ptr; + +use std::os::raw::c_int; +use std::thread; +use std::time::Duration; + +use crate::ffi; + +use crate::error::error_from_handle; +use crate::{Connection, DatabaseName, Result}; + +impl Connection { + /// Back up the `name` database to the given + /// destination path. + /// + /// If `progress` is not `None`, it will be called periodically + /// until the backup completes. + /// + /// For more fine-grained control over the backup process (e.g., + /// to sleep periodically during the backup or to back up to an + /// already-open database connection), see the `backup` module. + /// + /// # Failure + /// + /// Will return `Err` if the destination path cannot be opened + /// or if the backup fails. + pub fn backup<P: AsRef<Path>>( + &self, + name: DatabaseName<'_>, + dst_path: P, + progress: Option<fn(Progress)>, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + let mut dst = Connection::open(dst_path)?; + let backup = Backup::new_with_names(self, name, &mut dst, DatabaseName::Main)?; + + let mut r = More; + while r == More { + r = backup.step(100)?; + if let Some(f) = progress { + f(backup.progress()); + } + } + + match r { + Done => Ok(()), + Busy => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_BUSY) }), + Locked => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_LOCKED) }), + More => unreachable!(), + } + } + + /// Restore the given source path into the + /// `name` database. If `progress` is not `None`, it will be + /// called periodically until the restore completes. + /// + /// For more fine-grained control over the restore process (e.g., + /// to sleep periodically during the restore or to restore from an + /// already-open database connection), see the `backup` module. + /// + /// # Failure + /// + /// Will return `Err` if the destination path cannot be opened + /// or if the restore fails. + pub fn restore<P: AsRef<Path>, F: Fn(Progress)>( + &mut self, + name: DatabaseName<'_>, + src_path: P, + progress: Option<F>, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + let src = Connection::open(src_path)?; + let restore = Backup::new_with_names(&src, DatabaseName::Main, self, name)?; + + let mut r = More; + let mut busy_count = 0_i32; + 'restore_loop: while r == More || r == Busy { + r = restore.step(100)?; + if let Some(ref f) = progress { + f(restore.progress()); + } + if r == Busy { + busy_count += 1; + if busy_count >= 3 { + break 'restore_loop; + } + thread::sleep(Duration::from_millis(100)); + } + } + + match r { + Done => Ok(()), + Busy => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_BUSY) }), + Locked => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_LOCKED) }), + More => unreachable!(), + } + } +} + +/// Possible successful results of calling +/// [`Backup::step`]. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum StepResult { + /// The backup is complete. + Done, + + /// The step was successful but there are still more pages that need to be + /// backed up. + More, + + /// The step failed because appropriate locks could not be acquired. This is + /// not a fatal error - the step can be retried. + Busy, + + /// The step failed because the source connection was writing to the + /// database. This is not a fatal error - the step can be retried. + Locked, +} + +/// Struct specifying the progress of a backup. The +/// percentage completion can be calculated as `(pagecount - remaining) / +/// pagecount`. The progress of a backup is as of the last call to +/// [`step`](Backup::step) - if the source database is modified after a call to +/// [`step`](Backup::step), the progress value will become outdated and +/// potentially incorrect. +#[derive(Copy, Clone, Debug)] +pub struct Progress { + /// Number of pages in the source database that still need to be backed up. + pub remaining: c_int, + /// Total number of pages in the source database. + pub pagecount: c_int, +} + +/// A handle to an online backup. +pub struct Backup<'a, 'b> { + phantom_from: PhantomData<&'a Connection>, + to: &'b Connection, + b: *mut ffi::sqlite3_backup, +} + +impl Backup<'_, '_> { + /// Attempt to create a new handle that will allow backups from `from` to + /// `to`. Note that `to` is a `&mut` - this is because SQLite forbids any + /// API calls on the destination of a backup while the backup is taking + /// place. + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_init` call returns + /// `NULL`. + #[inline] + pub fn new<'a, 'b>(from: &'a Connection, to: &'b mut Connection) -> Result<Backup<'a, 'b>> { + Backup::new_with_names(from, DatabaseName::Main, to, DatabaseName::Main) + } + + /// Attempt to create a new handle that will allow backups from the + /// `from_name` database of `from` to the `to_name` database of `to`. Note + /// that `to` is a `&mut` - this is because SQLite forbids any API calls on + /// the destination of a backup while the backup is taking place. + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_init` call returns + /// `NULL`. + pub fn new_with_names<'a, 'b>( + from: &'a Connection, + from_name: DatabaseName<'_>, + to: &'b mut Connection, + to_name: DatabaseName<'_>, + ) -> Result<Backup<'a, 'b>> { + let to_name = to_name.as_cstring()?; + let from_name = from_name.as_cstring()?; + + let to_db = to.db.borrow_mut().db; + + let b = unsafe { + let b = ffi::sqlite3_backup_init( + to_db, + to_name.as_ptr(), + from.db.borrow_mut().db, + from_name.as_ptr(), + ); + if b.is_null() { + return Err(error_from_handle(to_db, ffi::sqlite3_errcode(to_db))); + } + b + }; + + Ok(Backup { + phantom_from: PhantomData, + to, + b, + }) + } + + /// Gets the progress of the backup as of the last call to + /// [`step`](Backup::step). + #[inline] + #[must_use] + pub fn progress(&self) -> Progress { + unsafe { + Progress { + remaining: ffi::sqlite3_backup_remaining(self.b), + pagecount: ffi::sqlite3_backup_pagecount(self.b), + } + } + } + + /// Attempts to back up the given number of pages. If `num_pages` is + /// negative, will attempt to back up all remaining pages. This will hold a + /// lock on the source database for the duration, so it is probably not + /// what you want for databases that are currently active (see + /// [`run_to_completion`](Backup::run_to_completion) for a better + /// alternative). + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_step` call returns + /// an error code other than `DONE`, `OK`, `BUSY`, or `LOCKED`. `BUSY` and + /// `LOCKED` are transient errors and are therefore returned as possible + /// `Ok` values. + #[inline] + pub fn step(&self, num_pages: c_int) -> Result<StepResult> { + use self::StepResult::{Busy, Done, Locked, More}; + + let rc = unsafe { ffi::sqlite3_backup_step(self.b, num_pages) }; + match rc { + ffi::SQLITE_DONE => Ok(Done), + ffi::SQLITE_OK => Ok(More), + ffi::SQLITE_BUSY => Ok(Busy), + ffi::SQLITE_LOCKED => Ok(Locked), + _ => self.to.decode_result(rc).map(|_| More), + } + } + + /// Attempts to run the entire backup. Will call + /// [`step(pages_per_step)`](Backup::step) as many times as necessary, + /// sleeping for `pause_between_pages` between each call to give the + /// source database time to process any pending queries. This is a + /// direct implementation of "Example 2: Online Backup of a Running + /// Database" from [SQLite's Online Backup API documentation](https://www.sqlite.org/backup.html). + /// + /// If `progress` is not `None`, it will be called after each step with the + /// current progress of the backup. Note that is possible the progress may + /// not change if the step returns `Busy` or `Locked` even though the + /// backup is still running. + /// + /// # Failure + /// + /// Will return `Err` if any of the calls to [`step`](Backup::step) return + /// `Err`. + pub fn run_to_completion( + &self, + pages_per_step: c_int, + pause_between_pages: Duration, + progress: Option<fn(Progress)>, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + + assert!(pages_per_step > 0, "pages_per_step must be positive"); + + loop { + let r = self.step(pages_per_step)?; + if let Some(progress) = progress { + progress(self.progress()); + } + match r { + More | Busy | Locked => thread::sleep(pause_between_pages), + Done => return Ok(()), + } + } + } +} + +impl Drop for Backup<'_, '_> { + #[inline] + fn drop(&mut self) { + unsafe { ffi::sqlite3_backup_finish(self.b) }; + } +} + +#[cfg(test)] +mod test { + use super::Backup; + use crate::{Connection, DatabaseName, Result}; + use std::time::Duration; + + #[test] + fn test_backup() -> Result<()> { + let src = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + src.execute_batch(sql)?; + + let mut dst = Connection::open_in_memory()?; + + { + let backup = Backup::new(&src, &mut dst)?; + backup.step(-1)?; + } + + let the_answer: i64 = dst.one_column("SELECT x FROM foo")?; + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)")?; + + { + let backup = Backup::new(&src, &mut dst)?; + backup.run_to_completion(5, Duration::from_millis(250), None)?; + } + + let the_answer: i64 = dst.one_column("SELECT SUM(x) FROM foo")?; + assert_eq!(42 + 43, the_answer); + Ok(()) + } + + #[test] + fn test_backup_temp() -> Result<()> { + let src = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TEMPORARY TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + src.execute_batch(sql)?; + + let mut dst = Connection::open_in_memory()?; + + { + let backup = + Backup::new_with_names(&src, DatabaseName::Temp, &mut dst, DatabaseName::Main)?; + backup.step(-1)?; + } + + let the_answer: i64 = dst.one_column("SELECT x FROM foo")?; + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)")?; + + { + let backup = + Backup::new_with_names(&src, DatabaseName::Temp, &mut dst, DatabaseName::Main)?; + backup.run_to_completion(5, Duration::from_millis(250), None)?; + } + + let the_answer: i64 = dst.one_column("SELECT SUM(x) FROM foo")?; + assert_eq!(42 + 43, the_answer); + Ok(()) + } + + #[test] + fn test_backup_attached() -> Result<()> { + let src = Connection::open_in_memory()?; + let sql = "ATTACH DATABASE ':memory:' AS my_attached; + BEGIN; + CREATE TABLE my_attached.foo(x INTEGER); + INSERT INTO my_attached.foo VALUES(42); + END;"; + src.execute_batch(sql)?; + + let mut dst = Connection::open_in_memory()?; + + { + let backup = Backup::new_with_names( + &src, + DatabaseName::Attached("my_attached"), + &mut dst, + DatabaseName::Main, + )?; + backup.step(-1)?; + } + + let the_answer: i64 = dst.one_column("SELECT x FROM foo")?; + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)")?; + + { + let backup = Backup::new_with_names( + &src, + DatabaseName::Attached("my_attached"), + &mut dst, + DatabaseName::Main, + )?; + backup.run_to_completion(5, Duration::from_millis(250), None)?; + } + + let the_answer: i64 = dst.one_column("SELECT SUM(x) FROM foo")?; + assert_eq!(42 + 43, the_answer); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/blob/mod.rs b/third_party/rust/rusqlite/src/blob/mod.rs new file mode 100644 index 0000000000..c9e797bc0f --- /dev/null +++ b/third_party/rust/rusqlite/src/blob/mod.rs @@ -0,0 +1,550 @@ +//! Incremental BLOB I/O. +//! +//! Note that SQLite does not provide API-level access to change the size of a +//! BLOB; that must be performed through SQL statements. +//! +//! There are two choices for how to perform IO on a [`Blob`]. +//! +//! 1. The implementations it provides of the `std::io::Read`, `std::io::Write`, +//! and `std::io::Seek` traits. +//! +//! 2. A positional IO API, e.g. [`Blob::read_at`], [`Blob::write_at`] and +//! similar. +//! +//! Documenting these in order: +//! +//! ## 1. `std::io` trait implementations. +//! +//! `Blob` conforms to `std::io::Read`, `std::io::Write`, and `std::io::Seek`, +//! so it plays nicely with other types that build on these (such as +//! `std::io::BufReader` and `std::io::BufWriter`). However, you must be careful +//! with the size of the blob. For example, when using a `BufWriter`, the +//! `BufWriter` will accept more data than the `Blob` will allow, so make sure +//! to call `flush` and check for errors. (See the unit tests in this module for +//! an example.) +//! +//! ## 2. Positional IO +//! +//! `Blob`s also offer a `pread` / `pwrite`-style positional IO api in the form +//! of [`Blob::read_at`], [`Blob::write_at`], [`Blob::raw_read_at`], +//! [`Blob::read_at_exact`], and [`Blob::raw_read_at_exact`]. +//! +//! These APIs all take the position to read from or write to from as a +//! parameter, instead of using an internal `pos` value. +//! +//! ### Positional IO Read Variants +//! +//! For the `read` functions, there are several functions provided: +//! +//! - [`Blob::read_at`] +//! - [`Blob::raw_read_at`] +//! - [`Blob::read_at_exact`] +//! - [`Blob::raw_read_at_exact`] +//! +//! These can be divided along two axes: raw/not raw, and exact/inexact: +//! +//! 1. Raw/not raw refers to the type of the destination buffer. The raw +//! functions take a `&mut [MaybeUninit<u8>]` as the destination buffer, +//! where the "normal" functions take a `&mut [u8]`. +//! +//! Using `MaybeUninit` here can be more efficient in some cases, but is +//! often inconvenient, so both are provided. +//! +//! 2. Exact/inexact refers to to whether or not the entire buffer must be +//! filled in order for the call to be considered a success. +//! +//! The "exact" functions require the provided buffer be entirely filled, or +//! they return an error, whereas the "inexact" functions read as much out of +//! the blob as is available, and return how much they were able to read. +//! +//! The inexact functions are preferable if you do not know the size of the +//! blob already, and the exact functions are preferable if you do. +//! +//! ### Comparison to using the `std::io` traits: +//! +//! In general, the positional methods offer the following Pro/Cons compared to +//! using the implementation `std::io::{Read, Write, Seek}` we provide for +//! `Blob`: +//! +//! 1. (Pro) There is no need to first seek to a position in order to perform IO +//! on it as the position is a parameter. +//! +//! 2. (Pro) `Blob`'s positional read functions don't mutate the blob in any +//! way, and take `&self`. No `&mut` access required. +//! +//! 3. (Pro) Positional IO functions return `Err(rusqlite::Error)` on failure, +//! rather than `Err(std::io::Error)`. Returning `rusqlite::Error` is more +//! accurate and convenient. +//! +//! Note that for the `std::io` API, no data is lost however, and it can be +//! recovered with `io_err.downcast::<rusqlite::Error>()` (this can be easy +//! to forget, though). +//! +//! 4. (Pro, for now). A `raw` version of the read API exists which can allow +//! reading into a `&mut [MaybeUninit<u8>]` buffer, which avoids a potential +//! costly initialization step. (However, `std::io` traits will certainly +//! gain this someday, which is why this is only a "Pro, for now"). +//! +//! 5. (Con) The set of functions is more bare-bones than what is offered in +//! `std::io`, which has a number of adapters, handy algorithms, further +//! traits. +//! +//! 6. (Con) No meaningful interoperability with other crates, so if you need +//! that you must use `std::io`. +//! +//! To generalize: the `std::io` traits are useful because they conform to a +//! standard interface that a lot of code knows how to handle, however that +//! interface is not a perfect fit for [`Blob`], so another small set of +//! functions is provided as well. +//! +//! # Example (`std::io`) +//! +//! ```rust +//! # use rusqlite::blob::ZeroBlob; +//! # use rusqlite::{Connection, DatabaseName}; +//! # use std::error::Error; +//! # use std::io::{Read, Seek, SeekFrom, Write}; +//! # fn main() -> Result<(), Box<dyn Error>> { +//! let db = Connection::open_in_memory()?; +//! db.execute_batch("CREATE TABLE test_table (content BLOB);")?; +//! +//! // Insert a BLOB into the `content` column of `test_table`. Note that the Blob +//! // I/O API provides no way of inserting or resizing BLOBs in the DB -- this +//! // must be done via SQL. +//! db.execute("INSERT INTO test_table (content) VALUES (ZEROBLOB(10))", [])?; +//! +//! // Get the row id off the BLOB we just inserted. +//! let rowid = db.last_insert_rowid(); +//! // Open the BLOB we just inserted for IO. +//! let mut blob = db.blob_open(DatabaseName::Main, "test_table", "content", rowid, false)?; +//! +//! // Write some data into the blob. Make sure to test that the number of bytes +//! // written matches what you expect; if you try to write too much, the data +//! // will be truncated to the size of the BLOB. +//! let bytes_written = blob.write(b"01234567")?; +//! assert_eq!(bytes_written, 8); +//! +//! // Move back to the start and read into a local buffer. +//! // Same guidance - make sure you check the number of bytes read! +//! blob.seek(SeekFrom::Start(0))?; +//! let mut buf = [0u8; 20]; +//! let bytes_read = blob.read(&mut buf[..])?; +//! assert_eq!(bytes_read, 10); // note we read 10 bytes because the blob has size 10 +//! +//! // Insert another BLOB, this time using a parameter passed in from +//! // rust (potentially with a dynamic size). +//! db.execute( +//! "INSERT INTO test_table (content) VALUES (?1)", +//! [ZeroBlob(64)], +//! )?; +//! +//! // given a new row ID, we can reopen the blob on that row +//! let rowid = db.last_insert_rowid(); +//! blob.reopen(rowid)?; +//! // Just check that the size is right. +//! assert_eq!(blob.len(), 64); +//! # Ok(()) +//! # } +//! ``` +//! +//! # Example (Positional) +//! +//! ```rust +//! # use rusqlite::blob::ZeroBlob; +//! # use rusqlite::{Connection, DatabaseName}; +//! # use std::error::Error; +//! # fn main() -> Result<(), Box<dyn Error>> { +//! let db = Connection::open_in_memory()?; +//! db.execute_batch("CREATE TABLE test_table (content BLOB);")?; +//! // Insert a blob into the `content` column of `test_table`. Note that the Blob +//! // I/O API provides no way of inserting or resizing blobs in the DB -- this +//! // must be done via SQL. +//! db.execute("INSERT INTO test_table (content) VALUES (ZEROBLOB(10))", [])?; +//! // Get the row id off the blob we just inserted. +//! let rowid = db.last_insert_rowid(); +//! // Open the blob we just inserted for IO. +//! let mut blob = db.blob_open(DatabaseName::Main, "test_table", "content", rowid, false)?; +//! // Write some data into the blob. +//! blob.write_at(b"ABCDEF", 2)?; +//! +//! // Read the whole blob into a local buffer. +//! let mut buf = [0u8; 10]; +//! blob.read_at_exact(&mut buf, 0)?; +//! assert_eq!(&buf, b"\0\0ABCDEF\0\0"); +//! +//! // Insert another blob, this time using a parameter passed in from +//! // rust (potentially with a dynamic size). +//! db.execute( +//! "INSERT INTO test_table (content) VALUES (?1)", +//! [ZeroBlob(64)], +//! )?; +//! +//! // given a new row ID, we can reopen the blob on that row +//! let rowid = db.last_insert_rowid(); +//! blob.reopen(rowid)?; +//! assert_eq!(blob.len(), 64); +//! # Ok(()) +//! # } +//! ``` +use std::cmp::min; +use std::io; +use std::ptr; + +use super::ffi; +use super::types::{ToSql, ToSqlOutput}; +use crate::{Connection, DatabaseName, Result}; + +mod pos_io; + +/// Handle to an open BLOB. See +/// [`rusqlite::blob`](crate::blob) documentation for in-depth discussion. +pub struct Blob<'conn> { + conn: &'conn Connection, + blob: *mut ffi::sqlite3_blob, + // used by std::io implementations, + pos: i32, +} + +impl Connection { + /// Open a handle to the BLOB located in `row_id`, + /// `column`, `table` in database `db`. + /// + /// # Failure + /// + /// Will return `Err` if `db`/`table`/`column` cannot be converted to a + /// C-compatible string or if the underlying SQLite BLOB open call + /// fails. + #[inline] + pub fn blob_open<'a>( + &'a self, + db: DatabaseName<'_>, + table: &str, + column: &str, + row_id: i64, + read_only: bool, + ) -> Result<Blob<'a>> { + let c = self.db.borrow_mut(); + let mut blob = ptr::null_mut(); + let db = db.as_cstring()?; + let table = super::str_to_cstring(table)?; + let column = super::str_to_cstring(column)?; + let rc = unsafe { + ffi::sqlite3_blob_open( + c.db(), + db.as_ptr(), + table.as_ptr(), + column.as_ptr(), + row_id, + !read_only as std::os::raw::c_int, + &mut blob, + ) + }; + c.decode_result(rc).map(|_| Blob { + conn: self, + blob, + pos: 0, + }) + } +} + +impl Blob<'_> { + /// Move a BLOB handle to a new row. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite BLOB reopen call fails. + #[inline] + pub fn reopen(&mut self, row: i64) -> Result<()> { + let rc = unsafe { ffi::sqlite3_blob_reopen(self.blob, row) }; + if rc != ffi::SQLITE_OK { + return self.conn.decode_result(rc); + } + self.pos = 0; + Ok(()) + } + + /// Return the size in bytes of the BLOB. + #[inline] + #[must_use] + pub fn size(&self) -> i32 { + unsafe { ffi::sqlite3_blob_bytes(self.blob) } + } + + /// Return the current size in bytes of the BLOB. + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.size().try_into().unwrap() + } + + /// Return true if the BLOB is empty. + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.size() == 0 + } + + /// Close a BLOB handle. + /// + /// Calling `close` explicitly is not required (the BLOB will be closed + /// when the `Blob` is dropped), but it is available so you can get any + /// errors that occur. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite close call fails. + #[inline] + pub fn close(mut self) -> Result<()> { + self.close_() + } + + #[inline] + fn close_(&mut self) -> Result<()> { + let rc = unsafe { ffi::sqlite3_blob_close(self.blob) }; + self.blob = ptr::null_mut(); + self.conn.decode_result(rc) + } +} + +impl io::Read for Blob<'_> { + /// Read data from a BLOB incrementally. Will return Ok(0) if the end of + /// the blob has been reached. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite read call fails. + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + let max_allowed_len = (self.size() - self.pos) as usize; + let n = min(buf.len(), max_allowed_len) as i32; + if n <= 0 { + return Ok(0); + } + let rc = unsafe { ffi::sqlite3_blob_read(self.blob, buf.as_mut_ptr().cast(), n, self.pos) }; + self.conn + .decode_result(rc) + .map(|_| { + self.pos += n; + n as usize + }) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + } +} + +impl io::Write for Blob<'_> { + /// Write data into a BLOB incrementally. Will return `Ok(0)` if the end of + /// the blob has been reached; consider using `Write::write_all(buf)` + /// if you want to get an error if the entirety of the buffer cannot be + /// written. + /// + /// This function may only modify the contents of the BLOB; it is not + /// possible to increase the size of a BLOB using this API. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite write call fails. + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let max_allowed_len = (self.size() - self.pos) as usize; + let n = min(buf.len(), max_allowed_len) as i32; + if n <= 0 { + return Ok(0); + } + let rc = unsafe { ffi::sqlite3_blob_write(self.blob, buf.as_ptr() as *mut _, n, self.pos) }; + self.conn + .decode_result(rc) + .map(|_| { + self.pos += n; + n as usize + }) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl io::Seek for Blob<'_> { + /// Seek to an offset, in bytes, in BLOB. + #[inline] + fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { + let pos = match pos { + io::SeekFrom::Start(offset) => offset as i64, + io::SeekFrom::Current(offset) => i64::from(self.pos) + offset, + io::SeekFrom::End(offset) => i64::from(self.size()) + offset, + }; + + if pos < 0 { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid seek to negative position", + )) + } else if pos > i64::from(self.size()) { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid seek to position past end of blob", + )) + } else { + self.pos = pos as i32; + Ok(pos as u64) + } + } +} + +#[allow(unused_must_use)] +impl Drop for Blob<'_> { + #[inline] + fn drop(&mut self) { + self.close_(); + } +} + +/// BLOB of length N that is filled with zeroes. +/// +/// Zeroblobs are intended to serve as placeholders for BLOBs whose content is +/// later written using incremental BLOB I/O routines. +/// +/// A negative value for the zeroblob results in a zero-length BLOB. +#[derive(Copy, Clone)] +pub struct ZeroBlob(pub i32); + +impl ToSql for ZeroBlob { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let ZeroBlob(length) = *self; + Ok(ToSqlOutput::ZeroBlob(length)) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, DatabaseName, Result}; + use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write}; + + fn db_with_test_blob() -> Result<(Connection, i64)> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE test (content BLOB); + INSERT INTO test VALUES (ZEROBLOB(10)); + END;"; + db.execute_batch(sql)?; + let rowid = db.last_insert_rowid(); + Ok((db, rowid)) + } + + #[test] + fn test_blob() -> Result<()> { + let (db, rowid) = db_with_test_blob()?; + + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + assert_eq!(4, blob.write(b"Clob").unwrap()); + assert_eq!(6, blob.write(b"567890xxxxxx").unwrap()); // cannot write past 10 + assert_eq!(0, blob.write(b"5678").unwrap()); // still cannot write past 10 + + blob.reopen(rowid)?; + blob.close()?; + + blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, true)?; + let mut bytes = [0u8; 5]; + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"Clob5"); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"67890"); + assert_eq!(0, blob.read(&mut bytes[..]).unwrap()); + + blob.seek(SeekFrom::Start(2)).unwrap(); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"ob567"); + + // only first 4 bytes of `bytes` should be read into + blob.seek(SeekFrom::Current(-1)).unwrap(); + assert_eq!(4, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"78907"); + + blob.seek(SeekFrom::End(-6)).unwrap(); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"56789"); + + blob.reopen(rowid)?; + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"Clob5"); + + // should not be able to seek negative or past end + blob.seek(SeekFrom::Current(-20)).unwrap_err(); + blob.seek(SeekFrom::End(0)).unwrap(); + blob.seek(SeekFrom::Current(1)).unwrap_err(); + + // write_all should detect when we return Ok(0) because there is no space left, + // and return a write error + blob.reopen(rowid)?; + blob.write_all(b"0123456789x").unwrap_err(); + Ok(()) + } + + #[test] + fn test_blob_in_bufreader() -> Result<()> { + let (db, rowid) = db_with_test_blob()?; + + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + assert_eq!(8, blob.write(b"one\ntwo\n").unwrap()); + + blob.reopen(rowid)?; + let mut reader = BufReader::new(blob); + + let mut line = String::new(); + assert_eq!(4, reader.read_line(&mut line).unwrap()); + assert_eq!("one\n", line); + + line.truncate(0); + assert_eq!(4, reader.read_line(&mut line).unwrap()); + assert_eq!("two\n", line); + + line.truncate(0); + assert_eq!(2, reader.read_line(&mut line).unwrap()); + assert_eq!("\0\0", line); + Ok(()) + } + + #[test] + fn test_blob_in_bufwriter() -> Result<()> { + let (db, rowid) = db_with_test_blob()?; + + { + let blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + let mut writer = BufWriter::new(blob); + + // trying to write too much and then flush should fail + assert_eq!(8, writer.write(b"01234567").unwrap()); + assert_eq!(8, writer.write(b"01234567").unwrap()); + writer.flush().unwrap_err(); + } + + { + // ... but it should've written the first 10 bytes + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + let mut bytes = [0u8; 10]; + assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(b"0123456701", &bytes); + } + + { + let blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + let mut writer = BufWriter::new(blob); + + // trying to write_all too much should fail + writer.write_all(b"aaaaaaaaaabbbbb").unwrap(); + writer.flush().unwrap_err(); + } + + { + // ... but it should've written the first 10 bytes + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; + let mut bytes = [0u8; 10]; + assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(b"aaaaaaaaaa", &bytes); + Ok(()) + } + } +} diff --git a/third_party/rust/rusqlite/src/blob/pos_io.rs b/third_party/rust/rusqlite/src/blob/pos_io.rs new file mode 100644 index 0000000000..d970ab735b --- /dev/null +++ b/third_party/rust/rusqlite/src/blob/pos_io.rs @@ -0,0 +1,272 @@ +use super::Blob; + +use std::convert::TryFrom; +use std::mem::MaybeUninit; +use std::slice::from_raw_parts_mut; + +use crate::ffi; +use crate::{Error, Result}; + +impl<'conn> Blob<'conn> { + /// Write `buf` to `self` starting at `write_start`, returning an error if + /// `write_start + buf.len()` is past the end of the blob. + /// + /// If an error is returned, no data is written. + /// + /// Note: the blob cannot be resized using this function -- that must be + /// done using SQL (for example, an `UPDATE` statement). + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position write to, instead of using the internal position that can be + /// manipulated by the `std::io` traits. + /// + /// Unlike the similarly named [`FileExt::write_at`][fext_write_at] function + /// (from `std::os::unix`), it's always an error to perform a "short write". + /// + /// [fext_write_at]: https://doc.rust-lang.org/std/os/unix/fs/trait.FileExt.html#tymethod.write_at + #[inline] + pub fn write_at(&mut self, buf: &[u8], write_start: usize) -> Result<()> { + let len = self.len(); + + if buf.len().saturating_add(write_start) > len { + return Err(Error::BlobSizeError); + } + // We know `len` fits in an `i32`, so either: + // + // 1. `buf.len() + write_start` overflows, in which case we'd hit the + // return above (courtesy of `saturating_add`). + // + // 2. `buf.len() + write_start` doesn't overflow but is larger than len, + // in which case ditto. + // + // 3. `buf.len() + write_start` doesn't overflow but is less than len. + // This means that both `buf.len()` and `write_start` can also be + // losslessly converted to i32, since `len` came from an i32. + // Sanity check the above. + debug_assert!(i32::try_from(write_start).is_ok() && i32::try_from(buf.len()).is_ok()); + self.conn.decode_result(unsafe { + ffi::sqlite3_blob_write( + self.blob, + buf.as_ptr().cast(), + buf.len() as i32, + write_start as i32, + ) + }) + } + + /// An alias for `write_at` provided for compatibility with the conceptually + /// equivalent [`std::os::unix::FileExt::write_all_at`][write_all_at] + /// function from libstd: + /// + /// [write_all_at]: https://doc.rust-lang.org/std/os/unix/fs/trait.FileExt.html#method.write_all_at + #[inline] + pub fn write_all_at(&mut self, buf: &[u8], write_start: usize) -> Result<()> { + self.write_at(buf, write_start) + } + + /// Read as much as possible from `offset` to `offset + buf.len()` out of + /// `self`, writing into `buf`. On success, returns the number of bytes + /// written. + /// + /// If there's insufficient data in `self`, then the returned value will be + /// less than `buf.len()`. + /// + /// See also [`Blob::raw_read_at`], which can take an uninitialized buffer, + /// or [`Blob::read_at_exact`] which returns an error if the entire `buf` is + /// not read. + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position to read from, instead of using the internal position that can + /// be manipulated by the `std::io` traits. Consequently, it does not change + /// that value either. + #[inline] + pub fn read_at(&self, buf: &mut [u8], read_start: usize) -> Result<usize> { + // Safety: this is safe because `raw_read_at` never stores uninitialized + // data into `as_uninit`. + let as_uninit: &mut [MaybeUninit<u8>] = + unsafe { from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) }; + self.raw_read_at(as_uninit, read_start).map(|s| s.len()) + } + + /// Read as much as possible from `offset` to `offset + buf.len()` out of + /// `self`, writing into `buf`. On success, returns the portion of `buf` + /// which was initialized by this call. + /// + /// If there's insufficient data in `self`, then the returned value will be + /// shorter than `buf`. + /// + /// See also [`Blob::read_at`], which takes a `&mut [u8]` buffer instead of + /// a slice of `MaybeUninit<u8>`. + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position to read from, instead of using the internal position that can + /// be manipulated by the `std::io` traits. Consequently, it does not change + /// that value either. + #[inline] + pub fn raw_read_at<'a>( + &self, + buf: &'a mut [MaybeUninit<u8>], + read_start: usize, + ) -> Result<&'a mut [u8]> { + let len = self.len(); + + let read_len = match len.checked_sub(read_start) { + None | Some(0) => 0, + Some(v) => v.min(buf.len()), + }; + + if read_len == 0 { + // We could return `Ok(&mut [])`, but it seems confusing that the + // pointers don't match, so fabricate a empty slice of u8 with the + // same base pointer as `buf`. + let empty = unsafe { from_raw_parts_mut(buf.as_mut_ptr().cast::<u8>(), 0) }; + return Ok(empty); + } + + // At this point we believe `read_start as i32` is lossless because: + // + // 1. `len as i32` is known to be lossless, since it comes from a SQLite + // api returning an i32. + // + // 2. If we got here, `len.checked_sub(read_start)` was Some (or else + // we'd have hit the `if read_len == 0` early return), so `len` must + // be larger than `read_start`, and so it must fit in i32 as well. + debug_assert!(i32::try_from(read_start).is_ok()); + + // We also believe that `read_start + read_len <= len` because: + // + // 1. This is equivalent to `read_len <= len - read_start` via algebra. + // 2. We know that `read_len` is `min(len - read_start, buf.len())` + // 3. Expanding, this is `min(len - read_start, buf.len()) <= len - read_start`, + // or `min(A, B) <= A` which is clearly true. + // + // Note that this stuff is in debug_assert so no need to use checked_add + // and such -- we'll always panic on overflow in debug builds. + debug_assert!(read_start + read_len <= len); + + // These follow naturally. + debug_assert!(buf.len() >= read_len); + debug_assert!(i32::try_from(buf.len()).is_ok()); + debug_assert!(i32::try_from(read_len).is_ok()); + + unsafe { + self.conn.decode_result(ffi::sqlite3_blob_read( + self.blob, + buf.as_mut_ptr().cast(), + read_len as i32, + read_start as i32, + ))?; + + Ok(from_raw_parts_mut(buf.as_mut_ptr().cast::<u8>(), read_len)) + } + } + + /// Equivalent to [`Blob::read_at`], but returns a `BlobSizeError` if `buf` + /// is not fully initialized. + #[inline] + pub fn read_at_exact(&self, buf: &mut [u8], read_start: usize) -> Result<()> { + let n = self.read_at(buf, read_start)?; + if n != buf.len() { + Err(Error::BlobSizeError) + } else { + Ok(()) + } + } + + /// Equivalent to [`Blob::raw_read_at`], but returns a `BlobSizeError` if + /// `buf` is not fully initialized. + #[inline] + pub fn raw_read_at_exact<'a>( + &self, + buf: &'a mut [MaybeUninit<u8>], + read_start: usize, + ) -> Result<&'a mut [u8]> { + let buflen = buf.len(); + let initted = self.raw_read_at(buf, read_start)?; + if initted.len() != buflen { + Err(Error::BlobSizeError) + } else { + Ok(initted) + } + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, DatabaseName, Result}; + // to ensure we don't modify seek pos + use std::io::Seek as _; + + #[test] + fn test_pos_io() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE test_table(content BLOB);")?; + db.execute("INSERT INTO test_table(content) VALUES (ZEROBLOB(10))", [])?; + + let rowid = db.last_insert_rowid(); + let mut blob = db.blob_open(DatabaseName::Main, "test_table", "content", rowid, false)?; + // modify the seek pos to ensure we aren't using it or modifying it. + blob.seek(std::io::SeekFrom::Start(1)).unwrap(); + + let one2ten: [u8; 10] = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + blob.write_at(&one2ten, 0).unwrap(); + + let mut s = [0u8; 10]; + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &one2ten, "write should go through"); + blob.read_at_exact(&mut s, 1).unwrap_err(); + + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &one2ten, "should be unchanged"); + + let mut fives = [0u8; 5]; + blob.read_at_exact(&mut fives, 0).unwrap(); + assert_eq!(&fives, &[1u8, 2, 3, 4, 5]); + + blob.read_at_exact(&mut fives, 5).unwrap(); + assert_eq!(&fives, &[6u8, 7, 8, 9, 10]); + blob.read_at_exact(&mut fives, 7).unwrap_err(); + blob.read_at_exact(&mut fives, 12).unwrap_err(); + blob.read_at_exact(&mut fives, 10).unwrap_err(); + blob.read_at_exact(&mut fives, i32::MAX as usize) + .unwrap_err(); + blob.read_at_exact(&mut fives, i32::MAX as usize + 1) + .unwrap_err(); + + // zero length writes are fine if in bounds + blob.read_at_exact(&mut [], 10).unwrap(); + blob.read_at_exact(&mut [], 0).unwrap(); + blob.read_at_exact(&mut [], 5).unwrap(); + + blob.write_all_at(&[16, 17, 18, 19, 20], 5).unwrap(); + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &[1u8, 2, 3, 4, 5, 16, 17, 18, 19, 20]); + + blob.write_at(&[100, 99, 98, 97, 96], 6).unwrap_err(); + blob.write_at(&[100, 99, 98, 97, 96], i32::MAX as usize) + .unwrap_err(); + blob.write_at(&[100, 99, 98, 97, 96], i32::MAX as usize + 1) + .unwrap_err(); + + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &[1u8, 2, 3, 4, 5, 16, 17, 18, 19, 20]); + + let mut s2: [std::mem::MaybeUninit<u8>; 10] = [std::mem::MaybeUninit::uninit(); 10]; + { + let read = blob.raw_read_at_exact(&mut s2, 0).unwrap(); + assert_eq!(read, &s); + assert!(std::ptr::eq(read.as_ptr(), s2.as_ptr().cast())); + } + + let mut empty = []; + assert!(std::ptr::eq( + blob.raw_read_at_exact(&mut empty, 0).unwrap().as_ptr(), + empty.as_ptr().cast(), + )); + blob.raw_read_at_exact(&mut s2, 5).unwrap_err(); + + let end_pos = blob.stream_position().unwrap(); + assert_eq!(end_pos, 1); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/busy.rs b/third_party/rust/rusqlite/src/busy.rs new file mode 100644 index 0000000000..18fa7e2907 --- /dev/null +++ b/third_party/rust/rusqlite/src/busy.rs @@ -0,0 +1,170 @@ +//! Busy handler (when the database is locked) +use std::convert::TryInto; +use std::mem; +use std::os::raw::{c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; +use std::time::Duration; + +use crate::ffi; +use crate::{Connection, InnerConnection, Result}; + +impl Connection { + /// Set a busy handler that sleeps for a specified amount of time when a + /// table is locked. The handler will sleep multiple times until at + /// least "ms" milliseconds of sleeping have accumulated. + /// + /// Calling this routine with an argument equal to zero turns off all busy + /// handlers. + /// + /// There can only be a single busy handler for a particular database + /// connection at any given moment. If another busy handler was defined + /// (using [`busy_handler`](Connection::busy_handler)) prior to calling this + /// routine, that other busy handler is cleared. + /// + /// Newly created connections currently have a default busy timeout of + /// 5000ms, but this may be subject to change. + pub fn busy_timeout(&self, timeout: Duration) -> Result<()> { + let ms: i32 = timeout + .as_secs() + .checked_mul(1000) + .and_then(|t| t.checked_add(timeout.subsec_millis().into())) + .and_then(|t| t.try_into().ok()) + .expect("too big"); + self.db.borrow_mut().busy_timeout(ms) + } + + /// Register a callback to handle `SQLITE_BUSY` errors. + /// + /// If the busy callback is `None`, then `SQLITE_BUSY` is returned + /// immediately upon encountering the lock. The argument to the busy + /// handler callback is the number of times that the + /// busy handler has been invoked previously for the + /// same locking event. If the busy callback returns `false`, then no + /// additional attempts are made to access the + /// database and `SQLITE_BUSY` is returned to the + /// application. If the callback returns `true`, then another attempt + /// is made to access the database and the cycle repeats. + /// + /// There can only be a single busy handler defined for each database + /// connection. Setting a new busy handler clears any previously set + /// handler. Note that calling [`busy_timeout()`](Connection::busy_timeout) + /// or evaluating `PRAGMA busy_timeout=N` will change the busy handler + /// and thus clear any previously set busy handler. + /// + /// Newly created connections default to a + /// [`busy_timeout()`](Connection::busy_timeout) handler with a timeout + /// of 5000ms, although this is subject to change. + pub fn busy_handler(&self, callback: Option<fn(i32) -> bool>) -> Result<()> { + unsafe extern "C" fn busy_handler_callback(p_arg: *mut c_void, count: c_int) -> c_int { + let handler_fn: fn(i32) -> bool = mem::transmute(p_arg); + c_int::from(catch_unwind(|| handler_fn(count)).unwrap_or_default()) + } + let c = self.db.borrow_mut(); + let r = match callback { + Some(f) => unsafe { + ffi::sqlite3_busy_handler(c.db(), Some(busy_handler_callback), f as *mut c_void) + }, + None => unsafe { ffi::sqlite3_busy_handler(c.db(), None, ptr::null_mut()) }, + }; + c.decode_result(r) + } +} + +impl InnerConnection { + #[inline] + fn busy_timeout(&mut self, timeout: c_int) -> Result<()> { + let r = unsafe { ffi::sqlite3_busy_timeout(self.db, timeout) }; + self.decode_result(r) + } +} + +#[cfg(test)] +mod test { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::mpsc::sync_channel; + use std::thread; + use std::time::Duration; + + use crate::{Connection, ErrorCode, Result, TransactionBehavior}; + + #[test] + fn test_default_busy() -> Result<()> { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let mut db1 = Connection::open(&path)?; + let tx1 = db1.transaction_with_behavior(TransactionBehavior::Exclusive)?; + let db2 = Connection::open(&path)?; + let r: Result<()> = db2.query_row("PRAGMA schema_version", [], |_| unreachable!()); + assert_eq!( + r.unwrap_err().sqlite_error_code(), + Some(ErrorCode::DatabaseBusy) + ); + tx1.rollback() + } + + #[test] + #[ignore] // FIXME: unstable + fn test_busy_timeout() { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let db2 = Connection::open(&path).unwrap(); + db2.busy_timeout(Duration::from_secs(1)).unwrap(); + + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db1 = Connection::open(&path).unwrap(); + let tx1 = db1 + .transaction_with_behavior(TransactionBehavior::Exclusive) + .unwrap(); + rx.send(1).unwrap(); + thread::sleep(Duration::from_millis(100)); + tx1.rollback().unwrap(); + }); + + assert_eq!(tx.recv().unwrap(), 1); + let _ = db2 + .query_row("PRAGMA schema_version", [], |row| row.get::<_, i32>(0)) + .expect("unexpected error"); + + child.join().unwrap(); + } + + #[test] + #[ignore] // FIXME: unstable + fn test_busy_handler() { + static CALLED: AtomicBool = AtomicBool::new(false); + fn busy_handler(_: i32) -> bool { + CALLED.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(100)); + true + } + + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let db2 = Connection::open(&path).unwrap(); + db2.busy_handler(Some(busy_handler)).unwrap(); + + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db1 = Connection::open(&path).unwrap(); + let tx1 = db1 + .transaction_with_behavior(TransactionBehavior::Exclusive) + .unwrap(); + rx.send(1).unwrap(); + thread::sleep(Duration::from_millis(100)); + tx1.rollback().unwrap(); + }); + + assert_eq!(tx.recv().unwrap(), 1); + let _ = db2 + .query_row("PRAGMA schema_version", [], |row| row.get::<_, i32>(0)) + .expect("unexpected error"); + assert!(CALLED.load(Ordering::Relaxed)); + + child.join().unwrap(); + } +} diff --git a/third_party/rust/rusqlite/src/cache.rs b/third_party/rust/rusqlite/src/cache.rs new file mode 100644 index 0000000000..05ddb875f9 --- /dev/null +++ b/third_party/rust/rusqlite/src/cache.rs @@ -0,0 +1,350 @@ +//! Prepared statements cache for faster execution. + +use crate::raw_statement::RawStatement; +use crate::{Connection, PrepFlags, Result, Statement}; +use hashlink::LruCache; +use std::cell::RefCell; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +impl Connection { + /// Prepare a SQL statement for execution, returning a previously prepared + /// (but not currently in-use) statement if one is available. The + /// returned statement will be cached for reuse by future calls to + /// [`prepare_cached`](Connection::prepare_cached) once it is dropped. + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert_new_people(conn: &Connection) -> Result<()> { + /// { + /// let mut stmt = conn.prepare_cached("INSERT INTO People (name) VALUES (?1)")?; + /// stmt.execute(["Joe Smith"])?; + /// } + /// { + /// // This will return the same underlying SQLite statement handle without + /// // having to prepare it again. + /// let mut stmt = conn.prepare_cached("INSERT INTO People (name) VALUES (?1)")?; + /// stmt.execute(["Bob Jones"])?; + /// } + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> { + self.cache.get(self, sql) + } + + /// Set the maximum number of cached prepared statements this connection + /// will hold. By default, a connection will hold a relatively small + /// number of cached statements. If you need more, or know that you + /// will not use cached statements, you + /// can set the capacity manually using this method. + #[inline] + pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) { + self.cache.set_capacity(capacity); + } + + /// Remove/finalize all prepared statements currently in the cache. + #[inline] + pub fn flush_prepared_statement_cache(&self) { + self.cache.flush(); + } +} + +/// Prepared statements LRU cache. +// #[derive(Debug)] // FIXME: https://github.com/kyren/hashlink/pull/4 +pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>); + +#[allow(clippy::non_send_fields_in_send_ty)] +unsafe impl Send for StatementCache {} + +/// Cacheable statement. +/// +/// Statement will return automatically to the cache by default. +/// If you want the statement to be discarded, call +/// [`discard()`](CachedStatement::discard) on it. +pub struct CachedStatement<'conn> { + stmt: Option<Statement<'conn>>, + cache: &'conn StatementCache, +} + +impl<'conn> Deref for CachedStatement<'conn> { + type Target = Statement<'conn>; + + #[inline] + fn deref(&self) -> &Statement<'conn> { + self.stmt.as_ref().unwrap() + } +} + +impl<'conn> DerefMut for CachedStatement<'conn> { + #[inline] + fn deref_mut(&mut self) -> &mut Statement<'conn> { + self.stmt.as_mut().unwrap() + } +} + +impl Drop for CachedStatement<'_> { + #[allow(unused_must_use)] + #[inline] + fn drop(&mut self) { + if let Some(stmt) = self.stmt.take() { + self.cache.cache_stmt(unsafe { stmt.into_raw() }); + } + } +} + +impl CachedStatement<'_> { + #[inline] + fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> { + CachedStatement { + stmt: Some(stmt), + cache, + } + } + + /// Discard the statement, preventing it from being returned to its + /// [`Connection`]'s collection of cached statements. + #[inline] + pub fn discard(mut self) { + self.stmt = None; + } +} + +impl StatementCache { + /// Create a statement cache. + #[inline] + pub fn with_capacity(capacity: usize) -> StatementCache { + StatementCache(RefCell::new(LruCache::new(capacity))) + } + + #[inline] + fn set_capacity(&self, capacity: usize) { + self.0.borrow_mut().set_capacity(capacity); + } + + // Search the cache for a prepared-statement object that implements `sql`. + // If no such prepared-statement can be found, allocate and prepare a new one. + // + // # Failure + // + // Will return `Err` if no cached statement can be found and the underlying + // SQLite prepare call fails. + fn get<'conn>( + &'conn self, + conn: &'conn Connection, + sql: &str, + ) -> Result<CachedStatement<'conn>> { + let trimmed = sql.trim(); + let mut cache = self.0.borrow_mut(); + let stmt = match cache.remove(trimmed) { + Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)), + None => conn.prepare_with_flags(trimmed, PrepFlags::SQLITE_PREPARE_PERSISTENT), + }; + stmt.map(|mut stmt| { + stmt.stmt.set_statement_cache_key(trimmed); + CachedStatement::new(stmt, self) + }) + } + + // Return a statement to the cache. + fn cache_stmt(&self, stmt: RawStatement) { + if stmt.is_null() { + return; + } + let mut cache = self.0.borrow_mut(); + stmt.clear_bindings(); + if let Some(sql) = stmt.statement_cache_key() { + cache.insert(sql, stmt); + } else { + debug_assert!( + false, + "bug in statement cache code, statement returned to cache that without key" + ); + } + } + + #[inline] + fn flush(&self) { + let mut cache = self.0.borrow_mut(); + cache.clear(); + } +} + +#[cfg(test)] +mod test { + use super::StatementCache; + use crate::{Connection, Result}; + use fallible_iterator::FallibleIterator; + + impl StatementCache { + fn clear(&self) { + self.0.borrow_mut().clear(); + } + + fn len(&self) -> usize { + self.0.borrow().len() + } + + fn capacity(&self) -> usize { + self.0.borrow().capacity() + } + } + + #[test] + fn test_cache() -> Result<()> { + let db = Connection::open_in_memory()?; + let cache = &db.cache; + let initial_capacity = cache.capacity(); + assert_eq!(0, cache.len()); + assert!(initial_capacity > 0); + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + + cache.clear(); + assert_eq!(0, cache.len()); + assert_eq!(initial_capacity, cache.capacity()); + Ok(()) + } + + #[test] + fn test_set_capacity() -> Result<()> { + let db = Connection::open_in_memory()?; + let cache = &db.cache; + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + + db.set_prepared_statement_cache_capacity(0); + assert_eq!(0, cache.len()); + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(0, cache.len()); + + db.set_prepared_statement_cache_capacity(8); + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + Ok(()) + } + + #[test] + fn test_discard() -> Result<()> { + let db = Connection::open_in_memory()?; + let cache = &db.cache; + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + stmt.discard(); + } + assert_eq!(0, cache.len()); + Ok(()) + } + + #[test] + fn test_ddl() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch( + r#" + CREATE TABLE foo (x INT); + INSERT INTO foo VALUES (1); + "#, + )?; + + let sql = "SELECT * FROM foo"; + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(Ok(Some(1i32)), stmt.query([])?.map(|r| r.get(0)).next()); + } + + db.execute_batch( + r#" + ALTER TABLE foo ADD COLUMN y INT; + UPDATE foo SET y = 2; + "#, + )?; + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!( + Ok(Some((1i32, 2i32))), + stmt.query([])?.map(|r| Ok((r.get(0)?, r.get(1)?))).next() + ); + } + Ok(()) + } + + #[test] + fn test_connection_close() -> Result<()> { + let conn = Connection::open_in_memory()?; + conn.prepare_cached("SELECT * FROM sqlite_master;")?; + + conn.close().expect("connection not closed"); + Ok(()) + } + + #[test] + fn test_cache_key() -> Result<()> { + let db = Connection::open_in_memory()?; + let cache = &db.cache; + assert_eq!(0, cache.len()); + + //let sql = " PRAGMA schema_version; -- comment"; + let sql = "PRAGMA schema_version; "; + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + + { + let mut stmt = db.prepare_cached(sql)?; + assert_eq!(0, cache.len()); + assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?); + } + assert_eq!(1, cache.len()); + Ok(()) + } + + #[test] + fn test_empty_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + conn.prepare_cached("")?; + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/collation.rs b/third_party/rust/rusqlite/src/collation.rs new file mode 100644 index 0000000000..c1fe3f7837 --- /dev/null +++ b/third_party/rust/rusqlite/src/collation.rs @@ -0,0 +1,215 @@ +//! Add, remove, or modify a collation +use std::cmp::Ordering; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::{catch_unwind, UnwindSafe}; +use std::ptr; +use std::slice; + +use crate::ffi; +use crate::{str_to_cstring, Connection, InnerConnection, Result}; + +// FIXME copy/paste from function.rs +unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) { + drop(Box::from_raw(p.cast::<T>())); +} + +impl Connection { + /// Add or modify a collation. + #[inline] + pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()> + where + C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, + { + self.db + .borrow_mut() + .create_collation(collation_name, x_compare) + } + + /// Collation needed callback + #[inline] + pub fn collation_needed( + &self, + x_coll_needed: fn(&Connection, &str) -> Result<()>, + ) -> Result<()> { + self.db.borrow_mut().collation_needed(x_coll_needed) + } + + /// Remove collation. + #[inline] + pub fn remove_collation(&self, collation_name: &str) -> Result<()> { + self.db.borrow_mut().remove_collation(collation_name) + } +} + +impl InnerConnection { + fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()> + where + C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure<C>( + arg1: *mut c_void, + arg2: c_int, + arg3: *const c_void, + arg4: c_int, + arg5: *const c_void, + ) -> c_int + where + C: Fn(&str, &str) -> Ordering, + { + let r = catch_unwind(|| { + let boxed_f: *mut C = arg1.cast::<C>(); + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + let s1 = { + let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize); + String::from_utf8_lossy(c_slice) + }; + let s2 = { + let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize); + String::from_utf8_lossy(c_slice) + }; + (*boxed_f)(s1.as_ref(), s2.as_ref()) + }); + let t = match r { + Err(_) => { + return -1; // FIXME How ? + } + Ok(r) => r, + }; + + match t { + Ordering::Less => -1, + Ordering::Equal => 0, + Ordering::Greater => 1, + } + } + + let boxed_f: *mut C = Box::into_raw(Box::new(x_compare)); + let c_name = str_to_cstring(collation_name)?; + let flags = ffi::SQLITE_UTF8; + let r = unsafe { + ffi::sqlite3_create_collation_v2( + self.db(), + c_name.as_ptr(), + flags, + boxed_f.cast::<c_void>(), + Some(call_boxed_closure::<C>), + Some(free_boxed_value::<C>), + ) + }; + let res = self.decode_result(r); + // The xDestroy callback is not called if the sqlite3_create_collation_v2() + // function fails. + if res.is_err() { + drop(unsafe { Box::from_raw(boxed_f) }); + } + res + } + + fn collation_needed( + &mut self, + x_coll_needed: fn(&Connection, &str) -> Result<()>, + ) -> Result<()> { + use std::mem; + #[allow(clippy::needless_return)] + unsafe extern "C" fn collation_needed_callback( + arg1: *mut c_void, + arg2: *mut ffi::sqlite3, + e_text_rep: c_int, + arg3: *const c_char, + ) { + use std::ffi::CStr; + use std::str; + + if e_text_rep != ffi::SQLITE_UTF8 { + // TODO: validate + return; + } + + let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1); + let res = catch_unwind(|| { + let conn = Connection::from_handle(arg2).unwrap(); + let collation_name = { + let c_slice = CStr::from_ptr(arg3).to_bytes(); + str::from_utf8(c_slice).expect("illegal collation sequence name") + }; + callback(&conn, collation_name) + }); + if res.is_err() { + return; // FIXME How ? + } + } + + let r = unsafe { + ffi::sqlite3_collation_needed( + self.db(), + x_coll_needed as *mut c_void, + Some(collation_needed_callback), + ) + }; + self.decode_result(r) + } + + #[inline] + fn remove_collation(&mut self, collation_name: &str) -> Result<()> { + let c_name = str_to_cstring(collation_name)?; + let r = unsafe { + ffi::sqlite3_create_collation_v2( + self.db(), + c_name.as_ptr(), + ffi::SQLITE_UTF8, + ptr::null_mut(), + None, + None, + ) + }; + self.decode_result(r) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result}; + use fallible_streaming_iterator::FallibleStreamingIterator; + use std::cmp::Ordering; + use unicase::UniCase; + + fn unicase_compare(s1: &str, s2: &str) -> Ordering { + UniCase::new(s1).cmp(&UniCase::new(s2)) + } + + #[test] + fn test_unicase() -> Result<()> { + let db = Connection::open_in_memory()?; + + db.create_collation("unicase", unicase_compare)?; + + collate(db) + } + + fn collate(db: Connection) -> Result<()> { + db.execute_batch( + "CREATE TABLE foo (bar); + INSERT INTO foo (bar) VALUES ('Maße'); + INSERT INTO foo (bar) VALUES ('MASSE');", + )?; + let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?; + let rows = stmt.query([])?; + assert_eq!(rows.count()?, 1); + Ok(()) + } + + fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> { + if "unicase" == collation_name { + db.create_collation(collation_name, unicase_compare) + } else { + Ok(()) + } + } + + #[test] + fn test_collation_needed() -> Result<()> { + let db = Connection::open_in_memory()?; + db.collation_needed(collation_needed)?; + collate(db) + } +} diff --git a/third_party/rust/rusqlite/src/column.rs b/third_party/rust/rusqlite/src/column.rs new file mode 100644 index 0000000000..b18eb8a912 --- /dev/null +++ b/third_party/rust/rusqlite/src/column.rs @@ -0,0 +1,248 @@ +use std::str; + +use crate::{Error, Result, Statement}; + +/// Information about a column of a SQLite query. +#[cfg(feature = "column_decltype")] +#[cfg_attr(docsrs, doc(cfg(feature = "column_decltype")))] +#[derive(Debug)] +pub struct Column<'stmt> { + name: &'stmt str, + decl_type: Option<&'stmt str>, +} + +#[cfg(feature = "column_decltype")] +#[cfg_attr(docsrs, doc(cfg(feature = "column_decltype")))] +impl Column<'_> { + /// Returns the name of the column. + #[inline] + #[must_use] + pub fn name(&self) -> &str { + self.name + } + + /// Returns the type of the column (`None` for expression). + #[inline] + #[must_use] + pub fn decl_type(&self) -> Option<&str> { + self.decl_type + } +} + +impl Statement<'_> { + /// Get all the column names in the result set of the prepared statement. + /// + /// If associated DB schema can be altered concurrently, you should make + /// sure that current statement has already been stepped once before + /// calling this method. + pub fn column_names(&self) -> Vec<&str> { + let n = self.column_count(); + let mut cols = Vec::with_capacity(n); + for i in 0..n { + let s = self.column_name_unwrap(i); + cols.push(s); + } + cols + } + + /// Return the number of columns in the result set returned by the prepared + /// statement. + /// + /// If associated DB schema can be altered concurrently, you should make + /// sure that current statement has already been stepped once before + /// calling this method. + #[inline] + pub fn column_count(&self) -> usize { + self.stmt.column_count() + } + + /// Check that column name reference lifetime is limited: + /// https://www.sqlite.org/c3ref/column_name.html + /// > The returned string pointer is valid... + /// + /// `column_name` reference can become invalid if `stmt` is reprepared + /// (because of schema change) when `query_row` is called. So we assert + /// that a compilation error happens if this reference is kept alive: + /// ```compile_fail + /// use rusqlite::{Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// let mut stmt = db.prepare("SELECT 1 as x")?; + /// let column_name = stmt.column_name(0)?; + /// let x = stmt.query_row([], |r| r.get::<_, i64>(0))?; // E0502 + /// assert_eq!(1, x); + /// assert_eq!("x", column_name); + /// Ok(()) + /// } + /// ``` + #[inline] + pub(super) fn column_name_unwrap(&self, col: usize) -> &str { + // Just panic if the bounds are wrong for now, we never call this + // without checking first. + self.column_name(col).expect("Column out of bounds") + } + + /// Returns the name assigned to a particular column in the result set + /// returned by the prepared statement. + /// + /// If associated DB schema can be altered concurrently, you should make + /// sure that current statement has already been stepped once before + /// calling this method. + /// + /// ## Failure + /// + /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid + /// column range for this row. + /// + /// # Panics + /// + /// Panics when column name is not valid UTF-8. + #[inline] + pub fn column_name(&self, col: usize) -> Result<&str> { + self.stmt + .column_name(col) + // clippy::or_fun_call (nightly) vs clippy::unnecessary-lazy-evaluations (stable) + .ok_or(Error::InvalidColumnIndex(col)) + .map(|slice| { + str::from_utf8(slice.to_bytes()).expect("Invalid UTF-8 sequence in column name") + }) + } + + /// Returns the column index in the result set for a given column name. + /// + /// If there is no AS clause then the name of the column is unspecified and + /// may change from one release of SQLite to the next. + /// + /// If associated DB schema can be altered concurrently, you should make + /// sure that current statement has already been stepped once before + /// calling this method. + /// + /// # Failure + /// + /// Will return an `Error::InvalidColumnName` when there is no column with + /// the specified `name`. + #[inline] + pub fn column_index(&self, name: &str) -> Result<usize> { + let bytes = name.as_bytes(); + let n = self.column_count(); + for i in 0..n { + // Note: `column_name` is only fallible if `i` is out of bounds, + // which we've already checked. + if bytes.eq_ignore_ascii_case(self.stmt.column_name(i).unwrap().to_bytes()) { + return Ok(i); + } + } + Err(Error::InvalidColumnName(String::from(name))) + } + + /// Returns a slice describing the columns of the result of the query. + /// + /// If associated DB schema can be altered concurrently, you should make + /// sure that current statement has already been stepped once before + /// calling this method. + #[cfg(feature = "column_decltype")] + #[cfg_attr(docsrs, doc(cfg(feature = "column_decltype")))] + pub fn columns(&self) -> Vec<Column> { + let n = self.column_count(); + let mut cols = Vec::with_capacity(n); + for i in 0..n { + let name = self.column_name_unwrap(i); + let slice = self.stmt.column_decltype(i); + let decl_type = slice.map(|s| { + str::from_utf8(s.to_bytes()).expect("Invalid UTF-8 sequence in column declaration") + }); + cols.push(Column { name, decl_type }); + } + cols + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result}; + + #[test] + #[cfg(feature = "column_decltype")] + fn test_columns() -> Result<()> { + use super::Column; + + let db = Connection::open_in_memory()?; + let query = db.prepare("SELECT * FROM sqlite_master")?; + let columns = query.columns(); + let column_names: Vec<&str> = columns.iter().map(Column::name).collect(); + assert_eq!( + column_names.as_slice(), + &["type", "name", "tbl_name", "rootpage", "sql"] + ); + let column_types: Vec<Option<String>> = columns + .iter() + .map(|col| col.decl_type().map(str::to_lowercase)) + .collect(); + assert_eq!( + &column_types[..3], + &[ + Some("text".to_owned()), + Some("text".to_owned()), + Some("text".to_owned()), + ] + ); + Ok(()) + } + + #[test] + fn test_column_name_in_error() -> Result<()> { + use crate::{types::Type, Error}; + let db = Connection::open_in_memory()?; + db.execute_batch( + "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, NULL); + END;", + )?; + let mut stmt = db.prepare("SELECT x as renamed, y FROM foo")?; + let mut rows = stmt.query([])?; + let row = rows.next()?.unwrap(); + match row.get::<_, String>(0).unwrap_err() { + Error::InvalidColumnType(idx, name, ty) => { + assert_eq!(idx, 0); + assert_eq!(name, "renamed"); + assert_eq!(ty, Type::Integer); + } + e => { + panic!("Unexpected error type: {e:?}"); + } + } + match row.get::<_, String>("y").unwrap_err() { + Error::InvalidColumnType(idx, name, ty) => { + assert_eq!(idx, 1); + assert_eq!(name, "y"); + assert_eq!(ty, Type::Null); + } + e => { + panic!("Unexpected error type: {e:?}"); + } + } + Ok(()) + } + + /// `column_name` reference should stay valid until `stmt` is reprepared (or + /// reset) even if DB schema is altered (SQLite documentation is + /// ambiguous here because it says reference "is valid until (...) the next + /// call to sqlite3_column_name() or sqlite3_column_name16() on the same + /// column.". We assume that reference is valid if only + /// `sqlite3_column_name()` is used): + #[test] + #[cfg(feature = "modern_sqlite")] + fn test_column_name_reference() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE y (x);")?; + let stmt = db.prepare("SELECT x FROM y;")?; + let column_name = stmt.column_name(0)?; + assert_eq!("x", column_name); + db.execute_batch("ALTER TABLE y RENAME COLUMN x TO z;")?; + // column name is not refreshed until statement is re-prepared + let same_column_name = stmt.column_name(0)?; + assert_eq!(same_column_name, column_name); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/config.rs b/third_party/rust/rusqlite/src/config.rs new file mode 100644 index 0000000000..194ee59639 --- /dev/null +++ b/third_party/rust/rusqlite/src/config.rs @@ -0,0 +1,163 @@ +//! Configure database connections + +use std::os::raw::c_int; + +use crate::error::check; +use crate::ffi; +use crate::{Connection, Result}; + +/// Database Connection Configuration Options +/// See [Database Connection Configuration Options](https://sqlite.org/c3ref/c_dbconfig_enable_fkey.html) for details. +#[repr(i32)] +#[allow(non_snake_case, non_camel_case_types)] +#[non_exhaustive] +#[allow(clippy::upper_case_acronyms)] +pub enum DbConfig { + //SQLITE_DBCONFIG_MAINDBNAME = 1000, /* const char* */ + //SQLITE_DBCONFIG_LOOKASIDE = 1001, /* void* int int */ + /// Enable or disable the enforcement of foreign key constraints. + SQLITE_DBCONFIG_ENABLE_FKEY = ffi::SQLITE_DBCONFIG_ENABLE_FKEY, + /// Enable or disable triggers. + SQLITE_DBCONFIG_ENABLE_TRIGGER = ffi::SQLITE_DBCONFIG_ENABLE_TRIGGER, + /// Enable or disable the fts3_tokenizer() function which is part of the + /// FTS3 full-text search engine extension. + SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER = ffi::SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER, // 3.12.0 + //SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION = 1005, + /// In WAL mode, enable or disable the checkpoint operation before closing + /// the connection. + SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006, // 3.16.2 + /// Activates or deactivates the query planner stability guarantee (QPSG). + SQLITE_DBCONFIG_ENABLE_QPSG = 1007, // 3.20.0 + /// Includes or excludes output for any operations performed by trigger + /// programs from the output of EXPLAIN QUERY PLAN commands. + SQLITE_DBCONFIG_TRIGGER_EQP = 1008, // 3.22.0 + /// Activates or deactivates the "reset" flag for a database connection. + /// Run VACUUM with this flag set to reset the database. + SQLITE_DBCONFIG_RESET_DATABASE = 1009, // 3.24.0 + /// Activates or deactivates the "defensive" flag for a database connection. + SQLITE_DBCONFIG_DEFENSIVE = 1010, // 3.26.0 + /// Activates or deactivates the "writable_schema" flag. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_WRITABLE_SCHEMA = 1011, // 3.28.0 + /// Activates or deactivates the legacy behavior of the ALTER TABLE RENAME + /// command. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_LEGACY_ALTER_TABLE = 1012, // 3.29 + /// Activates or deactivates the legacy double-quoted string literal + /// misfeature for DML statements only. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_DQS_DML = 1013, // 3.29.0 + /// Activates or deactivates the legacy double-quoted string literal + /// misfeature for DDL statements. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_DQS_DDL = 1014, // 3.29.0 + /// Enable or disable views. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_ENABLE_VIEW = 1015, // 3.30.0 + /// Activates or deactivates the legacy file format flag. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_LEGACY_FILE_FORMAT = 1016, // 3.31.0 + /// Tells SQLite to assume that database schemas (the contents of the + /// sqlite_master tables) are untainted by malicious content. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_TRUSTED_SCHEMA = 1017, // 3.31.0 + /// Sets or clears a flag that enables collection of the + /// sqlite3_stmt_scanstatus_v2() statistics + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_STMT_SCANSTATUS = 1018, // 3.42.0 + /// Changes the default order in which tables and indexes are scanned + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_REVERSE_SCANORDER = 1019, // 3.42.0 +} + +impl Connection { + /// Returns the current value of a `config`. + /// + /// - `SQLITE_DBCONFIG_ENABLE_FKEY`: return `false` or `true` to indicate + /// whether FK enforcement is off or on + /// - `SQLITE_DBCONFIG_ENABLE_TRIGGER`: return `false` or `true` to indicate + /// whether triggers are disabled or enabled + /// - `SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER`: return `false` or `true` to + /// indicate whether `fts3_tokenizer` are disabled or enabled + /// - `SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE`: return `false` to indicate + /// checkpoints-on-close are not disabled or `true` if they are + /// - `SQLITE_DBCONFIG_ENABLE_QPSG`: return `false` or `true` to indicate + /// whether the QPSG is disabled or enabled + /// - `SQLITE_DBCONFIG_TRIGGER_EQP`: return `false` to indicate + /// output-for-trigger are not disabled or `true` if it is + #[inline] + pub fn db_config(&self, config: DbConfig) -> Result<bool> { + let c = self.db.borrow(); + unsafe { + let mut val = 0; + check(ffi::sqlite3_db_config( + c.db(), + config as c_int, + -1, + &mut val, + ))?; + Ok(val != 0) + } + } + + /// Make configuration changes to a database connection + /// + /// - `SQLITE_DBCONFIG_ENABLE_FKEY`: `false` to disable FK enforcement, + /// `true` to enable FK enforcement + /// - `SQLITE_DBCONFIG_ENABLE_TRIGGER`: `false` to disable triggers, `true` + /// to enable triggers + /// - `SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER`: `false` to disable + /// `fts3_tokenizer()`, `true` to enable `fts3_tokenizer()` + /// - `SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE`: `false` (the default) to enable + /// checkpoints-on-close, `true` to disable them + /// - `SQLITE_DBCONFIG_ENABLE_QPSG`: `false` to disable the QPSG, `true` to + /// enable QPSG + /// - `SQLITE_DBCONFIG_TRIGGER_EQP`: `false` to disable output for trigger + /// programs, `true` to enable it + #[inline] + pub fn set_db_config(&self, config: DbConfig, new_val: bool) -> Result<bool> { + let c = self.db.borrow_mut(); + unsafe { + let mut val = 0; + check(ffi::sqlite3_db_config( + c.db(), + config as c_int, + new_val as c_int, + &mut val, + ))?; + Ok(val != 0) + } + } +} + +#[cfg(test)] +mod test { + use super::DbConfig; + use crate::{Connection, Result}; + + #[test] + fn test_db_config() -> Result<()> { + let db = Connection::open_in_memory()?; + + let opposite = !db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY)?; + assert_eq!( + db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY, opposite), + Ok(opposite) + ); + assert_eq!( + db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY), + Ok(opposite) + ); + + let opposite = !db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER)?; + assert_eq!( + db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER, opposite), + Ok(opposite) + ); + assert_eq!( + db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER), + Ok(opposite) + ); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/context.rs b/third_party/rust/rusqlite/src/context.rs new file mode 100644 index 0000000000..bcaefc9395 --- /dev/null +++ b/third_party/rust/rusqlite/src/context.rs @@ -0,0 +1,75 @@ +//! Code related to `sqlite3_context` common to `functions` and `vtab` modules. + +use std::os::raw::{c_int, c_void}; +#[cfg(feature = "array")] +use std::rc::Rc; + +use crate::ffi; +use crate::ffi::sqlite3_context; + +use crate::str_for_sqlite; +use crate::types::{ToSqlOutput, ValueRef}; +#[cfg(feature = "array")] +use crate::vtab::array::{free_array, ARRAY_TYPE}; + +// This function is inline despite it's size because what's in the ToSqlOutput +// is often known to the compiler, and thus const prop/DCE can substantially +// simplify the function. +#[inline] +pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<'_>) { + let value = match *result { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(len) => { + // TODO sqlite3_result_zeroblob64 // 3.8.11 + return ffi::sqlite3_result_zeroblob(ctx, len); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(ref a) => { + return ffi::sqlite3_result_pointer( + ctx, + Rc::into_raw(a.clone()) as *mut c_void, + ARRAY_TYPE, + Some(free_array), + ); + } + }; + + match value { + ValueRef::Null => ffi::sqlite3_result_null(ctx), + ValueRef::Integer(i) => ffi::sqlite3_result_int64(ctx, i), + ValueRef::Real(r) => ffi::sqlite3_result_double(ctx, r), + ValueRef::Text(s) => { + let length = s.len(); + if length > c_int::MAX as usize { + ffi::sqlite3_result_error_toobig(ctx); + } else { + let (c_str, len, destructor) = match str_for_sqlite(s) { + Ok(c_str) => c_str, + // TODO sqlite3_result_error + Err(_) => return ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), + }; + // TODO sqlite3_result_text64 // 3.8.7 + ffi::sqlite3_result_text(ctx, c_str, len, destructor); + } + } + ValueRef::Blob(b) => { + let length = b.len(); + if length > c_int::MAX as usize { + ffi::sqlite3_result_error_toobig(ctx); + } else if length == 0 { + ffi::sqlite3_result_zeroblob(ctx, 0); + } else { + // TODO sqlite3_result_blob64 // 3.8.7 + ffi::sqlite3_result_blob( + ctx, + b.as_ptr().cast::<c_void>(), + length as c_int, + ffi::SQLITE_TRANSIENT(), + ); + } + } + } +} diff --git a/third_party/rust/rusqlite/src/error.rs b/third_party/rust/rusqlite/src/error.rs new file mode 100644 index 0000000000..69aaf2ef21 --- /dev/null +++ b/third_party/rust/rusqlite/src/error.rs @@ -0,0 +1,476 @@ +use crate::types::FromSqlError; +use crate::types::Type; +use crate::{errmsg_to_string, ffi, Result}; +use std::error; +use std::fmt; +use std::os::raw::c_int; +use std::path::PathBuf; +use std::str; + +/// Enum listing possible errors from rusqlite. +#[derive(Debug)] +#[allow(clippy::enum_variant_names)] +#[non_exhaustive] +pub enum Error { + /// An error from an underlying SQLite call. + SqliteFailure(ffi::Error, Option<String>), + + /// Error reported when attempting to open a connection when SQLite was + /// configured to allow single-threaded use only. + SqliteSingleThreadedMode, + + /// Error when the value of a particular column is requested, but it cannot + /// be converted to the requested Rust type. + FromSqlConversionFailure(usize, Type, Box<dyn error::Error + Send + Sync + 'static>), + + /// Error when SQLite gives us an integral value outside the range of the + /// requested type (e.g., trying to get the value 1000 into a `u8`). + /// The associated `usize` is the column index, + /// and the associated `i64` is the value returned by SQLite. + IntegralValueOutOfRange(usize, i64), + + /// Error converting a string to UTF-8. + Utf8Error(str::Utf8Error), + + /// Error converting a string to a C-compatible string because it contained + /// an embedded nul. + NulError(std::ffi::NulError), + + /// Error when using SQL named parameters and passing a parameter name not + /// present in the SQL. + InvalidParameterName(String), + + /// Error converting a file path to a string. + InvalidPath(PathBuf), + + /// Error returned when an [`execute`](crate::Connection::execute) call + /// returns rows. + ExecuteReturnedResults, + + /// Error when a query that was expected to return at least one row (e.g., + /// for [`query_row`](crate::Connection::query_row)) did not return any. + QueryReturnedNoRows, + + /// Error when the value of a particular column is requested, but the index + /// is out of range for the statement. + InvalidColumnIndex(usize), + + /// Error when the value of a named column is requested, but no column + /// matches the name for the statement. + InvalidColumnName(String), + + /// Error when the value of a particular column is requested, but the type + /// of the result in that column cannot be converted to the requested + /// Rust type. + InvalidColumnType(usize, String, Type), + + /// Error when a query that was expected to insert one row did not insert + /// any or insert many. + StatementChangedRows(usize), + + /// Error returned by + /// [`functions::Context::get`](crate::functions::Context::get) when the + /// function argument cannot be converted to the requested type. + #[cfg(feature = "functions")] + #[cfg_attr(docsrs, doc(cfg(feature = "functions")))] + InvalidFunctionParameterType(usize, Type), + /// Error returned by [`vtab::Values::get`](crate::vtab::Values::get) when + /// the filter argument cannot be converted to the requested type. + #[cfg(feature = "vtab")] + #[cfg_attr(docsrs, doc(cfg(feature = "vtab")))] + InvalidFilterParameterType(usize, Type), + + /// An error case available for implementors of custom user functions (e.g., + /// [`create_scalar_function`](crate::Connection::create_scalar_function)). + #[cfg(feature = "functions")] + #[cfg_attr(docsrs, doc(cfg(feature = "functions")))] + #[allow(dead_code)] + UserFunctionError(Box<dyn error::Error + Send + Sync + 'static>), + + /// Error available for the implementors of the + /// [`ToSql`](crate::types::ToSql) trait. + ToSqlConversionFailure(Box<dyn error::Error + Send + Sync + 'static>), + + /// Error when the SQL is not a `SELECT`, is not read-only. + InvalidQuery, + + /// An error case available for implementors of custom modules (e.g., + /// [`create_module`](crate::Connection::create_module)). + #[cfg(feature = "vtab")] + #[cfg_attr(docsrs, doc(cfg(feature = "vtab")))] + #[allow(dead_code)] + ModuleError(String), + + /// An unwinding panic occurs in an UDF (user-defined function). + #[cfg(feature = "functions")] + #[cfg_attr(docsrs, doc(cfg(feature = "functions")))] + UnwindingPanic, + + /// An error returned when + /// [`Context::get_aux`](crate::functions::Context::get_aux) attempts to + /// retrieve data of a different type than what had been stored using + /// [`Context::set_aux`](crate::functions::Context::set_aux). + #[cfg(feature = "functions")] + #[cfg_attr(docsrs, doc(cfg(feature = "functions")))] + GetAuxWrongType, + + /// Error when the SQL contains multiple statements. + MultipleStatement, + /// Error when the number of bound parameters does not match the number of + /// parameters in the query. The first `usize` is how many parameters were + /// given, the 2nd is how many were expected. + InvalidParameterCount(usize, usize), + + /// Returned from various functions in the Blob IO positional API. For + /// example, + /// [`Blob::raw_read_at_exact`](crate::blob::Blob::raw_read_at_exact) will + /// return it if the blob has insufficient data. + #[cfg(feature = "blob")] + #[cfg_attr(docsrs, doc(cfg(feature = "blob")))] + BlobSizeError, + /// Error referencing a specific token in the input SQL + #[cfg(feature = "modern_sqlite")] // 3.38.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + SqlInputError { + /// error code + error: ffi::Error, + /// error message + msg: String, + /// SQL input + sql: String, + /// byte offset of the start of invalid token + offset: c_int, + }, + /// Loadable extension initialization error + #[cfg(feature = "loadable_extension")] + #[cfg_attr(docsrs, doc(cfg(feature = "loadable_extension")))] + InitError(ffi::InitError), +} + +impl PartialEq for Error { + fn eq(&self, other: &Error) -> bool { + match (self, other) { + (Error::SqliteFailure(e1, s1), Error::SqliteFailure(e2, s2)) => e1 == e2 && s1 == s2, + (Error::SqliteSingleThreadedMode, Error::SqliteSingleThreadedMode) => true, + (Error::IntegralValueOutOfRange(i1, n1), Error::IntegralValueOutOfRange(i2, n2)) => { + i1 == i2 && n1 == n2 + } + (Error::Utf8Error(e1), Error::Utf8Error(e2)) => e1 == e2, + (Error::NulError(e1), Error::NulError(e2)) => e1 == e2, + (Error::InvalidParameterName(n1), Error::InvalidParameterName(n2)) => n1 == n2, + (Error::InvalidPath(p1), Error::InvalidPath(p2)) => p1 == p2, + (Error::ExecuteReturnedResults, Error::ExecuteReturnedResults) => true, + (Error::QueryReturnedNoRows, Error::QueryReturnedNoRows) => true, + (Error::InvalidColumnIndex(i1), Error::InvalidColumnIndex(i2)) => i1 == i2, + (Error::InvalidColumnName(n1), Error::InvalidColumnName(n2)) => n1 == n2, + (Error::InvalidColumnType(i1, n1, t1), Error::InvalidColumnType(i2, n2, t2)) => { + i1 == i2 && t1 == t2 && n1 == n2 + } + (Error::StatementChangedRows(n1), Error::StatementChangedRows(n2)) => n1 == n2, + #[cfg(feature = "functions")] + ( + Error::InvalidFunctionParameterType(i1, t1), + Error::InvalidFunctionParameterType(i2, t2), + ) => i1 == i2 && t1 == t2, + #[cfg(feature = "vtab")] + ( + Error::InvalidFilterParameterType(i1, t1), + Error::InvalidFilterParameterType(i2, t2), + ) => i1 == i2 && t1 == t2, + (Error::InvalidQuery, Error::InvalidQuery) => true, + #[cfg(feature = "vtab")] + (Error::ModuleError(s1), Error::ModuleError(s2)) => s1 == s2, + #[cfg(feature = "functions")] + (Error::UnwindingPanic, Error::UnwindingPanic) => true, + #[cfg(feature = "functions")] + (Error::GetAuxWrongType, Error::GetAuxWrongType) => true, + (Error::InvalidParameterCount(i1, n1), Error::InvalidParameterCount(i2, n2)) => { + i1 == i2 && n1 == n2 + } + #[cfg(feature = "blob")] + (Error::BlobSizeError, Error::BlobSizeError) => true, + #[cfg(feature = "modern_sqlite")] + ( + Error::SqlInputError { + error: e1, + msg: m1, + sql: s1, + offset: o1, + }, + Error::SqlInputError { + error: e2, + msg: m2, + sql: s2, + offset: o2, + }, + ) => e1 == e2 && m1 == m2 && s1 == s2 && o1 == o2, + #[cfg(feature = "loadable_extension")] + (Error::InitError(e1), Error::InitError(e2)) => e1 == e2, + (..) => false, + } + } +} + +impl From<str::Utf8Error> for Error { + #[cold] + fn from(err: str::Utf8Error) -> Error { + Error::Utf8Error(err) + } +} + +impl From<std::ffi::NulError> for Error { + #[cold] + fn from(err: std::ffi::NulError) -> Error { + Error::NulError(err) + } +} + +const UNKNOWN_COLUMN: usize = usize::MAX; + +/// The conversion isn't precise, but it's convenient to have it +/// to allow use of `get_raw(…).as_…()?` in callbacks that take `Error`. +impl From<FromSqlError> for Error { + #[cold] + fn from(err: FromSqlError) -> Error { + // The error type requires index and type fields, but they aren't known in this + // context. + match err { + FromSqlError::OutOfRange(val) => Error::IntegralValueOutOfRange(UNKNOWN_COLUMN, val), + FromSqlError::InvalidBlobSize { .. } => { + Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Blob, Box::new(err)) + } + FromSqlError::Other(source) => { + Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Null, source) + } + _ => Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Null, Box::new(err)), + } + } +} + +#[cfg(feature = "loadable_extension")] +impl From<ffi::InitError> for Error { + #[cold] + fn from(err: ffi::InitError) -> Error { + Error::InitError(err) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Error::SqliteFailure(ref err, None) => err.fmt(f), + Error::SqliteFailure(_, Some(ref s)) => write!(f, "{s}"), + Error::SqliteSingleThreadedMode => write!( + f, + "SQLite was compiled or configured for single-threaded use only" + ), + Error::FromSqlConversionFailure(i, ref t, ref err) => { + if i != UNKNOWN_COLUMN { + write!(f, "Conversion error from type {t} at index: {i}, {err}") + } else { + err.fmt(f) + } + } + Error::IntegralValueOutOfRange(col, val) => { + if col != UNKNOWN_COLUMN { + write!(f, "Integer {val} out of range at index {col}") + } else { + write!(f, "Integer {val} out of range") + } + } + Error::Utf8Error(ref err) => err.fmt(f), + Error::NulError(ref err) => err.fmt(f), + Error::InvalidParameterName(ref name) => write!(f, "Invalid parameter name: {name}"), + Error::InvalidPath(ref p) => write!(f, "Invalid path: {}", p.to_string_lossy()), + Error::ExecuteReturnedResults => { + write!(f, "Execute returned results - did you mean to call query?") + } + Error::QueryReturnedNoRows => write!(f, "Query returned no rows"), + Error::InvalidColumnIndex(i) => write!(f, "Invalid column index: {i}"), + Error::InvalidColumnName(ref name) => write!(f, "Invalid column name: {name}"), + Error::InvalidColumnType(i, ref name, ref t) => { + write!(f, "Invalid column type {t} at index: {i}, name: {name}") + } + Error::InvalidParameterCount(i1, n1) => write!( + f, + "Wrong number of parameters passed to query. Got {i1}, needed {n1}" + ), + Error::StatementChangedRows(i) => write!(f, "Query changed {i} rows"), + + #[cfg(feature = "functions")] + Error::InvalidFunctionParameterType(i, ref t) => { + write!(f, "Invalid function parameter type {t} at index {i}") + } + #[cfg(feature = "vtab")] + Error::InvalidFilterParameterType(i, ref t) => { + write!(f, "Invalid filter parameter type {t} at index {i}") + } + #[cfg(feature = "functions")] + Error::UserFunctionError(ref err) => err.fmt(f), + Error::ToSqlConversionFailure(ref err) => err.fmt(f), + Error::InvalidQuery => write!(f, "Query is not read-only"), + #[cfg(feature = "vtab")] + Error::ModuleError(ref desc) => write!(f, "{desc}"), + #[cfg(feature = "functions")] + Error::UnwindingPanic => write!(f, "unwinding panic"), + #[cfg(feature = "functions")] + Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"), + Error::MultipleStatement => write!(f, "Multiple statements provided"), + #[cfg(feature = "blob")] + Error::BlobSizeError => "Blob size is insufficient".fmt(f), + #[cfg(feature = "modern_sqlite")] + Error::SqlInputError { + ref msg, + offset, + ref sql, + .. + } => write!(f, "{msg} in {sql} at offset {offset}"), + #[cfg(feature = "loadable_extension")] + Error::InitError(ref err) => err.fmt(f), + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match *self { + Error::SqliteFailure(ref err, _) => Some(err), + Error::Utf8Error(ref err) => Some(err), + Error::NulError(ref err) => Some(err), + + Error::IntegralValueOutOfRange(..) + | Error::SqliteSingleThreadedMode + | Error::InvalidParameterName(_) + | Error::ExecuteReturnedResults + | Error::QueryReturnedNoRows + | Error::InvalidColumnIndex(_) + | Error::InvalidColumnName(_) + | Error::InvalidColumnType(..) + | Error::InvalidPath(_) + | Error::InvalidParameterCount(..) + | Error::StatementChangedRows(_) + | Error::InvalidQuery + | Error::MultipleStatement => None, + + #[cfg(feature = "functions")] + Error::InvalidFunctionParameterType(..) => None, + #[cfg(feature = "vtab")] + Error::InvalidFilterParameterType(..) => None, + + #[cfg(feature = "functions")] + Error::UserFunctionError(ref err) => Some(&**err), + + Error::FromSqlConversionFailure(_, _, ref err) + | Error::ToSqlConversionFailure(ref err) => Some(&**err), + + #[cfg(feature = "vtab")] + Error::ModuleError(_) => None, + + #[cfg(feature = "functions")] + Error::UnwindingPanic => None, + + #[cfg(feature = "functions")] + Error::GetAuxWrongType => None, + + #[cfg(feature = "blob")] + Error::BlobSizeError => None, + #[cfg(feature = "modern_sqlite")] + Error::SqlInputError { ref error, .. } => Some(error), + #[cfg(feature = "loadable_extension")] + Error::InitError(ref err) => Some(err), + } + } +} + +impl Error { + /// Returns the underlying SQLite error if this is [`Error::SqliteFailure`]. + #[inline] + #[must_use] + pub fn sqlite_error(&self) -> Option<&ffi::Error> { + match self { + Self::SqliteFailure(error, _) => Some(error), + _ => None, + } + } + + /// Returns the underlying SQLite error code if this is + /// [`Error::SqliteFailure`]. + #[inline] + #[must_use] + pub fn sqlite_error_code(&self) -> Option<ffi::ErrorCode> { + self.sqlite_error().map(|error| error.code) + } +} + +// These are public but not re-exported by lib.rs, so only visible within crate. + +#[cold] +pub fn error_from_sqlite_code(code: c_int, message: Option<String>) -> Error { + Error::SqliteFailure(ffi::Error::new(code), message) +} + +#[cold] +pub unsafe fn error_from_handle(db: *mut ffi::sqlite3, code: c_int) -> Error { + let message = if db.is_null() { + None + } else { + Some(errmsg_to_string(ffi::sqlite3_errmsg(db))) + }; + error_from_sqlite_code(code, message) +} + +#[cold] +#[cfg(not(feature = "modern_sqlite"))] // SQLite >= 3.38.0 +pub unsafe fn error_with_offset(db: *mut ffi::sqlite3, code: c_int, _sql: &str) -> Error { + error_from_handle(db, code) +} + +#[cold] +#[cfg(feature = "modern_sqlite")] // SQLite >= 3.38.0 +pub unsafe fn error_with_offset(db: *mut ffi::sqlite3, code: c_int, sql: &str) -> Error { + if db.is_null() { + error_from_sqlite_code(code, None) + } else { + let error = ffi::Error::new(code); + let msg = errmsg_to_string(ffi::sqlite3_errmsg(db)); + if ffi::ErrorCode::Unknown == error.code { + let offset = ffi::sqlite3_error_offset(db); + if offset >= 0 { + return Error::SqlInputError { + error, + msg, + sql: sql.to_owned(), + offset, + }; + } + } + Error::SqliteFailure(error, Some(msg)) + } +} + +pub fn check(code: c_int) -> Result<()> { + if code != crate::ffi::SQLITE_OK { + Err(error_from_sqlite_code(code, None)) + } else { + Ok(()) + } +} + +/// Transform Rust error to SQLite error (message and code). +/// # Safety +/// This function is unsafe because it uses raw pointer +pub unsafe fn to_sqlite_error(e: &Error, err_msg: *mut *mut std::os::raw::c_char) -> c_int { + use crate::util::alloc; + match e { + Error::SqliteFailure(err, s) => { + if let Some(s) = s { + *err_msg = alloc(s); + } + err.extended_code + } + err => { + *err_msg = alloc(&err.to_string()); + ffi::SQLITE_ERROR + } + } +} diff --git a/third_party/rust/rusqlite/src/functions.rs b/third_party/rust/rusqlite/src/functions.rs new file mode 100644 index 0000000000..522f1167d0 --- /dev/null +++ b/third_party/rust/rusqlite/src/functions.rs @@ -0,0 +1,1071 @@ +//! Create or redefine SQL functions. +//! +//! # Example +//! +//! Adding a `regexp` function to a connection in which compiled regular +//! expressions are cached in a `HashMap`. For an alternative implementation +//! that uses SQLite's [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html) interface +//! to avoid recompiling regular expressions, see the unit tests for this +//! module. +//! +//! ```rust +//! use regex::Regex; +//! use rusqlite::functions::FunctionFlags; +//! use rusqlite::{Connection, Error, Result}; +//! use std::sync::Arc; +//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; +//! +//! fn add_regexp_function(db: &Connection) -> Result<()> { +//! db.create_scalar_function( +//! "regexp", +//! 2, +//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, +//! move |ctx| { +//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); +//! let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> { +//! Ok(Regex::new(vr.as_str()?)?) +//! })?; +//! let is_match = { +//! let text = ctx +//! .get_raw(1) +//! .as_str() +//! .map_err(|e| Error::UserFunctionError(e.into()))?; +//! +//! regexp.is_match(text) +//! }; +//! +//! Ok(is_match) +//! }, +//! ) +//! } +//! +//! fn main() -> Result<()> { +//! let db = Connection::open_in_memory()?; +//! add_regexp_function(&db)?; +//! +//! let is_match: bool = +//! db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| { +//! row.get(0) +//! })?; +//! +//! assert!(is_match); +//! Ok(()) +//! } +//! ``` +use std::any::Any; +use std::marker::PhantomData; +use std::ops::Deref; +use std::os::raw::{c_int, c_void}; +use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; +use std::ptr; +use std::slice; +use std::sync::Arc; + +use crate::ffi; +use crate::ffi::sqlite3_context; +use crate::ffi::sqlite3_value; + +use crate::context::set_result; +use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; + +use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; + +unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) { + if let Error::SqliteFailure(ref err, ref s) = *err { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } else { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(&err.to_string()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } +} + +unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) { + drop(Box::from_raw(p.cast::<T>())); +} + +/// Context is a wrapper for the SQLite function +/// evaluation context. +pub struct Context<'a> { + ctx: *mut sqlite3_context, + args: &'a [*mut sqlite3_value], +} + +impl Context<'_> { + /// Returns the number of arguments to the function. + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.args.len() + } + + /// Returns `true` when there is no argument. + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.args.is_empty() + } + + /// Returns the `idx`th argument as a `T`. + /// + /// # Failure + /// + /// Will panic if `idx` is greater than or equal to + /// [`self.len()`](Context::len). + /// + /// Will return Err if the underlying SQLite type cannot be converted to a + /// `T`. + pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> { + let arg = self.args[idx]; + let value = unsafe { ValueRef::from_value(arg) }; + FromSql::column_result(value).map_err(|err| match err { + FromSqlError::InvalidType => { + Error::InvalidFunctionParameterType(idx, value.data_type()) + } + FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i), + FromSqlError::Other(err) => { + Error::FromSqlConversionFailure(idx, value.data_type(), err) + } + FromSqlError::InvalidBlobSize { .. } => { + Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err)) + } + }) + } + + /// Returns the `idx`th argument as a `ValueRef`. + /// + /// # Failure + /// + /// Will panic if `idx` is greater than or equal to + /// [`self.len()`](Context::len). + #[inline] + #[must_use] + pub fn get_raw(&self, idx: usize) -> ValueRef<'_> { + let arg = self.args[idx]; + unsafe { ValueRef::from_value(arg) } + } + + /// Returns the subtype of `idx`th argument. + /// + /// # Failure + /// + /// Will panic if `idx` is greater than or equal to + /// [`self.len()`](Context::len). + pub fn get_subtype(&self, idx: usize) -> std::os::raw::c_uint { + let arg = self.args[idx]; + unsafe { ffi::sqlite3_value_subtype(arg) } + } + + /// Fetch or insert the auxiliary data associated with a particular + /// parameter. This is intended to be an easier-to-use way of fetching it + /// compared to calling [`get_aux`](Context::get_aux) and + /// [`set_aux`](Context::set_aux) separately. + /// + /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of + /// this feature, or the unit tests of this module for an example. + pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>> + where + T: Send + Sync + 'static, + E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>, + F: FnOnce(ValueRef<'_>) -> Result<T, E>, + { + if let Some(v) = self.get_aux(arg)? { + Ok(v) + } else { + let vr = self.get_raw(arg as usize); + self.set_aux( + arg, + func(vr).map_err(|e| Error::UserFunctionError(e.into()))?, + ) + } + } + + /// Sets the auxiliary data associated with a particular parameter. See + /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of + /// this feature, or the unit tests of this module for an example. + pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> { + let orig: Arc<T> = Arc::new(value); + let inner: AuxInner = orig.clone(); + let outer = Box::new(inner); + let raw: *mut AuxInner = Box::into_raw(outer); + unsafe { + ffi::sqlite3_set_auxdata( + self.ctx, + arg, + raw.cast(), + Some(free_boxed_value::<AuxInner>), + ); + }; + Ok(orig) + } + + /// Gets the auxiliary data that was associated with a given parameter via + /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been + /// associated, and Ok(Some(v)) if it has. Returns an error if the + /// requested type does not match. + pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> { + let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner }; + if p.is_null() { + Ok(None) + } else { + let v: AuxInner = AuxInner::clone(unsafe { &*p }); + v.downcast::<T>() + .map(Some) + .map_err(|_| Error::GetAuxWrongType) + } + } + + /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html) + /// + /// # Safety + /// + /// This function is marked unsafe because there is a potential for other + /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213). + pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> { + let handle = ffi::sqlite3_context_db_handle(self.ctx); + Ok(ConnectionRef { + conn: Connection::from_handle(handle)?, + phantom: PhantomData, + }) + } + + /// Set the Subtype of an SQL function + pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) { + unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) }; + } +} + +/// A reference to a connection handle with a lifetime bound to something. +pub struct ConnectionRef<'ctx> { + // comes from Connection::from_handle(sqlite3_context_db_handle(...)) + // and is non-owning + conn: Connection, + phantom: PhantomData<&'ctx Context<'ctx>>, +} + +impl Deref for ConnectionRef<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + &self.conn + } +} + +type AuxInner = Arc<dyn Any + Send + Sync + 'static>; + +/// Aggregate is the callback interface for user-defined +/// aggregate function. +/// +/// `A` is the type of the aggregation context and `T` is the type of the final +/// result. Implementations should be stateless. +pub trait Aggregate<A, T> +where + A: RefUnwindSafe + UnwindSafe, + T: ToSql, +{ + /// Initializes the aggregation context. Will be called prior to the first + /// call to [`step()`](Aggregate::step) to set up the context for an + /// invocation of the function. (Note: `init()` will not be called if + /// there are no rows.) + fn init(&self, ctx: &mut Context<'_>) -> Result<A>; + + /// "step" function called once for each row in an aggregate group. May be + /// called 0 times if there are no rows. + fn step(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>; + + /// Computes and returns the final result. Will be called exactly once for + /// each invocation of the function. If [`step()`](Aggregate::step) was + /// called at least once, will be given `Some(A)` (the same `A` as was + /// created by [`init`](Aggregate::init) and given to + /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not + /// called (because the function is running against 0 rows), will be + /// given `None`. + /// + /// The passed context will have no arguments. + fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>; +} + +/// `WindowAggregate` is the callback interface for +/// user-defined aggregate window function. +#[cfg(feature = "window")] +#[cfg_attr(docsrs, doc(cfg(feature = "window")))] +pub trait WindowAggregate<A, T>: Aggregate<A, T> +where + A: RefUnwindSafe + UnwindSafe, + T: ToSql, +{ + /// Returns the current value of the aggregate. Unlike xFinal, the + /// implementation should not delete any context. + fn value(&self, acc: Option<&mut A>) -> Result<T>; + + /// Removes a row from the current window. + fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>; +} + +bitflags::bitflags! { + /// Function Flags. + /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html) + /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details. + #[repr(C)] + pub struct FunctionFlags: ::std::os::raw::c_int { + /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters. + const SQLITE_UTF8 = ffi::SQLITE_UTF8; + /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters. + const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE; + /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters. + const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE; + /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters. + const SQLITE_UTF16 = ffi::SQLITE_UTF16; + /// Means that the function always gives the same output when the input parameters are the same. + const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; // 3.8.3 + /// Means that the function may only be invoked from top-level SQL. + const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0 + /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the sub-types of its arguments. + const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0 + /// Means that the function is unlikely to cause problems even if misused. + const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0 + } +} + +impl Default for FunctionFlags { + #[inline] + fn default() -> FunctionFlags { + FunctionFlags::SQLITE_UTF8 + } +} + +impl Connection { + /// Attach a user-defined scalar function to + /// this database connection. + /// + /// `fn_name` is the name the function will be accessible from SQL. + /// `n_arg` is the number of arguments to the function. Use `-1` for a + /// variable number. If the function always returns the same value + /// given the same input, `deterministic` should be `true`. + /// + /// The function will remain available until the connection is closed or + /// until it is explicitly removed via + /// [`remove_function`](Connection::remove_function). + /// + /// # Example + /// + /// ```rust + /// # use rusqlite::{Connection, Result}; + /// # use rusqlite::functions::FunctionFlags; + /// fn scalar_function_example(db: Connection) -> Result<()> { + /// db.create_scalar_function( + /// "halve", + /// 1, + /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + /// |ctx| { + /// let value = ctx.get::<f64>(0)?; + /// Ok(value / 2f64) + /// }, + /// )?; + /// + /// let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?; + /// assert_eq!(six_halved, 3f64); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return Err if the function could not be attached to the connection. + #[inline] + pub fn create_scalar_function<F, T>( + &self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + x_func: F, + ) -> Result<()> + where + F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, + T: ToSql, + { + self.db + .borrow_mut() + .create_scalar_function(fn_name, n_arg, flags, x_func) + } + + /// Attach a user-defined aggregate function to this + /// database connection. + /// + /// # Failure + /// + /// Will return Err if the function could not be attached to the connection. + #[inline] + pub fn create_aggregate_function<A, D, T>( + &self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + aggr: D, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate<A, T> + 'static, + T: ToSql, + { + self.db + .borrow_mut() + .create_aggregate_function(fn_name, n_arg, flags, aggr) + } + + /// Attach a user-defined aggregate window function to + /// this database connection. + /// + /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more + /// information. + #[cfg(feature = "window")] + #[cfg_attr(docsrs, doc(cfg(feature = "window")))] + #[inline] + pub fn create_window_function<A, W, T>( + &self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + aggr: W, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate<A, T> + 'static, + T: ToSql, + { + self.db + .borrow_mut() + .create_window_function(fn_name, n_arg, flags, aggr) + } + + /// Removes a user-defined function from this + /// database connection. + /// + /// `fn_name` and `n_arg` should match the name and number of arguments + /// given to [`create_scalar_function`](Connection::create_scalar_function) + /// or [`create_aggregate_function`](Connection::create_aggregate_function). + /// + /// # Failure + /// + /// Will return Err if the function could not be removed. + #[inline] + pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> { + self.db.borrow_mut().remove_function(fn_name, n_arg) + } +} + +impl InnerConnection { + fn create_scalar_function<F, T>( + &mut self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + x_func: F, + ) -> Result<()> + where + F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, + T: ToSql, + { + unsafe extern "C" fn call_boxed_closure<F, T>( + ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) where + F: FnMut(&Context<'_>) -> Result<T>, + T: ToSql, + { + let r = catch_unwind(|| { + let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>(); + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + let ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_f)(&ctx) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } + } + + let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); + let c_name = str_to_cstring(fn_name)?; + let r = unsafe { + ffi::sqlite3_create_function_v2( + self.db(), + c_name.as_ptr(), + n_arg, + flags.bits(), + boxed_f.cast::<c_void>(), + Some(call_boxed_closure::<F, T>), + None, + None, + Some(free_boxed_value::<F>), + ) + }; + self.decode_result(r) + } + + fn create_aggregate_function<A, D, T>( + &mut self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + aggr: D, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate<A, T> + 'static, + T: ToSql, + { + let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); + let c_name = str_to_cstring(fn_name)?; + let r = unsafe { + ffi::sqlite3_create_function_v2( + self.db(), + c_name.as_ptr(), + n_arg, + flags.bits(), + boxed_aggr.cast::<c_void>(), + None, + Some(call_boxed_step::<A, D, T>), + Some(call_boxed_final::<A, D, T>), + Some(free_boxed_value::<D>), + ) + }; + self.decode_result(r) + } + + #[cfg(feature = "window")] + fn create_window_function<A, W, T>( + &mut self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + aggr: W, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate<A, T> + 'static, + T: ToSql, + { + let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr)); + let c_name = str_to_cstring(fn_name)?; + let r = unsafe { + ffi::sqlite3_create_window_function( + self.db(), + c_name.as_ptr(), + n_arg, + flags.bits(), + boxed_aggr.cast::<c_void>(), + Some(call_boxed_step::<A, W, T>), + Some(call_boxed_final::<A, W, T>), + Some(call_boxed_value::<A, W, T>), + Some(call_boxed_inverse::<A, W, T>), + Some(free_boxed_value::<W>), + ) + }; + self.decode_result(r) + } + + fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { + let c_name = str_to_cstring(fn_name)?; + let r = unsafe { + ffi::sqlite3_create_function_v2( + self.db(), + c_name.as_ptr(), + n_arg, + ffi::SQLITE_UTF8, + ptr::null_mut(), + None, + None, + None, + None, + ) + }; + self.decode_result(r) + } +} + +unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> { + let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A; + if pac.is_null() { + return None; + } + Some(pac) +} + +unsafe extern "C" fn call_boxed_step<A, D, T>( + ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, +) where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate<A, T>, + T: ToSql, +{ + let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { + pac + } else { + ffi::sqlite3_result_error_nomem(ctx); + return; + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>(); + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + let mut ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + + #[allow(clippy::unnecessary_cast)] + if (*pac as *mut A).is_null() { + *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?)); + } + + (*boxed_aggr).step(&mut ctx, &mut **pac) + }); + let r = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + match r { + Ok(_) => {} + Err(err) => report_error(ctx, &err), + }; +} + +#[cfg(feature = "window")] +unsafe extern "C" fn call_boxed_inverse<A, W, T>( + ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, +) where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate<A, T>, + T: ToSql, +{ + let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { + pac + } else { + ffi::sqlite3_result_error_nomem(ctx); + return; + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>(); + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + let mut ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_aggr).inverse(&mut ctx, &mut **pac) + }); + let r = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + match r { + Ok(_) => {} + Err(err) => report_error(ctx, &err), + }; +} + +unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) +where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate<A, T>, + T: ToSql, +{ + // Within the xFinal callback, it is customary to set N=0 in calls to + // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. + let a: Option<A> = match aggregate_context(ctx, 0) { + Some(pac) => + { + #[allow(clippy::unnecessary_cast)] + if (*pac as *mut A).is_null() { + None + } else { + let a = Box::from_raw(*pac); + Some(*a) + } + } + None => None, + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>(); + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + let mut ctx = Context { ctx, args: &mut [] }; + (*boxed_aggr).finalize(&mut ctx, a) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } +} + +#[cfg(feature = "window")] +unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context) +where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate<A, T>, + T: ToSql, +{ + // Within the xValue callback, it is customary to set N=0 in calls to + // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. + let pac = aggregate_context(ctx, 0).filter(|&pac| { + #[allow(clippy::unnecessary_cast)] + !(*pac as *mut A).is_null() + }); + + let r = catch_unwind(|| { + let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>(); + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + (*boxed_aggr).value(pac.map(|pac| &mut **pac)) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } +} + +#[cfg(test)] +mod test { + use regex::Regex; + use std::os::raw::c_double; + + #[cfg(feature = "window")] + use crate::functions::WindowAggregate; + use crate::functions::{Aggregate, Context, FunctionFlags}; + use crate::{Connection, Error, Result}; + + fn half(ctx: &Context<'_>) -> Result<c_double> { + assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); + let value = ctx.get::<c_double>(0)?; + Ok(value / 2f64) + } + + #[test] + fn test_function_half() -> Result<()> { + let db = Connection::open_in_memory()?; + db.create_scalar_function( + "half", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + half, + )?; + let result: f64 = db.one_column("SELECT half(6)")?; + + assert!((3f64 - result).abs() < f64::EPSILON); + Ok(()) + } + + #[test] + fn test_remove_function() -> Result<()> { + let db = Connection::open_in_memory()?; + db.create_scalar_function( + "half", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + half, + )?; + let result: f64 = db.one_column("SELECT half(6)")?; + assert!((3f64 - result).abs() < f64::EPSILON); + + db.remove_function("half", 1)?; + let result: Result<f64> = db.one_column("SELECT half(6)"); + result.unwrap_err(); + Ok(()) + } + + // This implementation of a regexp scalar function uses SQLite's auxiliary data + // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular + // expression multiple times within one query. + fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> { + assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); + type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; + let regexp: std::sync::Arc<Regex> = ctx + .get_or_create_aux(0, |vr| -> Result<_, BoxError> { + Ok(Regex::new(vr.as_str()?)?) + })?; + + let is_match = { + let text = ctx + .get_raw(1) + .as_str() + .map_err(|e| Error::UserFunctionError(e.into()))?; + + regexp.is_match(text) + }; + + Ok(is_match) + } + + #[test] + fn test_function_regexp_with_auxiliary() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch( + "BEGIN; + CREATE TABLE foo (x string); + INSERT INTO foo VALUES ('lisa'); + INSERT INTO foo VALUES ('lXsi'); + INSERT INTO foo VALUES ('lisX'); + END;", + )?; + db.create_scalar_function( + "regexp", + 2, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + regexp_with_auxiliary, + )?; + + let result: bool = db.one_column("SELECT regexp('l.s[aeiouy]', 'lisa')")?; + + assert!(result); + + let result: i64 = + db.one_column("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1")?; + + assert_eq!(2, result); + Ok(()) + } + + #[test] + fn test_varargs_function() -> Result<()> { + let db = Connection::open_in_memory()?; + db.create_scalar_function( + "my_concat", + -1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + |ctx| { + let mut ret = String::new(); + + for idx in 0..ctx.len() { + let s = ctx.get::<String>(idx)?; + ret.push_str(&s); + } + + Ok(ret) + }, + )?; + + for &(expected, query) in &[ + ("", "SELECT my_concat()"), + ("onetwo", "SELECT my_concat('one', 'two')"), + ("abc", "SELECT my_concat('a', 'b', 'c')"), + ] { + let result: String = db.one_column(query)?; + assert_eq!(expected, result); + } + Ok(()) + } + + #[test] + fn test_get_aux_type_checking() -> Result<()> { + let db = Connection::open_in_memory()?; + db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| { + if !ctx.get::<bool>(1)? { + ctx.set_aux::<i64>(0, 100)?; + } else { + assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); + assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100); + } + Ok(true) + })?; + + let res: bool = + db.one_column("SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)")?; + // Doesn't actually matter, we'll assert in the function if there's a problem. + assert!(res); + Ok(()) + } + + struct Sum; + struct Count; + + impl Aggregate<i64, Option<i64>> for Sum { + fn init(&self, _: &mut Context<'_>) -> Result<i64> { + Ok(0) + } + + fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { + *sum += ctx.get::<i64>(0)?; + Ok(()) + } + + fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> { + Ok(sum) + } + } + + impl Aggregate<i64, i64> for Count { + fn init(&self, _: &mut Context<'_>) -> Result<i64> { + Ok(0) + } + + fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { + *sum += 1; + Ok(()) + } + + fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> { + Ok(sum.unwrap_or(0)) + } + } + + #[test] + fn test_sum() -> Result<()> { + let db = Connection::open_in_memory()?; + db.create_aggregate_function( + "my_sum", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Sum, + )?; + + // sum should return NULL when given no columns (contrast with count below) + let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: Option<i64> = db.one_column(no_result)?; + assert!(result.is_none()); + + let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.one_column(single_sum)?; + assert_eq!(4, result); + + let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ + 2, 1)"; + let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?; + assert_eq!((4, 2), result); + Ok(()) + } + + #[test] + fn test_count() -> Result<()> { + let db = Connection::open_in_memory()?; + db.create_aggregate_function( + "my_count", + -1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Count, + )?; + + // count should return 0 when given no columns (contrast with sum above) + let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: i64 = db.one_column(no_result)?; + assert_eq!(result, 0); + + let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.one_column(single_sum)?; + assert_eq!(2, result); + Ok(()) + } + + #[cfg(feature = "window")] + impl WindowAggregate<i64, Option<i64>> for Sum { + fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { + *sum -= ctx.get::<i64>(0)?; + Ok(()) + } + + fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> { + Ok(sum.copied()) + } + } + + #[test] + #[cfg(feature = "window")] + fn test_window() -> Result<()> { + use fallible_iterator::FallibleIterator; + + let db = Connection::open_in_memory()?; + db.create_window_function( + "sumint", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Sum, + )?; + db.execute_batch( + "CREATE TABLE t3(x, y); + INSERT INTO t3 VALUES('a', 4), + ('b', 5), + ('c', 3), + ('d', 8), + ('e', 1);", + )?; + + let mut stmt = db.prepare( + "SELECT x, sumint(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM t3 ORDER BY x;", + )?; + + let results: Vec<(String, i64)> = stmt + .query([])? + .map(|row| Ok((row.get("x")?, row.get("sum_y")?))) + .collect()?; + let expected = vec![ + ("a".to_owned(), 9), + ("b".to_owned(), 12), + ("c".to_owned(), 16), + ("d".to_owned(), 12), + ("e".to_owned(), 9), + ]; + assert_eq!(expected, results); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/hooks.rs b/third_party/rust/rusqlite/src/hooks.rs new file mode 100644 index 0000000000..52a53a3886 --- /dev/null +++ b/third_party/rust/rusqlite/src/hooks.rs @@ -0,0 +1,806 @@ +//! Commit, Data Change and Rollback Notification Callbacks +#![allow(non_camel_case_types)] + +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::{catch_unwind, RefUnwindSafe}; +use std::ptr; + +use crate::ffi; + +use crate::{Connection, InnerConnection}; + +/// Action Codes +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(i32)] +#[non_exhaustive] +#[allow(clippy::upper_case_acronyms)] +pub enum Action { + /// Unsupported / unexpected action + UNKNOWN = -1, + /// DELETE command + SQLITE_DELETE = ffi::SQLITE_DELETE, + /// INSERT command + SQLITE_INSERT = ffi::SQLITE_INSERT, + /// UPDATE command + SQLITE_UPDATE = ffi::SQLITE_UPDATE, +} + +impl From<i32> for Action { + #[inline] + fn from(code: i32) -> Action { + match code { + ffi::SQLITE_DELETE => Action::SQLITE_DELETE, + ffi::SQLITE_INSERT => Action::SQLITE_INSERT, + ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE, + _ => Action::UNKNOWN, + } + } +} + +/// The context received by an authorizer hook. +/// +/// See <https://sqlite.org/c3ref/set_authorizer.html> for more info. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct AuthContext<'c> { + /// The action to be authorized. + pub action: AuthAction<'c>, + + /// The database name, if applicable. + pub database_name: Option<&'c str>, + + /// The inner-most trigger or view responsible for the access attempt. + /// `None` if the access attempt was made by top-level SQL code. + pub accessor: Option<&'c str>, +} + +/// Actions and arguments found within a statement during +/// preparation. +/// +/// See <https://sqlite.org/c3ref/c_alter_table.html> for more info. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +#[allow(missing_docs)] +pub enum AuthAction<'c> { + /// This variant is not normally produced by SQLite. You may encounter it + // if you're using a different version than what's supported by this library. + Unknown { + /// The unknown authorization action code. + code: i32, + /// The third arg to the authorizer callback. + arg1: Option<&'c str>, + /// The fourth arg to the authorizer callback. + arg2: Option<&'c str>, + }, + CreateIndex { + index_name: &'c str, + table_name: &'c str, + }, + CreateTable { + table_name: &'c str, + }, + CreateTempIndex { + index_name: &'c str, + table_name: &'c str, + }, + CreateTempTable { + table_name: &'c str, + }, + CreateTempTrigger { + trigger_name: &'c str, + table_name: &'c str, + }, + CreateTempView { + view_name: &'c str, + }, + CreateTrigger { + trigger_name: &'c str, + table_name: &'c str, + }, + CreateView { + view_name: &'c str, + }, + Delete { + table_name: &'c str, + }, + DropIndex { + index_name: &'c str, + table_name: &'c str, + }, + DropTable { + table_name: &'c str, + }, + DropTempIndex { + index_name: &'c str, + table_name: &'c str, + }, + DropTempTable { + table_name: &'c str, + }, + DropTempTrigger { + trigger_name: &'c str, + table_name: &'c str, + }, + DropTempView { + view_name: &'c str, + }, + DropTrigger { + trigger_name: &'c str, + table_name: &'c str, + }, + DropView { + view_name: &'c str, + }, + Insert { + table_name: &'c str, + }, + Pragma { + pragma_name: &'c str, + /// The pragma value, if present (e.g., `PRAGMA name = value;`). + pragma_value: Option<&'c str>, + }, + Read { + table_name: &'c str, + column_name: &'c str, + }, + Select, + Transaction { + operation: TransactionOperation, + }, + Update { + table_name: &'c str, + column_name: &'c str, + }, + Attach { + filename: &'c str, + }, + Detach { + database_name: &'c str, + }, + AlterTable { + database_name: &'c str, + table_name: &'c str, + }, + Reindex { + index_name: &'c str, + }, + Analyze { + table_name: &'c str, + }, + CreateVtable { + table_name: &'c str, + module_name: &'c str, + }, + DropVtable { + table_name: &'c str, + module_name: &'c str, + }, + Function { + function_name: &'c str, + }, + Savepoint { + operation: TransactionOperation, + savepoint_name: &'c str, + }, + Recursive, +} + +impl<'c> AuthAction<'c> { + fn from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self { + match (code, arg1, arg2) { + (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex { + index_name, + table_name, + }, + (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name }, + (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => { + Self::CreateTempIndex { + index_name, + table_name, + } + } + (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => { + Self::CreateTempTable { table_name } + } + (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => { + Self::CreateTempTrigger { + trigger_name, + table_name, + } + } + (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => { + Self::CreateTempView { view_name } + } + (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => { + Self::CreateTrigger { + trigger_name, + table_name, + } + } + (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name }, + (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name }, + (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex { + index_name, + table_name, + }, + (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name }, + (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => { + Self::DropTempIndex { + index_name, + table_name, + } + } + (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => { + Self::DropTempTable { table_name } + } + (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => { + Self::DropTempTrigger { + trigger_name, + table_name, + } + } + (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name }, + (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger { + trigger_name, + table_name, + }, + (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name }, + (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name }, + (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma { + pragma_name, + pragma_value, + }, + (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read { + table_name, + column_name, + }, + (ffi::SQLITE_SELECT, ..) => Self::Select, + (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction { + operation: TransactionOperation::from_str(operation_str), + }, + (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update { + table_name, + column_name, + }, + (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename }, + (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name }, + (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable { + database_name, + table_name, + }, + (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name }, + (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name }, + (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => { + Self::CreateVtable { + table_name, + module_name, + } + } + (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable { + table_name, + module_name, + }, + (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name }, + (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint { + operation: TransactionOperation::from_str(operation_str), + savepoint_name, + }, + (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive, + (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 }, + } + } +} + +pub(crate) type BoxedAuthorizer = + Box<dyn for<'c> FnMut(AuthContext<'c>) -> Authorization + Send + 'static>; + +/// A transaction operation. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +#[allow(missing_docs)] +pub enum TransactionOperation { + Unknown, + Begin, + Release, + Rollback, +} + +impl TransactionOperation { + fn from_str(op_str: &str) -> Self { + match op_str { + "BEGIN" => Self::Begin, + "RELEASE" => Self::Release, + "ROLLBACK" => Self::Rollback, + _ => Self::Unknown, + } + } +} + +/// [`authorizer`](Connection::authorizer) return code +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum Authorization { + /// Authorize the action. + Allow, + /// Don't allow access, but don't trigger an error either. + Ignore, + /// Trigger an error. + Deny, +} + +impl Authorization { + fn into_raw(self) -> c_int { + match self { + Self::Allow => ffi::SQLITE_OK, + Self::Ignore => ffi::SQLITE_IGNORE, + Self::Deny => ffi::SQLITE_DENY, + } + } +} + +impl Connection { + /// Register a callback function to be invoked whenever + /// a transaction is committed. + /// + /// The callback returns `true` to rollback. + #[inline] + pub fn commit_hook<F>(&self, hook: Option<F>) + where + F: FnMut() -> bool + Send + 'static, + { + self.db.borrow_mut().commit_hook(hook); + } + + /// Register a callback function to be invoked whenever + /// a transaction is committed. + #[inline] + pub fn rollback_hook<F>(&self, hook: Option<F>) + where + F: FnMut() + Send + 'static, + { + self.db.borrow_mut().rollback_hook(hook); + } + + /// Register a callback function to be invoked whenever + /// a row is updated, inserted or deleted in a rowid table. + /// + /// The callback parameters are: + /// + /// - the type of database update (`SQLITE_INSERT`, `SQLITE_UPDATE` or + /// `SQLITE_DELETE`), + /// - the name of the database ("main", "temp", ...), + /// - the name of the table that is updated, + /// - the ROWID of the row that is updated. + #[inline] + pub fn update_hook<F>(&self, hook: Option<F>) + where + F: FnMut(Action, &str, &str, i64) + Send + 'static, + { + self.db.borrow_mut().update_hook(hook); + } + + /// Register a query progress callback. + /// + /// The parameter `num_ops` is the approximate number of virtual machine + /// instructions that are evaluated between successive invocations of the + /// `handler`. If `num_ops` is less than one then the progress handler + /// is disabled. + /// + /// If the progress callback returns `true`, the operation is interrupted. + pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>) + where + F: FnMut() -> bool + Send + RefUnwindSafe + 'static, + { + self.db.borrow_mut().progress_handler(num_ops, handler); + } + + /// Register an authorizer callback that's invoked + /// as a statement is being prepared. + #[inline] + pub fn authorizer<'c, F>(&self, hook: Option<F>) + where + F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static, + { + self.db.borrow_mut().authorizer(hook); + } +} + +impl InnerConnection { + #[inline] + pub fn remove_hooks(&mut self) { + self.update_hook(None::<fn(Action, &str, &str, i64)>); + self.commit_hook(None::<fn() -> bool>); + self.rollback_hook(None::<fn()>); + self.progress_handler(0, None::<fn() -> bool>); + self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>); + } + + fn commit_hook<F>(&mut self, hook: Option<F>) + where + F: FnMut() -> bool + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int + where + F: FnMut() -> bool, + { + let r = catch_unwind(|| { + let boxed_hook: *mut F = p_arg.cast::<F>(); + (*boxed_hook)() + }); + c_int::from(r.unwrap_or_default()) + } + + // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with + // `sqlite3_commit_hook`. so we keep the `xDestroy` function in + // `InnerConnection.free_boxed_hook`. + let free_commit_hook = if hook.is_some() { + Some(free_boxed_hook::<F> as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_commit_hook( + self.db(), + Some(call_boxed_closure::<F>), + boxed_hook.cast(), + ) + } + } + _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_commit_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_commit_hook = free_commit_hook; + } + + fn rollback_hook<F>(&mut self, hook: Option<F>) + where + F: FnMut() + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) + where + F: FnMut(), + { + drop(catch_unwind(|| { + let boxed_hook: *mut F = p_arg.cast::<F>(); + (*boxed_hook)(); + })); + } + + let free_rollback_hook = if hook.is_some() { + Some(free_boxed_hook::<F> as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_rollback_hook( + self.db(), + Some(call_boxed_closure::<F>), + boxed_hook.cast(), + ) + } + } + _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_rollback_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_rollback_hook = free_rollback_hook; + } + + fn update_hook<F>(&mut self, hook: Option<F>) + where + F: FnMut(Action, &str, &str, i64) + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>( + p_arg: *mut c_void, + action_code: c_int, + p_db_name: *const c_char, + p_table_name: *const c_char, + row_id: i64, + ) where + F: FnMut(Action, &str, &str, i64), + { + let action = Action::from(action_code); + drop(catch_unwind(|| { + let boxed_hook: *mut F = p_arg.cast::<F>(); + (*boxed_hook)( + action, + expect_utf8(p_db_name, "database name"), + expect_utf8(p_table_name, "table name"), + row_id, + ); + })); + } + + let free_update_hook = if hook.is_some() { + Some(free_boxed_hook::<F> as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_update_hook( + self.db(), + Some(call_boxed_closure::<F>), + boxed_hook.cast(), + ) + } + } + _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_update_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_update_hook = free_update_hook; + } + + fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>) + where + F: FnMut() -> bool + Send + RefUnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int + where + F: FnMut() -> bool, + { + let r = catch_unwind(|| { + let boxed_handler: *mut F = p_arg.cast::<F>(); + (*boxed_handler)() + }); + c_int::from(r.unwrap_or_default()) + } + + if let Some(handler) = handler { + let boxed_handler = Box::new(handler); + unsafe { + ffi::sqlite3_progress_handler( + self.db(), + num_ops, + Some(call_boxed_closure::<F>), + &*boxed_handler as *const F as *mut _, + ); + } + self.progress_handler = Some(boxed_handler); + } else { + unsafe { ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) } + self.progress_handler = None; + }; + } + + fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>) + where + F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure<'c, F>( + p_arg: *mut c_void, + action_code: c_int, + param1: *const c_char, + param2: *const c_char, + db_name: *const c_char, + trigger_or_view_name: *const c_char, + ) -> c_int + where + F: FnMut(AuthContext<'c>) -> Authorization + Send + 'static, + { + catch_unwind(|| { + let action = AuthAction::from_raw( + action_code, + expect_optional_utf8(param1, "authorizer param 1"), + expect_optional_utf8(param2, "authorizer param 2"), + ); + let auth_ctx = AuthContext { + action, + database_name: expect_optional_utf8(db_name, "database name"), + accessor: expect_optional_utf8( + trigger_or_view_name, + "accessor (inner-most trigger or view)", + ), + }; + let boxed_hook: *mut F = p_arg.cast::<F>(); + (*boxed_hook)(auth_ctx) + }) + .map_or_else(|_| ffi::SQLITE_ERROR, Authorization::into_raw) + } + + let callback_fn = authorizer + .as_ref() + .map(|_| call_boxed_closure::<'c, F> as unsafe extern "C" fn(_, _, _, _, _, _) -> _); + let boxed_authorizer = authorizer.map(Box::new); + + match unsafe { + ffi::sqlite3_set_authorizer( + self.db(), + callback_fn, + boxed_authorizer + .as_ref() + .map_or_else(ptr::null_mut, |f| &**f as *const F as *mut _), + ) + } { + ffi::SQLITE_OK => { + self.authorizer = boxed_authorizer.map(|ba| ba as _); + } + err_code => { + // The only error that `sqlite3_set_authorizer` returns is `SQLITE_MISUSE` + // when compiled with `ENABLE_API_ARMOR` and the db pointer is invalid. + // This library does not allow constructing a null db ptr, so if this branch + // is hit, something very bad has happened. Panicking instead of returning + // `Result` keeps this hook's API consistent with the others. + panic!("unexpectedly failed to set_authorizer: {}", unsafe { + crate::error::error_from_handle(self.db(), err_code) + }); + } + } + } +} + +unsafe fn free_boxed_hook<F>(p: *mut c_void) { + drop(Box::from_raw(p.cast::<F>())); +} + +unsafe fn expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str { + expect_optional_utf8(p_str, description) + .unwrap_or_else(|| panic!("received empty {description}")) +} + +unsafe fn expect_optional_utf8<'a>( + p_str: *const c_char, + description: &'static str, +) -> Option<&'a str> { + if p_str.is_null() { + return None; + } + std::str::from_utf8(std::ffi::CStr::from_ptr(p_str).to_bytes()) + .unwrap_or_else(|_| panic!("received non-utf8 string as {description}")) + .into() +} + +#[cfg(test)] +mod test { + use super::Action; + use crate::{Connection, Result}; + use std::sync::atomic::{AtomicBool, Ordering}; + + #[test] + fn test_commit_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.commit_hook(Some(|| { + CALLED.store(true, Ordering::Relaxed); + false + })); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_fn_commit_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + fn hook() -> bool { + true + } + + db.commit_hook(Some(hook)); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap_err(); + Ok(()) + } + + #[test] + fn test_rollback_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.rollback_hook(Some(|| { + CALLED.store(true, Ordering::Relaxed); + })); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_update_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.update_hook(Some(|action, db: &str, tbl: &str, row_id| { + assert_eq!(Action::SQLITE_INSERT, action); + assert_eq!("main", db); + assert_eq!("foo", tbl); + assert_eq!(1, row_id); + CALLED.store(true, Ordering::Relaxed); + })); + db.execute_batch("CREATE TABLE foo (t TEXT)")?; + db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_progress_handler() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.progress_handler( + 1, + Some(|| { + CALLED.store(true, Ordering::Relaxed); + false + }), + ); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_progress_handler_interrupt() -> Result<()> { + let db = Connection::open_in_memory()?; + + fn handler() -> bool { + true + } + + db.progress_handler(1, Some(handler)); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap_err(); + Ok(()) + } + + #[test] + fn test_authorizer() -> Result<()> { + use super::{AuthAction, AuthContext, Authorization}; + + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (public TEXT, private TEXT)") + .unwrap(); + + let authorizer = move |ctx: AuthContext<'_>| match ctx.action { + AuthAction::Read { + column_name: "private", + .. + } => Authorization::Ignore, + AuthAction::DropTable { .. } => Authorization::Deny, + AuthAction::Pragma { .. } => panic!("shouldn't be called"), + _ => Authorization::Allow, + }; + + db.authorizer(Some(authorizer)); + db.execute_batch( + "BEGIN TRANSACTION; INSERT INTO foo VALUES ('pub txt', 'priv txt'); COMMIT;", + ) + .unwrap(); + db.query_row_and_then("SELECT * FROM foo", [], |row| -> Result<()> { + assert_eq!(row.get::<_, String>("public")?, "pub txt"); + assert!(row.get::<_, Option<String>>("private")?.is_none()); + Ok(()) + }) + .unwrap(); + db.execute_batch("DROP TABLE foo").unwrap_err(); + + db.authorizer(None::<fn(AuthContext<'_>) -> Authorization>); + db.execute_batch("PRAGMA user_version=1").unwrap(); // Disallowed by first authorizer, but it's now removed. + + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/inner_connection.rs b/third_party/rust/rusqlite/src/inner_connection.rs new file mode 100644 index 0000000000..b7c930fdea --- /dev/null +++ b/third_party/rust/rusqlite/src/inner_connection.rs @@ -0,0 +1,462 @@ +use std::ffi::CStr; +use std::os::raw::{c_char, c_int}; +#[cfg(feature = "load_extension")] +use std::path::Path; +use std::ptr; +use std::str; +use std::sync::atomic::AtomicBool; +use std::sync::{Arc, Mutex}; + +use super::ffi; +use super::str_for_sqlite; +use super::{Connection, InterruptHandle, OpenFlags, PrepFlags, Result}; +use crate::error::{error_from_handle, error_from_sqlite_code, error_with_offset, Error}; +use crate::raw_statement::RawStatement; +use crate::statement::Statement; +use crate::version::version_number; + +pub struct InnerConnection { + pub db: *mut ffi::sqlite3, + // It's unsafe to call `sqlite3_close` while another thread is performing + // a `sqlite3_interrupt`, and vice versa, so we take this mutex during + // those functions. This protects a copy of the `db` pointer (which is + // cleared on closing), however the main copy, `db`, is unprotected. + // Otherwise, a long running query would prevent calling interrupt, as + // interrupt would only acquire the lock after the query's completion. + interrupt_lock: Arc<Mutex<*mut ffi::sqlite3>>, + #[cfg(feature = "hooks")] + pub free_commit_hook: Option<unsafe fn(*mut std::os::raw::c_void)>, + #[cfg(feature = "hooks")] + pub free_rollback_hook: Option<unsafe fn(*mut std::os::raw::c_void)>, + #[cfg(feature = "hooks")] + pub free_update_hook: Option<unsafe fn(*mut std::os::raw::c_void)>, + #[cfg(feature = "hooks")] + pub progress_handler: Option<Box<dyn FnMut() -> bool + Send>>, + #[cfg(feature = "hooks")] + pub authorizer: Option<crate::hooks::BoxedAuthorizer>, + owned: bool, +} + +unsafe impl Send for InnerConnection {} + +impl InnerConnection { + #[allow(clippy::mutex_atomic, clippy::arc_with_non_send_sync)] // See unsafe impl Send / Sync for InterruptHandle + #[inline] + pub unsafe fn new(db: *mut ffi::sqlite3, owned: bool) -> InnerConnection { + InnerConnection { + db, + interrupt_lock: Arc::new(Mutex::new(db)), + #[cfg(feature = "hooks")] + free_commit_hook: None, + #[cfg(feature = "hooks")] + free_rollback_hook: None, + #[cfg(feature = "hooks")] + free_update_hook: None, + #[cfg(feature = "hooks")] + progress_handler: None, + #[cfg(feature = "hooks")] + authorizer: None, + owned, + } + } + + pub fn open_with_flags( + c_path: &CStr, + flags: OpenFlags, + vfs: Option<&CStr>, + ) -> Result<InnerConnection> { + ensure_safe_sqlite_threading_mode()?; + + // Replicate the check for sane open flags from SQLite, because the check in + // SQLite itself wasn't added until version 3.7.3. + debug_assert_eq!(1 << OpenFlags::SQLITE_OPEN_READ_ONLY.bits(), 0x02); + debug_assert_eq!(1 << OpenFlags::SQLITE_OPEN_READ_WRITE.bits(), 0x04); + debug_assert_eq!( + 1 << (OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE).bits(), + 0x40 + ); + if (1 << (flags.bits() & 0x7)) & 0x46 == 0 { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + None, + )); + } + + let z_vfs = match vfs { + Some(c_vfs) => c_vfs.as_ptr(), + None => ptr::null(), + }; + + unsafe { + let mut db: *mut ffi::sqlite3 = ptr::null_mut(); + let r = ffi::sqlite3_open_v2(c_path.as_ptr(), &mut db, flags.bits(), z_vfs); + if r != ffi::SQLITE_OK { + let e = if db.is_null() { + error_from_sqlite_code(r, Some(c_path.to_string_lossy().to_string())) + } else { + let mut e = error_from_handle(db, r); + if let Error::SqliteFailure( + ffi::Error { + code: ffi::ErrorCode::CannotOpen, + .. + }, + Some(msg), + ) = e + { + e = Error::SqliteFailure( + ffi::Error::new(r), + Some(format!("{msg}: {}", c_path.to_string_lossy())), + ); + } + ffi::sqlite3_close(db); + e + }; + + return Err(e); + } + + // attempt to turn on extended results code; don't fail if we can't. + ffi::sqlite3_extended_result_codes(db, 1); + + let r = ffi::sqlite3_busy_timeout(db, 5000); + if r != ffi::SQLITE_OK { + let e = error_from_handle(db, r); + ffi::sqlite3_close(db); + return Err(e); + } + + Ok(InnerConnection::new(db, true)) + } + } + + #[inline] + pub fn db(&self) -> *mut ffi::sqlite3 { + self.db + } + + #[inline] + pub fn decode_result(&self, code: c_int) -> Result<()> { + unsafe { InnerConnection::decode_result_raw(self.db(), code) } + } + + #[inline] + unsafe fn decode_result_raw(db: *mut ffi::sqlite3, code: c_int) -> Result<()> { + if code == ffi::SQLITE_OK { + Ok(()) + } else { + Err(error_from_handle(db, code)) + } + } + + #[allow(clippy::mutex_atomic)] + pub fn close(&mut self) -> Result<()> { + if self.db.is_null() { + return Ok(()); + } + self.remove_hooks(); + let mut shared_handle = self.interrupt_lock.lock().unwrap(); + assert!( + !shared_handle.is_null(), + "Bug: Somehow interrupt_lock was cleared before the DB was closed" + ); + if !self.owned { + self.db = ptr::null_mut(); + return Ok(()); + } + unsafe { + let r = ffi::sqlite3_close(self.db); + // Need to use _raw because _guard has a reference out, and + // decode_result takes &mut self. + let r = InnerConnection::decode_result_raw(self.db, r); + if r.is_ok() { + *shared_handle = ptr::null_mut(); + self.db = ptr::null_mut(); + } + r + } + } + + #[inline] + pub fn get_interrupt_handle(&self) -> InterruptHandle { + InterruptHandle { + db_lock: Arc::clone(&self.interrupt_lock), + } + } + + #[inline] + #[cfg(feature = "load_extension")] + pub unsafe fn enable_load_extension(&mut self, onoff: c_int) -> Result<()> { + let r = ffi::sqlite3_enable_load_extension(self.db, onoff); + self.decode_result(r) + } + + #[cfg(feature = "load_extension")] + pub unsafe fn load_extension( + &self, + dylib_path: &Path, + entry_point: Option<&str>, + ) -> Result<()> { + let dylib_str = super::path_to_cstring(dylib_path)?; + let mut errmsg: *mut c_char = ptr::null_mut(); + let r = if let Some(entry_point) = entry_point { + let c_entry = crate::str_to_cstring(entry_point)?; + ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), c_entry.as_ptr(), &mut errmsg) + } else { + ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), ptr::null(), &mut errmsg) + }; + if r == ffi::SQLITE_OK { + Ok(()) + } else { + let message = super::errmsg_to_string(errmsg); + ffi::sqlite3_free(errmsg.cast::<std::os::raw::c_void>()); + Err(error_from_sqlite_code(r, Some(message))) + } + } + + #[inline] + pub fn last_insert_rowid(&self) -> i64 { + unsafe { ffi::sqlite3_last_insert_rowid(self.db()) } + } + + pub fn prepare<'a>( + &mut self, + conn: &'a Connection, + sql: &str, + flags: PrepFlags, + ) -> Result<Statement<'a>> { + let mut c_stmt: *mut ffi::sqlite3_stmt = ptr::null_mut(); + let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?; + let mut c_tail: *const c_char = ptr::null(); + // TODO sqlite3_prepare_v3 (https://sqlite.org/c3ref/c_prepare_normalize.html) // 3.20.0, #728 + #[cfg(not(feature = "unlock_notify"))] + let r = unsafe { self.prepare_(c_sql, len, flags, &mut c_stmt, &mut c_tail) }; + #[cfg(feature = "unlock_notify")] + let r = unsafe { + use crate::unlock_notify; + let mut rc; + loop { + rc = self.prepare_(c_sql, len, flags, &mut c_stmt, &mut c_tail); + if !unlock_notify::is_locked(self.db, rc) { + break; + } + rc = unlock_notify::wait_for_unlock_notify(self.db); + if rc != ffi::SQLITE_OK { + break; + } + } + rc + }; + // If there is an error, *ppStmt is set to NULL. + if r != ffi::SQLITE_OK { + return Err(unsafe { error_with_offset(self.db, r, sql) }); + } + // If the input text contains no SQL (if the input is an empty string or a + // comment) then *ppStmt is set to NULL. + let tail = if c_tail.is_null() { + 0 + } else { + let n = (c_tail as isize) - (c_sql as isize); + if n <= 0 || n >= len as isize { + 0 + } else { + n as usize + } + }; + Ok(Statement::new(conn, unsafe { + RawStatement::new(c_stmt, tail) + })) + } + + #[inline] + #[cfg(not(feature = "modern_sqlite"))] + unsafe fn prepare_( + &self, + z_sql: *const c_char, + n_byte: c_int, + _: PrepFlags, + pp_stmt: *mut *mut ffi::sqlite3_stmt, + pz_tail: *mut *const c_char, + ) -> c_int { + ffi::sqlite3_prepare_v2(self.db(), z_sql, n_byte, pp_stmt, pz_tail) + } + + #[inline] + #[cfg(feature = "modern_sqlite")] + unsafe fn prepare_( + &self, + z_sql: *const c_char, + n_byte: c_int, + flags: PrepFlags, + pp_stmt: *mut *mut ffi::sqlite3_stmt, + pz_tail: *mut *const c_char, + ) -> c_int { + ffi::sqlite3_prepare_v3(self.db(), z_sql, n_byte, flags.bits(), pp_stmt, pz_tail) + } + + #[inline] + pub fn changes(&self) -> u64 { + #[cfg(not(feature = "modern_sqlite"))] + unsafe { + ffi::sqlite3_changes(self.db()) as u64 + } + #[cfg(feature = "modern_sqlite")] // 3.37.0 + unsafe { + ffi::sqlite3_changes64(self.db()) as u64 + } + } + + #[inline] + pub fn is_autocommit(&self) -> bool { + unsafe { ffi::sqlite3_get_autocommit(self.db()) != 0 } + } + + pub fn is_busy(&self) -> bool { + let db = self.db(); + unsafe { + let mut stmt = ffi::sqlite3_next_stmt(db, ptr::null_mut()); + while !stmt.is_null() { + if ffi::sqlite3_stmt_busy(stmt) != 0 { + return true; + } + stmt = ffi::sqlite3_next_stmt(db, stmt); + } + } + false + } + + pub fn cache_flush(&mut self) -> Result<()> { + crate::error::check(unsafe { ffi::sqlite3_db_cacheflush(self.db()) }) + } + + #[cfg(not(feature = "hooks"))] + #[inline] + fn remove_hooks(&mut self) {} + + pub fn db_readonly(&self, db_name: super::DatabaseName<'_>) -> Result<bool> { + let name = db_name.as_cstring()?; + let r = unsafe { ffi::sqlite3_db_readonly(self.db, name.as_ptr()) }; + match r { + 0 => Ok(false), + 1 => Ok(true), + -1 => Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("{db_name:?} is not the name of a database")), + )), + _ => Err(error_from_sqlite_code( + r, + Some("Unexpected result".to_owned()), + )), + } + } + + #[cfg(feature = "modern_sqlite")] // 3.37.0 + pub fn txn_state( + &self, + db_name: Option<super::DatabaseName<'_>>, + ) -> Result<super::transaction::TransactionState> { + let r = if let Some(ref name) = db_name { + let name = name.as_cstring()?; + unsafe { ffi::sqlite3_txn_state(self.db, name.as_ptr()) } + } else { + unsafe { ffi::sqlite3_txn_state(self.db, ptr::null()) } + }; + match r { + 0 => Ok(super::transaction::TransactionState::None), + 1 => Ok(super::transaction::TransactionState::Read), + 2 => Ok(super::transaction::TransactionState::Write), + -1 => Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("{db_name:?} is not the name of a valid schema")), + )), + _ => Err(error_from_sqlite_code( + r, + Some("Unexpected result".to_owned()), + )), + } + } + + #[inline] + #[cfg(feature = "release_memory")] + pub fn release_memory(&self) -> Result<()> { + self.decode_result(unsafe { ffi::sqlite3_db_release_memory(self.db) }) + } +} + +impl Drop for InnerConnection { + #[allow(unused_must_use)] + #[inline] + fn drop(&mut self) { + self.close(); + } +} + +#[cfg(not(any(target_arch = "wasm32", feature = "loadable_extension")))] +static SQLITE_INIT: std::sync::Once = std::sync::Once::new(); + +pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false); + +// threading mode checks are not necessary (and do not work) on target +// platforms that do not have threading (such as webassembly) +#[cfg(target_arch = "wasm32")] +fn ensure_safe_sqlite_threading_mode() -> Result<()> { + Ok(()) +} + +#[cfg(not(any(target_arch = "wasm32")))] +fn ensure_safe_sqlite_threading_mode() -> Result<()> { + // Ensure SQLite was compiled in threadsafe mode. + if unsafe { ffi::sqlite3_threadsafe() == 0 } { + return Err(Error::SqliteSingleThreadedMode); + } + + // Now we know SQLite is _capable_ of being in Multi-thread of Serialized mode, + // but it's possible someone configured it to be in Single-thread mode + // before calling into us. That would mean we're exposing an unsafe API via + // a safe one (in Rust terminology), which is no good. We have two options + // to protect against this, depending on the version of SQLite we're linked + // with: + // + // 1. If we're on 3.7.0 or later, we can ask SQLite for a mutex and check for + // the magic value 8. This isn't documented, but it's what SQLite + // returns for its mutex allocation function in Single-thread mode. + // 2. If we're prior to SQLite 3.7.0, AFAIK there's no way to check the + // threading mode. The check we perform for >= 3.7.0 will segfault. + // Instead, we insist on being able to call sqlite3_config and + // sqlite3_initialize ourself, ensuring we know the threading + // mode. This will fail if someone else has already initialized SQLite + // even if they initialized it safely. That's not ideal either, which is + // why we expose bypass_sqlite_initialization above. + if version_number() >= 3_007_000 { + const SQLITE_SINGLETHREADED_MUTEX_MAGIC: usize = 8; + let is_singlethreaded = unsafe { + let mutex_ptr = ffi::sqlite3_mutex_alloc(0); + let is_singlethreaded = mutex_ptr as usize == SQLITE_SINGLETHREADED_MUTEX_MAGIC; + ffi::sqlite3_mutex_free(mutex_ptr); + is_singlethreaded + }; + if is_singlethreaded { + Err(Error::SqliteSingleThreadedMode) + } else { + Ok(()) + } + } else { + #[cfg(not(feature = "loadable_extension"))] + SQLITE_INIT.call_once(|| { + use std::sync::atomic::Ordering; + if BYPASS_SQLITE_INIT.load(Ordering::Relaxed) { + return; + } + + unsafe { + assert!(ffi::sqlite3_config(ffi::SQLITE_CONFIG_MULTITHREAD) == ffi::SQLITE_OK && ffi::sqlite3_initialize() == ffi::SQLITE_OK, + "Could not ensure safe initialization of SQLite.\n\ + To fix this, either:\n\ + * Upgrade SQLite to at least version 3.7.0\n\ + * Ensure that SQLite has been initialized in Multi-thread or Serialized mode and call\n\ + rusqlite::bypass_sqlite_initialization() prior to your first connection attempt." + ); + } + }); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/lib.rs b/third_party/rust/rusqlite/src/lib.rs new file mode 100644 index 0000000000..cb4b14b9dc --- /dev/null +++ b/third_party/rust/rusqlite/src/lib.rs @@ -0,0 +1,2200 @@ +//! Rusqlite is an ergonomic wrapper for using SQLite from Rust. +//! +//! Historically, the API was based on the one from +//! [`rust-postgres`](https://github.com/sfackler/rust-postgres). However, the +//! two have diverged in many ways, and no compatibility between the two is +//! intended. +//! +//! ```rust +//! use rusqlite::{params, Connection, Result}; +//! +//! #[derive(Debug)] +//! struct Person { +//! id: i32, +//! name: String, +//! data: Option<Vec<u8>>, +//! } +//! +//! fn main() -> Result<()> { +//! let conn = Connection::open_in_memory()?; +//! +//! conn.execute( +//! "CREATE TABLE person ( +//! id INTEGER PRIMARY KEY, +//! name TEXT NOT NULL, +//! data BLOB +//! )", +//! (), // empty list of parameters. +//! )?; +//! let me = Person { +//! id: 0, +//! name: "Steven".to_string(), +//! data: None, +//! }; +//! conn.execute( +//! "INSERT INTO person (name, data) VALUES (?1, ?2)", +//! (&me.name, &me.data), +//! )?; +//! +//! let mut stmt = conn.prepare("SELECT id, name, data FROM person")?; +//! let person_iter = stmt.query_map([], |row| { +//! Ok(Person { +//! id: row.get(0)?, +//! name: row.get(1)?, +//! data: row.get(2)?, +//! }) +//! })?; +//! +//! for person in person_iter { +//! println!("Found person {:?}", person.unwrap()); +//! } +//! Ok(()) +//! } +//! ``` +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_cfg))] + +pub use libsqlite3_sys as ffi; + +use std::cell::RefCell; +use std::default::Default; +use std::ffi::{CStr, CString}; +use std::fmt; +use std::os::raw::{c_char, c_int}; + +use std::path::Path; +use std::result; +use std::str; +use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex}; + +use crate::cache::StatementCache; +use crate::inner_connection::{InnerConnection, BYPASS_SQLITE_INIT}; +use crate::raw_statement::RawStatement; +use crate::types::ValueRef; + +pub use crate::cache::CachedStatement; +#[cfg(feature = "column_decltype")] +pub use crate::column::Column; +pub use crate::error::{to_sqlite_error, Error}; +pub use crate::ffi::ErrorCode; +#[cfg(feature = "load_extension")] +pub use crate::load_extension_guard::LoadExtensionGuard; +pub use crate::params::{params_from_iter, Params, ParamsFromIter}; +pub use crate::row::{AndThenRows, Map, MappedRows, Row, RowIndex, Rows}; +pub use crate::statement::{Statement, StatementStatus}; +#[cfg(feature = "modern_sqlite")] +pub use crate::transaction::TransactionState; +pub use crate::transaction::{DropBehavior, Savepoint, Transaction, TransactionBehavior}; +pub use crate::types::ToSql; +pub use crate::version::*; +#[cfg(feature = "rusqlite-macros")] +#[doc(hidden)] +pub use rusqlite_macros::__bind; + +mod error; + +#[cfg(feature = "backup")] +#[cfg_attr(docsrs, doc(cfg(feature = "backup")))] +pub mod backup; +#[cfg(feature = "blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "blob")))] +pub mod blob; +mod busy; +mod cache; +#[cfg(feature = "collation")] +#[cfg_attr(docsrs, doc(cfg(feature = "collation")))] +mod collation; +mod column; +pub mod config; +#[cfg(any(feature = "functions", feature = "vtab"))] +mod context; +#[cfg(feature = "functions")] +#[cfg_attr(docsrs, doc(cfg(feature = "functions")))] +pub mod functions; +#[cfg(feature = "hooks")] +#[cfg_attr(docsrs, doc(cfg(feature = "hooks")))] +pub mod hooks; +mod inner_connection; +#[cfg(feature = "limits")] +#[cfg_attr(docsrs, doc(cfg(feature = "limits")))] +pub mod limits; +#[cfg(feature = "load_extension")] +mod load_extension_guard; +mod params; +mod pragma; +mod raw_statement; +mod row; +#[cfg(feature = "serialize")] +#[cfg_attr(docsrs, doc(cfg(feature = "serialize")))] +pub mod serialize; +#[cfg(feature = "session")] +#[cfg_attr(docsrs, doc(cfg(feature = "session")))] +pub mod session; +mod statement; +#[cfg(feature = "trace")] +#[cfg_attr(docsrs, doc(cfg(feature = "trace")))] +pub mod trace; +mod transaction; +pub mod types; +#[cfg(feature = "unlock_notify")] +mod unlock_notify; +mod version; +#[cfg(feature = "vtab")] +#[cfg_attr(docsrs, doc(cfg(feature = "vtab")))] +pub mod vtab; + +pub(crate) mod util; +pub(crate) use util::SmallCString; + +// Number of cached prepared statements we'll hold on to. +const STATEMENT_CACHE_DEFAULT_CAPACITY: usize = 16; + +/// A macro making it more convenient to longer lists of +/// parameters as a `&[&dyn ToSql]`. +/// +/// # Example +/// +/// ```rust,no_run +/// # use rusqlite::{Result, Connection, params}; +/// +/// struct Person { +/// name: String, +/// age_in_years: u8, +/// data: Option<Vec<u8>>, +/// } +/// +/// fn add_person(conn: &Connection, person: &Person) -> Result<()> { +/// conn.execute( +/// "INSERT INTO person(name, age_in_years, data) VALUES (?1, ?2, ?3)", +/// params![person.name, person.age_in_years, person.data], +/// )?; +/// Ok(()) +/// } +/// ``` +#[macro_export] +macro_rules! params { + () => { + &[] as &[&dyn $crate::ToSql] + }; + ($($param:expr),+ $(,)?) => { + &[$(&$param as &dyn $crate::ToSql),+] as &[&dyn $crate::ToSql] + }; +} + +/// A macro making it more convenient to pass lists of named parameters +/// as a `&[(&str, &dyn ToSql)]`. +/// +/// # Example +/// +/// ```rust,no_run +/// # use rusqlite::{Result, Connection, named_params}; +/// +/// struct Person { +/// name: String, +/// age_in_years: u8, +/// data: Option<Vec<u8>>, +/// } +/// +/// fn add_person(conn: &Connection, person: &Person) -> Result<()> { +/// conn.execute( +/// "INSERT INTO person (name, age_in_years, data) +/// VALUES (:name, :age, :data)", +/// named_params! { +/// ":name": person.name, +/// ":age": person.age_in_years, +/// ":data": person.data, +/// }, +/// )?; +/// Ok(()) +/// } +/// ``` +#[macro_export] +macro_rules! named_params { + () => { + &[] as &[(&str, &dyn $crate::ToSql)] + }; + // Note: It's a lot more work to support this as part of the same macro as + // `params!`, unfortunately. + ($($param_name:literal: $param_val:expr),+ $(,)?) => { + &[$(($param_name, &$param_val as &dyn $crate::ToSql)),+] as &[(&str, &dyn $crate::ToSql)] + }; +} + +/// Captured identifiers in SQL +/// +/// * only SQLite `$x` / `@x` / `:x` syntax works (Rust `&x` syntax does not +/// work). +/// * `$x.y` expression does not work. +/// +/// # Example +/// +/// ```rust, no_run +/// # use rusqlite::{prepare_and_bind, Connection, Result, Statement}; +/// +/// fn misc(db: &Connection) -> Result<Statement> { +/// let name = "Lisa"; +/// let age = 8; +/// let smart = true; +/// Ok(prepare_and_bind!(db, "SELECT $name, @age, :smart;")) +/// } +/// ``` +#[cfg(feature = "rusqlite-macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "rusqlite-macros")))] +#[macro_export] +macro_rules! prepare_and_bind { + ($conn:expr, $sql:literal) => {{ + let mut stmt = $conn.prepare($sql)?; + $crate::__bind!(stmt $sql); + stmt + }}; +} + +/// Captured identifiers in SQL +/// +/// * only SQLite `$x` / `@x` / `:x` syntax works (Rust `&x` syntax does not +/// work). +/// * `$x.y` expression does not work. +#[cfg(feature = "rusqlite-macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "rusqlite-macros")))] +#[macro_export] +macro_rules! prepare_cached_and_bind { + ($conn:expr, $sql:literal) => {{ + let mut stmt = $conn.prepare_cached($sql)?; + $crate::__bind!(stmt $sql); + stmt + }}; +} + +/// A typedef of the result returned by many methods. +pub type Result<T, E = Error> = result::Result<T, E>; + +/// See the [method documentation](#tymethod.optional). +pub trait OptionalExtension<T> { + /// Converts a `Result<T>` into a `Result<Option<T>>`. + /// + /// By default, Rusqlite treats 0 rows being returned from a query that is + /// expected to return 1 row as an error. This method will + /// handle that error, and give you back an `Option<T>` instead. + fn optional(self) -> Result<Option<T>>; +} + +impl<T> OptionalExtension<T> for Result<T> { + fn optional(self) -> Result<Option<T>> { + match self { + Ok(value) => Ok(Some(value)), + Err(Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e), + } + } +} + +unsafe fn errmsg_to_string(errmsg: *const c_char) -> String { + let c_slice = CStr::from_ptr(errmsg).to_bytes(); + String::from_utf8_lossy(c_slice).into_owned() +} + +fn str_to_cstring(s: &str) -> Result<SmallCString> { + Ok(SmallCString::new(s)?) +} + +/// Returns `Ok((string ptr, len as c_int, SQLITE_STATIC | SQLITE_TRANSIENT))` +/// normally. +/// Returns error if the string is too large for sqlite. +/// The `sqlite3_destructor_type` item is always `SQLITE_TRANSIENT` unless +/// the string was empty (in which case it's `SQLITE_STATIC`, and the ptr is +/// static). +fn str_for_sqlite(s: &[u8]) -> Result<(*const c_char, c_int, ffi::sqlite3_destructor_type)> { + let len = len_as_c_int(s.len())?; + let (ptr, dtor_info) = if len != 0 { + (s.as_ptr().cast::<c_char>(), ffi::SQLITE_TRANSIENT()) + } else { + // Return a pointer guaranteed to live forever + ("".as_ptr().cast::<c_char>(), ffi::SQLITE_STATIC()) + }; + Ok((ptr, len, dtor_info)) +} + +// Helper to cast to c_int safely, returning the correct error type if the cast +// failed. +fn len_as_c_int(len: usize) -> Result<c_int> { + if len >= (c_int::MAX as usize) { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_TOOBIG), + None, + )) + } else { + Ok(len as c_int) + } +} + +#[cfg(unix)] +fn path_to_cstring(p: &Path) -> Result<CString> { + use std::os::unix::ffi::OsStrExt; + Ok(CString::new(p.as_os_str().as_bytes())?) +} + +#[cfg(not(unix))] +fn path_to_cstring(p: &Path) -> Result<CString> { + let s = p.to_str().ok_or_else(|| Error::InvalidPath(p.to_owned()))?; + Ok(CString::new(s)?) +} + +/// Name for a database within a SQLite connection. +#[derive(Copy, Clone, Debug)] +pub enum DatabaseName<'a> { + /// The main database. + Main, + + /// The temporary database (e.g., any "CREATE TEMPORARY TABLE" tables). + Temp, + + /// A database that has been attached via "ATTACH DATABASE ...". + Attached(&'a str), +} + +/// Shorthand for [`DatabaseName::Main`]. +pub const MAIN_DB: DatabaseName<'static> = DatabaseName::Main; + +/// Shorthand for [`DatabaseName::Temp`]. +pub const TEMP_DB: DatabaseName<'static> = DatabaseName::Temp; + +// Currently DatabaseName is only used by the backup and blob mods, so hide +// this (private) impl to avoid dead code warnings. +impl DatabaseName<'_> { + #[inline] + fn as_cstring(&self) -> Result<SmallCString> { + use self::DatabaseName::{Attached, Main, Temp}; + match *self { + Main => str_to_cstring("main"), + Temp => str_to_cstring("temp"), + Attached(s) => str_to_cstring(s), + } + } +} + +/// A connection to a SQLite database. +pub struct Connection { + db: RefCell<InnerConnection>, + cache: StatementCache, +} + +unsafe impl Send for Connection {} + +impl Drop for Connection { + #[inline] + fn drop(&mut self) { + self.flush_prepared_statement_cache(); + } +} + +impl Connection { + /// Open a new connection to a SQLite database. If a database does not exist + /// at the path, one is created. + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn open_my_db() -> Result<()> { + /// let path = "./my_db.db3"; + /// let db = Connection::open(path)?; + /// // Use the database somehow... + /// println!("{}", db.is_autocommit()); + /// Ok(()) + /// } + /// ``` + /// + /// # Flags + /// + /// `Connection::open(path)` is equivalent to using + /// [`Connection::open_with_flags`] with the default [`OpenFlags`]. That is, + /// it's equivalent to: + /// + /// ```ignore + /// Connection::open_with_flags( + /// path, + /// OpenFlags::SQLITE_OPEN_READ_WRITE + /// | OpenFlags::SQLITE_OPEN_CREATE + /// | OpenFlags::SQLITE_OPEN_URI + /// | OpenFlags::SQLITE_OPEN_NO_MUTEX, + /// ) + /// ``` + /// + /// These flags have the following effects: + /// + /// - Open the database for both reading or writing. + /// - Create the database if one does not exist at the path. + /// - Allow the filename to be interpreted as a URI (see <https://www.sqlite.org/uri.html#uri_filenames_in_sqlite> + /// for details). + /// - Disables the use of a per-connection mutex. + /// + /// Rusqlite enforces thread-safety at compile time, so additional + /// locking is not needed and provides no benefit. (See the + /// documentation on [`OpenFlags::SQLITE_OPEN_FULL_MUTEX`] for some + /// additional discussion about this). + /// + /// Most of these are also the default settings for the C API, although + /// technically the default locking behavior is controlled by the flags used + /// when compiling SQLite -- rather than let it vary, we choose `NO_MUTEX` + /// because it's a fairly clearly the best choice for users of this library. + /// + /// # Failure + /// + /// Will return `Err` if `path` cannot be converted to a C-compatible string + /// or if the underlying SQLite open call fails. + #[inline] + pub fn open<P: AsRef<Path>>(path: P) -> Result<Connection> { + let flags = OpenFlags::default(); + Connection::open_with_flags(path, flags) + } + + /// Open a new connection to an in-memory SQLite database. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite open call fails. + #[inline] + pub fn open_in_memory() -> Result<Connection> { + let flags = OpenFlags::default(); + Connection::open_in_memory_with_flags(flags) + } + + /// Open a new connection to a SQLite database. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if `path` cannot be converted to a C-compatible + /// string or if the underlying SQLite open call fails. + #[inline] + pub fn open_with_flags<P: AsRef<Path>>(path: P, flags: OpenFlags) -> Result<Connection> { + let c_path = path_to_cstring(path.as_ref())?; + InnerConnection::open_with_flags(&c_path, flags, None).map(|db| Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + }) + } + + /// Open a new connection to a SQLite database using the specific flags and + /// vfs name. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if either `path` or `vfs` cannot be converted to a + /// C-compatible string or if the underlying SQLite open call fails. + #[inline] + pub fn open_with_flags_and_vfs<P: AsRef<Path>>( + path: P, + flags: OpenFlags, + vfs: &str, + ) -> Result<Connection> { + let c_path = path_to_cstring(path.as_ref())?; + let c_vfs = str_to_cstring(vfs)?; + InnerConnection::open_with_flags(&c_path, flags, Some(&c_vfs)).map(|db| Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + }) + } + + /// Open a new connection to an in-memory SQLite database. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite open call fails. + #[inline] + pub fn open_in_memory_with_flags(flags: OpenFlags) -> Result<Connection> { + Connection::open_with_flags(":memory:", flags) + } + + /// Open a new connection to an in-memory SQLite database using the specific + /// flags and vfs name. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if `vfs` cannot be converted to a C-compatible + /// string or if the underlying SQLite open call fails. + #[inline] + pub fn open_in_memory_with_flags_and_vfs(flags: OpenFlags, vfs: &str) -> Result<Connection> { + Connection::open_with_flags_and_vfs(":memory:", flags, vfs) + } + + /// Convenience method to run multiple SQL statements (that cannot take any + /// parameters). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn create_tables(conn: &Connection) -> Result<()> { + /// conn.execute_batch( + /// "BEGIN; + /// CREATE TABLE foo(x INTEGER); + /// CREATE TABLE bar(y TEXT); + /// COMMIT;", + /// ) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn execute_batch(&self, sql: &str) -> Result<()> { + let mut sql = sql; + while !sql.is_empty() { + let stmt = self.prepare(sql)?; + if !stmt.stmt.is_null() && stmt.step()? && cfg!(feature = "extra_check") { + // Some PRAGMA may return rows + return Err(Error::ExecuteReturnedResults); + } + let tail = stmt.stmt.tail(); + if tail == 0 || tail >= sql.len() { + break; + } + sql = &sql[tail..]; + } + Ok(()) + } + + /// Convenience method to prepare and execute a single SQL statement. + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// ## Example + /// + /// ### With positional params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection}; + /// fn update_rows(conn: &Connection) { + /// match conn.execute("UPDATE foo SET bar = 'baz' WHERE qux = ?1", [1i32]) { + /// Ok(updated) => println!("{} rows were updated", updated), + /// Err(err) => println!("update failed: {}", err), + /// } + /// } + /// ``` + /// + /// ### With positional params of varying types + /// + /// ```rust,no_run + /// # use rusqlite::{params, Connection}; + /// fn update_rows(conn: &Connection) { + /// match conn.execute( + /// "UPDATE foo SET bar = 'baz' WHERE qux = ?1 AND quux = ?2", + /// params![1i32, 1.5f64], + /// ) { + /// Ok(updated) => println!("{} rows were updated", updated), + /// Err(err) => println!("update failed: {}", err), + /// } + /// } + /// ``` + /// + /// ### With named params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert(conn: &Connection) -> Result<usize> { + /// conn.execute( + /// "INSERT INTO test (name) VALUES (:name)", + /// &[(":name", "one")], + /// ) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn execute<P: Params>(&self, sql: &str, params: P) -> Result<usize> { + self.prepare(sql) + .and_then(|mut stmt| stmt.check_no_tail().and_then(|()| stmt.execute(params))) + } + + /// Returns the path to the database file, if one exists and is known. + /// + /// Returns `Some("")` for a temporary or in-memory database. + /// + /// Note that in some cases [PRAGMA + /// database_list](https://sqlite.org/pragma.html#pragma_database_list) is + /// likely to be more robust. + #[inline] + pub fn path(&self) -> Option<&str> { + unsafe { + let db = self.handle(); + let db_name = DatabaseName::Main.as_cstring().unwrap(); + let db_filename = ffi::sqlite3_db_filename(db, db_name.as_ptr()); + if db_filename.is_null() { + None + } else { + CStr::from_ptr(db_filename).to_str().ok() + } + } + } + + /// Attempts to free as much heap memory as possible from the database + /// connection. + /// + /// This calls [`sqlite3_db_release_memory`](https://www.sqlite.org/c3ref/db_release_memory.html). + #[inline] + #[cfg(feature = "release_memory")] + pub fn release_memory(&self) -> Result<()> { + self.db.borrow_mut().release_memory() + } + + /// Get the SQLite rowid of the most recent successful INSERT. + /// + /// Uses [sqlite3_last_insert_rowid](https://www.sqlite.org/c3ref/last_insert_rowid.html) under + /// the hood. + #[inline] + pub fn last_insert_rowid(&self) -> i64 { + self.db.borrow_mut().last_insert_rowid() + } + + /// Convenience method to execute a query that is expected to return a + /// single row. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Result, Connection}; + /// fn preferred_locale(conn: &Connection) -> Result<String> { + /// conn.query_row( + /// "SELECT value FROM preferences WHERE name='locale'", + /// [], + /// |row| row.get(0), + /// ) + /// } + /// ``` + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call `.optional()` on the result of + /// this to get a `Result<Option<T>>`. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn query_row<T, P, F>(&self, sql: &str, params: P, f: F) -> Result<T> + where + P: Params, + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; + stmt.query_row(params, f) + } + + // https://sqlite.org/tclsqlite.html#onecolumn + #[cfg(test)] + pub(crate) fn one_column<T: types::FromSql>(&self, sql: &str) -> Result<T> { + self.query_row(sql, [], |r| r.get(0)) + } + + /// Convenience method to execute a query that is expected to return a + /// single row, and execute a mapping via `f` on that returned row with + /// the possibility of failure. The `Result` type of `f` must implement + /// `std::convert::From<Error>`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Result, Connection}; + /// fn preferred_locale(conn: &Connection) -> Result<String> { + /// conn.query_row_and_then( + /// "SELECT value FROM preferences WHERE name='locale'", + /// [], + /// |row| row.get(0), + /// ) + /// } + /// ``` + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn query_row_and_then<T, E, P, F>(&self, sql: &str, params: P, f: F) -> Result<T, E> + where + P: Params, + F: FnOnce(&Row<'_>) -> Result<T, E>, + E: From<Error>, + { + let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; + let mut rows = stmt.query(params)?; + + rows.get_expected_row().map_err(E::from).and_then(f) + } + + /// Prepare a SQL statement for execution. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert_new_people(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("INSERT INTO People (name) VALUES (?1)")?; + /// stmt.execute(["Joe Smith"])?; + /// stmt.execute(["Bob Jones"])?; + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn prepare(&self, sql: &str) -> Result<Statement<'_>> { + self.prepare_with_flags(sql, PrepFlags::default()) + } + + /// Prepare a SQL statement for execution. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + #[inline] + pub fn prepare_with_flags(&self, sql: &str, flags: PrepFlags) -> Result<Statement<'_>> { + self.db.borrow_mut().prepare(self, sql, flags) + } + + /// Close the SQLite connection. + /// + /// This is functionally equivalent to the `Drop` implementation for + /// `Connection` except that on failure, it returns an error and the + /// connection itself (presumably so closing can be attempted again). + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn close(self) -> Result<(), (Connection, Error)> { + self.flush_prepared_statement_cache(); + let r = self.db.borrow_mut().close(); + r.map_err(move |err| (self, err)) + } + + /// Enable loading of SQLite extensions from both SQL queries and Rust. + /// + /// You must call [`Connection::load_extension_disable`] when you're + /// finished loading extensions (failure to call it can lead to bad things, + /// see "Safety"), so you should strongly consider using + /// [`LoadExtensionGuard`] instead of this function, automatically disables + /// extension loading when it goes out of scope. + /// + /// # Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn load_my_extension(conn: &Connection) -> Result<()> { + /// // Safety: We fully trust the loaded extension and execute no untrusted SQL + /// // while extension loading is enabled. + /// unsafe { + /// conn.load_extension_enable()?; + /// let r = conn.load_extension("my/trusted/extension", None); + /// conn.load_extension_disable()?; + /// r + /// } + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + /// + /// # Safety + /// + /// TLDR: Don't execute any untrusted queries between this call and + /// [`Connection::load_extension_disable`]. + /// + /// Perhaps surprisingly, this function does not only allow the use of + /// [`Connection::load_extension`] from Rust, but it also allows SQL queries + /// to perform [the same operation][loadext]. For example, in the period + /// between `load_extension_enable` and `load_extension_disable`, the + /// following operation will load and call some function in some dynamic + /// library: + /// + /// ```sql + /// SELECT load_extension('why_is_this_possible.dll', 'dubious_func'); + /// ``` + /// + /// This means that while this is enabled a carefully crafted SQL query can + /// be used to escalate a SQL injection attack into code execution. + /// + /// Safely using this function requires that you trust all SQL queries run + /// between when it is called, and when loading is disabled (by + /// [`Connection::load_extension_disable`]). + /// + /// [loadext]: https://www.sqlite.org/lang_corefunc.html#load_extension + #[cfg(feature = "load_extension")] + #[cfg_attr(docsrs, doc(cfg(feature = "load_extension")))] + #[inline] + pub unsafe fn load_extension_enable(&self) -> Result<()> { + self.db.borrow_mut().enable_load_extension(1) + } + + /// Disable loading of SQLite extensions. + /// + /// See [`Connection::load_extension_enable`] for an example. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[cfg(feature = "load_extension")] + #[cfg_attr(docsrs, doc(cfg(feature = "load_extension")))] + #[inline] + pub fn load_extension_disable(&self) -> Result<()> { + // It's always safe to turn off extension loading. + unsafe { self.db.borrow_mut().enable_load_extension(0) } + } + + /// Load the SQLite extension at `dylib_path`. `dylib_path` is passed + /// through to `sqlite3_load_extension`, which may attempt OS-specific + /// modifications if the file cannot be loaded directly (for example + /// converting `"some/ext"` to `"some/ext.so"`, `"some\\ext.dll"`, ...). + /// + /// If `entry_point` is `None`, SQLite will attempt to find the entry point. + /// If it is not `None`, the entry point will be passed through to + /// `sqlite3_load_extension`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, LoadExtensionGuard}; + /// fn load_my_extension(conn: &Connection) -> Result<()> { + /// // Safety: we don't execute any SQL statements while + /// // extension loading is enabled. + /// let _guard = unsafe { LoadExtensionGuard::new(conn)? }; + /// // Safety: `my_sqlite_extension` is highly trustworthy. + /// unsafe { conn.load_extension("my_sqlite_extension", None) } + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + /// + /// # Safety + /// + /// This is equivalent to performing a `dlopen`/`LoadLibrary` on a shared + /// library, and calling a function inside, and thus requires that you trust + /// the library that you're loading. + /// + /// That is to say: to safely use this, the code in the extension must be + /// sound, trusted, correctly use the SQLite APIs, and not contain any + /// memory or thread safety errors. + #[cfg(feature = "load_extension")] + #[cfg_attr(docsrs, doc(cfg(feature = "load_extension")))] + #[inline] + pub unsafe fn load_extension<P: AsRef<Path>>( + &self, + dylib_path: P, + entry_point: Option<&str>, + ) -> Result<()> { + self.db + .borrow_mut() + .load_extension(dylib_path.as_ref(), entry_point) + } + + /// Get access to the underlying SQLite database connection handle. + /// + /// # Warning + /// + /// You should not need to use this function. If you do need to, please + /// [open an issue on the rusqlite repository](https://github.com/rusqlite/rusqlite/issues) and describe + /// your use case. + /// + /// # Safety + /// + /// This function is unsafe because it gives you raw access + /// to the SQLite connection, and what you do with it could impact the + /// safety of this `Connection`. + #[inline] + pub unsafe fn handle(&self) -> *mut ffi::sqlite3 { + self.db.borrow().db() + } + + /// Create a `Connection` from a raw handle. + /// + /// The underlying SQLite database connection handle will not be closed when + /// the returned connection is dropped/closed. + /// + /// # Safety + /// + /// This function is unsafe because improper use may impact the Connection. + #[inline] + pub unsafe fn from_handle(db: *mut ffi::sqlite3) -> Result<Connection> { + let db = InnerConnection::new(db, false); + Ok(Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + }) + } + + /// Like SQLITE_EXTENSION_INIT2 macro + #[cfg(feature = "loadable_extension")] + #[cfg_attr(docsrs, doc(cfg(feature = "loadable_extension")))] + pub unsafe fn extension_init2( + db: *mut ffi::sqlite3, + p_api: *mut ffi::sqlite3_api_routines, + ) -> Result<Connection> { + ffi::rusqlite_extension_init2(p_api)?; + Connection::from_handle(db) + } + + /// Create a `Connection` from a raw owned handle. + /// + /// The returned connection will attempt to close the inner connection + /// when dropped/closed. This function should only be called on connections + /// owned by the caller. + /// + /// # Safety + /// + /// This function is unsafe because improper use may impact the Connection. + /// In particular, it should only be called on connections created + /// and owned by the caller, e.g. as a result of calling + /// `ffi::sqlite3_open`(). + #[inline] + pub unsafe fn from_handle_owned(db: *mut ffi::sqlite3) -> Result<Connection> { + let db = InnerConnection::new(db, true); + Ok(Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + }) + } + + /// Get access to a handle that can be used to interrupt long running + /// queries from another thread. + #[inline] + pub fn get_interrupt_handle(&self) -> InterruptHandle { + self.db.borrow().get_interrupt_handle() + } + + #[inline] + fn decode_result(&self, code: c_int) -> Result<()> { + self.db.borrow().decode_result(code) + } + + /// Return the number of rows modified, inserted or deleted by the most + /// recently completed INSERT, UPDATE or DELETE statement on the database + /// connection. + /// + /// See <https://www.sqlite.org/c3ref/changes.html> + #[inline] + pub fn changes(&self) -> u64 { + self.db.borrow().changes() + } + + /// Test for auto-commit mode. + /// Autocommit mode is on by default. + #[inline] + pub fn is_autocommit(&self) -> bool { + self.db.borrow().is_autocommit() + } + + /// Determine if all associated prepared statements have been reset. + #[inline] + pub fn is_busy(&self) -> bool { + self.db.borrow().is_busy() + } + + /// Flush caches to disk mid-transaction + pub fn cache_flush(&self) -> Result<()> { + self.db.borrow_mut().cache_flush() + } + + /// Determine if a database is read-only + pub fn is_readonly(&self, db_name: DatabaseName<'_>) -> Result<bool> { + self.db.borrow().db_readonly(db_name) + } +} + +impl fmt::Debug for Connection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection") + .field("path", &self.path()) + .finish() + } +} + +/// Batch iterator +/// ```rust +/// use rusqlite::{Batch, Connection, Result}; +/// +/// fn main() -> Result<()> { +/// let conn = Connection::open_in_memory()?; +/// let sql = r" +/// CREATE TABLE tbl1 (col); +/// CREATE TABLE tbl2 (col); +/// "; +/// let mut batch = Batch::new(&conn, sql); +/// while let Some(mut stmt) = batch.next()? { +/// stmt.execute([])?; +/// } +/// Ok(()) +/// } +/// ``` +#[derive(Debug)] +pub struct Batch<'conn, 'sql> { + conn: &'conn Connection, + sql: &'sql str, + tail: usize, +} + +impl<'conn, 'sql> Batch<'conn, 'sql> { + /// Constructor + pub fn new(conn: &'conn Connection, sql: &'sql str) -> Batch<'conn, 'sql> { + Batch { conn, sql, tail: 0 } + } + + /// Iterates on each batch statements. + /// + /// Returns `Ok(None)` when batch is completed. + #[allow(clippy::should_implement_trait)] // fallible iterator + pub fn next(&mut self) -> Result<Option<Statement<'conn>>> { + while self.tail < self.sql.len() { + let sql = &self.sql[self.tail..]; + let next = self.conn.prepare(sql)?; + let tail = next.stmt.tail(); + if tail == 0 { + self.tail = self.sql.len(); + } else { + self.tail += tail; + } + if next.stmt.is_null() { + continue; + } + return Ok(Some(next)); + } + Ok(None) + } +} + +impl<'conn> Iterator for Batch<'conn, '_> { + type Item = Result<Statement<'conn>>; + + fn next(&mut self) -> Option<Result<Statement<'conn>>> { + self.next().transpose() + } +} + +bitflags::bitflags! { + /// Flags for opening SQLite database connections. See + /// [sqlite3_open_v2](https://www.sqlite.org/c3ref/open.html) for details. + /// + /// The default open flags are `SQLITE_OPEN_READ_WRITE | SQLITE_OPEN_CREATE + /// | SQLITE_OPEN_URI | SQLITE_OPEN_NO_MUTEX`. See [`Connection::open`] for + /// some discussion about these flags. + #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] + #[repr(C)] + pub struct OpenFlags: ::std::os::raw::c_int { + /// The database is opened in read-only mode. + /// If the database does not already exist, an error is returned. + const SQLITE_OPEN_READ_ONLY = ffi::SQLITE_OPEN_READONLY; + /// The database is opened for reading and writing if possible, + /// or reading only if the file is write protected by the operating system. + /// In either case the database must already exist, otherwise an error is returned. + const SQLITE_OPEN_READ_WRITE = ffi::SQLITE_OPEN_READWRITE; + /// The database is created if it does not already exist + const SQLITE_OPEN_CREATE = ffi::SQLITE_OPEN_CREATE; + /// The filename can be interpreted as a URI if this flag is set. + const SQLITE_OPEN_URI = ffi::SQLITE_OPEN_URI; + /// The database will be opened as an in-memory database. + const SQLITE_OPEN_MEMORY = ffi::SQLITE_OPEN_MEMORY; + /// The new database connection will not use a per-connection mutex (the + /// connection will use the "multi-thread" threading mode, in SQLite + /// parlance). + /// + /// This is used by default, as proper `Send`/`Sync` usage (in + /// particular, the fact that [`Connection`] does not implement `Sync`) + /// ensures thread-safety without the need to perform locking around all + /// calls. + const SQLITE_OPEN_NO_MUTEX = ffi::SQLITE_OPEN_NOMUTEX; + /// The new database connection will use a per-connection mutex -- the + /// "serialized" threading mode, in SQLite parlance. + /// + /// # Caveats + /// + /// This flag should probably never be used with `rusqlite`, as we + /// ensure thread-safety statically (we implement [`Send`] and not + /// [`Sync`]). That said + /// + /// Critically, even if this flag is used, the [`Connection`] is not + /// safe to use across multiple threads simultaneously. To access a + /// database from multiple threads, you should either create multiple + /// connections, one for each thread (if you have very many threads, + /// wrapping the `rusqlite::Connection` in a mutex is also reasonable). + /// + /// This is both because of the additional per-connection state stored + /// by `rusqlite` (for example, the prepared statement cache), and + /// because not all of SQLites functions are fully thread safe, even in + /// serialized/`SQLITE_OPEN_FULLMUTEX` mode. + /// + /// All that said, it's fairly harmless to enable this flag with + /// `rusqlite`, it will just slow things down while providing no + /// benefit. + const SQLITE_OPEN_FULL_MUTEX = ffi::SQLITE_OPEN_FULLMUTEX; + /// The database is opened with shared cache enabled. + /// + /// This is frequently useful for in-memory connections, but note that + /// broadly speaking it's discouraged by SQLite itself, which states + /// "Any use of shared cache is discouraged" in the official + /// [documentation](https://www.sqlite.org/c3ref/enable_shared_cache.html). + const SQLITE_OPEN_SHARED_CACHE = 0x0002_0000; + /// The database is opened shared cache disabled. + const SQLITE_OPEN_PRIVATE_CACHE = 0x0004_0000; + /// The database filename is not allowed to be a symbolic link. (3.31.0) + const SQLITE_OPEN_NOFOLLOW = 0x0100_0000; + /// Extended result codes. (3.37.0) + const SQLITE_OPEN_EXRESCODE = 0x0200_0000; + } +} + +impl Default for OpenFlags { + #[inline] + fn default() -> OpenFlags { + // Note: update the `Connection::open` and top-level `OpenFlags` docs if + // you change these. + OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_NO_MUTEX + | OpenFlags::SQLITE_OPEN_URI + } +} + +bitflags::bitflags! { + /// Prepare flags. See + /// [sqlite3_prepare_v3](https://sqlite.org/c3ref/c_prepare_normalize.html) for details. + #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] + #[repr(C)] + pub struct PrepFlags: ::std::os::raw::c_uint { + /// A hint to the query planner that the prepared statement will be retained for a long time and probably reused many times. + const SQLITE_PREPARE_PERSISTENT = 0x01; + /// Causes the SQL compiler to return an error (error code SQLITE_ERROR) if the statement uses any virtual tables. + const SQLITE_PREPARE_NO_VTAB = 0x04; + } +} + +/// rusqlite's check for a safe SQLite threading mode requires SQLite 3.7.0 or +/// later. If you are running against a SQLite older than that, rusqlite +/// attempts to ensure safety by performing configuration and initialization of +/// SQLite itself the first time you +/// attempt to open a connection. By default, rusqlite panics if that +/// initialization fails, since that could mean SQLite has been initialized in +/// single-thread mode. +/// +/// If you are encountering that panic _and_ can ensure that SQLite has been +/// initialized in either multi-thread or serialized mode, call this function +/// prior to attempting to open a connection and rusqlite's initialization +/// process will by skipped. +/// +/// # Safety +/// +/// This function is unsafe because if you call it and SQLite has actually been +/// configured to run in single-thread mode, +/// you may encounter memory errors or data corruption or any number of terrible +/// things that should not be possible when you're using Rust. +pub unsafe fn bypass_sqlite_initialization() { + BYPASS_SQLITE_INIT.store(true, Ordering::Relaxed); +} + +/// Allows interrupting a long-running computation. +pub struct InterruptHandle { + db_lock: Arc<Mutex<*mut ffi::sqlite3>>, +} + +unsafe impl Send for InterruptHandle {} +unsafe impl Sync for InterruptHandle {} + +impl InterruptHandle { + /// Interrupt the query currently executing on another thread. This will + /// cause that query to fail with a `SQLITE3_INTERRUPT` error. + pub fn interrupt(&self) { + let db_handle = self.db_lock.lock().unwrap(); + if !db_handle.is_null() { + unsafe { ffi::sqlite3_interrupt(*db_handle) } + } + } +} + +#[cfg(doctest)] +doc_comment::doctest!("../README.md"); + +#[cfg(test)] +mod test { + use super::*; + use crate::ffi; + use fallible_iterator::FallibleIterator; + use std::error::Error as StdError; + use std::fmt; + + // this function is never called, but is still type checked; in + // particular, calls with specific instantiations will require + // that those types are `Send`. + #[allow( + dead_code, + unconditional_recursion, + clippy::extra_unused_type_parameters + )] + fn ensure_send<T: Send>() { + ensure_send::<Connection>(); + ensure_send::<InterruptHandle>(); + } + + #[allow( + dead_code, + unconditional_recursion, + clippy::extra_unused_type_parameters + )] + fn ensure_sync<T: Sync>() { + ensure_sync::<InterruptHandle>(); + } + + fn checked_memory_handle() -> Connection { + Connection::open_in_memory().unwrap() + } + + #[test] + fn test_concurrent_transactions_busy_commit() -> Result<()> { + use std::time::Duration; + let tmp = tempfile::tempdir().unwrap(); + let path = tmp.path().join("transactions.db3"); + + Connection::open(&path)?.execute_batch( + " + BEGIN; CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); END;", + )?; + + let mut db1 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE)?; + let mut db2 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY)?; + + db1.busy_timeout(Duration::from_millis(0))?; + db2.busy_timeout(Duration::from_millis(0))?; + + { + let tx1 = db1.transaction()?; + let tx2 = db2.transaction()?; + + // SELECT first makes sqlite lock with a shared lock + tx1.query_row("SELECT x FROM foo LIMIT 1", [], |_| Ok(()))?; + tx2.query_row("SELECT x FROM foo LIMIT 1", [], |_| Ok(()))?; + + tx1.execute("INSERT INTO foo VALUES(?1)", [1])?; + let _ = tx2.execute("INSERT INTO foo VALUES(?1)", [2]); + + let _ = tx1.commit(); + let _ = tx2.commit(); + } + + let _ = db1 + .transaction() + .expect("commit should have closed transaction"); + let _ = db2 + .transaction() + .expect("commit should have closed transaction"); + Ok(()) + } + + #[test] + fn test_persistence() -> Result<()> { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + { + let db = Connection::open(&path)?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + db.execute_batch(sql)?; + } + + let path_string = path.to_str().unwrap(); + let db = Connection::open(path_string)?; + let the_answer: i64 = db.one_column("SELECT x FROM foo")?; + + assert_eq!(42i64, the_answer); + Ok(()) + } + + #[test] + fn test_open() { + Connection::open_in_memory().unwrap(); + + let db = checked_memory_handle(); + db.close().unwrap(); + } + + #[test] + fn test_path() -> Result<()> { + let tmp = tempfile::tempdir().unwrap(); + let db = Connection::open("")?; + assert_eq!(Some(""), db.path()); + let db = Connection::open_in_memory()?; + assert_eq!(Some(""), db.path()); + let db = Connection::open("file:dummy.db?mode=memory&cache=shared")?; + assert_eq!(Some(""), db.path()); + let path = tmp.path().join("file.db"); + let db = Connection::open(path)?; + assert!(db.path().map(|p| p.ends_with("file.db")).unwrap_or(false)); + Ok(()) + } + + #[test] + fn test_open_failure() { + let filename = "no_such_file.db"; + let result = Connection::open_with_flags(filename, OpenFlags::SQLITE_OPEN_READ_ONLY); + let err = result.unwrap_err(); + if let Error::SqliteFailure(e, Some(msg)) = err { + assert_eq!(ErrorCode::CannotOpen, e.code); + assert_eq!(ffi::SQLITE_CANTOPEN, e.extended_code); + assert!( + msg.contains(filename), + "error message '{msg}' does not contain '{filename}'" + ); + } else { + panic!("SqliteFailure expected"); + } + } + + #[cfg(unix)] + #[test] + fn test_invalid_unicode_file_names() -> Result<()> { + use std::ffi::OsStr; + use std::fs::File; + use std::os::unix::ffi::OsStrExt; + let temp_dir = tempfile::tempdir().unwrap(); + + let path = temp_dir.path(); + if File::create(path.join(OsStr::from_bytes(&[0xFE]))).is_err() { + // Skip test, filesystem doesn't support invalid Unicode + return Ok(()); + } + let db_path = path.join(OsStr::from_bytes(&[0xFF])); + { + let db = Connection::open(&db_path)?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + db.execute_batch(sql)?; + } + + let db = Connection::open(&db_path)?; + let the_answer: i64 = db.one_column("SELECT x FROM foo")?; + + assert_eq!(42i64, the_answer); + Ok(()) + } + + #[test] + fn test_close_retry() -> Result<()> { + let db = Connection::open_in_memory()?; + + // force the DB to be busy by preparing a statement; this must be done at the + // FFI level to allow us to call .close() without dropping the prepared + // statement first. + let raw_stmt = { + use super::str_to_cstring; + use std::os::raw::c_int; + use std::ptr; + + let raw_db = db.db.borrow_mut().db; + let sql = "SELECT 1"; + let mut raw_stmt: *mut ffi::sqlite3_stmt = ptr::null_mut(); + let cstring = str_to_cstring(sql)?; + let rc = unsafe { + ffi::sqlite3_prepare_v2( + raw_db, + cstring.as_ptr(), + (sql.len() + 1) as c_int, + &mut raw_stmt, + ptr::null_mut(), + ) + }; + assert_eq!(rc, ffi::SQLITE_OK); + raw_stmt + }; + + // now that we have an open statement, trying (and retrying) to close should + // fail. + let (db, _) = db.close().unwrap_err(); + let (db, _) = db.close().unwrap_err(); + let (db, _) = db.close().unwrap_err(); + + // finalize the open statement so a final close will succeed + assert_eq!(ffi::SQLITE_OK, unsafe { ffi::sqlite3_finalize(raw_stmt) }); + + db.close().unwrap(); + Ok(()) + } + + #[test] + fn test_open_with_flags() { + for bad_flags in &[ + OpenFlags::empty(), + OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_READ_WRITE, + OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_CREATE, + ] { + Connection::open_in_memory_with_flags(*bad_flags).unwrap_err(); + } + } + + #[test] + fn test_execute_batch() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + INSERT INTO foo VALUES(3); + INSERT INTO foo VALUES(4); + END;"; + db.execute_batch(sql)?; + + db.execute_batch("UPDATE foo SET x = 3 WHERE x < 3")?; + + db.execute_batch("INVALID SQL").unwrap_err(); + Ok(()) + } + + #[test] + fn test_execute() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER)")?; + + assert_eq!(1, db.execute("INSERT INTO foo(x) VALUES (?1)", [1i32])?); + assert_eq!(1, db.execute("INSERT INTO foo(x) VALUES (?1)", [2i32])?); + + assert_eq!(3i32, db.one_column::<i32>("SELECT SUM(x) FROM foo")?); + Ok(()) + } + + #[test] + #[cfg(feature = "extra_check")] + fn test_execute_select() { + let db = checked_memory_handle(); + let err = db.execute("SELECT 1 WHERE 1 < ?1", [1i32]).unwrap_err(); + assert_eq!( + err, + Error::ExecuteReturnedResults, + "Unexpected error: {err}" + ); + } + + #[test] + #[cfg(feature = "extra_check")] + fn test_execute_multiple() { + let db = checked_memory_handle(); + let err = db + .execute( + "CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)", + [], + ) + .unwrap_err(); + match err { + Error::MultipleStatement => (), + _ => panic!("Unexpected error: {err}"), + } + } + + #[test] + fn test_prepare_column_names() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + + let stmt = db.prepare("SELECT * FROM foo")?; + assert_eq!(stmt.column_count(), 1); + assert_eq!(stmt.column_names(), vec!["x"]); + + let stmt = db.prepare("SELECT x AS a, x AS b FROM foo")?; + assert_eq!(stmt.column_count(), 2); + assert_eq!(stmt.column_names(), vec!["a", "b"]); + Ok(()) + } + + #[test] + fn test_prepare_execute() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + + let mut insert_stmt = db.prepare("INSERT INTO foo(x) VALUES(?1)")?; + assert_eq!(insert_stmt.execute([1i32])?, 1); + assert_eq!(insert_stmt.execute([2i32])?, 1); + assert_eq!(insert_stmt.execute([3i32])?, 1); + + assert_eq!(insert_stmt.execute(["hello"])?, 1); + assert_eq!(insert_stmt.execute(["goodbye"])?, 1); + assert_eq!(insert_stmt.execute([types::Null])?, 1); + + let mut update_stmt = db.prepare("UPDATE foo SET x=?1 WHERE x<?2")?; + assert_eq!(update_stmt.execute([3i32, 3i32])?, 2); + assert_eq!(update_stmt.execute([3i32, 3i32])?, 0); + assert_eq!(update_stmt.execute([8i32, 8i32])?, 3); + Ok(()) + } + + #[test] + fn test_prepare_query() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + + let mut insert_stmt = db.prepare("INSERT INTO foo(x) VALUES(?1)")?; + assert_eq!(insert_stmt.execute([1i32])?, 1); + assert_eq!(insert_stmt.execute([2i32])?, 1); + assert_eq!(insert_stmt.execute([3i32])?, 1); + + let mut query = db.prepare("SELECT x FROM foo WHERE x < ?1 ORDER BY x DESC")?; + { + let mut rows = query.query([4i32])?; + let mut v = Vec::<i32>::new(); + + while let Some(row) = rows.next()? { + v.push(row.get(0)?); + } + + assert_eq!(v, [3i32, 2, 1]); + } + + { + let mut rows = query.query([3i32])?; + let mut v = Vec::<i32>::new(); + + while let Some(row) = rows.next()? { + v.push(row.get(0)?); + } + + assert_eq!(v, [2i32, 1]); + } + Ok(()) + } + + #[test] + fn test_query_map() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let results: Result<Vec<String>> = query.query([])?.map(|row| row.get(1)).collect(); + + assert_eq!(results?.concat(), "hello, world!"); + Ok(()) + } + + #[test] + fn test_query_row() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + INSERT INTO foo VALUES(3); + INSERT INTO foo VALUES(4); + END;"; + db.execute_batch(sql)?; + + assert_eq!(10i64, db.one_column::<i64>("SELECT SUM(x) FROM foo")?); + + let result: Result<i64> = db.one_column("SELECT x FROM foo WHERE x > 5"); + match result.unwrap_err() { + Error::QueryReturnedNoRows => (), + err => panic!("Unexpected error {err}"), + } + + let bad_query_result = db.query_row("NOT A PROPER QUERY; test123", [], |_| Ok(())); + + bad_query_result.unwrap_err(); + Ok(()) + } + + #[test] + fn test_optional() -> Result<()> { + let db = Connection::open_in_memory()?; + + let result: Result<i64> = db.one_column("SELECT 1 WHERE 0 <> 0"); + let result = result.optional(); + match result? { + None => (), + _ => panic!("Unexpected result"), + } + + let result: Result<i64> = db.one_column("SELECT 1 WHERE 0 == 0"); + let result = result.optional(); + match result? { + Some(1) => (), + _ => panic!("Unexpected result"), + } + + let bad_query_result: Result<i64> = db.one_column("NOT A PROPER QUERY"); + let bad_query_result = bad_query_result.optional(); + bad_query_result.unwrap_err(); + Ok(()) + } + + #[test] + fn test_pragma_query_row() -> Result<()> { + let db = Connection::open_in_memory()?; + assert_eq!("memory", db.one_column::<String>("PRAGMA journal_mode")?); + let mode = db.one_column::<String>("PRAGMA journal_mode=off")?; + if cfg!(features = "bundled") { + assert_eq!(mode, "off"); + } else { + // Note: system SQLite on macOS defaults to "off" rather than + // "memory" for the journal mode (which cannot be changed for + // in-memory connections). This seems like it's *probably* legal + // according to the docs below, so we relax this test when not + // bundling: + // + // From https://www.sqlite.org/pragma.html#pragma_journal_mode + // > Note that the journal_mode for an in-memory database is either + // > MEMORY or OFF and can not be changed to a different value. An + // > attempt to change the journal_mode of an in-memory database to + // > any setting other than MEMORY or OFF is ignored. + assert!(mode == "memory" || mode == "off", "Got mode {mode:?}"); + } + + Ok(()) + } + + #[test] + fn test_prepare_failures() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + + let err = db.prepare("SELECT * FROM does_not_exist").unwrap_err(); + assert!(format!("{err}").contains("does_not_exist")); + Ok(()) + } + + #[test] + fn test_last_insert_rowid() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER PRIMARY KEY)")?; + db.execute_batch("INSERT INTO foo DEFAULT VALUES")?; + + assert_eq!(db.last_insert_rowid(), 1); + + let mut stmt = db.prepare("INSERT INTO foo DEFAULT VALUES")?; + for _ in 0i32..9 { + stmt.execute([])?; + } + assert_eq!(db.last_insert_rowid(), 10); + Ok(()) + } + + #[test] + fn test_is_autocommit() -> Result<()> { + let db = Connection::open_in_memory()?; + assert!( + db.is_autocommit(), + "autocommit expected to be active by default" + ); + Ok(()) + } + + #[test] + fn test_is_busy() -> Result<()> { + let db = Connection::open_in_memory()?; + assert!(!db.is_busy()); + let mut stmt = db.prepare("PRAGMA schema_version")?; + assert!(!db.is_busy()); + { + let mut rows = stmt.query([])?; + assert!(!db.is_busy()); + let row = rows.next()?; + assert!(db.is_busy()); + assert!(row.is_some()); + } + assert!(!db.is_busy()); + Ok(()) + } + + #[test] + fn test_statement_debugging() -> Result<()> { + let db = Connection::open_in_memory()?; + let query = "SELECT 12345"; + let stmt = db.prepare(query)?; + + assert!(format!("{stmt:?}").contains(query)); + Ok(()) + } + + #[test] + fn test_notnull_constraint_error() -> Result<()> { + // extended error codes for constraints were added in SQLite 3.7.16; if we're + // running on our bundled version, we know the extended error code exists. + fn check_extended_code(extended_code: c_int) { + assert_eq!(extended_code, ffi::SQLITE_CONSTRAINT_NOTNULL); + } + + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x NOT NULL)")?; + + let result = db.execute("INSERT INTO foo (x) VALUES (NULL)", []); + + match result.unwrap_err() { + Error::SqliteFailure(err, _) => { + assert_eq!(err.code, ErrorCode::ConstraintViolation); + check_extended_code(err.extended_code); + } + err => panic!("Unexpected error {err}"), + } + Ok(()) + } + + #[test] + fn test_version_string() { + let n = version_number(); + let major = n / 1_000_000; + let minor = (n % 1_000_000) / 1_000; + let patch = n % 1_000; + + assert!(version().contains(&format!("{major}.{minor}.{patch}"))); + } + + #[test] + #[cfg(feature = "functions")] + fn test_interrupt() -> Result<()> { + let db = Connection::open_in_memory()?; + + let interrupt_handle = db.get_interrupt_handle(); + + db.create_scalar_function( + "interrupt", + 0, + functions::FunctionFlags::default(), + move |_| { + interrupt_handle.interrupt(); + Ok(0) + }, + )?; + + let mut stmt = + db.prepare("SELECT interrupt() FROM (SELECT 1 UNION SELECT 2 UNION SELECT 3)")?; + + let result: Result<Vec<i32>> = stmt.query([])?.map(|r| r.get(0)).collect(); + + assert_eq!( + result.unwrap_err().sqlite_error_code(), + Some(ErrorCode::OperationInterrupted) + ); + Ok(()) + } + + #[test] + fn test_interrupt_close() { + let db = checked_memory_handle(); + let handle = db.get_interrupt_handle(); + handle.interrupt(); + db.close().unwrap(); + handle.interrupt(); + + // Look at it's internals to see if we cleared it out properly. + let db_guard = handle.db_lock.lock().unwrap(); + assert!(db_guard.is_null()); + // It would be nice to test that we properly handle close/interrupt + // running at the same time, but it seems impossible to do with any + // degree of reliability. + } + + #[test] + fn test_get_raw() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(i, x);")?; + let vals = ["foobar", "1234", "qwerty"]; + let mut insert_stmt = db.prepare("INSERT INTO foo(i, x) VALUES(?1, ?2)")?; + for (i, v) in vals.iter().enumerate() { + let i_to_insert = i as i64; + assert_eq!(insert_stmt.execute(params![i_to_insert, v])?, 1); + } + + let mut query = db.prepare("SELECT i, x FROM foo")?; + let mut rows = query.query([])?; + + while let Some(row) = rows.next()? { + let i = row.get_ref(0)?.as_i64()?; + let expect = vals[i as usize]; + let x = row.get_ref("x")?.as_str()?; + assert_eq!(x, expect); + } + + let mut query = db.prepare("SELECT x FROM foo")?; + let rows = query.query_map([], |row| { + let x = row.get_ref(0)?.as_str()?; // check From<FromSqlError> for Error + Ok(x[..].to_owned()) + })?; + + for (i, row) in rows.enumerate() { + assert_eq!(row?, vals[i]); + } + Ok(()) + } + + #[test] + fn test_from_handle() -> Result<()> { + let db = Connection::open_in_memory()?; + let handle = unsafe { db.handle() }; + { + let db = unsafe { Connection::from_handle(handle) }?; + db.execute_batch("PRAGMA VACUUM")?; + } + db.close().unwrap(); + Ok(()) + } + + #[test] + fn test_from_handle_owned() -> Result<()> { + let mut handle: *mut ffi::sqlite3 = std::ptr::null_mut(); + let r = unsafe { ffi::sqlite3_open(":memory:\0".as_ptr() as *const i8, &mut handle) }; + assert_eq!(r, ffi::SQLITE_OK); + let db = unsafe { Connection::from_handle_owned(handle) }?; + db.execute_batch("PRAGMA VACUUM")?; + Ok(()) + } + + mod query_and_then_tests { + + use super::*; + + #[derive(Debug)] + enum CustomError { + SomeError, + Sqlite(Error), + } + + impl fmt::Display for CustomError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match *self { + CustomError::SomeError => write!(f, "my custom error"), + CustomError::Sqlite(ref se) => write!(f, "my custom error: {se}"), + } + } + } + + impl StdError for CustomError { + fn description(&self) -> &str { + "my custom error" + } + + fn cause(&self) -> Option<&dyn StdError> { + match *self { + CustomError::SomeError => None, + CustomError::Sqlite(ref se) => Some(se), + } + } + } + + impl From<Error> for CustomError { + fn from(se: Error) -> CustomError { + CustomError::Sqlite(se) + } + } + + type CustomResult<T> = Result<T, CustomError>; + + #[test] + fn test_query_and_then() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let results: Result<Vec<String>> = + query.query_and_then([], |row| row.get(1))?.collect(); + + assert_eq!(results?.concat(), "hello, world!"); + Ok(()) + } + + #[test] + fn test_query_and_then_fails() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let bad_type: Result<Vec<f64>> = query.query_and_then([], |row| row.get(1))?.collect(); + + match bad_type.unwrap_err() { + Error::InvalidColumnType(..) => (), + err => panic!("Unexpected error {err}"), + } + + let bad_idx: Result<Vec<String>> = + query.query_and_then([], |row| row.get(3))?.collect(); + + match bad_idx.unwrap_err() { + Error::InvalidColumnIndex(_) => (), + err => panic!("Unexpected error {err}"), + } + Ok(()) + } + + #[test] + fn test_query_and_then_custom_error() -> CustomResult<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let results: CustomResult<Vec<String>> = query + .query_and_then([], |row| row.get(1).map_err(CustomError::Sqlite))? + .collect(); + + assert_eq!(results?.concat(), "hello, world!"); + Ok(()) + } + + #[test] + fn test_query_and_then_custom_error_fails() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql)?; + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC")?; + let bad_type: CustomResult<Vec<f64>> = query + .query_and_then([], |row| row.get(1).map_err(CustomError::Sqlite))? + .collect(); + + match bad_type.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnType(..)) => (), + err => panic!("Unexpected error {err}"), + } + + let bad_idx: CustomResult<Vec<String>> = query + .query_and_then([], |row| row.get(3).map_err(CustomError::Sqlite))? + .collect(); + + match bad_idx.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnIndex(_)) => (), + err => panic!("Unexpected error {err}"), + } + + let non_sqlite_err: CustomResult<Vec<String>> = query + .query_and_then([], |_| Err(CustomError::SomeError))? + .collect(); + + match non_sqlite_err.unwrap_err() { + CustomError::SomeError => (), + err => panic!("Unexpected error {err}"), + } + Ok(()) + } + + #[test] + fn test_query_row_and_then_custom_error() -> CustomResult<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql)?; + + let query = "SELECT x, y FROM foo ORDER BY x DESC"; + let results: CustomResult<String> = + db.query_row_and_then(query, [], |row| row.get(1).map_err(CustomError::Sqlite)); + + assert_eq!(results?, "hello"); + Ok(()) + } + + #[test] + fn test_query_row_and_then_custom_error_fails() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql)?; + + let query = "SELECT x, y FROM foo ORDER BY x DESC"; + let bad_type: CustomResult<f64> = + db.query_row_and_then(query, [], |row| row.get(1).map_err(CustomError::Sqlite)); + + match bad_type.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnType(..)) => (), + err => panic!("Unexpected error {err}"), + } + + let bad_idx: CustomResult<String> = + db.query_row_and_then(query, [], |row| row.get(3).map_err(CustomError::Sqlite)); + + match bad_idx.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnIndex(_)) => (), + err => panic!("Unexpected error {err}"), + } + + let non_sqlite_err: CustomResult<String> = + db.query_row_and_then(query, [], |_| Err(CustomError::SomeError)); + + match non_sqlite_err.unwrap_err() { + CustomError::SomeError => (), + err => panic!("Unexpected error {err}"), + } + Ok(()) + } + } + + #[test] + fn test_dynamic() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql)?; + + db.query_row("SELECT * FROM foo", [], |r| { + assert_eq!(2, r.as_ref().column_count()); + Ok(()) + }) + } + #[test] + fn test_dyn_box() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER);")?; + let b: Box<dyn ToSql> = Box::new(5); + db.execute("INSERT INTO foo VALUES(?1)", [b])?; + db.query_row("SELECT x FROM foo", [], |r| { + assert_eq!(5, r.get_unwrap::<_, i32>(0)); + Ok(()) + }) + } + + #[test] + fn test_params() -> Result<()> { + let db = Connection::open_in_memory()?; + db.query_row( + "SELECT + ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, + ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, + ?21, ?22, ?23, ?24, ?25, ?26, ?27, ?28, ?29, ?30, + ?31, ?32, ?33, ?34;", + params![ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, + ], + |r| { + assert_eq!(1, r.get_unwrap::<_, i32>(0)); + Ok(()) + }, + ) + } + + #[test] + #[cfg(not(feature = "extra_check"))] + fn test_alter_table() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE x(t);")?; + // `execute_batch` should be used but `execute` should also work + db.execute("ALTER TABLE x RENAME TO y;", [])?; + Ok(()) + } + + #[test] + fn test_batch() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r" + CREATE TABLE tbl1 (col); + CREATE TABLE tbl2 (col); + "; + let batch = Batch::new(&db, sql); + for stmt in batch { + let mut stmt = stmt?; + stmt.execute([])?; + } + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn test_returning() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER PRIMARY KEY)")?; + let row_id = db.one_column::<i64>("INSERT INTO foo DEFAULT VALUES RETURNING ROWID")?; + assert_eq!(row_id, 1); + Ok(()) + } + + #[test] + fn test_cache_flush() -> Result<()> { + let db = Connection::open_in_memory()?; + db.cache_flush() + } + + #[test] + fn db_readonly() -> Result<()> { + let db = Connection::open_in_memory()?; + assert!(!db.is_readonly(MAIN_DB)?); + Ok(()) + } + + #[test] + #[cfg(feature = "rusqlite-macros")] + fn prepare_and_bind() -> Result<()> { + let db = Connection::open_in_memory()?; + let name = "Lisa"; + let age = 8; + let mut stmt = prepare_and_bind!(db, "SELECT $name, $age;"); + let (v1, v2) = stmt + .raw_query() + .next() + .and_then(|o| o.ok_or(Error::QueryReturnedNoRows)) + .and_then(|r| Ok((r.get::<_, String>(0)?, r.get::<_, i64>(1)?)))?; + assert_eq!((v1.as_str(), v2), (name, age)); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/limits.rs b/third_party/rust/rusqlite/src/limits.rs new file mode 100644 index 0000000000..d0694e304c --- /dev/null +++ b/third_party/rust/rusqlite/src/limits.rs @@ -0,0 +1,163 @@ +//! Run-Time Limits + +use crate::{ffi, Connection}; +use std::os::raw::c_int; + +/// Run-Time limit categories, for use with [`Connection::limit`] and +/// [`Connection::set_limit`]. +/// +/// See the official documentation for more information: +/// - <https://www.sqlite.org/c3ref/c_limit_attached.html> +/// - <https://www.sqlite.org/limits.html> +#[repr(i32)] +#[non_exhaustive] +#[allow(clippy::upper_case_acronyms, non_camel_case_types)] +#[cfg_attr(docsrs, doc(cfg(feature = "limits")))] +pub enum Limit { + /// The maximum size of any string or BLOB or table row, in bytes. + SQLITE_LIMIT_LENGTH = ffi::SQLITE_LIMIT_LENGTH, + /// The maximum length of an SQL statement, in bytes. + SQLITE_LIMIT_SQL_LENGTH = ffi::SQLITE_LIMIT_SQL_LENGTH, + /// The maximum number of columns in a table definition or in the result set + /// of a SELECT or the maximum number of columns in an index or in an + /// ORDER BY or GROUP BY clause. + SQLITE_LIMIT_COLUMN = ffi::SQLITE_LIMIT_COLUMN, + /// The maximum depth of the parse tree on any expression. + SQLITE_LIMIT_EXPR_DEPTH = ffi::SQLITE_LIMIT_EXPR_DEPTH, + /// The maximum number of terms in a compound SELECT statement. + SQLITE_LIMIT_COMPOUND_SELECT = ffi::SQLITE_LIMIT_COMPOUND_SELECT, + /// The maximum number of instructions in a virtual machine program used to + /// implement an SQL statement. + SQLITE_LIMIT_VDBE_OP = ffi::SQLITE_LIMIT_VDBE_OP, + /// The maximum number of arguments on a function. + SQLITE_LIMIT_FUNCTION_ARG = ffi::SQLITE_LIMIT_FUNCTION_ARG, + /// The maximum number of attached databases. + SQLITE_LIMIT_ATTACHED = ffi::SQLITE_LIMIT_ATTACHED, + /// The maximum length of the pattern argument to the LIKE or GLOB + /// operators. + SQLITE_LIMIT_LIKE_PATTERN_LENGTH = ffi::SQLITE_LIMIT_LIKE_PATTERN_LENGTH, + /// The maximum index number of any parameter in an SQL statement. + SQLITE_LIMIT_VARIABLE_NUMBER = ffi::SQLITE_LIMIT_VARIABLE_NUMBER, + /// The maximum depth of recursion for triggers. + SQLITE_LIMIT_TRIGGER_DEPTH = ffi::SQLITE_LIMIT_TRIGGER_DEPTH, + /// The maximum number of auxiliary worker threads that a single prepared + /// statement may start. + SQLITE_LIMIT_WORKER_THREADS = ffi::SQLITE_LIMIT_WORKER_THREADS, +} + +impl Connection { + /// Returns the current value of a [`Limit`]. + #[inline] + #[cfg_attr(docsrs, doc(cfg(feature = "limits")))] + pub fn limit(&self, limit: Limit) -> i32 { + let c = self.db.borrow(); + unsafe { ffi::sqlite3_limit(c.db(), limit as c_int, -1) } + } + + /// Changes the [`Limit`] to `new_val`, returning the prior + /// value of the limit. + #[inline] + #[cfg_attr(docsrs, doc(cfg(feature = "limits")))] + pub fn set_limit(&self, limit: Limit, new_val: i32) -> i32 { + let c = self.db.borrow_mut(); + unsafe { ffi::sqlite3_limit(c.db(), limit as c_int, new_val) } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{Connection, Result}; + + #[test] + fn test_limit_values() { + assert_eq!(Limit::SQLITE_LIMIT_LENGTH as i32, ffi::SQLITE_LIMIT_LENGTH,); + assert_eq!( + Limit::SQLITE_LIMIT_SQL_LENGTH as i32, + ffi::SQLITE_LIMIT_SQL_LENGTH, + ); + assert_eq!(Limit::SQLITE_LIMIT_COLUMN as i32, ffi::SQLITE_LIMIT_COLUMN,); + assert_eq!( + Limit::SQLITE_LIMIT_EXPR_DEPTH as i32, + ffi::SQLITE_LIMIT_EXPR_DEPTH, + ); + assert_eq!( + Limit::SQLITE_LIMIT_COMPOUND_SELECT as i32, + ffi::SQLITE_LIMIT_COMPOUND_SELECT, + ); + assert_eq!( + Limit::SQLITE_LIMIT_VDBE_OP as i32, + ffi::SQLITE_LIMIT_VDBE_OP, + ); + assert_eq!( + Limit::SQLITE_LIMIT_FUNCTION_ARG as i32, + ffi::SQLITE_LIMIT_FUNCTION_ARG, + ); + assert_eq!( + Limit::SQLITE_LIMIT_ATTACHED as i32, + ffi::SQLITE_LIMIT_ATTACHED, + ); + assert_eq!( + Limit::SQLITE_LIMIT_LIKE_PATTERN_LENGTH as i32, + ffi::SQLITE_LIMIT_LIKE_PATTERN_LENGTH, + ); + assert_eq!( + Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32, + ffi::SQLITE_LIMIT_VARIABLE_NUMBER, + ); + #[cfg(feature = "bundled")] + assert_eq!( + Limit::SQLITE_LIMIT_TRIGGER_DEPTH as i32, + ffi::SQLITE_LIMIT_TRIGGER_DEPTH, + ); + #[cfg(feature = "bundled")] + assert_eq!( + Limit::SQLITE_LIMIT_WORKER_THREADS as i32, + ffi::SQLITE_LIMIT_WORKER_THREADS, + ); + } + + #[test] + fn test_limit() -> Result<()> { + let db = Connection::open_in_memory()?; + db.set_limit(Limit::SQLITE_LIMIT_LENGTH, 1024); + assert_eq!(1024, db.limit(Limit::SQLITE_LIMIT_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_SQL_LENGTH, 1024); + assert_eq!(1024, db.limit(Limit::SQLITE_LIMIT_SQL_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_COLUMN, 64); + assert_eq!(64, db.limit(Limit::SQLITE_LIMIT_COLUMN)); + + db.set_limit(Limit::SQLITE_LIMIT_EXPR_DEPTH, 256); + assert_eq!(256, db.limit(Limit::SQLITE_LIMIT_EXPR_DEPTH)); + + db.set_limit(Limit::SQLITE_LIMIT_COMPOUND_SELECT, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_COMPOUND_SELECT)); + + db.set_limit(Limit::SQLITE_LIMIT_FUNCTION_ARG, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_FUNCTION_ARG)); + + db.set_limit(Limit::SQLITE_LIMIT_ATTACHED, 2); + assert_eq!(2, db.limit(Limit::SQLITE_LIMIT_ATTACHED)); + + db.set_limit(Limit::SQLITE_LIMIT_LIKE_PATTERN_LENGTH, 128); + assert_eq!(128, db.limit(Limit::SQLITE_LIMIT_LIKE_PATTERN_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER, 99); + assert_eq!(99, db.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)); + + // SQLITE_LIMIT_TRIGGER_DEPTH was added in SQLite 3.6.18. + if crate::version_number() >= 3_006_018 { + db.set_limit(Limit::SQLITE_LIMIT_TRIGGER_DEPTH, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_TRIGGER_DEPTH)); + } + + // SQLITE_LIMIT_WORKER_THREADS was added in SQLite 3.8.7. + if crate::version_number() >= 3_008_007 { + db.set_limit(Limit::SQLITE_LIMIT_WORKER_THREADS, 2); + assert_eq!(2, db.limit(Limit::SQLITE_LIMIT_WORKER_THREADS)); + } + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/load_extension_guard.rs b/third_party/rust/rusqlite/src/load_extension_guard.rs new file mode 100644 index 0000000000..deed3b4bdd --- /dev/null +++ b/third_party/rust/rusqlite/src/load_extension_guard.rs @@ -0,0 +1,46 @@ +use crate::{Connection, Result}; + +/// RAII guard temporarily enabling SQLite extensions to be loaded. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, LoadExtensionGuard}; +/// # use std::path::{Path}; +/// fn load_my_extension(conn: &Connection) -> Result<()> { +/// unsafe { +/// let _guard = LoadExtensionGuard::new(conn)?; +/// conn.load_extension("trusted/sqlite/extension", None) +/// } +/// } +/// ``` +#[cfg_attr(docsrs, doc(cfg(feature = "load_extension")))] +pub struct LoadExtensionGuard<'conn> { + conn: &'conn Connection, +} + +impl LoadExtensionGuard<'_> { + /// Attempt to enable loading extensions. Loading extensions will be + /// disabled when this guard goes out of scope. Cannot be meaningfully + /// nested. + /// + /// # Safety + /// + /// You must not run untrusted queries while extension loading is enabled. + /// + /// See the safety comment on [`Connection::load_extension_enable`] for more + /// details. + #[inline] + pub unsafe fn new(conn: &Connection) -> Result<LoadExtensionGuard<'_>> { + conn.load_extension_enable() + .map(|_| LoadExtensionGuard { conn }) + } +} + +#[allow(unused_must_use)] +impl Drop for LoadExtensionGuard<'_> { + #[inline] + fn drop(&mut self) { + self.conn.load_extension_disable(); + } +} diff --git a/third_party/rust/rusqlite/src/params.rs b/third_party/rust/rusqlite/src/params.rs new file mode 100644 index 0000000000..a4c506667c --- /dev/null +++ b/third_party/rust/rusqlite/src/params.rs @@ -0,0 +1,455 @@ +use crate::{Result, Statement, ToSql}; + +mod sealed { + /// This trait exists just to ensure that the only impls of `trait Params` + /// that are allowed are ones in this crate. + pub trait Sealed {} +} +use sealed::Sealed; + +/// Trait used for [sets of parameter][params] passed into SQL +/// statements/queries. +/// +/// [params]: https://www.sqlite.org/c3ref/bind_blob.html +/// +/// Note: Currently, this trait can only be implemented inside this crate. +/// Additionally, it's methods (which are `doc(hidden)`) should currently not be +/// considered part of the stable API, although it's possible they will +/// stabilize in the future. +/// +/// # Passing parameters to SQLite +/// +/// Many functions in this library let you pass parameters to SQLite. Doing this +/// lets you avoid any risk of SQL injection, and is simpler than escaping +/// things manually. Aside from deprecated functions and a few helpers, this is +/// indicated by the function taking a generic argument that implements `Params` +/// (this trait). +/// +/// ## Positional parameters +/// +/// For cases where you want to pass a list of parameters where the number of +/// parameters is known at compile time, this can be done in one of the +/// following ways: +/// +/// - For small lists of parameters up to 16 items, they may alternatively be +/// passed as a tuple, as in `thing.query((1, "foo"))`. +/// +/// This is somewhat inconvenient for a single item, since you need a +/// weird-looking trailing comma: `thing.query(("example",))`. That case is +/// perhaps more cleanly expressed as `thing.query(["example"])`. +/// +/// - Using the [`rusqlite::params!`](crate::params!) macro, e.g. +/// `thing.query(rusqlite::params![1, "foo", bar])`. This is mostly useful for +/// heterogeneous lists where the number of parameters greater than 16, or +/// homogenous lists of parameters where the number of parameters exceeds 32. +/// +/// - For small homogeneous lists of parameters, they can either be passed as: +/// +/// - an array, as in `thing.query([1i32, 2, 3, 4])` or `thing.query(["foo", +/// "bar", "baz"])`. +/// +/// - a reference to an array of references, as in `thing.query(&["foo", +/// "bar", "baz"])` or `thing.query(&[&1i32, &2, &3])`. +/// +/// (Note: in this case we don't implement this for slices for coherence +/// reasons, so it really is only for the "reference to array" types — +/// hence why the number of parameters must be <= 32 or you need to +/// reach for `rusqlite::params!`) +/// +/// Unfortunately, in the current design it's not possible to allow this for +/// references to arrays of non-references (e.g. `&[1i32, 2, 3]`). Code like +/// this should instead either use `params!`, an array literal, a `&[&dyn +/// ToSql]` or if none of those work, [`ParamsFromIter`]. +/// +/// - As a slice of `ToSql` trait object references, e.g. `&[&dyn ToSql]`. This +/// is mostly useful for passing parameter lists around as arguments without +/// having every function take a generic `P: Params`. +/// +/// ### Example (positional) +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, params}; +/// fn update_rows(conn: &Connection) -> Result<()> { +/// let mut stmt = conn.prepare("INSERT INTO test (a, b) VALUES (?1, ?2)")?; +/// +/// // Using a tuple: +/// stmt.execute((0, "foobar"))?; +/// +/// // Using `rusqlite::params!`: +/// stmt.execute(params![1i32, "blah"])?; +/// +/// // array literal — non-references +/// stmt.execute([2i32, 3i32])?; +/// +/// // array literal — references +/// stmt.execute(["foo", "bar"])?; +/// +/// // Slice literal, references: +/// stmt.execute(&[&2i32, &3i32])?; +/// +/// // Note: The types behind the references don't have to be `Sized` +/// stmt.execute(&["foo", "bar"])?; +/// +/// // However, this doesn't work (see above): +/// // stmt.execute(&[1i32, 2i32])?; +/// Ok(()) +/// } +/// ``` +/// +/// ## Named parameters +/// +/// SQLite lets you name parameters using a number of conventions (":foo", +/// "@foo", "$foo"). You can pass named parameters in to SQLite using rusqlite +/// in a few ways: +/// +/// - Using the [`rusqlite::named_params!`](crate::named_params!) macro, as in +/// `stmt.execute(named_params!{ ":name": "foo", ":age": 99 })`. Similar to +/// the `params` macro, this is most useful for heterogeneous lists of +/// parameters, or lists where the number of parameters exceeds 32. +/// +/// - As a slice of `&[(&str, &dyn ToSql)]`. This is what essentially all of +/// these boil down to in the end, conceptually at least. In theory you can +/// pass this as `stmt`. +/// +/// - As array references, similar to the positional params. This looks like +/// `thing.query(&[(":foo", &1i32), (":bar", &2i32)])` or +/// `thing.query(&[(":foo", "abc"), (":bar", "def")])`. +/// +/// Note: Unbound named parameters will be left to the value they previously +/// were bound with, falling back to `NULL` for parameters which have never been +/// bound. +/// +/// ### Example (named) +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, named_params}; +/// fn insert(conn: &Connection) -> Result<()> { +/// let mut stmt = conn.prepare("INSERT INTO test (key, value) VALUES (:key, :value)")?; +/// // Using `rusqlite::params!`: +/// stmt.execute(named_params! { ":key": "one", ":val": 2 })?; +/// // Alternatively: +/// stmt.execute(&[(":key", "three"), (":val", "four")])?; +/// // Or: +/// stmt.execute(&[(":key", &100), (":val", &200)])?; +/// Ok(()) +/// } +/// ``` +/// +/// ## No parameters +/// +/// You can just use an empty tuple or the empty array literal to run a query +/// that accepts no parameters. +/// +/// ### Example (no parameters) +/// +/// The empty tuple: +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, params}; +/// fn delete_all_users(conn: &Connection) -> Result<()> { +/// // You may also use `()`. +/// conn.execute("DELETE FROM users", ())?; +/// Ok(()) +/// } +/// ``` +/// +/// The empty array: +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, params}; +/// fn delete_all_users(conn: &Connection) -> Result<()> { +/// // Just use an empty array (e.g. `[]`) for no params. +/// conn.execute("DELETE FROM users", [])?; +/// Ok(()) +/// } +/// ``` +/// +/// ## Dynamic parameter list +/// +/// If you have a number of parameters which is unknown at compile time (for +/// example, building a dynamic query at runtime), you have two choices: +/// +/// - Use a `&[&dyn ToSql]`. This is often annoying to construct if you don't +/// already have this type on-hand. +/// - Use the [`ParamsFromIter`] type. This essentially lets you wrap an +/// iterator some `T: ToSql` with something that implements `Params`. The +/// usage of this looks like `rusqlite::params_from_iter(something)`. +/// +/// A lot of the considerations here are similar either way, so you should see +/// the [`ParamsFromIter`] documentation for more info / examples. +pub trait Params: Sealed { + // XXX not public api, might not need to expose. + // + // Binds the parameters to the statement. It is unlikely calling this + // explicitly will do what you want. Please use `Statement::query` or + // similar directly. + // + // For now, just hide the function in the docs... + #[doc(hidden)] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()>; +} + +// Explicitly impl for empty array. Critically, for `conn.execute([])` to be +// unambiguous, this must be the *only* implementation for an empty array. +// +// This sadly prevents `impl<T: ToSql, const N: usize> Params for [T; N]`, which +// forces people to use `params![...]` or `rusqlite::params_from_iter` for long +// homogenous lists of parameters. This is not that big of a deal, but is +// unfortunate, especially because I mostly did it because I wanted a simple +// syntax for no-params that didnt require importing -- the empty tuple fits +// that nicely, but I didn't think of it until much later. +// +// Admittedly, if we did have the generic impl, then we *wouldn't* support the +// empty array literal as a parameter, since the `T` there would fail to be +// inferred. The error message here would probably be quite bad, and so on +// further thought, probably would end up causing *more* surprises, not less. +impl Sealed for [&(dyn ToSql + Send + Sync); 0] {} +impl Params for [&(dyn ToSql + Send + Sync); 0] { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.ensure_parameter_count(0) + } +} + +impl Sealed for &[&dyn ToSql] {} +impl Params for &[&dyn ToSql] { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters(self) + } +} + +impl Sealed for &[(&str, &dyn ToSql)] {} +impl Params for &[(&str, &dyn ToSql)] { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters_named(self) + } +} + +// Manual impls for the empty and singleton tuple, although the rest are covered +// by macros. +impl Sealed for () {} +impl Params for () { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.ensure_parameter_count(0) + } +} + +// I'm pretty sure you could tweak the `single_tuple_impl` to accept this. +impl<T: ToSql> Sealed for (T,) {} +impl<T: ToSql> Params for (T,) { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.ensure_parameter_count(1)?; + stmt.raw_bind_parameter(1, self.0)?; + Ok(()) + } +} + +macro_rules! single_tuple_impl { + ($count:literal : $(($field:tt $ftype:ident)),* $(,)?) => { + impl<$($ftype,)*> Sealed for ($($ftype,)*) where $($ftype: ToSql,)* {} + impl<$($ftype,)*> Params for ($($ftype,)*) where $($ftype: ToSql,)* { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.ensure_parameter_count($count)?; + $({ + debug_assert!($field < $count); + stmt.raw_bind_parameter($field + 1, self.$field)?; + })+ + Ok(()) + } + } + } +} + +// We use a the macro for the rest, but don't bother with trying to implement it +// in a single invocation (it's possible to do, but my attempts were almost the +// same amount of code as just writing it out this way, and much more dense -- +// it is a more complicated case than the TryFrom macro we have for row->tuple). +// +// Note that going up to 16 (rather than the 12 that the impls in the stdlib +// usually support) is just because we did the same in the `TryFrom<Row>` impl. +// I didn't catch that then, but there's no reason to remove it, and it seems +// nice to be consistent here; this way putting data in the database and getting +// data out of the database are more symmetric in a (mostly superficial) sense. +single_tuple_impl!(2: (0 A), (1 B)); +single_tuple_impl!(3: (0 A), (1 B), (2 C)); +single_tuple_impl!(4: (0 A), (1 B), (2 C), (3 D)); +single_tuple_impl!(5: (0 A), (1 B), (2 C), (3 D), (4 E)); +single_tuple_impl!(6: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F)); +single_tuple_impl!(7: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G)); +single_tuple_impl!(8: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H)); +single_tuple_impl!(9: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I)); +single_tuple_impl!(10: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J)); +single_tuple_impl!(11: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K)); +single_tuple_impl!(12: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L)); +single_tuple_impl!(13: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L), (12 M)); +single_tuple_impl!(14: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L), (12 M), (13 N)); +single_tuple_impl!(15: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L), (12 M), (13 N), (14 O)); +single_tuple_impl!(16: (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J), (10 K), (11 L), (12 M), (13 N), (14 O), (15 P)); + +macro_rules! impl_for_array_ref { + ($($N:literal)+) => {$( + // These are already generic, and there's a shedload of them, so lets + // avoid the compile time hit from making them all inline for now. + impl<T: ToSql + ?Sized> Sealed for &[&T; $N] {} + impl<T: ToSql + ?Sized> Params for &[&T; $N] { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters(self) + } + } + impl<T: ToSql + ?Sized> Sealed for &[(&str, &T); $N] {} + impl<T: ToSql + ?Sized> Params for &[(&str, &T); $N] { + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters_named(self) + } + } + impl<T: ToSql> Sealed for [T; $N] {} + impl<T: ToSql> Params for [T; $N] { + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters(&self) + } + } + )+}; +} + +// Following libstd/libcore's (old) lead, implement this for arrays up to `[_; +// 32]`. Note `[_; 0]` is intentionally omitted for coherence reasons, see the +// note above the impl of `[&dyn ToSql; 0]` for more information. +// +// Note that this unfortunately means we can't use const generics here, but I +// don't really think it matters -- users who hit that can use `params!` anyway. +impl_for_array_ref!( + 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 + 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 +); + +/// Adapter type which allows any iterator over [`ToSql`] values to implement +/// [`Params`]. +/// +/// This struct is created by the [`params_from_iter`] function. +/// +/// This can be useful if you have something like an `&[String]` (of unknown +/// length), and you want to use them with an API that wants something +/// implementing `Params`. This way, you can avoid having to allocate storage +/// for something like a `&[&dyn ToSql]`. +/// +/// This essentially is only ever actually needed when dynamically generating +/// SQL — static SQL (by definition) has the number of parameters known +/// statically. As dynamically generating SQL is itself pretty advanced, this +/// API is itself for advanced use cases (See "Realistic use case" in the +/// examples). +/// +/// # Example +/// +/// ## Basic usage +/// +/// ```rust,no_run +/// use rusqlite::{params_from_iter, Connection, Result}; +/// use std::collections::BTreeSet; +/// +/// fn query(conn: &Connection, ids: &BTreeSet<String>) -> Result<()> { +/// assert_eq!(ids.len(), 3, "Unrealistic sample code"); +/// +/// let mut stmt = conn.prepare("SELECT * FROM users WHERE id IN (?1, ?2, ?3)")?; +/// let _rows = stmt.query(params_from_iter(ids.iter()))?; +/// +/// // use _rows... +/// Ok(()) +/// } +/// ``` +/// +/// ## Realistic use case +/// +/// Here's how you'd use `ParamsFromIter` to call [`Statement::exists`] with a +/// dynamic number of parameters. +/// +/// ```rust,no_run +/// use rusqlite::{Connection, Result}; +/// +/// pub fn any_active_users(conn: &Connection, usernames: &[String]) -> Result<bool> { +/// if usernames.is_empty() { +/// return Ok(false); +/// } +/// +/// // Note: `repeat_vars` never returns anything attacker-controlled, so +/// // it's fine to use it in a dynamically-built SQL string. +/// let vars = repeat_vars(usernames.len()); +/// +/// let sql = format!( +/// // In practice this would probably be better as an `EXISTS` query. +/// "SELECT 1 FROM user WHERE is_active AND name IN ({}) LIMIT 1", +/// vars, +/// ); +/// let mut stmt = conn.prepare(&sql)?; +/// stmt.exists(rusqlite::params_from_iter(usernames)) +/// } +/// +/// // Helper function to return a comma-separated sequence of `?`. +/// // - `repeat_vars(0) => panic!(...)` +/// // - `repeat_vars(1) => "?"` +/// // - `repeat_vars(2) => "?,?"` +/// // - `repeat_vars(3) => "?,?,?"` +/// // - ... +/// fn repeat_vars(count: usize) -> String { +/// assert_ne!(count, 0); +/// let mut s = "?,".repeat(count); +/// // Remove trailing comma +/// s.pop(); +/// s +/// } +/// ``` +/// +/// That is fairly complex, and even so would need even more work to be fully +/// production-ready: +/// +/// - production code should ensure `usernames` isn't so large that it will +/// surpass [`conn.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)`][limits]), +/// chunking if too large. (Note that the limits api requires rusqlite to have +/// the "limits" feature). +/// +/// - `repeat_vars` can be implemented in a way that avoids needing to allocate +/// a String. +/// +/// - Etc... +/// +/// [limits]: crate::Connection::limit +/// +/// This complexity reflects the fact that `ParamsFromIter` is mainly intended +/// for advanced use cases — most of the time you should know how many +/// parameters you have statically (and if you don't, you're either doing +/// something tricky, or should take a moment to think about the design). +#[derive(Clone, Debug)] +pub struct ParamsFromIter<I>(I); + +/// Constructor function for a [`ParamsFromIter`]. See its documentation for +/// more. +#[inline] +pub fn params_from_iter<I>(iter: I) -> ParamsFromIter<I> +where + I: IntoIterator, + I::Item: ToSql, +{ + ParamsFromIter(iter) +} + +impl<I> Sealed for ParamsFromIter<I> +where + I: IntoIterator, + I::Item: ToSql, +{ +} + +impl<I> Params for ParamsFromIter<I> +where + I: IntoIterator, + I::Item: ToSql, +{ + #[inline] + fn __bind_in(self, stmt: &mut Statement<'_>) -> Result<()> { + stmt.bind_parameters(self.0) + } +} diff --git a/third_party/rust/rusqlite/src/pragma.rs b/third_party/rust/rusqlite/src/pragma.rs new file mode 100644 index 0000000000..46bbde14a2 --- /dev/null +++ b/third_party/rust/rusqlite/src/pragma.rs @@ -0,0 +1,454 @@ +//! Pragma helpers + +use std::ops::Deref; + +use crate::error::Error; +use crate::ffi; +use crate::types::{ToSql, ToSqlOutput, ValueRef}; +use crate::{Connection, DatabaseName, Result, Row}; + +pub struct Sql { + buf: String, +} + +impl Sql { + pub fn new() -> Sql { + Sql { buf: String::new() } + } + + pub fn push_pragma( + &mut self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + ) -> Result<()> { + self.push_keyword("PRAGMA")?; + self.push_space(); + if let Some(schema_name) = schema_name { + self.push_schema_name(schema_name); + self.push_dot(); + } + self.push_keyword(pragma_name) + } + + pub fn push_keyword(&mut self, keyword: &str) -> Result<()> { + if !keyword.is_empty() && is_identifier(keyword) { + self.buf.push_str(keyword); + Ok(()) + } else { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Invalid keyword \"{keyword}\"")), + )) + } + } + + pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) { + match schema_name { + DatabaseName::Main => self.buf.push_str("main"), + DatabaseName::Temp => self.buf.push_str("temp"), + DatabaseName::Attached(s) => self.push_identifier(s), + }; + } + + pub fn push_identifier(&mut self, s: &str) { + if is_identifier(s) { + self.buf.push_str(s); + } else { + self.wrap_and_escape(s, '"'); + } + } + + pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> { + let value = value.to_sql()?; + let value = match value { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } + }; + match value { + ValueRef::Integer(i) => { + self.push_int(i); + } + ValueRef::Real(r) => { + self.push_real(r); + } + ValueRef::Text(s) => { + let s = std::str::from_utf8(s)?; + self.push_string_literal(s); + } + _ => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } + }; + Ok(()) + } + + pub fn push_string_literal(&mut self, s: &str) { + self.wrap_and_escape(s, '\''); + } + + pub fn push_int(&mut self, i: i64) { + self.buf.push_str(&i.to_string()); + } + + pub fn push_real(&mut self, f: f64) { + self.buf.push_str(&f.to_string()); + } + + pub fn push_space(&mut self) { + self.buf.push(' '); + } + + pub fn push_dot(&mut self) { + self.buf.push('.'); + } + + pub fn push_equal_sign(&mut self) { + self.buf.push('='); + } + + pub fn open_brace(&mut self) { + self.buf.push('('); + } + + pub fn close_brace(&mut self) { + self.buf.push(')'); + } + + pub fn as_str(&self) -> &str { + &self.buf + } + + fn wrap_and_escape(&mut self, s: &str, quote: char) { + self.buf.push(quote); + let chars = s.chars(); + for ch in chars { + // escape `quote` by doubling it + if ch == quote { + self.buf.push(ch); + } + self.buf.push(ch); + } + self.buf.push(quote); + } +} + +impl Deref for Sql { + type Target = str; + + fn deref(&self) -> &str { + self.as_str() + } +} + +impl Connection { + /// Query the current value of `pragma_name`. + /// + /// Some pragmas will return multiple rows/values which cannot be retrieved + /// with this method. + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT user_version FROM pragma_user_version;` + pub fn pragma_query_value<T, F>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + f: F, + ) -> Result<T> + where + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut query = Sql::new(); + query.push_pragma(schema_name, pragma_name)?; + self.query_row(&query, [], f) + } + + /// Query the current rows/values of `pragma_name`. + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT * FROM pragma_collation_list;` + pub fn pragma_query<F>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + mut f: F, + ) -> Result<()> + where + F: FnMut(&Row<'_>) -> Result<()>, + { + let mut query = Sql::new(); + query.push_pragma(schema_name, pragma_name)?; + let mut stmt = self.prepare(&query)?; + let mut rows = stmt.query([])?; + while let Some(result_row) = rows.next()? { + let row = result_row; + f(row)?; + } + Ok(()) + } + + /// Query the current value(s) of `pragma_name` associated to + /// `pragma_value`. + /// + /// This method can be used with query-only pragmas which need an argument + /// (e.g. `table_info('one_tbl')`) or pragmas which returns value(s) + /// (e.g. `integrity_check`). + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT * FROM pragma_table_info(?1);` + pub fn pragma<F, V>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + pragma_value: V, + mut f: F, + ) -> Result<()> + where + F: FnMut(&Row<'_>) -> Result<()>, + V: ToSql, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.open_brace(); + sql.push_value(&pragma_value)?; + sql.close_brace(); + let mut stmt = self.prepare(&sql)?; + let mut rows = stmt.query([])?; + while let Some(result_row) = rows.next()? { + let row = result_row; + f(row)?; + } + Ok(()) + } + + /// Set a new value to `pragma_name`. + /// + /// Some pragmas will return the updated value which cannot be retrieved + /// with this method. + pub fn pragma_update<V>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + pragma_value: V, + ) -> Result<()> + where + V: ToSql, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.push_equal_sign(); + sql.push_value(&pragma_value)?; + self.execute_batch(&sql) + } + + /// Set a new value to `pragma_name` and return the updated value. + /// + /// Only few pragmas automatically return the updated value. + pub fn pragma_update_and_check<F, T, V>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + pragma_value: V, + f: F, + ) -> Result<T> + where + F: FnOnce(&Row<'_>) -> Result<T>, + V: ToSql, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.push_equal_sign(); + sql.push_value(&pragma_value)?; + self.query_row(&sql, [], f) + } +} + +fn is_identifier(s: &str) -> bool { + let chars = s.char_indices(); + for (i, ch) in chars { + if i == 0 { + if !is_identifier_start(ch) { + return false; + } + } else if !is_identifier_continue(ch) { + return false; + } + } + true +} + +fn is_identifier_start(c: char) -> bool { + c.is_ascii_uppercase() || c == '_' || c.is_ascii_lowercase() || c > '\x7F' +} + +fn is_identifier_continue(c: char) -> bool { + c == '$' + || c.is_ascii_digit() + || c.is_ascii_uppercase() + || c == '_' + || c.is_ascii_lowercase() + || c > '\x7F' +} + +#[cfg(test)] +mod test { + use super::Sql; + use crate::pragma; + use crate::{Connection, DatabaseName, Result}; + + #[test] + fn pragma_query_value() -> Result<()> { + let db = Connection::open_in_memory()?; + let user_version: i32 = db.pragma_query_value(None, "user_version", |row| row.get(0))?; + assert_eq!(0, user_version); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn pragma_func_query_value() -> Result<()> { + let db = Connection::open_in_memory()?; + let user_version: i32 = db.one_column("SELECT user_version FROM pragma_user_version")?; + assert_eq!(0, user_version); + Ok(()) + } + + #[test] + fn pragma_query_no_schema() -> Result<()> { + let db = Connection::open_in_memory()?; + let mut user_version = -1; + db.pragma_query(None, "user_version", |row| { + user_version = row.get(0)?; + Ok(()) + })?; + assert_eq!(0, user_version); + Ok(()) + } + + #[test] + fn pragma_query_with_schema() -> Result<()> { + let db = Connection::open_in_memory()?; + let mut user_version = -1; + db.pragma_query(Some(DatabaseName::Main), "user_version", |row| { + user_version = row.get(0)?; + Ok(()) + })?; + assert_eq!(0, user_version); + Ok(()) + } + + #[test] + fn pragma() -> Result<()> { + let db = Connection::open_in_memory()?; + let mut columns = Vec::new(); + db.pragma(None, "table_info", "sqlite_master", |row| { + let column: String = row.get(1)?; + columns.push(column); + Ok(()) + })?; + assert_eq!(5, columns.len()); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn pragma_func() -> Result<()> { + let db = Connection::open_in_memory()?; + let mut table_info = db.prepare("SELECT * FROM pragma_table_info(?1)")?; + let mut columns = Vec::new(); + let mut rows = table_info.query(["sqlite_master"])?; + + while let Some(row) = rows.next()? { + let column: String = row.get(1)?; + columns.push(column); + } + assert_eq!(5, columns.len()); + Ok(()) + } + + #[test] + fn pragma_update() -> Result<()> { + let db = Connection::open_in_memory()?; + db.pragma_update(None, "user_version", 1) + } + + #[test] + fn pragma_update_and_check() -> Result<()> { + let db = Connection::open_in_memory()?; + let journal_mode: String = + db.pragma_update_and_check(None, "journal_mode", "OFF", |row| row.get(0))?; + assert!( + journal_mode == "off" || journal_mode == "memory", + "mode: {journal_mode:?}" + ); + // Sanity checks to ensure the move to a generic `ToSql` wasn't breaking + let mode = + db.pragma_update_and_check(None, "journal_mode", "OFF", |row| row.get::<_, String>(0))?; + assert!(mode == "off" || mode == "memory", "mode: {mode:?}"); + + let param: &dyn crate::ToSql = &"OFF"; + let mode = + db.pragma_update_and_check(None, "journal_mode", param, |row| row.get::<_, String>(0))?; + assert!(mode == "off" || mode == "memory", "mode: {mode:?}"); + Ok(()) + } + + #[test] + fn is_identifier() { + assert!(pragma::is_identifier("full")); + assert!(pragma::is_identifier("r2d2")); + assert!(!pragma::is_identifier("sp ce")); + assert!(!pragma::is_identifier("semi;colon")); + } + + #[test] + fn double_quote() { + let mut sql = Sql::new(); + sql.push_schema_name(DatabaseName::Attached(r#"schema";--"#)); + assert_eq!(r#""schema"";--""#, sql.as_str()); + } + + #[test] + fn wrap_and_escape() { + let mut sql = Sql::new(); + sql.push_string_literal("value'; --"); + assert_eq!("'value''; --'", sql.as_str()); + } + + #[test] + fn locking_mode() -> Result<()> { + let db = Connection::open_in_memory()?; + let r = db.pragma_update(None, "locking_mode", "exclusive"); + if cfg!(feature = "extra_check") { + r.unwrap_err(); + } else { + r?; + } + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/raw_statement.rs b/third_party/rust/rusqlite/src/raw_statement.rs new file mode 100644 index 0000000000..1683c7b857 --- /dev/null +++ b/third_party/rust/rusqlite/src/raw_statement.rs @@ -0,0 +1,240 @@ +use super::ffi; +use super::StatementStatus; +use crate::util::ParamIndexCache; +use crate::util::SqliteMallocString; +use std::ffi::CStr; +use std::os::raw::c_int; +use std::ptr; +use std::sync::Arc; + +// Private newtype for raw sqlite3_stmts that finalize themselves when dropped. +#[derive(Debug)] +pub struct RawStatement { + ptr: *mut ffi::sqlite3_stmt, + tail: usize, + // Cached indices of named parameters, computed on the fly. + cache: ParamIndexCache, + // Cached SQL (trimmed) that we use as the key when we're in the statement + // cache. This is None for statements which didn't come from the statement + // cache. + // + // This is probably the same as `self.sql()` in most cases, but we don't + // care either way -- It's a better cache key as it is anyway since it's the + // actual source we got from rust. + // + // One example of a case where the result of `sqlite_sql` and the value in + // `statement_cache_key` might differ is if the statement has a `tail`. + statement_cache_key: Option<Arc<str>>, +} + +impl RawStatement { + #[inline] + pub unsafe fn new(stmt: *mut ffi::sqlite3_stmt, tail: usize) -> RawStatement { + RawStatement { + ptr: stmt, + tail, + cache: ParamIndexCache::default(), + statement_cache_key: None, + } + } + + #[inline] + pub fn is_null(&self) -> bool { + self.ptr.is_null() + } + + #[inline] + pub(crate) fn set_statement_cache_key(&mut self, p: impl Into<Arc<str>>) { + self.statement_cache_key = Some(p.into()); + } + + #[inline] + pub(crate) fn statement_cache_key(&self) -> Option<Arc<str>> { + self.statement_cache_key.clone() + } + + #[inline] + pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { + self.ptr + } + + #[inline] + pub fn column_count(&self) -> usize { + // Note: Can't cache this as it changes if the schema is altered. + unsafe { ffi::sqlite3_column_count(self.ptr) as usize } + } + + #[inline] + pub fn column_type(&self, idx: usize) -> c_int { + unsafe { ffi::sqlite3_column_type(self.ptr, idx as c_int) } + } + + #[inline] + #[cfg(feature = "column_decltype")] + pub fn column_decltype(&self, idx: usize) -> Option<&CStr> { + unsafe { + let decltype = ffi::sqlite3_column_decltype(self.ptr, idx as c_int); + if decltype.is_null() { + None + } else { + Some(CStr::from_ptr(decltype)) + } + } + } + + #[inline] + pub fn column_name(&self, idx: usize) -> Option<&CStr> { + let idx = idx as c_int; + if idx < 0 || idx >= self.column_count() as c_int { + return None; + } + unsafe { + let ptr = ffi::sqlite3_column_name(self.ptr, idx); + // If ptr is null here, it's an OOM, so there's probably nothing + // meaningful we can do. Just assert instead of returning None. + assert!( + !ptr.is_null(), + "Null pointer from sqlite3_column_name: Out of memory?" + ); + Some(CStr::from_ptr(ptr)) + } + } + + #[inline] + #[cfg(not(feature = "unlock_notify"))] + pub fn step(&self) -> c_int { + unsafe { ffi::sqlite3_step(self.ptr) } + } + + #[cfg(feature = "unlock_notify")] + pub fn step(&self) -> c_int { + use crate::unlock_notify; + let mut db = ptr::null_mut::<ffi::sqlite3>(); + loop { + unsafe { + let mut rc = ffi::sqlite3_step(self.ptr); + // Bail out early for success and errors unrelated to locking. We + // still need check `is_locked` after this, but checking now lets us + // avoid one or two (admittedly cheap) calls into SQLite that we + // don't need to make. + if (rc & 0xff) != ffi::SQLITE_LOCKED { + break rc; + } + if db.is_null() { + db = ffi::sqlite3_db_handle(self.ptr); + } + if !unlock_notify::is_locked(db, rc) { + break rc; + } + rc = unlock_notify::wait_for_unlock_notify(db); + if rc != ffi::SQLITE_OK { + break rc; + } + self.reset(); + } + } + } + + #[inline] + pub fn reset(&self) -> c_int { + unsafe { ffi::sqlite3_reset(self.ptr) } + } + + #[inline] + pub fn bind_parameter_count(&self) -> usize { + unsafe { ffi::sqlite3_bind_parameter_count(self.ptr) as usize } + } + + #[inline] + pub fn bind_parameter_index(&self, name: &str) -> Option<usize> { + self.cache.get_or_insert_with(name, |param_cstr| { + let r = unsafe { ffi::sqlite3_bind_parameter_index(self.ptr, param_cstr.as_ptr()) }; + match r { + 0 => None, + i => Some(i as usize), + } + }) + } + + #[inline] + pub fn bind_parameter_name(&self, index: i32) -> Option<&CStr> { + unsafe { + let name = ffi::sqlite3_bind_parameter_name(self.ptr, index); + if name.is_null() { + None + } else { + Some(CStr::from_ptr(name)) + } + } + } + + #[inline] + pub fn clear_bindings(&self) { + unsafe { + ffi::sqlite3_clear_bindings(self.ptr); + } // rc is always SQLITE_OK + } + + #[inline] + pub fn sql(&self) -> Option<&CStr> { + if self.ptr.is_null() { + None + } else { + Some(unsafe { CStr::from_ptr(ffi::sqlite3_sql(self.ptr)) }) + } + } + + #[inline] + pub fn finalize(mut self) -> c_int { + self.finalize_() + } + + #[inline] + fn finalize_(&mut self) -> c_int { + let r = unsafe { ffi::sqlite3_finalize(self.ptr) }; + self.ptr = ptr::null_mut(); + r + } + + // does not work for PRAGMA + #[inline] + pub fn readonly(&self) -> bool { + unsafe { ffi::sqlite3_stmt_readonly(self.ptr) != 0 } + } + + #[inline] + pub(crate) fn expanded_sql(&self) -> Option<SqliteMallocString> { + unsafe { SqliteMallocString::from_raw(ffi::sqlite3_expanded_sql(self.ptr)) } + } + + #[inline] + pub fn get_status(&self, status: StatementStatus, reset: bool) -> i32 { + assert!(!self.ptr.is_null()); + unsafe { ffi::sqlite3_stmt_status(self.ptr, status as i32, reset as i32) } + } + + #[inline] + #[cfg(feature = "extra_check")] + pub fn has_tail(&self) -> bool { + self.tail != 0 + } + + #[inline] + pub fn tail(&self) -> usize { + self.tail + } + + #[inline] + #[cfg(feature = "modern_sqlite")] // 3.28.0 + pub fn is_explain(&self) -> i32 { + unsafe { ffi::sqlite3_stmt_isexplain(self.ptr) } + } + + // TODO sqlite3_normalized_sql (https://sqlite.org/c3ref/expanded_sql.html) // 3.27.0 + SQLITE_ENABLE_NORMALIZE +} + +impl Drop for RawStatement { + fn drop(&mut self) { + self.finalize_(); + } +} diff --git a/third_party/rust/rusqlite/src/row.rs b/third_party/rust/rusqlite/src/row.rs new file mode 100644 index 0000000000..2d2590040a --- /dev/null +++ b/third_party/rust/rusqlite/src/row.rs @@ -0,0 +1,588 @@ +use fallible_iterator::FallibleIterator; +use fallible_streaming_iterator::FallibleStreamingIterator; +use std::convert; + +use super::{Error, Result, Statement}; +use crate::types::{FromSql, FromSqlError, ValueRef}; + +/// An handle for the resulting rows of a query. +#[must_use = "Rows is lazy and will do nothing unless consumed"] +pub struct Rows<'stmt> { + pub(crate) stmt: Option<&'stmt Statement<'stmt>>, + row: Option<Row<'stmt>>, +} + +impl<'stmt> Rows<'stmt> { + #[inline] + fn reset(&mut self) { + if let Some(stmt) = self.stmt.take() { + stmt.reset(); + } + } + + /// Attempt to get the next row from the query. Returns `Ok(Some(Row))` if + /// there is another row, `Err(...)` if there was an error + /// getting the next row, and `Ok(None)` if all rows have been retrieved. + /// + /// ## Note + /// + /// This interface is not compatible with Rust's `Iterator` trait, because + /// the lifetime of the returned row is tied to the lifetime of `self`. + /// This is a fallible "streaming iterator". For a more natural interface, + /// consider using [`query_map`](Statement::query_map) or + /// [`query_and_then`](Statement::query_and_then) instead, which + /// return types that implement `Iterator`. + #[allow(clippy::should_implement_trait)] // cannot implement Iterator + #[inline] + pub fn next(&mut self) -> Result<Option<&Row<'stmt>>> { + self.advance()?; + Ok((*self).get()) + } + + /// Map over this `Rows`, converting it to a [`Map`], which + /// implements `FallibleIterator`. + /// ```rust,no_run + /// use fallible_iterator::FallibleIterator; + /// # use rusqlite::{Result, Statement}; + /// fn query(stmt: &mut Statement) -> Result<Vec<i64>> { + /// let rows = stmt.query([])?; + /// rows.map(|r| r.get(0)).collect() + /// } + /// ``` + // FIXME Hide FallibleStreamingIterator::map + #[inline] + pub fn map<F, B>(self, f: F) -> Map<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result<B>, + { + Map { rows: self, f } + } + + /// Map over this `Rows`, converting it to a [`MappedRows`], which + /// implements `Iterator`. + #[inline] + pub fn mapped<F, B>(self, f: F) -> MappedRows<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result<B>, + { + MappedRows { rows: self, map: f } + } + + /// Map over this `Rows` with a fallible function, converting it to a + /// [`AndThenRows`], which implements `Iterator` (instead of + /// `FallibleStreamingIterator`). + #[inline] + pub fn and_then<F, T, E>(self, f: F) -> AndThenRows<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result<T, E>, + { + AndThenRows { rows: self, map: f } + } + + /// Give access to the underlying statement + #[must_use] + pub fn as_ref(&self) -> Option<&Statement<'stmt>> { + self.stmt + } +} + +impl<'stmt> Rows<'stmt> { + #[inline] + pub(crate) fn new(stmt: &'stmt Statement<'stmt>) -> Rows<'stmt> { + Rows { + stmt: Some(stmt), + row: None, + } + } + + #[inline] + pub(crate) fn get_expected_row(&mut self) -> Result<&Row<'stmt>> { + match self.next()? { + Some(row) => Ok(row), + None => Err(Error::QueryReturnedNoRows), + } + } +} + +impl Drop for Rows<'_> { + #[inline] + fn drop(&mut self) { + self.reset(); + } +} + +/// `F` is used to transform the _streaming_ iterator into a _fallible_ +/// iterator. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct Map<'stmt, F> { + rows: Rows<'stmt>, + f: F, +} + +impl<F, B> FallibleIterator for Map<'_, F> +where + F: FnMut(&Row<'_>) -> Result<B>, +{ + type Error = Error; + type Item = B; + + #[inline] + fn next(&mut self) -> Result<Option<B>> { + match self.rows.next()? { + Some(v) => Ok(Some((self.f)(v)?)), + None => Ok(None), + } + } +} + +/// An iterator over the mapped resulting rows of a query. +/// +/// `F` is used to transform the _streaming_ iterator into a _standard_ +/// iterator. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct MappedRows<'stmt, F> { + rows: Rows<'stmt>, + map: F, +} + +impl<T, F> Iterator for MappedRows<'_, F> +where + F: FnMut(&Row<'_>) -> Result<T>, +{ + type Item = Result<T>; + + #[inline] + fn next(&mut self) -> Option<Result<T>> { + let map = &mut self.map; + self.rows + .next() + .transpose() + .map(|row_result| row_result.and_then(map)) + } +} + +/// An iterator over the mapped resulting rows of a query, with an Error type +/// unifying with Error. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct AndThenRows<'stmt, F> { + rows: Rows<'stmt>, + map: F, +} + +impl<T, E, F> Iterator for AndThenRows<'_, F> +where + E: From<Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, +{ + type Item = Result<T, E>; + + #[inline] + fn next(&mut self) -> Option<Self::Item> { + let map = &mut self.map; + self.rows + .next() + .transpose() + .map(|row_result| row_result.map_err(E::from).and_then(map)) + } +} + +/// `FallibleStreamingIterator` differs from the standard library's `Iterator` +/// in two ways: +/// * each call to `next` (`sqlite3_step`) can fail. +/// * returned `Row` is valid until `next` is called again or `Statement` is +/// reset or finalized. +/// +/// While these iterators cannot be used with Rust `for` loops, `while let` +/// loops offer a similar level of ergonomics: +/// ```rust,no_run +/// # use rusqlite::{Result, Statement}; +/// fn query(stmt: &mut Statement) -> Result<()> { +/// let mut rows = stmt.query([])?; +/// while let Some(row) = rows.next()? { +/// // scan columns value +/// } +/// Ok(()) +/// } +/// ``` +impl<'stmt> FallibleStreamingIterator for Rows<'stmt> { + type Error = Error; + type Item = Row<'stmt>; + + #[inline] + fn advance(&mut self) -> Result<()> { + if let Some(stmt) = self.stmt { + match stmt.step() { + Ok(true) => { + self.row = Some(Row { stmt }); + Ok(()) + } + Ok(false) => { + self.reset(); + self.row = None; + Ok(()) + } + Err(e) => { + self.reset(); + self.row = None; + Err(e) + } + } + } else { + self.row = None; + Ok(()) + } + } + + #[inline] + fn get(&self) -> Option<&Row<'stmt>> { + self.row.as_ref() + } +} + +/// A single result row of a query. +pub struct Row<'stmt> { + pub(crate) stmt: &'stmt Statement<'stmt>, +} + +impl<'stmt> Row<'stmt> { + /// Get the value of a particular column of the result row. + /// + /// # Panics + /// + /// Panics if calling [`row.get(idx)`](Row::get) would return an error, + /// including: + /// + /// * If the underlying SQLite column type is not a valid type as a source + /// for `T` + /// * If the underlying SQLite integral value is outside the range + /// representable by `T` + /// * If `idx` is outside the range of columns in the returned query + #[track_caller] + pub fn get_unwrap<I: RowIndex, T: FromSql>(&self, idx: I) -> T { + self.get(idx).unwrap() + } + + /// Get the value of a particular column of the result row. + /// + /// ## Failure + /// + /// Returns an `Error::InvalidColumnType` if the underlying SQLite column + /// type is not a valid type as a source for `T`. + /// + /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid + /// column range for this row. + /// + /// Returns an `Error::InvalidColumnName` if `idx` is not a valid column + /// name for this row. + /// + /// If the result type is i128 (which requires the `i128_blob` feature to be + /// enabled), and the underlying SQLite column is a blob whose size is not + /// 16 bytes, `Error::InvalidColumnType` will also be returned. + #[track_caller] + pub fn get<I: RowIndex, T: FromSql>(&self, idx: I) -> Result<T> { + let idx = idx.idx(self.stmt)?; + let value = self.stmt.value_ref(idx); + FromSql::column_result(value).map_err(|err| match err { + FromSqlError::InvalidType => Error::InvalidColumnType( + idx, + self.stmt.column_name_unwrap(idx).into(), + value.data_type(), + ), + FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i), + FromSqlError::Other(err) => { + Error::FromSqlConversionFailure(idx, value.data_type(), err) + } + FromSqlError::InvalidBlobSize { .. } => { + Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err)) + } + }) + } + + /// Get the value of a particular column of the result row as a `ValueRef`, + /// allowing data to be read out of a row without copying. + /// + /// This `ValueRef` is valid only as long as this Row, which is enforced by + /// it's lifetime. This means that while this method is completely safe, + /// it can be somewhat difficult to use, and most callers will be better + /// served by [`get`](Row::get) or [`get_unwrap`](Row::get_unwrap). + /// + /// ## Failure + /// + /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid + /// column range for this row. + /// + /// Returns an `Error::InvalidColumnName` if `idx` is not a valid column + /// name for this row. + pub fn get_ref<I: RowIndex>(&self, idx: I) -> Result<ValueRef<'_>> { + let idx = idx.idx(self.stmt)?; + // Narrowing from `ValueRef<'stmt>` (which `self.stmt.value_ref(idx)` + // returns) to `ValueRef<'a>` is needed because it's only valid until + // the next call to sqlite3_step. + let val_ref = self.stmt.value_ref(idx); + Ok(val_ref) + } + + /// Get the value of a particular column of the result row as a `ValueRef`, + /// allowing data to be read out of a row without copying. + /// + /// This `ValueRef` is valid only as long as this Row, which is enforced by + /// it's lifetime. This means that while this method is completely safe, + /// it can be difficult to use, and most callers will be better served by + /// [`get`](Row::get) or [`get_unwrap`](Row::get_unwrap). + /// + /// # Panics + /// + /// Panics if calling [`row.get_ref(idx)`](Row::get_ref) would return an + /// error, including: + /// + /// * If `idx` is outside the range of columns in the returned query. + /// * If `idx` is not a valid column name for this row. + #[track_caller] + pub fn get_ref_unwrap<I: RowIndex>(&self, idx: I) -> ValueRef<'_> { + self.get_ref(idx).unwrap() + } +} + +impl<'stmt> AsRef<Statement<'stmt>> for Row<'stmt> { + fn as_ref(&self) -> &Statement<'stmt> { + self.stmt + } +} + +/// Debug `Row` like an ordered `Map<Result<&str>, Result<(Type, ValueRef)>>` +/// with column name as key except that for `Type::Blob` only its size is +/// printed (not its content). +impl<'stmt> std::fmt::Debug for Row<'stmt> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut dm = f.debug_map(); + for c in 0..self.stmt.column_count() { + let name = self.stmt.column_name(c); + dm.key(&name); + let value = self.get_ref(c); + match value { + Ok(value) => { + let dt = value.data_type(); + match value { + ValueRef::Null => { + dm.value(&(dt, ())); + } + ValueRef::Integer(i) => { + dm.value(&(dt, i)); + } + ValueRef::Real(f) => { + dm.value(&(dt, f)); + } + ValueRef::Text(s) => { + dm.value(&(dt, String::from_utf8_lossy(s))); + } + ValueRef::Blob(b) => { + dm.value(&(dt, b.len())); + } + } + } + Err(ref _err) => { + dm.value(&value); + } + } + } + dm.finish() + } +} + +mod sealed { + /// This trait exists just to ensure that the only impls of `trait Params` + /// that are allowed are ones in this crate. + pub trait Sealed {} + impl Sealed for usize {} + impl Sealed for &str {} +} + +/// A trait implemented by types that can index into columns of a row. +/// +/// It is only implemented for `usize` and `&str`. +pub trait RowIndex: sealed::Sealed { + /// Returns the index of the appropriate column, or `None` if no such + /// column exists. + fn idx(&self, stmt: &Statement<'_>) -> Result<usize>; +} + +impl RowIndex for usize { + #[inline] + fn idx(&self, stmt: &Statement<'_>) -> Result<usize> { + if *self >= stmt.column_count() { + Err(Error::InvalidColumnIndex(*self)) + } else { + Ok(*self) + } + } +} + +impl RowIndex for &'_ str { + #[inline] + fn idx(&self, stmt: &Statement<'_>) -> Result<usize> { + stmt.column_index(self) + } +} + +macro_rules! tuple_try_from_row { + ($($field:ident),*) => { + impl<'a, $($field,)*> convert::TryFrom<&'a Row<'a>> for ($($field,)*) where $($field: FromSql,)* { + type Error = crate::Error; + + // we end with index += 1, which rustc warns about + // unused_variables and unused_mut are allowed for () + #[allow(unused_assignments, unused_variables, unused_mut)] + fn try_from(row: &'a Row<'a>) -> Result<Self> { + let mut index = 0; + $( + #[allow(non_snake_case)] + let $field = row.get::<_, $field>(index)?; + index += 1; + )* + Ok(($($field,)*)) + } + } + } +} + +macro_rules! tuples_try_from_row { + () => { + // not very useful, but maybe some other macro users will find this helpful + tuple_try_from_row!(); + }; + ($first:ident $(, $remaining:ident)*) => { + tuple_try_from_row!($first $(, $remaining)*); + tuples_try_from_row!($($remaining),*); + }; +} + +tuples_try_from_row!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P); + +#[cfg(test)] +mod tests { + #![allow(clippy::redundant_closure)] // false positives due to lifetime issues; clippy issue #5594 + use crate::{Connection, Result}; + + #[test] + fn test_try_from_row_for_tuple_1() -> Result<()> { + use crate::ToSql; + use std::convert::TryFrom; + + let conn = Connection::open_in_memory()?; + conn.execute( + "CREATE TABLE test (a INTEGER)", + crate::params_from_iter(std::iter::empty::<&dyn ToSql>()), + )?; + conn.execute("INSERT INTO test VALUES (42)", [])?; + let val = conn.query_row("SELECT a FROM test", [], |row| <(u32,)>::try_from(row))?; + assert_eq!(val, (42,)); + let fail = conn.query_row("SELECT a FROM test", [], |row| <(u32, u32)>::try_from(row)); + fail.unwrap_err(); + Ok(()) + } + + #[test] + fn test_try_from_row_for_tuple_2() -> Result<()> { + use std::convert::TryFrom; + + let conn = Connection::open_in_memory()?; + conn.execute("CREATE TABLE test (a INTEGER, b INTEGER)", [])?; + conn.execute("INSERT INTO test VALUES (42, 47)", [])?; + let val = conn.query_row("SELECT a, b FROM test", [], |row| { + <(u32, u32)>::try_from(row) + })?; + assert_eq!(val, (42, 47)); + let fail = conn.query_row("SELECT a, b FROM test", [], |row| { + <(u32, u32, u32)>::try_from(row) + }); + fail.unwrap_err(); + Ok(()) + } + + #[test] + fn test_try_from_row_for_tuple_16() -> Result<()> { + use std::convert::TryFrom; + + let create_table = "CREATE TABLE test ( + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER, + e INTEGER, + f INTEGER, + g INTEGER, + h INTEGER, + i INTEGER, + j INTEGER, + k INTEGER, + l INTEGER, + m INTEGER, + n INTEGER, + o INTEGER, + p INTEGER + )"; + + let insert_values = "INSERT INTO test VALUES ( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15 + )"; + + type BigTuple = ( + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + ); + + let conn = Connection::open_in_memory()?; + conn.execute(create_table, [])?; + conn.execute(insert_values, [])?; + let val = conn.query_row("SELECT * FROM test", [], |row| BigTuple::try_from(row))?; + // Debug is not implemented for tuples of 16 + assert_eq!(val.0, 0); + assert_eq!(val.1, 1); + assert_eq!(val.2, 2); + assert_eq!(val.3, 3); + assert_eq!(val.4, 4); + assert_eq!(val.5, 5); + assert_eq!(val.6, 6); + assert_eq!(val.7, 7); + assert_eq!(val.8, 8); + assert_eq!(val.9, 9); + assert_eq!(val.10, 10); + assert_eq!(val.11, 11); + assert_eq!(val.12, 12); + assert_eq!(val.13, 13); + assert_eq!(val.14, 14); + assert_eq!(val.15, 15); + + // We don't test one bigger because it's unimplemented + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/serialize.rs b/third_party/rust/rusqlite/src/serialize.rs new file mode 100644 index 0000000000..6852761b7c --- /dev/null +++ b/third_party/rust/rusqlite/src/serialize.rs @@ -0,0 +1,167 @@ +//! Serialize a database. +use std::convert::TryInto; +use std::marker::PhantomData; +use std::ops::Deref; +use std::ptr::NonNull; + +use crate::error::error_from_handle; +use crate::ffi; +use crate::{Connection, DatabaseName, Result}; + +/// Shared (SQLITE_SERIALIZE_NOCOPY) serialized database +pub struct SharedData<'conn> { + phantom: PhantomData<&'conn Connection>, + ptr: NonNull<u8>, + sz: usize, +} + +/// Owned serialized database +pub struct OwnedData { + ptr: NonNull<u8>, + sz: usize, +} + +impl OwnedData { + /// # Safety + /// + /// Caller must be certain that `ptr` is allocated by `sqlite3_malloc`. + pub unsafe fn from_raw_nonnull(ptr: NonNull<u8>, sz: usize) -> Self { + Self { ptr, sz } + } + + fn into_raw(self) -> (*mut u8, usize) { + let raw = (self.ptr.as_ptr(), self.sz); + std::mem::forget(self); + raw + } +} + +impl Drop for OwnedData { + fn drop(&mut self) { + unsafe { + ffi::sqlite3_free(self.ptr.as_ptr().cast()); + } + } +} + +/// Serialized database +pub enum Data<'conn> { + /// Shared (SQLITE_SERIALIZE_NOCOPY) serialized database + Shared(SharedData<'conn>), + /// Owned serialized database + Owned(OwnedData), +} + +impl<'conn> Deref for Data<'conn> { + type Target = [u8]; + + fn deref(&self) -> &[u8] { + let (ptr, sz) = match self { + Data::Owned(OwnedData { ptr, sz }) => (ptr.as_ptr(), *sz), + Data::Shared(SharedData { ptr, sz, .. }) => (ptr.as_ptr(), *sz), + }; + unsafe { std::slice::from_raw_parts(ptr, sz) } + } +} + +impl Connection { + /// Serialize a database. + pub fn serialize(&self, schema: DatabaseName) -> Result<Data> { + let schema = schema.as_cstring()?; + let mut sz = 0; + let mut ptr: *mut u8 = unsafe { + ffi::sqlite3_serialize( + self.handle(), + schema.as_ptr(), + &mut sz, + ffi::SQLITE_SERIALIZE_NOCOPY, + ) + }; + Ok(if ptr.is_null() { + ptr = unsafe { ffi::sqlite3_serialize(self.handle(), schema.as_ptr(), &mut sz, 0) }; + if ptr.is_null() { + return Err(unsafe { error_from_handle(self.handle(), ffi::SQLITE_NOMEM) }); + } + Data::Owned(OwnedData { + ptr: NonNull::new(ptr).unwrap(), + sz: sz.try_into().unwrap(), + }) + } else { + // shared buffer + Data::Shared(SharedData { + ptr: NonNull::new(ptr).unwrap(), + sz: sz.try_into().unwrap(), + phantom: PhantomData, + }) + }) + } + + /// Deserialize a database. + pub fn deserialize( + &mut self, + schema: DatabaseName<'_>, + data: OwnedData, + read_only: bool, + ) -> Result<()> { + let schema = schema.as_cstring()?; + let (data, sz) = data.into_raw(); + let sz = sz.try_into().unwrap(); + let flags = if read_only { + ffi::SQLITE_DESERIALIZE_FREEONCLOSE | ffi::SQLITE_DESERIALIZE_READONLY + } else { + ffi::SQLITE_DESERIALIZE_FREEONCLOSE | ffi::SQLITE_DESERIALIZE_RESIZEABLE + }; + let rc = unsafe { + ffi::sqlite3_deserialize(self.handle(), schema.as_ptr(), data, sz, sz, flags) + }; + if rc != ffi::SQLITE_OK { + // TODO sqlite3_free(data) ? + return Err(unsafe { error_from_handle(self.handle(), rc) }); + } + /* TODO + if let Some(mxSize) = mxSize { + unsafe { + ffi::sqlite3_file_control( + self.handle(), + schema.as_ptr(), + ffi::SQLITE_FCNTL_SIZE_LIMIT, + &mut mxSize, + ) + }; + }*/ + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{Connection, DatabaseName, Result}; + + #[test] + fn serialize() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE x AS SELECT 'data'")?; + let data = db.serialize(DatabaseName::Main)?; + let Data::Owned(data) = data else { + panic!("expected OwnedData") + }; + assert!(data.sz > 0); + Ok(()) + } + + #[test] + fn deserialize() -> Result<()> { + let src = Connection::open_in_memory()?; + src.execute_batch("CREATE TABLE x AS SELECT 'data'")?; + let data = src.serialize(DatabaseName::Main)?; + let Data::Owned(data) = data else { + panic!("expected OwnedData") + }; + + let mut dst = Connection::open_in_memory()?; + dst.deserialize(DatabaseName::Main, data, false)?; + dst.execute("DELETE FROM x", [])?; + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/session.rs b/third_party/rust/rusqlite/src/session.rs new file mode 100644 index 0000000000..0169a1cc9d --- /dev/null +++ b/third_party/rust/rusqlite/src/session.rs @@ -0,0 +1,933 @@ +//! [Session Extension](https://sqlite.org/sessionintro.html) +#![allow(non_camel_case_types)] + +use std::ffi::CStr; +use std::io::{Read, Write}; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int, c_uchar, c_void}; +use std::panic::{catch_unwind, RefUnwindSafe}; +use std::ptr; +use std::slice::{from_raw_parts, from_raw_parts_mut}; + +use fallible_streaming_iterator::FallibleStreamingIterator; + +use crate::error::{check, error_from_sqlite_code}; +use crate::ffi; +use crate::hooks::Action; +use crate::types::ValueRef; +use crate::{errmsg_to_string, str_to_cstring, Connection, DatabaseName, Result}; + +// https://sqlite.org/session.html + +type Filter = Option<Box<dyn Fn(&str) -> bool>>; + +/// An instance of this object is a session that can be +/// used to record changes to a database. +pub struct Session<'conn> { + phantom: PhantomData<&'conn Connection>, + s: *mut ffi::sqlite3_session, + filter: Filter, +} + +impl Session<'_> { + /// Create a new session object + #[inline] + pub fn new(db: &Connection) -> Result<Session<'_>> { + Session::new_with_name(db, DatabaseName::Main) + } + + /// Create a new session object + #[inline] + pub fn new_with_name<'conn>( + db: &'conn Connection, + name: DatabaseName<'_>, + ) -> Result<Session<'conn>> { + let name = name.as_cstring()?; + + let db = db.db.borrow_mut().db; + + let mut s: *mut ffi::sqlite3_session = ptr::null_mut(); + check(unsafe { ffi::sqlite3session_create(db, name.as_ptr(), &mut s) })?; + + Ok(Session { + phantom: PhantomData, + s, + filter: None, + }) + } + + /// Set a table filter + pub fn table_filter<F>(&mut self, filter: Option<F>) + where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>( + p_arg: *mut c_void, + tbl_str: *const c_char, + ) -> c_int + where + F: Fn(&str) -> bool + RefUnwindSafe, + { + use std::str; + + let boxed_filter: *mut F = p_arg as *mut F; + let tbl_name = { + let c_slice = CStr::from_ptr(tbl_str).to_bytes(); + str::from_utf8(c_slice) + }; + c_int::from( + catch_unwind(|| (*boxed_filter)(tbl_name.expect("non-utf8 table name"))) + .unwrap_or_default(), + ) + } + + match filter { + Some(filter) => { + let boxed_filter = Box::new(filter); + unsafe { + ffi::sqlite3session_table_filter( + self.s, + Some(call_boxed_closure::<F>), + &*boxed_filter as *const F as *mut _, + ); + } + self.filter = Some(boxed_filter); + } + _ => { + unsafe { ffi::sqlite3session_table_filter(self.s, None, ptr::null_mut()) } + self.filter = None; + } + }; + } + + /// Attach a table. `None` means all tables. + pub fn attach(&mut self, table: Option<&str>) -> Result<()> { + let table = if let Some(table) = table { + Some(str_to_cstring(table)?) + } else { + None + }; + let table = table.as_ref().map(|s| s.as_ptr()).unwrap_or(ptr::null()); + check(unsafe { ffi::sqlite3session_attach(self.s, table) }) + } + + /// Generate a Changeset + pub fn changeset(&mut self) -> Result<Changeset> { + let mut n = 0; + let mut cs: *mut c_void = ptr::null_mut(); + check(unsafe { ffi::sqlite3session_changeset(self.s, &mut n, &mut cs) })?; + Ok(Changeset { cs, n }) + } + + /// Write the set of changes represented by this session to `output`. + #[inline] + pub fn changeset_strm(&mut self, output: &mut dyn Write) -> Result<()> { + let output_ref = &output; + check(unsafe { + ffi::sqlite3session_changeset_strm( + self.s, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }) + } + + /// Generate a Patchset + #[inline] + pub fn patchset(&mut self) -> Result<Changeset> { + let mut n = 0; + let mut ps: *mut c_void = ptr::null_mut(); + check(unsafe { ffi::sqlite3session_patchset(self.s, &mut n, &mut ps) })?; + // TODO Validate: same struct + Ok(Changeset { cs: ps, n }) + } + + /// Write the set of patches represented by this session to `output`. + #[inline] + pub fn patchset_strm(&mut self, output: &mut dyn Write) -> Result<()> { + let output_ref = &output; + check(unsafe { + ffi::sqlite3session_patchset_strm( + self.s, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }) + } + + /// Load the difference between tables. + pub fn diff(&mut self, from: DatabaseName<'_>, table: &str) -> Result<()> { + let from = from.as_cstring()?; + let table = str_to_cstring(table)?; + let table = table.as_ptr(); + unsafe { + let mut errmsg = ptr::null_mut(); + let r = + ffi::sqlite3session_diff(self.s, from.as_ptr(), table, &mut errmsg as *mut *mut _); + if r != ffi::SQLITE_OK { + let errmsg: *mut c_char = errmsg; + let message = errmsg_to_string(&*errmsg); + ffi::sqlite3_free(errmsg as *mut c_void); + return Err(error_from_sqlite_code(r, Some(message))); + } + } + Ok(()) + } + + /// Test if a changeset has recorded any changes + #[inline] + pub fn is_empty(&self) -> bool { + unsafe { ffi::sqlite3session_isempty(self.s) != 0 } + } + + /// Query the current state of the session + #[inline] + pub fn is_enabled(&self) -> bool { + unsafe { ffi::sqlite3session_enable(self.s, -1) != 0 } + } + + /// Enable or disable the recording of changes + #[inline] + pub fn set_enabled(&mut self, enabled: bool) { + unsafe { + ffi::sqlite3session_enable(self.s, c_int::from(enabled)); + } + } + + /// Query the current state of the indirect flag + #[inline] + pub fn is_indirect(&self) -> bool { + unsafe { ffi::sqlite3session_indirect(self.s, -1) != 0 } + } + + /// Set or clear the indirect change flag + #[inline] + pub fn set_indirect(&mut self, indirect: bool) { + unsafe { + ffi::sqlite3session_indirect(self.s, c_int::from(indirect)); + } + } +} + +impl Drop for Session<'_> { + #[inline] + fn drop(&mut self) { + if self.filter.is_some() { + self.table_filter(None::<fn(&str) -> bool>); + } + unsafe { ffi::sqlite3session_delete(self.s) }; + } +} + +/// Invert a changeset +#[inline] +pub fn invert_strm(input: &mut dyn Read, output: &mut dyn Write) -> Result<()> { + let input_ref = &input; + let output_ref = &output; + check(unsafe { + ffi::sqlite3changeset_invert_strm( + Some(x_input), + input_ref as *const &mut dyn Read as *mut c_void, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }) +} + +/// Combine two changesets +#[inline] +pub fn concat_strm( + input_a: &mut dyn Read, + input_b: &mut dyn Read, + output: &mut dyn Write, +) -> Result<()> { + let input_a_ref = &input_a; + let input_b_ref = &input_b; + let output_ref = &output; + check(unsafe { + ffi::sqlite3changeset_concat_strm( + Some(x_input), + input_a_ref as *const &mut dyn Read as *mut c_void, + Some(x_input), + input_b_ref as *const &mut dyn Read as *mut c_void, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }) +} + +/// Changeset or Patchset +pub struct Changeset { + cs: *mut c_void, + n: c_int, +} + +impl Changeset { + /// Invert a changeset + #[inline] + pub fn invert(&self) -> Result<Changeset> { + let mut n = 0; + let mut cs = ptr::null_mut(); + check(unsafe { + ffi::sqlite3changeset_invert(self.n, self.cs, &mut n, &mut cs as *mut *mut _) + })?; + Ok(Changeset { cs, n }) + } + + /// Create an iterator to traverse a changeset + #[inline] + pub fn iter(&self) -> Result<ChangesetIter<'_>> { + let mut it = ptr::null_mut(); + check(unsafe { ffi::sqlite3changeset_start(&mut it as *mut *mut _, self.n, self.cs) })?; + Ok(ChangesetIter { + phantom: PhantomData, + it, + item: None, + }) + } + + /// Concatenate two changeset objects + #[inline] + pub fn concat(a: &Changeset, b: &Changeset) -> Result<Changeset> { + let mut n = 0; + let mut cs = ptr::null_mut(); + check(unsafe { + ffi::sqlite3changeset_concat(a.n, a.cs, b.n, b.cs, &mut n, &mut cs as *mut *mut _) + })?; + Ok(Changeset { cs, n }) + } +} + +impl Drop for Changeset { + #[inline] + fn drop(&mut self) { + unsafe { + ffi::sqlite3_free(self.cs); + } + } +} + +/// Cursor for iterating over the elements of a changeset +/// or patchset. +pub struct ChangesetIter<'changeset> { + phantom: PhantomData<&'changeset Changeset>, + it: *mut ffi::sqlite3_changeset_iter, + item: Option<ChangesetItem>, +} + +impl ChangesetIter<'_> { + /// Create an iterator on `input` + #[inline] + pub fn start_strm<'input>(input: &&'input mut dyn Read) -> Result<ChangesetIter<'input>> { + let mut it = ptr::null_mut(); + check(unsafe { + ffi::sqlite3changeset_start_strm( + &mut it as *mut *mut _, + Some(x_input), + input as *const &mut dyn Read as *mut c_void, + ) + })?; + Ok(ChangesetIter { + phantom: PhantomData, + it, + item: None, + }) + } +} + +impl FallibleStreamingIterator for ChangesetIter<'_> { + type Error = crate::error::Error; + type Item = ChangesetItem; + + #[inline] + fn advance(&mut self) -> Result<()> { + let rc = unsafe { ffi::sqlite3changeset_next(self.it) }; + match rc { + ffi::SQLITE_ROW => { + self.item = Some(ChangesetItem { it: self.it }); + Ok(()) + } + ffi::SQLITE_DONE => { + self.item = None; + Ok(()) + } + code => Err(error_from_sqlite_code(code, None)), + } + } + + #[inline] + fn get(&self) -> Option<&ChangesetItem> { + self.item.as_ref() + } +} + +/// Operation +pub struct Operation<'item> { + table_name: &'item str, + number_of_columns: i32, + code: Action, + indirect: bool, +} + +impl Operation<'_> { + /// Returns the table name. + #[inline] + pub fn table_name(&self) -> &str { + self.table_name + } + + /// Returns the number of columns in table + #[inline] + pub fn number_of_columns(&self) -> i32 { + self.number_of_columns + } + + /// Returns the action code. + #[inline] + pub fn code(&self) -> Action { + self.code + } + + /// Returns `true` for an 'indirect' change. + #[inline] + pub fn indirect(&self) -> bool { + self.indirect + } +} + +impl Drop for ChangesetIter<'_> { + #[inline] + fn drop(&mut self) { + unsafe { + ffi::sqlite3changeset_finalize(self.it); + } + } +} + +/// An item passed to a conflict-handler by +/// [`Connection::apply`](Connection::apply), or an item generated by +/// [`ChangesetIter::next`](ChangesetIter::next). +// TODO enum ? Delete, Insert, Update, ... +pub struct ChangesetItem { + it: *mut ffi::sqlite3_changeset_iter, +} + +impl ChangesetItem { + /// Obtain conflicting row values + /// + /// May only be called with an `SQLITE_CHANGESET_DATA` or + /// `SQLITE_CHANGESET_CONFLICT` conflict handler callback. + #[inline] + pub fn conflict(&self, col: usize) -> Result<ValueRef<'_>> { + unsafe { + let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); + check(ffi::sqlite3changeset_conflict( + self.it, + col as i32, + &mut p_value, + ))?; + Ok(ValueRef::from_value(p_value)) + } + } + + /// Determine the number of foreign key constraint violations + /// + /// May only be called with an `SQLITE_CHANGESET_FOREIGN_KEY` conflict + /// handler callback. + #[inline] + pub fn fk_conflicts(&self) -> Result<i32> { + unsafe { + let mut p_out = 0; + check(ffi::sqlite3changeset_fk_conflicts(self.it, &mut p_out))?; + Ok(p_out) + } + } + + /// Obtain new.* Values + /// + /// May only be called if the type of change is either `SQLITE_UPDATE` or + /// `SQLITE_INSERT`. + #[inline] + pub fn new_value(&self, col: usize) -> Result<ValueRef<'_>> { + unsafe { + let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); + check(ffi::sqlite3changeset_new(self.it, col as i32, &mut p_value))?; + Ok(ValueRef::from_value(p_value)) + } + } + + /// Obtain old.* Values + /// + /// May only be called if the type of change is either `SQLITE_DELETE` or + /// `SQLITE_UPDATE`. + #[inline] + pub fn old_value(&self, col: usize) -> Result<ValueRef<'_>> { + unsafe { + let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); + check(ffi::sqlite3changeset_old(self.it, col as i32, &mut p_value))?; + Ok(ValueRef::from_value(p_value)) + } + } + + /// Obtain the current operation + #[inline] + pub fn op(&self) -> Result<Operation<'_>> { + let mut number_of_columns = 0; + let mut code = 0; + let mut indirect = 0; + let tab = unsafe { + let mut pz_tab: *const c_char = ptr::null(); + check(ffi::sqlite3changeset_op( + self.it, + &mut pz_tab, + &mut number_of_columns, + &mut code, + &mut indirect, + ))?; + CStr::from_ptr(pz_tab) + }; + let table_name = tab.to_str()?; + Ok(Operation { + table_name, + number_of_columns, + code: Action::from(code), + indirect: indirect != 0, + }) + } + + /// Obtain the primary key definition of a table + #[inline] + pub fn pk(&self) -> Result<&[u8]> { + let mut number_of_columns = 0; + unsafe { + let mut pks: *mut c_uchar = ptr::null_mut(); + check(ffi::sqlite3changeset_pk( + self.it, + &mut pks, + &mut number_of_columns, + ))?; + Ok(from_raw_parts(pks, number_of_columns as usize)) + } + } +} + +/// Used to combine two or more changesets or +/// patchsets +pub struct Changegroup { + cg: *mut ffi::sqlite3_changegroup, +} + +impl Changegroup { + /// Create a new change group. + #[inline] + pub fn new() -> Result<Self> { + let mut cg = ptr::null_mut(); + check(unsafe { ffi::sqlite3changegroup_new(&mut cg) })?; + Ok(Changegroup { cg }) + } + + /// Add a changeset + #[inline] + pub fn add(&mut self, cs: &Changeset) -> Result<()> { + check(unsafe { ffi::sqlite3changegroup_add(self.cg, cs.n, cs.cs) }) + } + + /// Add a changeset read from `input` to this change group. + #[inline] + pub fn add_stream(&mut self, input: &mut dyn Read) -> Result<()> { + let input_ref = &input; + check(unsafe { + ffi::sqlite3changegroup_add_strm( + self.cg, + Some(x_input), + input_ref as *const &mut dyn Read as *mut c_void, + ) + }) + } + + /// Obtain a composite Changeset + #[inline] + pub fn output(&mut self) -> Result<Changeset> { + let mut n = 0; + let mut output: *mut c_void = ptr::null_mut(); + check(unsafe { ffi::sqlite3changegroup_output(self.cg, &mut n, &mut output) })?; + Ok(Changeset { cs: output, n }) + } + + /// Write the combined set of changes to `output`. + #[inline] + pub fn output_strm(&mut self, output: &mut dyn Write) -> Result<()> { + let output_ref = &output; + check(unsafe { + ffi::sqlite3changegroup_output_strm( + self.cg, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }) + } +} + +impl Drop for Changegroup { + #[inline] + fn drop(&mut self) { + unsafe { + ffi::sqlite3changegroup_delete(self.cg); + } + } +} + +impl Connection { + /// Apply a changeset to a database + pub fn apply<F, C>(&self, cs: &Changeset, filter: Option<F>, conflict: C) -> Result<()> + where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, + { + let db = self.db.borrow_mut().db; + + let filtered = filter.is_some(); + let tuple = &mut (filter, conflict); + check(unsafe { + if filtered { + ffi::sqlite3changeset_apply( + db, + cs.n, + cs.cs, + Some(call_filter::<F, C>), + Some(call_conflict::<F, C>), + tuple as *mut (Option<F>, C) as *mut c_void, + ) + } else { + ffi::sqlite3changeset_apply( + db, + cs.n, + cs.cs, + None, + Some(call_conflict::<F, C>), + tuple as *mut (Option<F>, C) as *mut c_void, + ) + } + }) + } + + /// Apply a changeset to a database + pub fn apply_strm<F, C>( + &self, + input: &mut dyn Read, + filter: Option<F>, + conflict: C, + ) -> Result<()> + where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, + { + let input_ref = &input; + let db = self.db.borrow_mut().db; + + let filtered = filter.is_some(); + let tuple = &mut (filter, conflict); + check(unsafe { + if filtered { + ffi::sqlite3changeset_apply_strm( + db, + Some(x_input), + input_ref as *const &mut dyn Read as *mut c_void, + Some(call_filter::<F, C>), + Some(call_conflict::<F, C>), + tuple as *mut (Option<F>, C) as *mut c_void, + ) + } else { + ffi::sqlite3changeset_apply_strm( + db, + Some(x_input), + input_ref as *const &mut dyn Read as *mut c_void, + None, + Some(call_conflict::<F, C>), + tuple as *mut (Option<F>, C) as *mut c_void, + ) + } + }) + } +} + +/// Constants passed to the conflict handler +/// See [here](https://sqlite.org/session.html#SQLITE_CHANGESET_CONFLICT) for details. +#[allow(missing_docs)] +#[repr(i32)] +#[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] +#[allow(clippy::upper_case_acronyms)] +pub enum ConflictType { + UNKNOWN = -1, + SQLITE_CHANGESET_DATA = ffi::SQLITE_CHANGESET_DATA, + SQLITE_CHANGESET_NOTFOUND = ffi::SQLITE_CHANGESET_NOTFOUND, + SQLITE_CHANGESET_CONFLICT = ffi::SQLITE_CHANGESET_CONFLICT, + SQLITE_CHANGESET_CONSTRAINT = ffi::SQLITE_CHANGESET_CONSTRAINT, + SQLITE_CHANGESET_FOREIGN_KEY = ffi::SQLITE_CHANGESET_FOREIGN_KEY, +} +impl From<i32> for ConflictType { + fn from(code: i32) -> ConflictType { + match code { + ffi::SQLITE_CHANGESET_DATA => ConflictType::SQLITE_CHANGESET_DATA, + ffi::SQLITE_CHANGESET_NOTFOUND => ConflictType::SQLITE_CHANGESET_NOTFOUND, + ffi::SQLITE_CHANGESET_CONFLICT => ConflictType::SQLITE_CHANGESET_CONFLICT, + ffi::SQLITE_CHANGESET_CONSTRAINT => ConflictType::SQLITE_CHANGESET_CONSTRAINT, + ffi::SQLITE_CHANGESET_FOREIGN_KEY => ConflictType::SQLITE_CHANGESET_FOREIGN_KEY, + _ => ConflictType::UNKNOWN, + } + } +} + +/// Constants returned by the conflict handler +/// See [here](https://sqlite.org/session.html#SQLITE_CHANGESET_ABORT) for details. +#[allow(missing_docs)] +#[repr(i32)] +#[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] +#[allow(clippy::upper_case_acronyms)] +pub enum ConflictAction { + SQLITE_CHANGESET_OMIT = ffi::SQLITE_CHANGESET_OMIT, + SQLITE_CHANGESET_REPLACE = ffi::SQLITE_CHANGESET_REPLACE, + SQLITE_CHANGESET_ABORT = ffi::SQLITE_CHANGESET_ABORT, +} + +unsafe extern "C" fn call_filter<F, C>(p_ctx: *mut c_void, tbl_str: *const c_char) -> c_int +where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, +{ + use std::str; + + let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C); + let tbl_name = { + let c_slice = CStr::from_ptr(tbl_str).to_bytes(); + str::from_utf8(c_slice) + }; + match *tuple { + (Some(ref filter), _) => c_int::from( + catch_unwind(|| filter(tbl_name.expect("illegal table name"))).unwrap_or_default(), + ), + _ => unimplemented!(), + } +} + +unsafe extern "C" fn call_conflict<F, C>( + p_ctx: *mut c_void, + e_conflict: c_int, + p: *mut ffi::sqlite3_changeset_iter, +) -> c_int +where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, +{ + let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C); + let conflict_type = ConflictType::from(e_conflict); + let item = ChangesetItem { it: p }; + if let Ok(action) = catch_unwind(|| (*tuple).1(conflict_type, item)) { + action as c_int + } else { + ffi::SQLITE_CHANGESET_ABORT + } +} + +unsafe extern "C" fn x_input(p_in: *mut c_void, data: *mut c_void, len: *mut c_int) -> c_int { + if p_in.is_null() { + return ffi::SQLITE_MISUSE; + } + let bytes: &mut [u8] = from_raw_parts_mut(data as *mut u8, *len as usize); + let input = p_in as *mut &mut dyn Read; + match (*input).read(bytes) { + Ok(n) => { + *len = n as i32; // TODO Validate: n = 0 may not mean the reader will always no longer be able to + // produce bytes. + ffi::SQLITE_OK + } + Err(_) => ffi::SQLITE_IOERR_READ, // TODO check if err is a (ru)sqlite Error => propagate + } +} + +unsafe extern "C" fn x_output(p_out: *mut c_void, data: *const c_void, len: c_int) -> c_int { + if p_out.is_null() { + return ffi::SQLITE_MISUSE; + } + // The sessions module never invokes an xOutput callback with the third + // parameter set to a value less than or equal to zero. + let bytes: &[u8] = from_raw_parts(data as *const u8, len as usize); + let output = p_out as *mut &mut dyn Write; + match (*output).write_all(bytes) { + Ok(_) => ffi::SQLITE_OK, + Err(_) => ffi::SQLITE_IOERR_WRITE, // TODO check if err is a (ru)sqlite Error => propagate + } +} + +#[cfg(test)] +mod test { + use fallible_streaming_iterator::FallibleStreamingIterator; + use std::io::Read; + use std::sync::atomic::{AtomicBool, Ordering}; + + use super::{Changeset, ChangesetIter, ConflictAction, ConflictType, Session}; + use crate::hooks::Action; + use crate::{Connection, Result}; + + fn one_changeset() -> Result<Changeset> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?; + + let mut session = Session::new(&db)?; + assert!(session.is_empty()); + + session.attach(None)?; + db.execute("INSERT INTO foo (t) VALUES (?1);", ["bar"])?; + + session.changeset() + } + + fn one_changeset_strm() -> Result<Vec<u8>> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?; + + let mut session = Session::new(&db)?; + assert!(session.is_empty()); + + session.attach(None)?; + db.execute("INSERT INTO foo (t) VALUES (?1);", ["bar"])?; + + let mut output = Vec::new(); + session.changeset_strm(&mut output)?; + Ok(output) + } + + #[test] + fn test_changeset() -> Result<()> { + let changeset = one_changeset()?; + let mut iter = changeset.iter()?; + let item = iter.next()?; + assert!(item.is_some()); + + let item = item.unwrap(); + let op = item.op()?; + assert_eq!("foo", op.table_name()); + assert_eq!(1, op.number_of_columns()); + assert_eq!(Action::SQLITE_INSERT, op.code()); + assert!(!op.indirect()); + + let pk = item.pk()?; + assert_eq!(&[1], pk); + + let new_value = item.new_value(0)?; + assert_eq!(Ok("bar"), new_value.as_str()); + Ok(()) + } + + #[test] + fn test_changeset_strm() -> Result<()> { + let output = one_changeset_strm()?; + assert!(!output.is_empty()); + assert_eq!(14, output.len()); + + let input: &mut dyn Read = &mut output.as_slice(); + let mut iter = ChangesetIter::start_strm(&input)?; + let item = iter.next()?; + assert!(item.is_some()); + Ok(()) + } + + #[test] + fn test_changeset_apply() -> Result<()> { + let changeset = one_changeset()?; + + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.apply( + &changeset, + None::<fn(&str) -> bool>, + |_conflict_type, _item| { + CALLED.store(true, Ordering::Relaxed); + ConflictAction::SQLITE_CHANGESET_OMIT + }, + )?; + + assert!(!CALLED.load(Ordering::Relaxed)); + let check = db.query_row("SELECT 1 FROM foo WHERE t = ?1", ["bar"], |row| { + row.get::<_, i32>(0) + })?; + assert_eq!(1, check); + + // conflict expected when same changeset applied again on the same db + db.apply( + &changeset, + None::<fn(&str) -> bool>, + |conflict_type, item| { + CALLED.store(true, Ordering::Relaxed); + assert_eq!(ConflictType::SQLITE_CHANGESET_CONFLICT, conflict_type); + let conflict = item.conflict(0).unwrap(); + assert_eq!(Ok("bar"), conflict.as_str()); + ConflictAction::SQLITE_CHANGESET_OMIT + }, + )?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_changeset_apply_strm() -> Result<()> { + let output = one_changeset_strm()?; + + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?; + + let mut input = output.as_slice(); + db.apply_strm( + &mut input, + None::<fn(&str) -> bool>, + |_conflict_type, _item| ConflictAction::SQLITE_CHANGESET_OMIT, + )?; + + let check = db.query_row("SELECT 1 FROM foo WHERE t = ?1", ["bar"], |row| { + row.get::<_, i32>(0) + })?; + assert_eq!(1, check); + Ok(()) + } + + #[test] + fn test_session_empty() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")?; + + let mut session = Session::new(&db)?; + assert!(session.is_empty()); + + session.attach(None)?; + db.execute("INSERT INTO foo (t) VALUES (?1);", ["bar"])?; + + assert!(!session.is_empty()); + Ok(()) + } + + #[test] + fn test_session_set_enabled() -> Result<()> { + let db = Connection::open_in_memory()?; + + let mut session = Session::new(&db)?; + assert!(session.is_enabled()); + session.set_enabled(false); + assert!(!session.is_enabled()); + Ok(()) + } + + #[test] + fn test_session_set_indirect() -> Result<()> { + let db = Connection::open_in_memory()?; + + let mut session = Session::new(&db)?; + assert!(!session.is_indirect()); + session.set_indirect(true); + assert!(session.is_indirect()); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/statement.rs b/third_party/rust/rusqlite/src/statement.rs new file mode 100644 index 0000000000..d39cc1fe87 --- /dev/null +++ b/third_party/rust/rusqlite/src/statement.rs @@ -0,0 +1,1360 @@ +use std::iter::IntoIterator; +use std::os::raw::{c_int, c_void}; +#[cfg(feature = "array")] +use std::rc::Rc; +use std::slice::from_raw_parts; +use std::{fmt, mem, ptr, str}; + +use super::ffi; +use super::{len_as_c_int, str_for_sqlite}; +use super::{ + AndThenRows, Connection, Error, MappedRows, Params, RawStatement, Result, Row, Rows, ValueRef, +}; +use crate::types::{ToSql, ToSqlOutput}; +#[cfg(feature = "array")] +use crate::vtab::array::{free_array, ARRAY_TYPE}; + +/// A prepared statement. +pub struct Statement<'conn> { + conn: &'conn Connection, + pub(crate) stmt: RawStatement, +} + +impl Statement<'_> { + /// Execute the prepared statement. + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// ## Example + /// + /// ### Use with positional parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, params}; + /// fn update_rows(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("UPDATE foo SET bar = ?1 WHERE qux = ?2")?; + /// // For a single parameter, or a parameter where all the values have + /// // the same type, just passing an array is simplest. + /// stmt.execute([2i32])?; + /// // The `rusqlite::params!` macro is mostly useful when the parameters do not + /// // all have the same type, or if there are more than 32 parameters + /// // at once, but it can be used in other cases. + /// stmt.execute(params![1i32])?; + /// // However, it's not required, many cases are fine as: + /// stmt.execute(&[&2i32])?; + /// // Or even: + /// stmt.execute([2i32])?; + /// // If you really want to, this is an option as well. + /// stmt.execute((2i32,))?; + /// Ok(()) + /// } + /// ``` + /// + /// #### Heterogeneous positional parameters + /// + /// ``` + /// use rusqlite::{Connection, Result}; + /// fn store_file(conn: &Connection, path: &str, data: &[u8]) -> Result<()> { + /// # // no need to do it for real. + /// # fn sha256(_: &[u8]) -> [u8; 32] { [0; 32] } + /// let query = "INSERT OR REPLACE INTO files(path, hash, data) VALUES (?1, ?2, ?3)"; + /// let mut stmt = conn.prepare_cached(query)?; + /// let hash: [u8; 32] = sha256(data); + /// // The easiest way to pass positional parameters of have several + /// // different types is by using a tuple. + /// stmt.execute((path, hash, data))?; + /// // Using the `params!` macro also works, and supports longer parameter lists: + /// stmt.execute(rusqlite::params![path, hash, data])?; + /// Ok(()) + /// } + /// # let c = Connection::open_in_memory().unwrap(); + /// # c.execute_batch("CREATE TABLE files(path TEXT PRIMARY KEY, hash BLOB, data BLOB)").unwrap(); + /// # store_file(&c, "foo/bar.txt", b"bibble").unwrap(); + /// # store_file(&c, "foo/baz.txt", b"bobble").unwrap(); + /// ``` + /// + /// ### Use with named parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, named_params}; + /// fn insert(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("INSERT INTO test (key, value) VALUES (:key, :value)")?; + /// // The `rusqlite::named_params!` macro (like `params!`) is useful for heterogeneous + /// // sets of parameters (where all parameters are not the same type), or for queries + /// // with many (more than 32) statically known parameters. + /// stmt.execute(named_params! { ":key": "one", ":val": 2 })?; + /// // However, named parameters can also be passed like: + /// stmt.execute(&[(":key", "three"), (":val", "four")])?; + /// // Or even: (note that a &T is required for the value type, currently) + /// stmt.execute(&[(":key", &100), (":val", &200)])?; + /// Ok(()) + /// } + /// ``` + /// + /// ### Use without parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, params}; + /// fn delete_all(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("DELETE FROM users")?; + /// stmt.execute([])?; + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails, the executed statement + /// returns rows (in which case `query` should be used instead), or the + /// underlying SQLite call fails. + #[inline] + pub fn execute<P: Params>(&mut self, params: P) -> Result<usize> { + params.__bind_in(self)?; + self.execute_with_bound_parameters() + } + + /// Execute an INSERT and return the ROWID. + /// + /// # Note + /// + /// This function is a convenience wrapper around + /// [`execute()`](Statement::execute) intended for queries that insert a + /// single item. It is possible to misuse this function in a way that it + /// cannot detect, such as by calling it on a statement which _updates_ + /// a single item rather than inserting one. Please don't do that. + /// + /// # Failure + /// + /// Will return `Err` if no row is inserted or many rows are inserted. + #[inline] + pub fn insert<P: Params>(&mut self, params: P) -> Result<i64> { + let changes = self.execute(params)?; + match changes { + 1 => Ok(self.conn.last_insert_rowid()), + _ => Err(Error::StatementChangedRows(changes)), + } + } + + /// Execute the prepared statement, returning a handle to the resulting + /// rows. + /// + /// Due to lifetime restrictions, the rows handle returned by `query` does + /// not implement the `Iterator` trait. Consider using + /// [`query_map`](Statement::query_map) or + /// [`query_and_then`](Statement::query_and_then) instead, which do. + /// + /// ## Example + /// + /// ### Use without parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result<Vec<String>> { + /// let mut stmt = conn.prepare("SELECT name FROM people")?; + /// let mut rows = stmt.query([])?; + /// + /// let mut names = Vec::new(); + /// while let Some(row) = rows.next()? { + /// names.push(row.get(0)?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// + /// ### Use with positional parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection, name: &str) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = ?1")?; + /// let mut rows = stmt.query(rusqlite::params![name])?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// Or, equivalently (but without the [`crate::params!`] macro). + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection, name: &str) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = ?1")?; + /// let mut rows = stmt.query([name])?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// ### Use with named parameters + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = :name")?; + /// let mut rows = stmt.query(&[(":name", "one")])?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// Note, the `named_params!` macro is provided for syntactic convenience, + /// and so the above example could also be written as: + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, named_params}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = :name")?; + /// let mut rows = stmt.query(named_params! { ":name": "one" })?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + #[inline] + pub fn query<P: Params>(&mut self, params: P) -> Result<Rows<'_>> { + params.__bind_in(self)?; + Ok(Rows::new(self)) + } + + /// Executes the prepared statement and maps a function over the resulting + /// rows, returning an iterator over the mapped function results. + /// + /// `f` is used to transform the _streaming_ iterator into a _standard_ + /// iterator. + /// + /// This is equivalent to `stmt.query(params)?.mapped(f)`. + /// + /// ## Example + /// + /// ### Use with positional params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result<Vec<String>> { + /// let mut stmt = conn.prepare("SELECT name FROM people")?; + /// let rows = stmt.query_map([], |row| row.get(0))?; + /// + /// let mut names = Vec::new(); + /// for name_result in rows { + /// names.push(name_result?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// + /// ### Use with named params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result<Vec<String>> { + /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = :id")?; + /// let rows = stmt.query_map(&[(":id", &"one")], |row| row.get(0))?; + /// + /// let mut names = Vec::new(); + /// for name_result in rows { + /// names.push(name_result?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + pub fn query_map<T, P, F>(&mut self, params: P, f: F) -> Result<MappedRows<'_, F>> + where + P: Params, + F: FnMut(&Row<'_>) -> Result<T>, + { + self.query(params).map(|rows| rows.mapped(f)) + } + + /// Executes the prepared statement and maps a function over the resulting + /// rows, where the function returns a `Result` with `Error` type + /// implementing `std::convert::From<Error>` (so errors can be unified). + /// + /// This is equivalent to `stmt.query(params)?.and_then(f)`. + /// + /// ## Example + /// + /// ### Use with named params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// struct Person { + /// name: String, + /// }; + /// + /// fn name_to_person(name: String) -> Result<Person> { + /// // ... check for valid name + /// Ok(Person { name }) + /// } + /// + /// fn get_names(conn: &Connection) -> Result<Vec<Person>> { + /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = :id")?; + /// let rows = stmt.query_and_then(&[(":id", "one")], |row| name_to_person(row.get(0)?))?; + /// + /// let mut persons = Vec::new(); + /// for person_result in rows { + /// persons.push(person_result?); + /// } + /// + /// Ok(persons) + /// } + /// ``` + /// + /// ### Use with positional params + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result<Vec<String>> { + /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = ?1")?; + /// let rows = stmt.query_and_then(["one"], |row| row.get::<_, String>(0))?; + /// + /// let mut persons = Vec::new(); + /// for person_result in rows { + /// persons.push(person_result?); + /// } + /// + /// Ok(persons) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails. + #[inline] + pub fn query_and_then<T, E, P, F>(&mut self, params: P, f: F) -> Result<AndThenRows<'_, F>> + where + P: Params, + E: From<Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, + { + self.query(params).map(|rows| rows.and_then(f)) + } + + /// Return `true` if a query in the SQL statement it executes returns one + /// or more rows and `false` if the SQL returns an empty set. + #[inline] + pub fn exists<P: Params>(&mut self, params: P) -> Result<bool> { + let mut rows = self.query(params)?; + let exists = rows.next()?.is_some(); + Ok(exists) + } + + /// Convenience method to execute a query that is expected to return a + /// single row. + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call + /// [`.optional()`](crate::OptionalExtension::optional) on the result of + /// this to get a `Result<Option<T>>` (requires that the trait + /// `rusqlite::OptionalExtension` is imported). + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn query_row<T, P, F>(&mut self, params: P, f: F) -> Result<T> + where + P: Params, + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut rows = self.query(params)?; + + rows.get_expected_row().and_then(f) + } + + /// Consumes the statement. + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn finalize(mut self) -> Result<()> { + self.finalize_() + } + + /// Return the (one-based) index of an SQL parameter given its name. + /// + /// Note that the initial ":" or "$" or "@" or "?" used to specify the + /// parameter is included as part of the name. + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn example(conn: &Connection) -> Result<()> { + /// let stmt = conn.prepare("SELECT * FROM test WHERE name = :example")?; + /// let index = stmt.parameter_index(":example")?; + /// assert_eq!(index, Some(1)); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return Err if `name` is invalid. Will return Ok(None) if the name + /// is valid but not a bound parameter of this statement. + #[inline] + pub fn parameter_index(&self, name: &str) -> Result<Option<usize>> { + Ok(self.stmt.bind_parameter_index(name)) + } + + /// Return the SQL parameter name given its (one-based) index (the inverse + /// of [`Statement::parameter_index`]). + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn example(conn: &Connection) -> Result<()> { + /// let stmt = conn.prepare("SELECT * FROM test WHERE name = :example")?; + /// let index = stmt.parameter_name(1); + /// assert_eq!(index, Some(":example")); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `None` if the column index is out of bounds or if the + /// parameter is positional. + /// + /// # Panics + /// + /// Panics when parameter name is not valid UTF-8. + #[inline] + pub fn parameter_name(&self, index: usize) -> Option<&'_ str> { + self.stmt.bind_parameter_name(index as i32).map(|name| { + str::from_utf8(name.to_bytes()).expect("Invalid UTF-8 sequence in parameter name") + }) + } + + #[inline] + pub(crate) fn bind_parameters<P>(&mut self, params: P) -> Result<()> + where + P: IntoIterator, + P::Item: ToSql, + { + let expected = self.stmt.bind_parameter_count(); + let mut index = 0; + for p in params { + index += 1; // The leftmost SQL parameter has an index of 1. + if index > expected { + break; + } + self.bind_parameter(&p, index)?; + } + if index != expected { + Err(Error::InvalidParameterCount(index, expected)) + } else { + Ok(()) + } + } + + #[inline] + pub(crate) fn ensure_parameter_count(&self, n: usize) -> Result<()> { + let count = self.parameter_count(); + if count != n { + Err(Error::InvalidParameterCount(n, count)) + } else { + Ok(()) + } + } + + #[inline] + pub(crate) fn bind_parameters_named<T: ?Sized + ToSql>( + &mut self, + params: &[(&str, &T)], + ) -> Result<()> { + for &(name, value) in params { + if let Some(i) = self.parameter_index(name)? { + let ts: &dyn ToSql = &value; + self.bind_parameter(ts, i)?; + } else { + return Err(Error::InvalidParameterName(name.into())); + } + } + Ok(()) + } + + /// Return the number of parameters that can be bound to this statement. + #[inline] + pub fn parameter_count(&self) -> usize { + self.stmt.bind_parameter_count() + } + + /// Low level API to directly bind a parameter to a given index. + /// + /// Note that the index is one-based, that is, the first parameter index is + /// 1 and not 0. This is consistent with the SQLite API and the values given + /// to parameters bound as `?NNN`. + /// + /// The valid values for `one_based_col_index` begin at `1`, and end at + /// [`Statement::parameter_count`], inclusive. + /// + /// # Caveats + /// + /// This should not generally be used, but is available for special cases + /// such as: + /// + /// - binding parameters where a gap exists. + /// - binding named and positional parameters in the same query. + /// - separating parameter binding from query execution. + /// + /// In general, statements that have had *any* parameters bound this way + /// should have *all* parameters bound this way, and be queried or executed + /// by [`Statement::raw_query`] or [`Statement::raw_execute`], other usage + /// is unsupported and will likely, probably in surprising ways. + /// + /// That is: Do not mix the "raw" statement functions with the rest of the + /// API, or the results may be surprising, and may even change in future + /// versions without comment. + /// + /// # Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test WHERE name = :name AND value > ?2")?; + /// let name_index = stmt.parameter_index(":name")?.expect("No such parameter"); + /// stmt.raw_bind_parameter(name_index, "foo")?; + /// stmt.raw_bind_parameter(2, 100)?; + /// let mut rows = stmt.raw_query(); + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + #[inline] + pub fn raw_bind_parameter<T: ToSql>( + &mut self, + one_based_col_index: usize, + param: T, + ) -> Result<()> { + // This is the same as `bind_parameter` but slightly more ergonomic and + // correctly takes `&mut self`. + self.bind_parameter(¶m, one_based_col_index) + } + + /// Low level API to execute a statement given that all parameters were + /// bound explicitly with the [`Statement::raw_bind_parameter`] API. + /// + /// # Caveats + /// + /// Any unbound parameters will have `NULL` as their value. + /// + /// This should not generally be used outside of special cases, and + /// functions in the [`Statement::execute`] family should be preferred. + /// + /// # Failure + /// + /// Will return `Err` if the executed statement returns rows (in which case + /// `query` should be used instead), or the underlying SQLite call fails. + #[inline] + pub fn raw_execute(&mut self) -> Result<usize> { + self.execute_with_bound_parameters() + } + + /// Low level API to get `Rows` for this query given that all parameters + /// were bound explicitly with the [`Statement::raw_bind_parameter`] API. + /// + /// # Caveats + /// + /// Any unbound parameters will have `NULL` as their value. + /// + /// This should not generally be used outside of special cases, and + /// functions in the [`Statement::query`] family should be preferred. + /// + /// Note that if the SQL does not return results, [`Statement::raw_execute`] + /// should be used instead. + #[inline] + pub fn raw_query(&mut self) -> Rows<'_> { + Rows::new(self) + } + + // generic because many of these branches can constant fold away. + fn bind_parameter<P: ?Sized + ToSql>(&self, param: &P, col: usize) -> Result<()> { + let value = param.to_sql()?; + + let ptr = unsafe { self.stmt.ptr() }; + let value = match value { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(len) => { + // TODO sqlite3_bind_zeroblob64 // 3.8.11 + return self + .conn + .decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, len) }); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(a) => { + return self.conn.decode_result(unsafe { + ffi::sqlite3_bind_pointer( + ptr, + col as c_int, + Rc::into_raw(a) as *mut c_void, + ARRAY_TYPE, + Some(free_array), + ) + }); + } + }; + self.conn.decode_result(match value { + ValueRef::Null => unsafe { ffi::sqlite3_bind_null(ptr, col as c_int) }, + ValueRef::Integer(i) => unsafe { ffi::sqlite3_bind_int64(ptr, col as c_int, i) }, + ValueRef::Real(r) => unsafe { ffi::sqlite3_bind_double(ptr, col as c_int, r) }, + ValueRef::Text(s) => unsafe { + let (c_str, len, destructor) = str_for_sqlite(s)?; + // TODO sqlite3_bind_text64 // 3.8.7 + ffi::sqlite3_bind_text(ptr, col as c_int, c_str, len, destructor) + }, + ValueRef::Blob(b) => unsafe { + let length = len_as_c_int(b.len())?; + if length == 0 { + ffi::sqlite3_bind_zeroblob(ptr, col as c_int, 0) + } else { + // TODO sqlite3_bind_blob64 // 3.8.7 + ffi::sqlite3_bind_blob( + ptr, + col as c_int, + b.as_ptr().cast::<c_void>(), + length, + ffi::SQLITE_TRANSIENT(), + ) + } + }, + }) + } + + #[inline] + fn execute_with_bound_parameters(&mut self) -> Result<usize> { + self.check_update()?; + let r = self.stmt.step(); + self.stmt.reset(); + match r { + ffi::SQLITE_DONE => Ok(self.conn.changes() as usize), + ffi::SQLITE_ROW => Err(Error::ExecuteReturnedResults), + _ => Err(self.conn.decode_result(r).unwrap_err()), + } + } + + #[inline] + fn finalize_(&mut self) -> Result<()> { + let mut stmt = unsafe { RawStatement::new(ptr::null_mut(), 0) }; + mem::swap(&mut stmt, &mut self.stmt); + self.conn.decode_result(stmt.finalize()) + } + + #[cfg(feature = "extra_check")] + #[inline] + fn check_update(&self) -> Result<()> { + // sqlite3_column_count works for DML but not for DDL (ie ALTER) + if self.column_count() > 0 && self.stmt.readonly() { + return Err(Error::ExecuteReturnedResults); + } + Ok(()) + } + + #[cfg(not(feature = "extra_check"))] + #[inline] + #[allow(clippy::unnecessary_wraps)] + fn check_update(&self) -> Result<()> { + Ok(()) + } + + /// Returns a string containing the SQL text of prepared statement with + /// bound parameters expanded. + pub fn expanded_sql(&self) -> Option<String> { + self.stmt + .expanded_sql() + .map(|s| s.to_string_lossy().to_string()) + } + + /// Get the value for one of the status counters for this statement. + #[inline] + pub fn get_status(&self, status: StatementStatus) -> i32 { + self.stmt.get_status(status, false) + } + + /// Reset the value of one of the status counters for this statement, + #[inline] + /// returning the value it had before resetting. + pub fn reset_status(&self, status: StatementStatus) -> i32 { + self.stmt.get_status(status, true) + } + + /// Returns 1 if the prepared statement is an EXPLAIN statement, + /// or 2 if the statement is an EXPLAIN QUERY PLAN, + /// or 0 if it is an ordinary statement or a NULL pointer. + #[inline] + #[cfg(feature = "modern_sqlite")] // 3.28.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn is_explain(&self) -> i32 { + self.stmt.is_explain() + } + + /// Returns true if the statement is read only. + #[inline] + pub fn readonly(&self) -> bool { + self.stmt.readonly() + } + + #[cfg(feature = "extra_check")] + #[inline] + pub(crate) fn check_no_tail(&self) -> Result<()> { + if self.stmt.has_tail() { + Err(Error::MultipleStatement) + } else { + Ok(()) + } + } + + #[cfg(not(feature = "extra_check"))] + #[inline] + #[allow(clippy::unnecessary_wraps)] + pub(crate) fn check_no_tail(&self) -> Result<()> { + Ok(()) + } + + /// Safety: This is unsafe, because using `sqlite3_stmt` after the + /// connection has closed is illegal, but `RawStatement` does not enforce + /// this, as it loses our protective `'conn` lifetime bound. + #[inline] + pub(crate) unsafe fn into_raw(mut self) -> RawStatement { + let mut stmt = RawStatement::new(ptr::null_mut(), 0); + mem::swap(&mut stmt, &mut self.stmt); + stmt + } + + /// Reset all bindings + pub fn clear_bindings(&mut self) { + self.stmt.clear_bindings(); + } +} + +impl fmt::Debug for Statement<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let sql = if self.stmt.is_null() { + Ok("") + } else { + str::from_utf8(self.stmt.sql().unwrap().to_bytes()) + }; + f.debug_struct("Statement") + .field("conn", self.conn) + .field("stmt", &self.stmt) + .field("sql", &sql) + .finish() + } +} + +impl Drop for Statement<'_> { + #[allow(unused_must_use)] + #[inline] + fn drop(&mut self) { + self.finalize_(); + } +} + +impl Statement<'_> { + #[inline] + pub(super) fn new(conn: &Connection, stmt: RawStatement) -> Statement<'_> { + Statement { conn, stmt } + } + + pub(super) fn value_ref(&self, col: usize) -> ValueRef<'_> { + let raw = unsafe { self.stmt.ptr() }; + + match self.stmt.column_type(col) { + ffi::SQLITE_NULL => ValueRef::Null, + ffi::SQLITE_INTEGER => { + ValueRef::Integer(unsafe { ffi::sqlite3_column_int64(raw, col as c_int) }) + } + ffi::SQLITE_FLOAT => { + ValueRef::Real(unsafe { ffi::sqlite3_column_double(raw, col as c_int) }) + } + ffi::SQLITE_TEXT => { + let s = unsafe { + // Quoting from "Using SQLite" book: + // To avoid problems, an application should first extract the desired type using + // a sqlite3_column_xxx() function, and then call the + // appropriate sqlite3_column_bytes() function. + let text = ffi::sqlite3_column_text(raw, col as c_int); + let len = ffi::sqlite3_column_bytes(raw, col as c_int); + assert!( + !text.is_null(), + "unexpected SQLITE_TEXT column type with NULL data" + ); + from_raw_parts(text.cast::<u8>(), len as usize) + }; + + ValueRef::Text(s) + } + ffi::SQLITE_BLOB => { + let (blob, len) = unsafe { + ( + ffi::sqlite3_column_blob(raw, col as c_int), + ffi::sqlite3_column_bytes(raw, col as c_int), + ) + }; + + assert!( + len >= 0, + "unexpected negative return from sqlite3_column_bytes" + ); + if len > 0 { + assert!( + !blob.is_null(), + "unexpected SQLITE_BLOB column type with NULL data" + ); + ValueRef::Blob(unsafe { from_raw_parts(blob.cast::<u8>(), len as usize) }) + } else { + // The return value from sqlite3_column_blob() for a zero-length BLOB + // is a NULL pointer. + ValueRef::Blob(&[]) + } + } + _ => unreachable!("sqlite3_column_type returned invalid value"), + } + } + + #[inline] + pub(super) fn step(&self) -> Result<bool> { + match self.stmt.step() { + ffi::SQLITE_ROW => Ok(true), + ffi::SQLITE_DONE => Ok(false), + code => Err(self.conn.decode_result(code).unwrap_err()), + } + } + + #[inline] + pub(super) fn reset(&self) -> c_int { + self.stmt.reset() + } +} + +/// Prepared statement status counters. +/// +/// See `https://www.sqlite.org/c3ref/c_stmtstatus_counter.html` +/// for explanations of each. +/// +/// Note that depending on your version of SQLite, all of these +/// may not be available. +#[repr(i32)] +#[derive(Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum StatementStatus { + /// Equivalent to SQLITE_STMTSTATUS_FULLSCAN_STEP + FullscanStep = 1, + /// Equivalent to SQLITE_STMTSTATUS_SORT + Sort = 2, + /// Equivalent to SQLITE_STMTSTATUS_AUTOINDEX + AutoIndex = 3, + /// Equivalent to SQLITE_STMTSTATUS_VM_STEP + VmStep = 4, + /// Equivalent to SQLITE_STMTSTATUS_REPREPARE (3.20.0) + RePrepare = 5, + /// Equivalent to SQLITE_STMTSTATUS_RUN (3.20.0) + Run = 6, + /// Equivalent to SQLITE_STMTSTATUS_FILTER_MISS + FilterMiss = 7, + /// Equivalent to SQLITE_STMTSTATUS_FILTER_HIT + FilterHit = 8, + /// Equivalent to SQLITE_STMTSTATUS_MEMUSED (3.20.0) + MemUsed = 99, +} + +#[cfg(test)] +mod test { + use crate::types::ToSql; + use crate::{params_from_iter, Connection, Error, Result}; + + #[test] + fn test_execute_named() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER)")?; + + assert_eq!( + db.execute("INSERT INTO foo(x) VALUES (:x)", &[(":x", &1i32)])?, + 1 + ); + assert_eq!( + db.execute("INSERT INTO foo(x) VALUES (:x)", &[(":x", &2i32)])?, + 1 + ); + assert_eq!( + db.execute( + "INSERT INTO foo(x) VALUES (:x)", + crate::named_params! {":x": 3i32} + )?, + 1 + ); + + assert_eq!( + 6i32, + db.query_row::<i32, _, _>( + "SELECT SUM(x) FROM foo WHERE x > :x", + &[(":x", &0i32)], + |r| r.get(0) + )? + ); + assert_eq!( + 5i32, + db.query_row::<i32, _, _>( + "SELECT SUM(x) FROM foo WHERE x > :x", + &[(":x", &1i32)], + |r| r.get(0) + )? + ); + Ok(()) + } + + #[test] + fn test_stmt_execute_named() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag \ + INTEGER)"; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("INSERT INTO test (name) VALUES (:name)")?; + stmt.execute(&[(":name", &"one")])?; + + let mut stmt = db.prepare("SELECT COUNT(*) FROM test WHERE name = :name")?; + assert_eq!( + 1i32, + stmt.query_row::<i32, _, _>(&[(":name", "one")], |r| r.get(0))? + ); + Ok(()) + } + + #[test] + fn test_query_named() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + "#; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("SELECT id FROM test where name = :name")?; + let mut rows = stmt.query(&[(":name", "one")])?; + let id: Result<i32> = rows.next()?.unwrap().get(0); + assert_eq!(Ok(1), id); + Ok(()) + } + + #[test] + fn test_query_map_named() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + "#; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("SELECT id FROM test where name = :name")?; + let mut rows = stmt.query_map(&[(":name", "one")], |row| { + let id: Result<i32> = row.get(0); + id.map(|i| 2 * i) + })?; + + let doubled_id: i32 = rows.next().unwrap()?; + assert_eq!(2, doubled_id); + Ok(()) + } + + #[test] + fn test_query_and_then_by_name() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + INSERT INTO test(id, name) VALUES (2, "one"); + "#; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("SELECT id FROM test where name = :name ORDER BY id ASC")?; + let mut rows = stmt.query_and_then(&[(":name", "one")], |row| { + let id: i32 = row.get(0)?; + if id == 1 { + Ok(id) + } else { + Err(Error::SqliteSingleThreadedMode) + } + })?; + + // first row should be Ok + let doubled_id: i32 = rows.next().unwrap()?; + assert_eq!(1, doubled_id); + + // second row should be Err + #[allow(clippy::match_wild_err_arm)] + match rows.next().unwrap() { + Ok(_) => panic!("invalid Ok"), + Err(Error::SqliteSingleThreadedMode) => (), + Err(_) => panic!("invalid Err"), + } + Ok(()) + } + + #[test] + fn test_unbound_parameters_are_null() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "CREATE TABLE test (x TEXT, y TEXT)"; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("INSERT INTO test (x, y) VALUES (:x, :y)")?; + stmt.execute(&[(":x", &"one")])?; + + let result: Option<String> = db.one_column("SELECT y FROM test WHERE x = 'one'")?; + assert!(result.is_none()); + Ok(()) + } + + #[test] + fn test_raw_binding() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE test (name TEXT, value INTEGER)")?; + { + let mut stmt = db.prepare("INSERT INTO test (name, value) VALUES (:name, ?3)")?; + + let name_idx = stmt.parameter_index(":name")?.unwrap(); + stmt.raw_bind_parameter(name_idx, "example")?; + stmt.raw_bind_parameter(3, 50i32)?; + let n = stmt.raw_execute()?; + assert_eq!(n, 1); + } + + { + let mut stmt = db.prepare("SELECT name, value FROM test WHERE value = ?2")?; + stmt.raw_bind_parameter(2, 50)?; + let mut rows = stmt.raw_query(); + { + let row = rows.next()?.unwrap(); + let name: String = row.get(0)?; + assert_eq!(name, "example"); + let value: i32 = row.get(1)?; + assert_eq!(value, 50); + } + assert!(rows.next()?.is_none()); + } + + Ok(()) + } + + #[test] + fn test_unbound_parameters_are_reused() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "CREATE TABLE test (x TEXT, y TEXT)"; + db.execute_batch(sql)?; + + let mut stmt = db.prepare("INSERT INTO test (x, y) VALUES (:x, :y)")?; + stmt.execute(&[(":x", "one")])?; + stmt.execute(&[(":y", "two")])?; + + let result: String = db.one_column("SELECT x FROM test WHERE y = 'two'")?; + assert_eq!(result, "one"); + Ok(()) + } + + #[test] + fn test_insert() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo(x INTEGER UNIQUE)")?; + let mut stmt = db.prepare("INSERT OR IGNORE INTO foo (x) VALUES (?1)")?; + assert_eq!(stmt.insert([1i32])?, 1); + assert_eq!(stmt.insert([2i32])?, 2); + match stmt.insert([1i32]).unwrap_err() { + Error::StatementChangedRows(0) => (), + err => panic!("Unexpected error {err}"), + } + let mut multi = db.prepare("INSERT INTO foo (x) SELECT 3 UNION ALL SELECT 4")?; + match multi.insert([]).unwrap_err() { + Error::StatementChangedRows(2) => (), + err => panic!("Unexpected error {err}"), + } + Ok(()) + } + + #[test] + fn test_insert_different_tables() -> Result<()> { + // Test for https://github.com/rusqlite/rusqlite/issues/171 + let db = Connection::open_in_memory()?; + db.execute_batch( + r" + CREATE TABLE foo(x INTEGER); + CREATE TABLE bar(x INTEGER); + ", + )?; + + assert_eq!(db.prepare("INSERT INTO foo VALUES (10)")?.insert([])?, 1); + assert_eq!(db.prepare("INSERT INTO bar VALUES (10)")?.insert([])?, 1); + Ok(()) + } + + #[test] + fn test_exists() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT 1 FROM foo WHERE x = ?1")?; + assert!(stmt.exists([1i32])?); + assert!(stmt.exists([2i32])?); + assert!(!stmt.exists([0i32])?); + Ok(()) + } + #[test] + fn test_tuple_params() -> Result<()> { + let db = Connection::open_in_memory()?; + let s = db.query_row("SELECT printf('[%s]', ?1)", ("abc",), |r| { + r.get::<_, String>(0) + })?; + assert_eq!(s, "[abc]"); + let s = db.query_row( + "SELECT printf('%d %s %d', ?1, ?2, ?3)", + (1i32, "abc", 2i32), + |r| r.get::<_, String>(0), + )?; + assert_eq!(s, "1 abc 2"); + let s = db.query_row( + "SELECT printf('%d %s %d %d', ?1, ?2, ?3, ?4)", + (1, "abc", 2i32, 4i64), + |r| r.get::<_, String>(0), + )?; + assert_eq!(s, "1 abc 2 4"); + #[rustfmt::skip] + let bigtup = ( + 0, "a", 1, "b", 2, "c", 3, "d", + 4, "e", 5, "f", 6, "g", 7, "h", + ); + let query = "SELECT printf( + '%d %s | %d %s | %d %s | %d %s || %d %s | %d %s | %d %s | %d %s', + ?1, ?2, ?3, ?4, + ?5, ?6, ?7, ?8, + ?9, ?10, ?11, ?12, + ?13, ?14, ?15, ?16 + )"; + let s = db.query_row(query, bigtup, |r| r.get::<_, String>(0))?; + assert_eq!(s, "0 a | 1 b | 2 c | 3 d || 4 e | 5 f | 6 g | 7 h"); + Ok(()) + } + + #[test] + fn test_query_row() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + INSERT INTO foo VALUES(2, 4); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT y FROM foo WHERE x = ?1")?; + let y: Result<i64> = stmt.query_row([1i32], |r| r.get(0)); + assert_eq!(3i64, y?); + Ok(()) + } + + #[test] + fn test_query_by_column_name() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT y FROM foo")?; + let y: Result<i64> = stmt.query_row([], |r| r.get("y")); + assert_eq!(3i64, y?); + Ok(()) + } + + #[test] + fn test_query_by_column_name_ignore_case() -> Result<()> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT y as Y FROM foo")?; + let y: Result<i64> = stmt.query_row([], |r| r.get("y")); + assert_eq!(3i64, y?); + Ok(()) + } + + #[test] + fn test_expanded_sql() -> Result<()> { + let db = Connection::open_in_memory()?; + let stmt = db.prepare("SELECT ?1")?; + stmt.bind_parameter(&1, 1)?; + assert_eq!(Some("SELECT 1".to_owned()), stmt.expanded_sql()); + Ok(()) + } + + #[test] + fn test_bind_parameters() -> Result<()> { + let db = Connection::open_in_memory()?; + // dynamic slice: + db.query_row( + "SELECT ?1, ?2, ?3", + [&1u8 as &dyn ToSql, &"one", &Some("one")], + |row| row.get::<_, u8>(0), + )?; + // existing collection: + let data = vec![1, 2, 3]; + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| { + row.get::<_, u8>(0) + })?; + db.query_row( + "SELECT ?1, ?2, ?3", + params_from_iter(data.as_slice()), + |row| row.get::<_, u8>(0), + )?; + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data), |row| { + row.get::<_, u8>(0) + })?; + + use std::collections::BTreeSet; + let data: BTreeSet<String> = ["one", "two", "three"] + .iter() + .map(|s| (*s).to_string()) + .collect(); + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| { + row.get::<_, String>(0) + })?; + + let data = [0; 3]; + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| { + row.get::<_, u8>(0) + })?; + db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data.iter()), |row| { + row.get::<_, u8>(0) + })?; + Ok(()) + } + + #[test] + fn test_parameter_name() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE test (name TEXT, value INTEGER)")?; + let stmt = db.prepare("INSERT INTO test (name, value) VALUES (:name, ?3)")?; + assert_eq!(stmt.parameter_name(0), None); + assert_eq!(stmt.parameter_name(1), Some(":name")); + assert_eq!(stmt.parameter_name(2), None); + Ok(()) + } + + #[test] + fn test_empty_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + let mut stmt = conn.prepare("")?; + assert_eq!(0, stmt.column_count()); + stmt.parameter_index("test").unwrap(); + stmt.step().unwrap_err(); + stmt.reset(); + stmt.execute([]).unwrap_err(); + Ok(()) + } + + #[test] + fn test_comment_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + conn.prepare("/*SELECT 1;*/")?; + Ok(()) + } + + #[test] + fn test_comment_and_sql_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + let stmt = conn.prepare("/*...*/ SELECT 1;")?; + assert_eq!(1, stmt.column_count()); + Ok(()) + } + + #[test] + fn test_semi_colon_stmt() -> Result<()> { + let conn = Connection::open_in_memory()?; + let stmt = conn.prepare(";")?; + assert_eq!(0, stmt.column_count()); + Ok(()) + } + + #[test] + fn test_utf16_conversion() -> Result<()> { + let db = Connection::open_in_memory()?; + db.pragma_update(None, "encoding", "UTF-16le")?; + let encoding: String = db.pragma_query_value(None, "encoding", |row| row.get(0))?; + assert_eq!("UTF-16le", encoding); + db.execute_batch("CREATE TABLE foo(x TEXT)")?; + let expected = "テスト"; + db.execute("INSERT INTO foo(x) VALUES (?1)", [&expected])?; + let actual: String = db.one_column("SELECT x FROM foo")?; + assert_eq!(expected, actual); + Ok(()) + } + + #[test] + fn test_nul_byte() -> Result<()> { + let db = Connection::open_in_memory()?; + let expected = "a\x00b"; + let actual: String = db.query_row("SELECT ?1", [expected], |row| row.get(0))?; + assert_eq!(expected, actual); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn is_explain() -> Result<()> { + let db = Connection::open_in_memory()?; + let stmt = db.prepare("SELECT 1;")?; + assert_eq!(0, stmt.is_explain()); + Ok(()) + } + + #[test] + fn readonly() -> Result<()> { + let db = Connection::open_in_memory()?; + let stmt = db.prepare("SELECT 1;")?; + assert!(stmt.readonly()); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] // SQLite >= 3.38.0 + fn test_error_offset() -> Result<()> { + use crate::ffi::ErrorCode; + let db = Connection::open_in_memory()?; + let r = db.execute_batch("SELECT CURRENT_TIMESTANP;"); + match r.unwrap_err() { + Error::SqlInputError { error, offset, .. } => { + assert_eq!(error.code, ErrorCode::Unknown); + assert_eq!(offset, 7); + } + err => panic!("Unexpected error {err}"), + } + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/trace.rs b/third_party/rust/rusqlite/src/trace.rs new file mode 100644 index 0000000000..7317a0ca8e --- /dev/null +++ b/third_party/rust/rusqlite/src/trace.rs @@ -0,0 +1,184 @@ +//! Tracing and profiling functions. Error and warning log. + +use std::ffi::{CStr, CString}; +use std::mem; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; +use std::time::Duration; + +use super::ffi; +use crate::Connection; + +/// Set up the process-wide SQLite error logging callback. +/// +/// # Safety +/// +/// This function is marked unsafe for two reasons: +/// +/// * The function is not threadsafe. No other SQLite calls may be made while +/// `config_log` is running, and multiple threads may not call `config_log` +/// simultaneously. +/// * The provided `callback` itself function has two requirements: +/// * It must not invoke any SQLite calls. +/// * It must be threadsafe if SQLite is used in a multithreaded way. +/// +/// cf [The Error And Warning Log](http://sqlite.org/errlog.html). +#[cfg(not(feature = "loadable_extension"))] +pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> crate::Result<()> { + extern "C" fn log_callback(p_arg: *mut c_void, err: c_int, msg: *const c_char) { + let c_slice = unsafe { CStr::from_ptr(msg).to_bytes() }; + let callback: fn(c_int, &str) = unsafe { mem::transmute(p_arg) }; + + let s = String::from_utf8_lossy(c_slice); + drop(catch_unwind(|| callback(err, &s))); + } + + let rc = if let Some(f) = callback { + ffi::sqlite3_config( + ffi::SQLITE_CONFIG_LOG, + log_callback as extern "C" fn(_, _, _), + f as *mut c_void, + ) + } else { + let nullptr: *mut c_void = ptr::null_mut(); + ffi::sqlite3_config(ffi::SQLITE_CONFIG_LOG, nullptr, nullptr) + }; + + if rc == ffi::SQLITE_OK { + Ok(()) + } else { + Err(crate::error::error_from_sqlite_code(rc, None)) + } +} + +/// Write a message into the error log established by +/// `config_log`. +#[inline] +pub fn log(err_code: c_int, msg: &str) { + let msg = CString::new(msg).expect("SQLite log messages cannot contain embedded zeroes"); + unsafe { + ffi::sqlite3_log(err_code, b"%s\0" as *const _ as *const c_char, msg.as_ptr()); + } +} + +impl Connection { + /// Register or clear a callback function that can be + /// used for tracing the execution of SQL statements. + /// + /// Prepared statement placeholders are replaced/logged with their assigned + /// values. There can only be a single tracer defined for each database + /// connection. Setting a new tracer clears the old one. + pub fn trace(&mut self, trace_fn: Option<fn(&str)>) { + unsafe extern "C" fn trace_callback(p_arg: *mut c_void, z_sql: *const c_char) { + let trace_fn: fn(&str) = mem::transmute(p_arg); + let c_slice = CStr::from_ptr(z_sql).to_bytes(); + let s = String::from_utf8_lossy(c_slice); + drop(catch_unwind(|| trace_fn(&s))); + } + + let c = self.db.borrow_mut(); + match trace_fn { + Some(f) => unsafe { + ffi::sqlite3_trace(c.db(), Some(trace_callback), f as *mut c_void); + }, + None => unsafe { + ffi::sqlite3_trace(c.db(), None, ptr::null_mut()); + }, + } + } + + /// Register or clear a callback function that can be + /// used for profiling the execution of SQL statements. + /// + /// There can only be a single profiler defined for each database + /// connection. Setting a new profiler clears the old one. + pub fn profile(&mut self, profile_fn: Option<fn(&str, Duration)>) { + unsafe extern "C" fn profile_callback( + p_arg: *mut c_void, + z_sql: *const c_char, + nanoseconds: u64, + ) { + let profile_fn: fn(&str, Duration) = mem::transmute(p_arg); + let c_slice = CStr::from_ptr(z_sql).to_bytes(); + let s = String::from_utf8_lossy(c_slice); + const NANOS_PER_SEC: u64 = 1_000_000_000; + + let duration = Duration::new( + nanoseconds / NANOS_PER_SEC, + (nanoseconds % NANOS_PER_SEC) as u32, + ); + drop(catch_unwind(|| profile_fn(&s, duration))); + } + + let c = self.db.borrow_mut(); + match profile_fn { + Some(f) => unsafe { + ffi::sqlite3_profile(c.db(), Some(profile_callback), f as *mut c_void) + }, + None => unsafe { ffi::sqlite3_profile(c.db(), None, ptr::null_mut()) }, + }; + } + + // TODO sqlite3_trace_v2 (https://sqlite.org/c3ref/trace_v2.html) // 3.14.0, #977 +} + +#[cfg(test)] +mod test { + use lazy_static::lazy_static; + use std::sync::Mutex; + use std::time::Duration; + + use crate::{Connection, Result}; + + #[test] + fn test_trace() -> Result<()> { + lazy_static! { + static ref TRACED_STMTS: Mutex<Vec<String>> = Mutex::new(Vec::new()); + } + fn tracer(s: &str) { + let mut traced_stmts = TRACED_STMTS.lock().unwrap(); + traced_stmts.push(s.to_owned()); + } + + let mut db = Connection::open_in_memory()?; + db.trace(Some(tracer)); + { + let _ = db.query_row("SELECT ?1", [1i32], |_| Ok(())); + let _ = db.query_row("SELECT ?1", ["hello"], |_| Ok(())); + } + db.trace(None); + { + let _ = db.query_row("SELECT ?1", [2i32], |_| Ok(())); + let _ = db.query_row("SELECT ?1", ["goodbye"], |_| Ok(())); + } + + let traced_stmts = TRACED_STMTS.lock().unwrap(); + assert_eq!(traced_stmts.len(), 2); + assert_eq!(traced_stmts[0], "SELECT 1"); + assert_eq!(traced_stmts[1], "SELECT 'hello'"); + Ok(()) + } + + #[test] + fn test_profile() -> Result<()> { + lazy_static! { + static ref PROFILED: Mutex<Vec<(String, Duration)>> = Mutex::new(Vec::new()); + } + fn profiler(s: &str, d: Duration) { + let mut profiled = PROFILED.lock().unwrap(); + profiled.push((s.to_owned(), d)); + } + + let mut db = Connection::open_in_memory()?; + db.profile(Some(profiler)); + db.execute_batch("PRAGMA application_id = 1")?; + db.profile(None); + db.execute_batch("PRAGMA application_id = 2")?; + + let profiled = PROFILED.lock().unwrap(); + assert_eq!(profiled.len(), 1); + assert_eq!(profiled[0].0, "PRAGMA application_id = 1"); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/transaction.rs b/third_party/rust/rusqlite/src/transaction.rs new file mode 100644 index 0000000000..2e905286b9 --- /dev/null +++ b/third_party/rust/rusqlite/src/transaction.rs @@ -0,0 +1,778 @@ +use crate::{Connection, Result}; +use std::ops::Deref; + +/// Options for transaction behavior. See [BEGIN +/// TRANSACTION](http://www.sqlite.org/lang_transaction.html) for details. +#[derive(Copy, Clone)] +#[non_exhaustive] +pub enum TransactionBehavior { + /// DEFERRED means that the transaction does not actually start until the + /// database is first accessed. + Deferred, + /// IMMEDIATE cause the database connection to start a new write + /// immediately, without waiting for a writes statement. + Immediate, + /// EXCLUSIVE prevents other database connections from reading the database + /// while the transaction is underway. + Exclusive, +} + +/// Options for how a Transaction or Savepoint should behave when it is dropped. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum DropBehavior { + /// Roll back the changes. This is the default. + Rollback, + + /// Commit the changes. + Commit, + + /// Do not commit or roll back changes - this will leave the transaction or + /// savepoint open, so should be used with care. + Ignore, + + /// Panic. Used to enforce intentional behavior during development. + Panic, +} + +/// Represents a transaction on a database connection. +/// +/// ## Note +/// +/// Transactions will roll back by default. Use `commit` method to explicitly +/// commit the transaction, or use `set_drop_behavior` to change what happens +/// when the transaction is dropped. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } +/// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } +/// fn perform_queries(conn: &mut Connection) -> Result<()> { +/// let tx = conn.transaction()?; +/// +/// do_queries_part_1(&tx)?; // tx causes rollback if this fails +/// do_queries_part_2(&tx)?; // tx causes rollback if this fails +/// +/// tx.commit() +/// } +/// ``` +#[derive(Debug)] +pub struct Transaction<'conn> { + conn: &'conn Connection, + drop_behavior: DropBehavior, +} + +/// Represents a savepoint on a database connection. +/// +/// ## Note +/// +/// Savepoints will roll back by default. Use `commit` method to explicitly +/// commit the savepoint, or use `set_drop_behavior` to change what happens +/// when the savepoint is dropped. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } +/// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } +/// fn perform_queries(conn: &mut Connection) -> Result<()> { +/// let sp = conn.savepoint()?; +/// +/// do_queries_part_1(&sp)?; // sp causes rollback if this fails +/// do_queries_part_2(&sp)?; // sp causes rollback if this fails +/// +/// sp.commit() +/// } +/// ``` +#[derive(Debug)] +pub struct Savepoint<'conn> { + conn: &'conn Connection, + name: String, + drop_behavior: DropBehavior, + committed: bool, +} + +impl Transaction<'_> { + /// Begin a new transaction. Cannot be nested; see `savepoint` for nested + /// transactions. + /// + /// Even though we don't mutate the connection, we take a `&mut Connection` + /// so as to prevent nested transactions on the same connection. For cases + /// where this is unacceptable, [`Transaction::new_unchecked`] is available. + #[inline] + pub fn new(conn: &mut Connection, behavior: TransactionBehavior) -> Result<Transaction<'_>> { + Self::new_unchecked(conn, behavior) + } + + /// Begin a new transaction, failing if a transaction is open. + /// + /// If a transaction is already open, this will return an error. Where + /// possible, [`Transaction::new`] should be preferred, as it provides a + /// compile-time guarantee that transactions are not nested. + #[inline] + pub fn new_unchecked( + conn: &Connection, + behavior: TransactionBehavior, + ) -> Result<Transaction<'_>> { + let query = match behavior { + TransactionBehavior::Deferred => "BEGIN DEFERRED", + TransactionBehavior::Immediate => "BEGIN IMMEDIATE", + TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE", + }; + conn.execute_batch(query).map(move |()| Transaction { + conn, + drop_behavior: DropBehavior::Rollback, + }) + } + + /// Starts a new [savepoint](http://www.sqlite.org/lang_savepoint.html), allowing nested + /// transactions. + /// + /// ## Note + /// + /// Just like outer level transactions, savepoint transactions rollback by + /// default. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn perform_queries_part_1_succeeds(_conn: &Connection) -> bool { true } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let mut tx = conn.transaction()?; + /// + /// { + /// let sp = tx.savepoint()?; + /// if perform_queries_part_1_succeeds(&sp) { + /// sp.commit()?; + /// } + /// // otherwise, sp will rollback + /// } + /// + /// tx.commit() + /// } + /// ``` + #[inline] + pub fn savepoint(&mut self) -> Result<Savepoint<'_>> { + Savepoint::new_(self.conn) + } + + /// Create a new savepoint with a custom savepoint name. See `savepoint()`. + #[inline] + pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> { + Savepoint::with_name_(self.conn, name) + } + + /// Get the current setting for what happens to the transaction when it is + /// dropped. + #[inline] + #[must_use] + pub fn drop_behavior(&self) -> DropBehavior { + self.drop_behavior + } + + /// Configure the transaction to perform the specified action when it is + /// dropped. + #[inline] + pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) { + self.drop_behavior = drop_behavior; + } + + /// A convenience method which consumes and commits a transaction. + #[inline] + pub fn commit(mut self) -> Result<()> { + self.commit_() + } + + #[inline] + fn commit_(&mut self) -> Result<()> { + self.conn.execute_batch("COMMIT")?; + Ok(()) + } + + /// A convenience method which consumes and rolls back a transaction. + #[inline] + pub fn rollback(mut self) -> Result<()> { + self.rollback_() + } + + #[inline] + fn rollback_(&mut self) -> Result<()> { + self.conn.execute_batch("ROLLBACK")?; + Ok(()) + } + + /// Consumes the transaction, committing or rolling back according to the + /// current setting (see `drop_behavior`). + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur. + #[inline] + pub fn finish(mut self) -> Result<()> { + self.finish_() + } + + #[inline] + fn finish_(&mut self) -> Result<()> { + if self.conn.is_autocommit() { + return Ok(()); + } + match self.drop_behavior() { + DropBehavior::Commit => self.commit_().or_else(|_| self.rollback_()), + DropBehavior::Rollback => self.rollback_(), + DropBehavior::Ignore => Ok(()), + DropBehavior::Panic => panic!("Transaction dropped unexpectedly."), + } + } +} + +impl Deref for Transaction<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.conn + } +} + +#[allow(unused_must_use)] +impl Drop for Transaction<'_> { + #[inline] + fn drop(&mut self) { + self.finish_(); + } +} + +impl Savepoint<'_> { + #[inline] + fn with_name_<T: Into<String>>(conn: &Connection, name: T) -> Result<Savepoint<'_>> { + let name = name.into(); + conn.execute_batch(&format!("SAVEPOINT {name}")) + .map(|()| Savepoint { + conn, + name, + drop_behavior: DropBehavior::Rollback, + committed: false, + }) + } + + #[inline] + fn new_(conn: &Connection) -> Result<Savepoint<'_>> { + Savepoint::with_name_(conn, "_rusqlite_sp") + } + + /// Begin a new savepoint. Can be nested. + #[inline] + pub fn new(conn: &mut Connection) -> Result<Savepoint<'_>> { + Savepoint::new_(conn) + } + + /// Begin a new savepoint with a user-provided savepoint name. + #[inline] + pub fn with_name<T: Into<String>>(conn: &mut Connection, name: T) -> Result<Savepoint<'_>> { + Savepoint::with_name_(conn, name) + } + + /// Begin a nested savepoint. + #[inline] + pub fn savepoint(&mut self) -> Result<Savepoint<'_>> { + Savepoint::new_(self.conn) + } + + /// Begin a nested savepoint with a user-provided savepoint name. + #[inline] + pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> { + Savepoint::with_name_(self.conn, name) + } + + /// Get the current setting for what happens to the savepoint when it is + /// dropped. + #[inline] + #[must_use] + pub fn drop_behavior(&self) -> DropBehavior { + self.drop_behavior + } + + /// Configure the savepoint to perform the specified action when it is + /// dropped. + #[inline] + pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) { + self.drop_behavior = drop_behavior; + } + + /// A convenience method which consumes and commits a savepoint. + #[inline] + pub fn commit(mut self) -> Result<()> { + self.commit_() + } + + #[inline] + fn commit_(&mut self) -> Result<()> { + self.conn.execute_batch(&format!("RELEASE {}", self.name))?; + self.committed = true; + Ok(()) + } + + /// A convenience method which rolls back a savepoint. + /// + /// ## Note + /// + /// Unlike `Transaction`s, savepoints remain active after they have been + /// rolled back, and can be rolled back again or committed. + #[inline] + pub fn rollback(&mut self) -> Result<()> { + self.conn + .execute_batch(&format!("ROLLBACK TO {}", self.name)) + } + + /// Consumes the savepoint, committing or rolling back according to the + /// current setting (see `drop_behavior`). + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur. + #[inline] + pub fn finish(mut self) -> Result<()> { + self.finish_() + } + + #[inline] + fn finish_(&mut self) -> Result<()> { + if self.committed { + return Ok(()); + } + match self.drop_behavior() { + DropBehavior::Commit => self + .commit_() + .or_else(|_| self.rollback().and_then(|()| self.commit_())), + DropBehavior::Rollback => self.rollback().and_then(|()| self.commit_()), + DropBehavior::Ignore => Ok(()), + DropBehavior::Panic => panic!("Savepoint dropped unexpectedly."), + } + } +} + +impl Deref for Savepoint<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.conn + } +} + +#[allow(unused_must_use)] +impl Drop for Savepoint<'_> { + #[inline] + fn drop(&mut self) { + self.finish_(); + } +} + +/// Transaction state of a database +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[non_exhaustive] +#[cfg(feature = "modern_sqlite")] // 3.37.0 +#[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] +pub enum TransactionState { + /// Equivalent to SQLITE_TXN_NONE + None, + /// Equivalent to SQLITE_TXN_READ + Read, + /// Equivalent to SQLITE_TXN_WRITE + Write, +} + +impl Connection { + /// Begin a new transaction with the default behavior (DEFERRED). + /// + /// The transaction defaults to rolling back when it is dropped. If you + /// want the transaction to commit, you must call + /// [`commit`](Transaction::commit) or + /// [`set_drop_behavior(DropBehavior::Commit)`](Transaction::set_drop_behavior). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let tx = conn.transaction()?; + /// + /// do_queries_part_1(&tx)?; // tx causes rollback if this fails + /// do_queries_part_2(&tx)?; // tx causes rollback if this fails + /// + /// tx.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn transaction(&mut self) -> Result<Transaction<'_>> { + Transaction::new(self, TransactionBehavior::Deferred) + } + + /// Begin a new transaction with a specified behavior. + /// + /// See [`transaction`](Connection::transaction). + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn transaction_with_behavior( + &mut self, + behavior: TransactionBehavior, + ) -> Result<Transaction<'_>> { + Transaction::new(self, behavior) + } + + /// Begin a new transaction with the default behavior (DEFERRED). + /// + /// Attempt to open a nested transaction will result in a SQLite error. + /// `Connection::transaction` prevents this at compile time by taking `&mut + /// self`, but `Connection::unchecked_transaction()` may be used to defer + /// the checking until runtime. + /// + /// See [`Connection::transaction`] and [`Transaction::new_unchecked`] + /// (which can be used if the default transaction behavior is undesirable). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # use std::rc::Rc; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: Rc<Connection>) -> Result<()> { + /// let tx = conn.unchecked_transaction()?; + /// + /// do_queries_part_1(&tx)?; // tx causes rollback if this fails + /// do_queries_part_2(&tx)?; // tx causes rollback if this fails + /// + /// tx.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. The specific + /// error returned if transactions are nested is currently unspecified. + pub fn unchecked_transaction(&self) -> Result<Transaction<'_>> { + Transaction::new_unchecked(self, TransactionBehavior::Deferred) + } + + /// Begin a new savepoint with the default behavior (DEFERRED). + /// + /// The savepoint defaults to rolling back when it is dropped. If you want + /// the savepoint to commit, you must call [`commit`](Savepoint::commit) or + /// [`set_drop_behavior(DropBehavior::Commit)`](Savepoint::set_drop_behavior). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let sp = conn.savepoint()?; + /// + /// do_queries_part_1(&sp)?; // sp causes rollback if this fails + /// do_queries_part_2(&sp)?; // sp causes rollback if this fails + /// + /// sp.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn savepoint(&mut self) -> Result<Savepoint<'_>> { + Savepoint::new(self) + } + + /// Begin a new savepoint with a specified name. + /// + /// See [`savepoint`](Connection::savepoint). + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[inline] + pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> { + Savepoint::with_name(self, name) + } + + /// Determine the transaction state of a database + #[cfg(feature = "modern_sqlite")] // 3.37.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn transaction_state( + &self, + db_name: Option<crate::DatabaseName<'_>>, + ) -> Result<TransactionState> { + self.db.borrow().txn_state(db_name) + } +} + +#[cfg(test)] +mod test { + use super::DropBehavior; + use crate::{Connection, Error, Result}; + + fn checked_memory_handle() -> Result<Connection> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (x INTEGER)")?; + Ok(db) + } + + #[test] + fn test_drop() -> Result<()> { + let mut db = checked_memory_handle()?; + { + let tx = db.transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(1)")?; + // default: rollback + } + { + let mut tx = db.transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(2)")?; + tx.set_drop_behavior(DropBehavior::Commit) + } + { + let tx = db.transaction()?; + assert_eq!(2i32, tx.one_column::<i32>("SELECT SUM(x) FROM foo")?); + } + Ok(()) + } + fn assert_nested_tx_error(e: Error) { + if let Error::SqliteFailure(e, Some(m)) = &e { + assert_eq!(e.extended_code, crate::ffi::SQLITE_ERROR); + // FIXME: Not ideal... + assert_eq!(e.code, crate::ErrorCode::Unknown); + assert!(m.contains("transaction")); + } else { + panic!("Unexpected error type: {e:?}"); + } + } + + #[test] + fn test_unchecked_nesting() -> Result<()> { + let db = checked_memory_handle()?; + + { + let tx = db.unchecked_transaction()?; + let e = tx.unchecked_transaction().unwrap_err(); + assert_nested_tx_error(e); + // default: rollback + } + { + let tx = db.unchecked_transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(1)")?; + // Ensure this doesn't interfere with ongoing transaction + let e = tx.unchecked_transaction().unwrap_err(); + assert_nested_tx_error(e); + + tx.execute_batch("INSERT INTO foo VALUES(1)")?; + tx.commit()?; + } + + assert_eq!(2i32, db.one_column::<i32>("SELECT SUM(x) FROM foo")?); + Ok(()) + } + + #[test] + fn test_explicit_rollback_commit() -> Result<()> { + let mut db = checked_memory_handle()?; + { + let mut tx = db.transaction()?; + { + let mut sp = tx.savepoint()?; + sp.execute_batch("INSERT INTO foo VALUES(1)")?; + sp.rollback()?; + sp.execute_batch("INSERT INTO foo VALUES(2)")?; + sp.commit()?; + } + tx.commit()?; + } + { + let tx = db.transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(4)")?; + tx.commit()?; + } + { + let tx = db.transaction()?; + assert_eq!(6i32, tx.one_column::<i32>("SELECT SUM(x) FROM foo")?); + } + Ok(()) + } + + #[test] + fn test_savepoint() -> Result<()> { + let mut db = checked_memory_handle()?; + { + let mut tx = db.transaction()?; + tx.execute_batch("INSERT INTO foo VALUES(1)")?; + assert_current_sum(1, &tx)?; + tx.set_drop_behavior(DropBehavior::Commit); + { + let mut sp1 = tx.savepoint()?; + sp1.execute_batch("INSERT INTO foo VALUES(2)")?; + assert_current_sum(3, &sp1)?; + // will rollback sp1 + { + let mut sp2 = sp1.savepoint()?; + sp2.execute_batch("INSERT INTO foo VALUES(4)")?; + assert_current_sum(7, &sp2)?; + // will rollback sp2 + { + let sp3 = sp2.savepoint()?; + sp3.execute_batch("INSERT INTO foo VALUES(8)")?; + assert_current_sum(15, &sp3)?; + sp3.commit()?; + // committed sp3, but will be erased by sp2 rollback + } + assert_current_sum(15, &sp2)?; + } + assert_current_sum(3, &sp1)?; + } + assert_current_sum(1, &tx)?; + } + assert_current_sum(1, &db)?; + Ok(()) + } + + #[test] + fn test_ignore_drop_behavior() -> Result<()> { + let mut db = checked_memory_handle()?; + + let mut tx = db.transaction()?; + { + let mut sp1 = tx.savepoint()?; + insert(1, &sp1)?; + sp1.rollback()?; + insert(2, &sp1)?; + { + let mut sp2 = sp1.savepoint()?; + sp2.set_drop_behavior(DropBehavior::Ignore); + insert(4, &sp2)?; + } + assert_current_sum(6, &sp1)?; + sp1.commit()?; + } + assert_current_sum(6, &tx)?; + Ok(()) + } + + #[test] + fn test_savepoint_drop_behavior_releases() -> Result<()> { + let mut db = checked_memory_handle()?; + + { + let mut sp = db.savepoint()?; + sp.set_drop_behavior(DropBehavior::Commit); + } + assert!(db.is_autocommit()); + { + let mut sp = db.savepoint()?; + sp.set_drop_behavior(DropBehavior::Rollback); + } + assert!(db.is_autocommit()); + + Ok(()) + } + + #[test] + fn test_savepoint_release_error() -> Result<()> { + let mut db = checked_memory_handle()?; + + db.pragma_update(None, "foreign_keys", true)?; + db.execute_batch("CREATE TABLE r(n INTEGER PRIMARY KEY NOT NULL); CREATE TABLE f(n REFERENCES r(n) DEFERRABLE INITIALLY DEFERRED);")?; + { + let mut sp = db.savepoint()?; + sp.execute("INSERT INTO f VALUES (0)", [])?; + sp.set_drop_behavior(DropBehavior::Commit); + } + assert!(db.is_autocommit()); + + Ok(()) + } + + #[test] + fn test_savepoint_names() -> Result<()> { + let mut db = checked_memory_handle()?; + + { + let mut sp1 = db.savepoint_with_name("my_sp")?; + insert(1, &sp1)?; + assert_current_sum(1, &sp1)?; + { + let mut sp2 = sp1.savepoint_with_name("my_sp")?; + sp2.set_drop_behavior(DropBehavior::Commit); + insert(2, &sp2)?; + assert_current_sum(3, &sp2)?; + sp2.rollback()?; + assert_current_sum(1, &sp2)?; + insert(4, &sp2)?; + } + assert_current_sum(5, &sp1)?; + sp1.rollback()?; + { + let mut sp2 = sp1.savepoint_with_name("my_sp")?; + sp2.set_drop_behavior(DropBehavior::Ignore); + insert(8, &sp2)?; + } + assert_current_sum(8, &sp1)?; + sp1.commit()?; + } + assert_current_sum(8, &db)?; + Ok(()) + } + + #[test] + fn test_rc() -> Result<()> { + use std::rc::Rc; + let mut conn = Connection::open_in_memory()?; + let rc_txn = Rc::new(conn.transaction()?); + + // This will compile only if Transaction is Debug + Rc::try_unwrap(rc_txn).unwrap(); + Ok(()) + } + + fn insert(x: i32, conn: &Connection) -> Result<usize> { + conn.execute("INSERT INTO foo VALUES(?1)", [x]) + } + + fn assert_current_sum(x: i32, conn: &Connection) -> Result<()> { + let i = conn.one_column::<i32>("SELECT SUM(x) FROM foo")?; + assert_eq!(x, i); + Ok(()) + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn txn_state() -> Result<()> { + use super::TransactionState; + use crate::DatabaseName; + let db = Connection::open_in_memory()?; + assert_eq!( + TransactionState::None, + db.transaction_state(Some(DatabaseName::Main))? + ); + assert_eq!(TransactionState::None, db.transaction_state(None)?); + db.execute_batch("BEGIN")?; + assert_eq!(TransactionState::None, db.transaction_state(None)?); + let _: i32 = db.pragma_query_value(None, "user_version", |row| row.get(0))?; + assert_eq!(TransactionState::Read, db.transaction_state(None)?); + db.pragma_update(None, "user_version", 1)?; + assert_eq!(TransactionState::Write, db.transaction_state(None)?); + db.execute_batch("ROLLBACK")?; + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/types/chrono.rs b/third_party/rust/rusqlite/src/types/chrono.rs new file mode 100644 index 0000000000..6b50e01099 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/chrono.rs @@ -0,0 +1,319 @@ +//! Convert most of the [Time Strings](http://sqlite.org/lang_datefunc.html) to chrono types. + +use chrono::{DateTime, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; + +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; + +/// ISO 8601 calendar date without timezone => "YYYY-MM-DD" +impl ToSql for NaiveDate { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.format("%F").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "YYYY-MM-DD" => ISO 8601 calendar date without timezone. +impl FromSql for NaiveDate { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value + .as_str() + .and_then(|s| match NaiveDate::parse_from_str(s, "%F") { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + }) + } +} + +/// ISO 8601 time without timezone => "HH:MM:SS.SSS" +impl ToSql for NaiveTime { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.format("%T%.f").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "HH:MM"/"HH:MM:SS"/"HH:MM:SS.SSS" => ISO 8601 time without timezone. +impl FromSql for NaiveTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + let fmt = match s.len() { + 5 => "%H:%M", + 8 => "%T", + _ => "%T%.f", + }; + match NaiveTime::parse_from_str(s, fmt) { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + } + }) + } +} + +/// ISO 8601 combined date and time without timezone => +/// "YYYY-MM-DD HH:MM:SS.SSS" +impl ToSql for NaiveDateTime { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.format("%F %T%.f").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "YYYY-MM-DD HH:MM:SS"/"YYYY-MM-DD HH:MM:SS.SSS" => ISO 8601 combined date +/// and time without timezone. ("YYYY-MM-DDTHH:MM:SS"/"YYYY-MM-DDTHH:MM:SS.SSS" +/// also supported) +impl FromSql for NaiveDateTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + let fmt = if s.len() >= 11 && s.as_bytes()[10] == b'T' { + "%FT%T%.f" + } else { + "%F %T%.f" + }; + + match NaiveDateTime::parse_from_str(s, fmt) { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + } + }) + } +} + +/// UTC time => UTC RFC3339 timestamp +/// ("YYYY-MM-DD HH:MM:SS.SSS+00:00"). +impl ToSql for DateTime<Utc> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.format("%F %T%.f%:z").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// Local time => UTC RFC3339 timestamp +/// ("YYYY-MM-DD HH:MM:SS.SSS+00:00"). +impl ToSql for DateTime<Local> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.with_timezone(&Utc).format("%F %T%.f%:z").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// Date and time with time zone => RFC3339 timestamp +/// ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM"). +impl ToSql for DateTime<FixedOffset> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.format("%F %T%.f%:z").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// RFC3339 ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM") into `DateTime<Utc>`. +impl FromSql for DateTime<Utc> { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + { + // Try to parse value as rfc3339 first. + let s = value.as_str()?; + + let fmt = if s.len() >= 11 && s.as_bytes()[10] == b'T' { + "%FT%T%.f%#z" + } else { + "%F %T%.f%#z" + }; + + if let Ok(dt) = DateTime::parse_from_str(s, fmt) { + return Ok(dt.with_timezone(&Utc)); + } + } + + // Couldn't parse as rfc3339 - fall back to NaiveDateTime. + NaiveDateTime::column_result(value).map(|dt| Utc.from_utc_datetime(&dt)) + } +} + +/// RFC3339 ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM") into `DateTime<Local>`. +impl FromSql for DateTime<Local> { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + let utc_dt = DateTime::<Utc>::column_result(value)?; + Ok(utc_dt.with_timezone(&Local)) + } +} + +/// RFC3339 ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM") into `DateTime<FixedOffset>`. +impl FromSql for DateTime<FixedOffset> { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + let s = String::column_result(value)?; + Self::parse_from_rfc3339(s.as_str()) + .or_else(|_| Self::parse_from_str(s.as_str(), "%F %T%.f%:z")) + .map_err(|e| FromSqlError::Other(Box::new(e))) + } +} + +#[cfg(test)] +mod test { + use crate::{ + types::{FromSql, ValueRef}, + Connection, Result, + }; + use chrono::{ + DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc, + }; + + fn checked_memory_handle() -> Result<Connection> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (t TEXT, i INTEGER, f FLOAT, b BLOB)")?; + Ok(db) + } + + #[test] + fn test_naive_date() -> Result<()> { + let db = checked_memory_handle()?; + let date = NaiveDate::from_ymd_opt(2016, 2, 23).unwrap(); + db.execute("INSERT INTO foo (t) VALUES (?1)", [date])?; + + let s: String = db.one_column("SELECT t FROM foo")?; + assert_eq!("2016-02-23", s); + let t: NaiveDate = db.one_column("SELECT t FROM foo")?; + assert_eq!(date, t); + Ok(()) + } + + #[test] + fn test_naive_time() -> Result<()> { + let db = checked_memory_handle()?; + let time = NaiveTime::from_hms_opt(23, 56, 4).unwrap(); + db.execute("INSERT INTO foo (t) VALUES (?1)", [time])?; + + let s: String = db.one_column("SELECT t FROM foo")?; + assert_eq!("23:56:04", s); + let v: NaiveTime = db.one_column("SELECT t FROM foo")?; + assert_eq!(time, v); + Ok(()) + } + + #[test] + fn test_naive_date_time() -> Result<()> { + let db = checked_memory_handle()?; + let date = NaiveDate::from_ymd_opt(2016, 2, 23).unwrap(); + let time = NaiveTime::from_hms_opt(23, 56, 4).unwrap(); + let dt = NaiveDateTime::new(date, time); + + db.execute("INSERT INTO foo (t) VALUES (?1)", [dt])?; + + let s: String = db.one_column("SELECT t FROM foo")?; + assert_eq!("2016-02-23 23:56:04", s); + let v: NaiveDateTime = db.one_column("SELECT t FROM foo")?; + assert_eq!(dt, v); + + db.execute("UPDATE foo set b = datetime(t)", [])?; // "YYYY-MM-DD HH:MM:SS" + let hms: NaiveDateTime = db.one_column("SELECT b FROM foo")?; + assert_eq!(dt, hms); + Ok(()) + } + + #[test] + fn test_date_time_utc() -> Result<()> { + let db = checked_memory_handle()?; + let date = NaiveDate::from_ymd_opt(2016, 2, 23).unwrap(); + let time = NaiveTime::from_hms_milli_opt(23, 56, 4, 789).unwrap(); + let dt = NaiveDateTime::new(date, time); + let utc = Utc.from_utc_datetime(&dt); + + db.execute("INSERT INTO foo (t) VALUES (?1)", [utc])?; + + let s: String = db.one_column("SELECT t FROM foo")?; + assert_eq!("2016-02-23 23:56:04.789+00:00", s); + + let v1: DateTime<Utc> = db.one_column("SELECT t FROM foo")?; + assert_eq!(utc, v1); + + let v2: DateTime<Utc> = db.one_column("SELECT '2016-02-23 23:56:04.789'")?; + assert_eq!(utc, v2); + + let v3: DateTime<Utc> = db.one_column("SELECT '2016-02-23 23:56:04'")?; + assert_eq!(utc - Duration::milliseconds(789), v3); + + let v4: DateTime<Utc> = db.one_column("SELECT '2016-02-23 23:56:04.789+00:00'")?; + assert_eq!(utc, v4); + Ok(()) + } + + #[test] + fn test_date_time_local() -> Result<()> { + let db = checked_memory_handle()?; + let date = NaiveDate::from_ymd_opt(2016, 2, 23).unwrap(); + let time = NaiveTime::from_hms_milli_opt(23, 56, 4, 789).unwrap(); + let dt = NaiveDateTime::new(date, time); + let local = Local.from_local_datetime(&dt).single().unwrap(); + + db.execute("INSERT INTO foo (t) VALUES (?1)", [local])?; + + // Stored string should be in UTC + let s: String = db.one_column("SELECT t FROM foo")?; + assert!(s.ends_with("+00:00")); + + let v: DateTime<Local> = db.one_column("SELECT t FROM foo")?; + assert_eq!(local, v); + Ok(()) + } + + #[test] + fn test_date_time_fixed() -> Result<()> { + let db = checked_memory_handle()?; + let time = DateTime::parse_from_rfc3339("2020-04-07T11:23:45+04:00").unwrap(); + + db.execute("INSERT INTO foo (t) VALUES (?1)", [time])?; + + // Stored string should preserve timezone offset + let s: String = db.one_column("SELECT t FROM foo")?; + assert!(s.ends_with("+04:00")); + + let v: DateTime<FixedOffset> = db.one_column("SELECT t FROM foo")?; + assert_eq!(time.offset(), v.offset()); + assert_eq!(time, v); + Ok(()) + } + + #[test] + fn test_sqlite_functions() -> Result<()> { + let db = checked_memory_handle()?; + let result: Result<NaiveTime> = db.one_column("SELECT CURRENT_TIME"); + result.unwrap(); + let result: Result<NaiveDate> = db.one_column("SELECT CURRENT_DATE"); + result.unwrap(); + let result: Result<NaiveDateTime> = db.one_column("SELECT CURRENT_TIMESTAMP"); + result.unwrap(); + let result: Result<DateTime<Utc>> = db.one_column("SELECT CURRENT_TIMESTAMP"); + result.unwrap(); + Ok(()) + } + + #[test] + fn test_naive_date_time_param() -> Result<()> { + let db = checked_memory_handle()?; + let result: Result<bool> = db.query_row("SELECT 1 WHERE ?1 BETWEEN datetime('now', '-1 minute') AND datetime('now', '+1 minute')", [Utc::now().naive_utc()], |r| r.get(0)); + result.unwrap(); + Ok(()) + } + + #[test] + fn test_date_time_param() -> Result<()> { + let db = checked_memory_handle()?; + let result: Result<bool> = db.query_row("SELECT 1 WHERE ?1 BETWEEN datetime('now', '-1 minute') AND datetime('now', '+1 minute')", [Utc::now()], |r| r.get(0)); + result.unwrap(); + Ok(()) + } + + #[test] + fn test_lenient_parse_timezone() { + DateTime::<Utc>::column_result(ValueRef::Text(b"1970-01-01T00:00:00Z")).unwrap(); + DateTime::<Utc>::column_result(ValueRef::Text(b"1970-01-01T00:00:00+00")).unwrap(); + } +} diff --git a/third_party/rust/rusqlite/src/types/from_sql.rs b/third_party/rust/rusqlite/src/types/from_sql.rs new file mode 100644 index 0000000000..acbb99ef30 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/from_sql.rs @@ -0,0 +1,366 @@ +use super::{Value, ValueRef}; +use std::convert::TryInto; +use std::error::Error; +use std::fmt; + +/// Enum listing possible errors from [`FromSql`] trait. +#[derive(Debug)] +#[non_exhaustive] +pub enum FromSqlError { + /// Error when an SQLite value is requested, but the type of the result + /// cannot be converted to the requested Rust type. + InvalidType, + + /// Error when the i64 value returned by SQLite cannot be stored into the + /// requested type. + OutOfRange(i64), + + /// Error when the blob result returned by SQLite cannot be stored into the + /// requested type due to a size mismatch. + InvalidBlobSize { + /// The expected size of the blob. + expected_size: usize, + /// The actual size of the blob that was returned. + blob_size: usize, + }, + + /// An error case available for implementors of the [`FromSql`] trait. + Other(Box<dyn Error + Send + Sync + 'static>), +} + +impl PartialEq for FromSqlError { + fn eq(&self, other: &FromSqlError) -> bool { + match (self, other) { + (FromSqlError::InvalidType, FromSqlError::InvalidType) => true, + (FromSqlError::OutOfRange(n1), FromSqlError::OutOfRange(n2)) => n1 == n2, + ( + FromSqlError::InvalidBlobSize { + expected_size: es1, + blob_size: bs1, + }, + FromSqlError::InvalidBlobSize { + expected_size: es2, + blob_size: bs2, + }, + ) => es1 == es2 && bs1 == bs2, + (..) => false, + } + } +} + +impl fmt::Display for FromSqlError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + FromSqlError::InvalidType => write!(f, "Invalid type"), + FromSqlError::OutOfRange(i) => write!(f, "Value {i} out of range"), + FromSqlError::InvalidBlobSize { + expected_size, + blob_size, + } => { + write!( + f, + "Cannot read {expected_size} byte value out of {blob_size} byte blob" + ) + } + FromSqlError::Other(ref err) => err.fmt(f), + } + } +} + +impl Error for FromSqlError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + if let FromSqlError::Other(ref err) = self { + Some(&**err) + } else { + None + } + } +} + +/// Result type for implementors of the [`FromSql`] trait. +pub type FromSqlResult<T> = Result<T, FromSqlError>; + +/// A trait for types that can be created from a SQLite value. +pub trait FromSql: Sized { + /// Converts SQLite value into Rust value. + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self>; +} + +macro_rules! from_sql_integral( + ($t:ident) => ( + impl FromSql for $t { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + let i = i64::column_result(value)?; + i.try_into().map_err(|_| FromSqlError::OutOfRange(i)) + } + } + ); + (non_zero $nz:ty, $z:ty) => ( + impl FromSql for $nz { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + let i = <$z>::column_result(value)?; + <$nz>::new(i).ok_or(FromSqlError::OutOfRange(0)) + } + } + ) +); + +from_sql_integral!(i8); +from_sql_integral!(i16); +from_sql_integral!(i32); +// from_sql_integral!(i64); // Not needed because the native type is i64. +from_sql_integral!(isize); +from_sql_integral!(u8); +from_sql_integral!(u16); +from_sql_integral!(u32); +from_sql_integral!(u64); +from_sql_integral!(usize); + +from_sql_integral!(non_zero std::num::NonZeroIsize, isize); +from_sql_integral!(non_zero std::num::NonZeroI8, i8); +from_sql_integral!(non_zero std::num::NonZeroI16, i16); +from_sql_integral!(non_zero std::num::NonZeroI32, i32); +from_sql_integral!(non_zero std::num::NonZeroI64, i64); +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +from_sql_integral!(non_zero std::num::NonZeroI128, i128); + +from_sql_integral!(non_zero std::num::NonZeroUsize, usize); +from_sql_integral!(non_zero std::num::NonZeroU8, u8); +from_sql_integral!(non_zero std::num::NonZeroU16, u16); +from_sql_integral!(non_zero std::num::NonZeroU32, u32); +from_sql_integral!(non_zero std::num::NonZeroU64, u64); +// std::num::NonZeroU128 is not supported since u128 isn't either + +impl FromSql for i64 { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_i64() + } +} + +impl FromSql for f32 { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Integer(i) => Ok(i as f32), + ValueRef::Real(f) => Ok(f as f32), + _ => Err(FromSqlError::InvalidType), + } + } +} + +impl FromSql for f64 { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Integer(i) => Ok(i as f64), + ValueRef::Real(f) => Ok(f), + _ => Err(FromSqlError::InvalidType), + } + } +} + +impl FromSql for bool { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + i64::column_result(value).map(|i| i != 0) + } +} + +impl FromSql for String { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().map(ToString::to_string) + } +} + +impl FromSql for Box<str> { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().map(Into::into) + } +} + +impl FromSql for std::rc::Rc<str> { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().map(Into::into) + } +} + +impl FromSql for std::sync::Arc<str> { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().map(Into::into) + } +} + +impl FromSql for Vec<u8> { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_blob().map(<[u8]>::to_vec) + } +} + +impl<const N: usize> FromSql for [u8; N] { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + let slice = value.as_blob()?; + slice.try_into().map_err(|_| FromSqlError::InvalidBlobSize { + expected_size: N, + blob_size: slice.len(), + }) + } +} + +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +impl FromSql for i128 { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + let bytes = <[u8; 16]>::column_result(value)?; + Ok(i128::from_be_bytes(bytes) ^ (1_i128 << 127)) + } +} + +#[cfg(feature = "uuid")] +#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] +impl FromSql for uuid::Uuid { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + let bytes = <[u8; 16]>::column_result(value)?; + Ok(uuid::Uuid::from_u128(u128::from_be_bytes(bytes))) + } +} + +impl<T: FromSql> FromSql for Option<T> { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Null => Ok(None), + _ => FromSql::column_result(value).map(Some), + } + } +} + +impl FromSql for Value { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + Ok(value.into()) + } +} + +#[cfg(test)] +mod test { + use super::FromSql; + use crate::{Connection, Error, Result}; + + #[test] + fn test_integral_ranges() -> Result<()> { + let db = Connection::open_in_memory()?; + + fn check_ranges<T>(db: &Connection, out_of_range: &[i64], in_range: &[i64]) + where + T: Into<i64> + FromSql + std::fmt::Debug, + { + for n in out_of_range { + let err = db + .query_row("SELECT ?1", [n], |r| r.get::<_, T>(0)) + .unwrap_err(); + match err { + Error::IntegralValueOutOfRange(_, value) => assert_eq!(*n, value), + _ => panic!("unexpected error: {err}"), + } + } + for n in in_range { + assert_eq!( + *n, + db.query_row("SELECT ?1", [n], |r| r.get::<_, T>(0)) + .unwrap() + .into() + ); + } + } + + check_ranges::<i8>(&db, &[-129, 128], &[-128, 0, 1, 127]); + check_ranges::<i16>(&db, &[-32769, 32768], &[-32768, -1, 0, 1, 32767]); + check_ranges::<i32>( + &db, + &[-2_147_483_649, 2_147_483_648], + &[-2_147_483_648, -1, 0, 1, 2_147_483_647], + ); + check_ranges::<u8>(&db, &[-2, -1, 256], &[0, 1, 255]); + check_ranges::<u16>(&db, &[-2, -1, 65536], &[0, 1, 65535]); + check_ranges::<u32>(&db, &[-2, -1, 4_294_967_296], &[0, 1, 4_294_967_295]); + Ok(()) + } + + #[test] + fn test_nonzero_ranges() -> Result<()> { + let db = Connection::open_in_memory()?; + + macro_rules! check_ranges { + ($nz:ty, $out_of_range:expr, $in_range:expr) => { + for &n in $out_of_range { + assert_eq!( + db.query_row("SELECT ?1", [n], |r| r.get::<_, $nz>(0)), + Err(Error::IntegralValueOutOfRange(0, n)), + "{}", + std::any::type_name::<$nz>() + ); + } + for &n in $in_range { + let non_zero = <$nz>::new(n).unwrap(); + assert_eq!( + Ok(non_zero), + db.query_row("SELECT ?1", [non_zero], |r| r.get::<_, $nz>(0)) + ); + } + }; + } + + check_ranges!(std::num::NonZeroI8, &[0, -129, 128], &[-128, 1, 127]); + check_ranges!( + std::num::NonZeroI16, + &[0, -32769, 32768], + &[-32768, -1, 1, 32767] + ); + check_ranges!( + std::num::NonZeroI32, + &[0, -2_147_483_649, 2_147_483_648], + &[-2_147_483_648, -1, 1, 2_147_483_647] + ); + check_ranges!( + std::num::NonZeroI64, + &[0], + &[-2_147_483_648, -1, 1, 2_147_483_647, i64::MAX, i64::MIN] + ); + check_ranges!( + std::num::NonZeroIsize, + &[0], + &[-2_147_483_648, -1, 1, 2_147_483_647] + ); + check_ranges!(std::num::NonZeroU8, &[0, -2, -1, 256], &[1, 255]); + check_ranges!(std::num::NonZeroU16, &[0, -2, -1, 65536], &[1, 65535]); + check_ranges!( + std::num::NonZeroU32, + &[0, -2, -1, 4_294_967_296], + &[1, 4_294_967_295] + ); + check_ranges!( + std::num::NonZeroU64, + &[0, -2, -1, -4_294_967_296], + &[1, 4_294_967_295, i64::MAX as u64] + ); + check_ranges!( + std::num::NonZeroUsize, + &[0, -2, -1, -4_294_967_296], + &[1, 4_294_967_295] + ); + + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/types/mod.rs b/third_party/rust/rusqlite/src/types/mod.rs new file mode 100644 index 0000000000..4dbc19dd95 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/mod.rs @@ -0,0 +1,447 @@ +//! Traits dealing with SQLite data types. +//! +//! SQLite uses a [dynamic type system](https://www.sqlite.org/datatype3.html). Implementations of +//! the [`ToSql`] and [`FromSql`] traits are provided for the basic types that +//! SQLite provides methods for: +//! +//! * Strings (`String` and `&str`) +//! * Blobs (`Vec<u8>` and `&[u8]`) +//! * Numbers +//! +//! The number situation is a little complicated due to the fact that all +//! numbers in SQLite are stored as `INTEGER` (`i64`) or `REAL` (`f64`). +//! +//! [`ToSql`] and [`FromSql`] are implemented for all primitive number types. +//! [`FromSql`] has different behaviour depending on the SQL and Rust types, and +//! the value. +//! +//! * `INTEGER` to integer: returns an +//! [`Error::IntegralValueOutOfRange`](crate::Error::IntegralValueOutOfRange) +//! error if the value does not fit in the Rust type. +//! * `REAL` to integer: always returns an +//! [`Error::InvalidColumnType`](crate::Error::InvalidColumnType) error. +//! * `INTEGER` to float: casts using `as` operator. Never fails. +//! * `REAL` to float: casts using `as` operator. Never fails. +//! +//! [`ToSql`] always succeeds except when storing a `u64` or `usize` value that +//! cannot fit in an `INTEGER` (`i64`). Also note that SQLite ignores column +//! types, so if you store an `i64` in a column with type `REAL` it will be +//! stored as an `INTEGER`, not a `REAL` (unless the column is part of a +//! [STRICT table](https://www.sqlite.org/stricttables.html)). +//! +//! If the `time` feature is enabled, implementations are +//! provided for `time::OffsetDateTime` that use the RFC 3339 date/time format, +//! `"%Y-%m-%dT%H:%M:%S.%fZ"`, to store time values as strings. These values +//! can be parsed by SQLite's builtin +//! [datetime](https://www.sqlite.org/lang_datefunc.html) functions. If you +//! want different storage for datetimes, you can use a newtype. +#![cfg_attr( + feature = "time", + doc = r##" +For example, to store datetimes as `i64`s counting the number of seconds since +the Unix epoch: + +``` +use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use rusqlite::Result; + +pub struct DateTimeSql(pub time::OffsetDateTime); + +impl FromSql for DateTimeSql { + fn column_result(value: ValueRef) -> FromSqlResult<Self> { + i64::column_result(value).and_then(|as_i64| { + time::OffsetDateTime::from_unix_timestamp(as_i64) + .map(|odt| DateTimeSql(odt)) + .map_err(|err| FromSqlError::Other(Box::new(err))) + }) + } +} + +impl ToSql for DateTimeSql { + fn to_sql(&self) -> Result<ToSqlOutput> { + Ok(self.0.unix_timestamp().into()) + } +} +``` + +"## +)] +//! [`ToSql`] and [`FromSql`] are also implemented for `Option<T>` where `T` +//! implements [`ToSql`] or [`FromSql`] for the cases where you want to know if +//! a value was NULL (which gets translated to `None`). + +pub use self::from_sql::{FromSql, FromSqlError, FromSqlResult}; +pub use self::to_sql::{ToSql, ToSqlOutput}; +pub use self::value::Value; +pub use self::value_ref::ValueRef; + +use std::fmt; + +#[cfg(feature = "chrono")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono")))] +mod chrono; +mod from_sql; +#[cfg(feature = "serde_json")] +#[cfg_attr(docsrs, doc(cfg(feature = "serde_json")))] +mod serde_json; +#[cfg(feature = "time")] +#[cfg_attr(docsrs, doc(cfg(feature = "time")))] +mod time; +mod to_sql; +#[cfg(feature = "url")] +#[cfg_attr(docsrs, doc(cfg(feature = "url")))] +mod url; +mod value; +mod value_ref; + +/// Empty struct that can be used to fill in a query parameter as `NULL`. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # use rusqlite::types::{Null}; +/// +/// fn insert_null(conn: &Connection) -> Result<usize> { +/// conn.execute("INSERT INTO people (name) VALUES (?1)", [Null]) +/// } +/// ``` +#[derive(Copy, Clone)] +pub struct Null; + +/// SQLite data types. +/// See [Fundamental Datatypes](https://sqlite.org/c3ref/c_blob.html). +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Type { + /// NULL + Null, + /// 64-bit signed integer + Integer, + /// 64-bit IEEE floating point number + Real, + /// String + Text, + /// BLOB + Blob, +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Type::Null => f.pad("Null"), + Type::Integer => f.pad("Integer"), + Type::Real => f.pad("Real"), + Type::Text => f.pad("Text"), + Type::Blob => f.pad("Blob"), + } + } +} + +#[cfg(test)] +mod test { + use super::Value; + use crate::{params, Connection, Error, Result, Statement}; + use std::os::raw::{c_double, c_int}; + + fn checked_memory_handle() -> Result<Connection> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (b BLOB, t TEXT, i INTEGER, f FLOAT, n)")?; + Ok(db) + } + + #[test] + fn test_blob() -> Result<()> { + let db = checked_memory_handle()?; + + let v1234 = vec![1u8, 2, 3, 4]; + db.execute("INSERT INTO foo(b) VALUES (?1)", [&v1234])?; + + let v: Vec<u8> = db.one_column("SELECT b FROM foo")?; + assert_eq!(v, v1234); + Ok(()) + } + + #[test] + fn test_empty_blob() -> Result<()> { + let db = checked_memory_handle()?; + + let empty = vec![]; + db.execute("INSERT INTO foo(b) VALUES (?1)", [&empty])?; + + let v: Vec<u8> = db.one_column("SELECT b FROM foo")?; + assert_eq!(v, empty); + Ok(()) + } + + #[test] + fn test_str() -> Result<()> { + let db = checked_memory_handle()?; + + let s = "hello, world!"; + db.execute("INSERT INTO foo(t) VALUES (?1)", [&s])?; + + let from: String = db.one_column("SELECT t FROM foo")?; + assert_eq!(from, s); + Ok(()) + } + + #[test] + fn test_string() -> Result<()> { + let db = checked_memory_handle()?; + + let s = "hello, world!"; + db.execute("INSERT INTO foo(t) VALUES (?1)", [s.to_owned()])?; + + let from: String = db.one_column("SELECT t FROM foo")?; + assert_eq!(from, s); + Ok(()) + } + + #[test] + fn test_value() -> Result<()> { + let db = checked_memory_handle()?; + + db.execute("INSERT INTO foo(i) VALUES (?1)", [Value::Integer(10)])?; + + assert_eq!(10i64, db.one_column::<i64>("SELECT i FROM foo")?); + Ok(()) + } + + #[test] + fn test_option() -> Result<()> { + let db = checked_memory_handle()?; + + let s = "hello, world!"; + let b = Some(vec![1u8, 2, 3, 4]); + + db.execute("INSERT INTO foo(t) VALUES (?1)", [Some(s)])?; + db.execute("INSERT INTO foo(b) VALUES (?1)", [&b])?; + + let mut stmt = db.prepare("SELECT t, b FROM foo ORDER BY ROWID ASC")?; + let mut rows = stmt.query([])?; + + { + let row1 = rows.next()?.unwrap(); + let s1: Option<String> = row1.get_unwrap(0); + let b1: Option<Vec<u8>> = row1.get_unwrap(1); + assert_eq!(s, s1.unwrap()); + assert!(b1.is_none()); + } + + { + let row2 = rows.next()?.unwrap(); + let s2: Option<String> = row2.get_unwrap(0); + let b2: Option<Vec<u8>> = row2.get_unwrap(1); + assert!(s2.is_none()); + assert_eq!(b, b2); + } + Ok(()) + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_mismatched_types() -> Result<()> { + fn is_invalid_column_type(err: Error) -> bool { + matches!(err, Error::InvalidColumnType(..)) + } + + let db = checked_memory_handle()?; + + db.execute( + "INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", + [], + )?; + + let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo")?; + let mut rows = stmt.query([])?; + + let row = rows.next()?.unwrap(); + + // check the correct types come back as expected + assert_eq!(vec![1, 2], row.get::<_, Vec<u8>>(0)?); + assert_eq!("text", row.get::<_, String>(1)?); + assert_eq!(1, row.get::<_, c_int>(2)?); + assert!((1.5 - row.get::<_, c_double>(3)?).abs() < f64::EPSILON); + assert_eq!(row.get::<_, Option<c_int>>(4)?, None); + assert_eq!(row.get::<_, Option<c_double>>(4)?, None); + assert_eq!(row.get::<_, Option<String>>(4)?, None); + + // check some invalid types + + // 0 is actually a blob (Vec<u8>) + assert!(is_invalid_column_type(row.get::<_, c_int>(0).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, c_int>(0).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, i64>(0).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(0).unwrap_err() + )); + assert!(is_invalid_column_type(row.get::<_, String>(0).unwrap_err())); + #[cfg(feature = "time")] + assert!(is_invalid_column_type( + row.get::<_, time::OffsetDateTime>(0).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Option<c_int>>(0).unwrap_err() + )); + + // 1 is actually a text (String) + assert!(is_invalid_column_type(row.get::<_, c_int>(1).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, i64>(1).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(1).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Vec<u8>>(1).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Option<c_int>>(1).unwrap_err() + )); + + // 2 is actually an integer + assert!(is_invalid_column_type(row.get::<_, String>(2).unwrap_err())); + assert!(is_invalid_column_type( + row.get::<_, Vec<u8>>(2).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Option<String>>(2).unwrap_err() + )); + + // 3 is actually a float (c_double) + assert!(is_invalid_column_type(row.get::<_, c_int>(3).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, i64>(3).err().unwrap())); + assert!(is_invalid_column_type(row.get::<_, String>(3).unwrap_err())); + assert!(is_invalid_column_type( + row.get::<_, Vec<u8>>(3).unwrap_err() + )); + assert!(is_invalid_column_type( + row.get::<_, Option<c_int>>(3).unwrap_err() + )); + + // 4 is actually NULL + assert!(is_invalid_column_type(row.get::<_, c_int>(4).unwrap_err())); + assert!(is_invalid_column_type(row.get::<_, i64>(4).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(4).unwrap_err() + )); + assert!(is_invalid_column_type(row.get::<_, String>(4).unwrap_err())); + assert!(is_invalid_column_type( + row.get::<_, Vec<u8>>(4).unwrap_err() + )); + #[cfg(feature = "time")] + assert!(is_invalid_column_type( + row.get::<_, time::OffsetDateTime>(4).unwrap_err() + )); + Ok(()) + } + + #[test] + fn test_dynamic_type() -> Result<()> { + use super::Value; + let db = checked_memory_handle()?; + + db.execute( + "INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", + [], + )?; + + let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo")?; + let mut rows = stmt.query([])?; + + let row = rows.next()?.unwrap(); + assert_eq!(Value::Blob(vec![1, 2]), row.get::<_, Value>(0)?); + assert_eq!(Value::Text(String::from("text")), row.get::<_, Value>(1)?); + assert_eq!(Value::Integer(1), row.get::<_, Value>(2)?); + match row.get::<_, Value>(3)? { + Value::Real(val) => assert!((1.5 - val).abs() < f64::EPSILON), + x => panic!("Invalid Value {x:?}"), + } + assert_eq!(Value::Null, row.get::<_, Value>(4)?); + Ok(()) + } + + macro_rules! test_conversion { + ($db_etc:ident, $insert_value:expr, $get_type:ty,expect $expected_value:expr) => { + $db_etc.insert_statement.execute(params![$insert_value])?; + let res = $db_etc + .query_statement + .query_row([], |row| row.get::<_, $get_type>(0)); + assert_eq!(res?, $expected_value); + $db_etc.delete_statement.execute([])?; + }; + ($db_etc:ident, $insert_value:expr, $get_type:ty,expect_from_sql_error) => { + $db_etc.insert_statement.execute(params![$insert_value])?; + let res = $db_etc + .query_statement + .query_row([], |row| row.get::<_, $get_type>(0)); + res.unwrap_err(); + $db_etc.delete_statement.execute([])?; + }; + ($db_etc:ident, $insert_value:expr, $get_type:ty,expect_to_sql_error) => { + $db_etc + .insert_statement + .execute(params![$insert_value]) + .unwrap_err(); + }; + } + + #[test] + fn test_numeric_conversions() -> Result<()> { + #![allow(clippy::float_cmp)] + + // Test what happens when we store an f32 and retrieve an i32 etc. + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (x)")?; + + // SQLite actually ignores the column types, so we just need to test + // different numeric values. + + struct DbEtc<'conn> { + insert_statement: Statement<'conn>, + query_statement: Statement<'conn>, + delete_statement: Statement<'conn>, + } + + let mut db_etc = DbEtc { + insert_statement: db.prepare("INSERT INTO foo VALUES (?1)")?, + query_statement: db.prepare("SELECT x FROM foo")?, + delete_statement: db.prepare("DELETE FROM foo")?, + }; + + // Basic non-converting test. + test_conversion!(db_etc, 0u8, u8, expect 0u8); + + // In-range integral conversions. + test_conversion!(db_etc, 100u8, i8, expect 100i8); + test_conversion!(db_etc, 200u8, u8, expect 200u8); + test_conversion!(db_etc, 100u16, i8, expect 100i8); + test_conversion!(db_etc, 200u16, u8, expect 200u8); + test_conversion!(db_etc, u32::MAX, u64, expect u32::MAX as u64); + test_conversion!(db_etc, i64::MIN, i64, expect i64::MIN); + test_conversion!(db_etc, i64::MAX, i64, expect i64::MAX); + test_conversion!(db_etc, i64::MAX, u64, expect i64::MAX as u64); + test_conversion!(db_etc, 100usize, usize, expect 100usize); + test_conversion!(db_etc, 100u64, u64, expect 100u64); + test_conversion!(db_etc, i64::MAX as u64, u64, expect i64::MAX as u64); + + // Out-of-range integral conversions. + test_conversion!(db_etc, 200u8, i8, expect_from_sql_error); + test_conversion!(db_etc, 400u16, i8, expect_from_sql_error); + test_conversion!(db_etc, 400u16, u8, expect_from_sql_error); + test_conversion!(db_etc, -1i8, u8, expect_from_sql_error); + test_conversion!(db_etc, i64::MIN, u64, expect_from_sql_error); + test_conversion!(db_etc, u64::MAX, i64, expect_to_sql_error); + test_conversion!(db_etc, u64::MAX, u64, expect_to_sql_error); + test_conversion!(db_etc, i64::MAX as u64 + 1, u64, expect_to_sql_error); + + // FromSql integer to float, always works. + test_conversion!(db_etc, i64::MIN, f32, expect i64::MIN as f32); + test_conversion!(db_etc, i64::MAX, f32, expect i64::MAX as f32); + test_conversion!(db_etc, i64::MIN, f64, expect i64::MIN as f64); + test_conversion!(db_etc, i64::MAX, f64, expect i64::MAX as f64); + + // FromSql float to int conversion, never works even if the actual value + // is an integer. + test_conversion!(db_etc, 0f64, i64, expect_from_sql_error); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/types/serde_json.rs b/third_party/rust/rusqlite/src/types/serde_json.rs new file mode 100644 index 0000000000..6e38ba36ba --- /dev/null +++ b/third_party/rust/rusqlite/src/types/serde_json.rs @@ -0,0 +1,135 @@ +//! [`ToSql`] and [`FromSql`] implementation for JSON `Value`. + +use serde_json::{Number, Value}; + +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::{Error, Result}; + +/// Serialize JSON `Value` to text: +/// +/// +/// | JSON | SQLite | +/// |----------|---------| +/// | Null | NULL | +/// | Bool | 'true' / 'false' | +/// | Number | INT or REAL except u64 | +/// | _ | TEXT | +impl ToSql for Value { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + match self { + Value::Null => Ok(ToSqlOutput::Borrowed(ValueRef::Null)), + Value::Number(n) if n.is_i64() => Ok(ToSqlOutput::from(n.as_i64().unwrap())), + Value::Number(n) if n.is_f64() => Ok(ToSqlOutput::from(n.as_f64().unwrap())), + _ => serde_json::to_string(self) + .map(ToSqlOutput::from) + .map_err(|err| Error::ToSqlConversionFailure(err.into())), + } + } +} + +/// Deserialize SQLite value to JSON `Value`: +/// +/// | SQLite | JSON | +/// |----------|---------| +/// | NULL | Null | +/// | 'null' | Null | +/// | 'true' | Bool | +/// | 1 | Number | +/// | 0.1 | Number | +/// | '"text"' | String | +/// | 'text' | _Error_ | +/// | '[0, 1]' | Array | +/// | '{"x": 1}' | Object | +impl FromSql for Value { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Text(s) => serde_json::from_slice(s), // KO for b"text" + ValueRef::Blob(b) => serde_json::from_slice(b), + ValueRef::Integer(i) => Ok(Value::Number(Number::from(i))), + ValueRef::Real(f) => { + match Number::from_f64(f) { + Some(n) => Ok(Value::Number(n)), + _ => return Err(FromSqlError::InvalidType), // FIXME + } + } + ValueRef::Null => Ok(Value::Null), + } + .map_err(|err| FromSqlError::Other(Box::new(err))) + } +} + +#[cfg(test)] +mod test { + use crate::types::ToSql; + use crate::{Connection, Result}; + use serde_json::{Number, Value}; + + fn checked_memory_handle() -> Result<Connection> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (t TEXT, b BLOB)")?; + Ok(db) + } + + #[test] + fn test_json_value() -> Result<()> { + let db = checked_memory_handle()?; + + let json = r#"{"foo": 13, "bar": "baz"}"#; + let data: Value = serde_json::from_str(json).unwrap(); + db.execute( + "INSERT INTO foo (t, b) VALUES (?1, ?2)", + [&data as &dyn ToSql, &json.as_bytes()], + )?; + + let t: Value = db.one_column("SELECT t FROM foo")?; + assert_eq!(data, t); + let b: Value = db.one_column("SELECT b FROM foo")?; + assert_eq!(data, b); + Ok(()) + } + + #[test] + fn test_to_sql() -> Result<()> { + let db = Connection::open_in_memory()?; + + let v: Option<String> = db.query_row("SELECT ?", [Value::Null], |r| r.get(0))?; + assert_eq!(None, v); + let v: String = db.query_row("SELECT ?", [Value::Bool(true)], |r| r.get(0))?; + assert_eq!("true", v); + let v: i64 = db.query_row("SELECT ?", [Value::Number(Number::from(1))], |r| r.get(0))?; + assert_eq!(1, v); + let v: f64 = db.query_row( + "SELECT ?", + [Value::Number(Number::from_f64(0.1).unwrap())], + |r| r.get(0), + )?; + assert_eq!(0.1, v); + let v: String = + db.query_row("SELECT ?", [Value::String("text".to_owned())], |r| r.get(0))?; + assert_eq!("\"text\"", v); + Ok(()) + } + + #[test] + fn test_from_sql() -> Result<()> { + let db = Connection::open_in_memory()?; + + let v: Value = db.one_column("SELECT NULL")?; + assert_eq!(Value::Null, v); + let v: Value = db.one_column("SELECT 'null'")?; + assert_eq!(Value::Null, v); + let v: Value = db.one_column("SELECT 'true'")?; + assert_eq!(Value::Bool(true), v); + let v: Value = db.one_column("SELECT 1")?; + assert_eq!(Value::Number(Number::from(1)), v); + let v: Value = db.one_column("SELECT 0.1")?; + assert_eq!(Value::Number(Number::from_f64(0.1).unwrap()), v); + let v: Value = db.one_column("SELECT '\"text\"'")?; + assert_eq!(Value::String("text".to_owned()), v); + let v: Result<Value> = db.one_column("SELECT 'text'"); + assert!(v.is_err()); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/types/time.rs b/third_party/rust/rusqlite/src/types/time.rs new file mode 100644 index 0000000000..265d6c6bd9 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/time.rs @@ -0,0 +1,426 @@ +//! Convert formats 1-10 in [Time Values](https://sqlite.org/lang_datefunc.html#time_values) to time types. +//! [`ToSql`] and [`FromSql`] implementation for [`time::OffsetDateTime`]. +//! [`ToSql`] and [`FromSql`] implementation for [`time::PrimitiveDateTime`]. +//! [`ToSql`] and [`FromSql`] implementation for [`time::Date`]. +//! [`ToSql`] and [`FromSql`] implementation for [`time::Time`]. +//! Time Strings in: +//! - Format 2: "YYYY-MM-DD HH:MM" +//! - Format 5: "YYYY-MM-DDTHH:MM" +//! - Format 8: "HH:MM" +//! without an explicit second value will assume 0 seconds. +//! Time String that contain an optional timezone without an explicit date are unsupported. +//! All other assumptions described in [Time Values](https://sqlite.org/lang_datefunc.html#time_values) section are unsupported. + +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::{Error, Result}; +use time::format_description::FormatItem; +use time::macros::format_description; +use time::{Date, OffsetDateTime, PrimitiveDateTime, Time}; + +const OFFSET_DATE_TIME_ENCODING: &[FormatItem<'_>] = format_description!( + version = 2, + "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond][offset_hour sign:mandatory]:[offset_minute]" +); +const PRIMITIVE_DATE_TIME_ENCODING: &[FormatItem<'_>] = format_description!( + version = 2, + "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]" +); +const TIME_ENCODING: &[FormatItem<'_>] = + format_description!(version = 2, "[hour]:[minute]:[second].[subsecond]"); + +const DATE_FORMAT: &[FormatItem<'_>] = format_description!(version = 2, "[year]-[month]-[day]"); +const TIME_FORMAT: &[FormatItem<'_>] = format_description!( + version = 2, + "[hour]:[minute][optional [:[second][optional [.[subsecond]]]]]" +); +const PRIMITIVE_DATE_TIME_FORMAT: &[FormatItem<'_>] = format_description!( + version = 2, + "[year]-[month]-[day][first [ ][T]][hour]:[minute][optional [:[second][optional [.[subsecond]]]]]" +); +const UTC_DATE_TIME_FORMAT: &[FormatItem<'_>] = format_description!( + version = 2, + "[year]-[month]-[day][first [ ][T]][hour]:[minute][optional [:[second][optional [.[subsecond]]]]][optional [Z]]" +); +const OFFSET_DATE_TIME_FORMAT: &[FormatItem<'_>] = format_description!( + version = 2, + "[year]-[month]-[day][first [ ][T]][hour]:[minute][optional [:[second][optional [.[subsecond]]]]][offset_hour sign:mandatory]:[offset_minute]" +); +const LEGACY_DATE_TIME_FORMAT: &[FormatItem<'_>] = format_description!( + version = 2, + "[year]-[month]-[day] [hour]:[minute]:[second]:[subsecond] [offset_hour sign:mandatory]:[offset_minute]" +); + +/// OffsetDatetime => RFC3339 format ("YYYY-MM-DD HH:MM:SS.SSS[+-]HH:MM") +impl ToSql for OffsetDateTime { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let time_string = self + .format(&OFFSET_DATE_TIME_ENCODING) + .map_err(|err| Error::ToSqlConversionFailure(err.into()))?; + Ok(ToSqlOutput::from(time_string)) + } +} + +// Supports parsing formats 2-7 from https://www.sqlite.org/lang_datefunc.html +// Formats 2-7 without a timezone assumes UTC +impl FromSql for OffsetDateTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + if let Some(b' ') = s.as_bytes().get(23) { + // legacy + return OffsetDateTime::parse(s, &LEGACY_DATE_TIME_FORMAT) + .map_err(|err| FromSqlError::Other(Box::new(err))); + } + if s[8..].contains('+') || s[8..].contains('-') { + // Formats 2-7 with timezone + return OffsetDateTime::parse(s, &OFFSET_DATE_TIME_FORMAT) + .map_err(|err| FromSqlError::Other(Box::new(err))); + } + // Formats 2-7 without timezone + PrimitiveDateTime::parse(s, &UTC_DATE_TIME_FORMAT) + .map(|p| p.assume_utc()) + .map_err(|err| FromSqlError::Other(Box::new(err))) + }) + } +} + +/// ISO 8601 calendar date without timezone => "YYYY-MM-DD" +impl ToSql for Date { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self + .format(&DATE_FORMAT) + .map_err(|err| Error::ToSqlConversionFailure(err.into()))?; + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "YYYY-MM-DD" => ISO 8601 calendar date without timezone. +impl FromSql for Date { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + Date::parse(s, &DATE_FORMAT).map_err(|err| FromSqlError::Other(err.into())) + }) + } +} + +/// ISO 8601 time without timezone => "HH:MM:SS.SSS" +impl ToSql for Time { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let time_str = self + .format(&TIME_ENCODING) + .map_err(|err| Error::ToSqlConversionFailure(err.into()))?; + Ok(ToSqlOutput::from(time_str)) + } +} + +/// "HH:MM"/"HH:MM:SS"/"HH:MM:SS.SSS" => ISO 8601 time without timezone. +impl FromSql for Time { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + Time::parse(s, &TIME_FORMAT).map_err(|err| FromSqlError::Other(err.into())) + }) + } +} + +/// ISO 8601 combined date and time without timezone => "YYYY-MM-DD HH:MM:SS.SSS" +impl ToSql for PrimitiveDateTime { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_time_str = self + .format(&PRIMITIVE_DATE_TIME_ENCODING) + .map_err(|err| Error::ToSqlConversionFailure(err.into()))?; + Ok(ToSqlOutput::from(date_time_str)) + } +} + +/// YYYY-MM-DD HH:MM +/// YYYY-MM-DDTHH:MM +/// YYYY-MM-DD HH:MM:SS +/// YYYY-MM-DDTHH:MM:SS +/// YYYY-MM-DD HH:MM:SS.SSS +/// YYYY-MM-DDTHH:MM:SS.SSS +/// => ISO 8601 combined date and time with timezone +impl FromSql for PrimitiveDateTime { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + PrimitiveDateTime::parse(s, &PRIMITIVE_DATE_TIME_FORMAT) + .map_err(|err| FromSqlError::Other(err.into())) + }) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result}; + use time::macros::{date, datetime, time}; + use time::{Date, OffsetDateTime, PrimitiveDateTime, Time}; + + fn checked_memory_handle() -> Result<Connection> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (t TEXT, i INTEGER, f FLOAT, b BLOB)")?; + Ok(db) + } + + #[test] + fn test_offset_date_time() -> Result<()> { + let db = checked_memory_handle()?; + + let mut ts_vec = vec![]; + + let make_datetime = |secs: i128, nanos: i128| { + OffsetDateTime::from_unix_timestamp_nanos(1_000_000_000 * secs + nanos).unwrap() + }; + + ts_vec.push(make_datetime(10_000, 0)); //January 1, 1970 2:46:40 AM + ts_vec.push(make_datetime(10_000, 1000)); //January 1, 1970 2:46:40 AM (and one microsecond) + ts_vec.push(make_datetime(1_500_391_124, 1_000_000)); //July 18, 2017 + ts_vec.push(make_datetime(2_000_000_000, 2_000_000)); //May 18, 2033 + ts_vec.push(make_datetime(3_000_000_000, 999_999_999)); //January 24, 2065 + ts_vec.push(make_datetime(10_000_000_000, 0)); //November 20, 2286 + + for ts in ts_vec { + db.execute("INSERT INTO foo(t) VALUES (?1)", [ts])?; + + let from: OffsetDateTime = db.one_column("SELECT t FROM foo")?; + + db.execute("DELETE FROM foo", [])?; + + assert_eq!(from, ts); + } + Ok(()) + } + + #[test] + fn test_offset_date_time_parsing() -> Result<()> { + let db = checked_memory_handle()?; + let tests = vec![ + // Rfc3339 + ( + "2013-10-07T08:23:19.123456789Z", + datetime!(2013-10-07 8:23:19.123456789 UTC), + ), + ( + "2013-10-07 08:23:19.123456789Z", + datetime!(2013-10-07 8:23:19.123456789 UTC), + ), + // Format 2 + ("2013-10-07 08:23", datetime!(2013-10-07 8:23 UTC)), + ("2013-10-07 08:23Z", datetime!(2013-10-07 8:23 UTC)), + ("2013-10-07 08:23+04:00", datetime!(2013-10-07 8:23 +4)), + // Format 3 + ("2013-10-07 08:23:19", datetime!(2013-10-07 8:23:19 UTC)), + ("2013-10-07 08:23:19Z", datetime!(2013-10-07 8:23:19 UTC)), + ( + "2013-10-07 08:23:19+04:00", + datetime!(2013-10-07 8:23:19 +4), + ), + // Format 4 + ( + "2013-10-07 08:23:19.123", + datetime!(2013-10-07 8:23:19.123 UTC), + ), + ( + "2013-10-07 08:23:19.123Z", + datetime!(2013-10-07 8:23:19.123 UTC), + ), + ( + "2013-10-07 08:23:19.123+04:00", + datetime!(2013-10-07 8:23:19.123 +4), + ), + // Format 5 + ("2013-10-07T08:23", datetime!(2013-10-07 8:23 UTC)), + ("2013-10-07T08:23Z", datetime!(2013-10-07 8:23 UTC)), + ("2013-10-07T08:23+04:00", datetime!(2013-10-07 8:23 +4)), + // Format 6 + ("2013-10-07T08:23:19", datetime!(2013-10-07 8:23:19 UTC)), + ("2013-10-07T08:23:19Z", datetime!(2013-10-07 8:23:19 UTC)), + ( + "2013-10-07T08:23:19+04:00", + datetime!(2013-10-07 8:23:19 +4), + ), + // Format 7 + ( + "2013-10-07T08:23:19.123", + datetime!(2013-10-07 8:23:19.123 UTC), + ), + ( + "2013-10-07T08:23:19.123Z", + datetime!(2013-10-07 8:23:19.123 UTC), + ), + ( + "2013-10-07T08:23:19.123+04:00", + datetime!(2013-10-07 8:23:19.123 +4), + ), + // Legacy + ( + "2013-10-07 08:23:12:987 -07:00", + datetime!(2013-10-07 8:23:12.987 -7), + ), + ]; + + for (s, t) in tests { + let result: OffsetDateTime = db.query_row("SELECT ?1", [s], |r| r.get(0))?; + assert_eq!(result, t); + } + Ok(()) + } + + #[test] + fn test_date() -> Result<()> { + let db = checked_memory_handle()?; + let date = date!(2016 - 02 - 23); + db.execute("INSERT INTO foo (t) VALUES (?1)", [date])?; + + let s: String = db.one_column("SELECT t FROM foo")?; + assert_eq!("2016-02-23", s); + let t: Date = db.one_column("SELECT t FROM foo")?; + assert_eq!(date, t); + Ok(()) + } + + #[test] + fn test_time() -> Result<()> { + let db = checked_memory_handle()?; + let time = time!(23:56:04.00001); + db.execute("INSERT INTO foo (t) VALUES (?1)", [time])?; + + let s: String = db.one_column("SELECT t FROM foo")?; + assert_eq!("23:56:04.00001", s); + let v: Time = db.one_column("SELECT t FROM foo")?; + assert_eq!(time, v); + Ok(()) + } + + #[test] + fn test_primitive_date_time() -> Result<()> { + let db = checked_memory_handle()?; + let dt = date!(2016 - 02 - 23).with_time(time!(23:56:04)); + + db.execute("INSERT INTO foo (t) VALUES (?1)", [dt])?; + + let s: String = db.one_column("SELECT t FROM foo")?; + assert_eq!("2016-02-23 23:56:04.0", s); + let v: PrimitiveDateTime = db.one_column("SELECT t FROM foo")?; + assert_eq!(dt, v); + + db.execute("UPDATE foo set b = datetime(t)", [])?; // "YYYY-MM-DD HH:MM:SS" + let hms: PrimitiveDateTime = db.one_column("SELECT b FROM foo")?; + assert_eq!(dt, hms); + Ok(()) + } + + #[test] + fn test_date_parsing() -> Result<()> { + let db = checked_memory_handle()?; + let result: Date = db.query_row("SELECT ?1", ["2013-10-07"], |r| r.get(0))?; + assert_eq!(result, date!(2013 - 10 - 07)); + Ok(()) + } + + #[test] + fn test_time_parsing() -> Result<()> { + let db = checked_memory_handle()?; + let tests = vec![ + ("08:23", time!(08:23)), + ("08:23:19", time!(08:23:19)), + ("08:23:19.111", time!(08:23:19.111)), + ]; + + for (s, t) in tests { + let result: Time = db.query_row("SELECT ?1", [s], |r| r.get(0))?; + assert_eq!(result, t); + } + Ok(()) + } + + #[test] + fn test_primitive_date_time_parsing() -> Result<()> { + let db = checked_memory_handle()?; + + let tests = vec![ + ("2013-10-07T08:23", datetime!(2013-10-07 8:23)), + ("2013-10-07T08:23:19", datetime!(2013-10-07 8:23:19)), + ("2013-10-07T08:23:19.111", datetime!(2013-10-07 8:23:19.111)), + ("2013-10-07 08:23", datetime!(2013-10-07 8:23)), + ("2013-10-07 08:23:19", datetime!(2013-10-07 8:23:19)), + ("2013-10-07 08:23:19.111", datetime!(2013-10-07 8:23:19.111)), + ]; + + for (s, t) in tests { + let result: PrimitiveDateTime = db.query_row("SELECT ?1", [s], |r| r.get(0))?; + assert_eq!(result, t); + } + Ok(()) + } + + #[test] + fn test_sqlite_functions() -> Result<()> { + let db = checked_memory_handle()?; + db.one_column::<Time>("SELECT CURRENT_TIME").unwrap(); + db.one_column::<Date>("SELECT CURRENT_DATE").unwrap(); + db.one_column::<PrimitiveDateTime>("SELECT CURRENT_TIMESTAMP") + .unwrap(); + db.one_column::<OffsetDateTime>("SELECT CURRENT_TIMESTAMP") + .unwrap(); + Ok(()) + } + + #[test] + fn test_time_param() -> Result<()> { + let db = checked_memory_handle()?; + let now = OffsetDateTime::now_utc().time(); + let result: Result<bool> = db.query_row( + "SELECT 1 WHERE ?1 BETWEEN time('now', '-1 minute') AND time('now', '+1 minute')", + [now], + |r| r.get(0), + ); + result.unwrap(); + Ok(()) + } + + #[test] + fn test_date_param() -> Result<()> { + let db = checked_memory_handle()?; + let now = OffsetDateTime::now_utc().date(); + let result: Result<bool> = db.query_row( + "SELECT 1 WHERE ?1 BETWEEN date('now', '-1 day') AND date('now', '+1 day')", + [now], + |r| r.get(0), + ); + result.unwrap(); + Ok(()) + } + + #[test] + fn test_primitive_date_time_param() -> Result<()> { + let db = checked_memory_handle()?; + let now = PrimitiveDateTime::new( + OffsetDateTime::now_utc().date(), + OffsetDateTime::now_utc().time(), + ); + let result: Result<bool> = db.query_row( + "SELECT 1 WHERE ?1 BETWEEN datetime('now', '-1 minute') AND datetime('now', '+1 minute')", + [now], + |r| r.get(0), + ); + result.unwrap(); + Ok(()) + } + + #[test] + fn test_offset_date_time_param() -> Result<()> { + let db = checked_memory_handle()?; + let result: Result<bool> = db.query_row( + "SELECT 1 WHERE ?1 BETWEEN datetime('now', '-1 minute') AND datetime('now', '+1 minute')", + [OffsetDateTime::now_utc()], + |r| r.get(0), + ); + result.unwrap(); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/types/to_sql.rs b/third_party/rust/rusqlite/src/types/to_sql.rs new file mode 100644 index 0000000000..29e63abdcc --- /dev/null +++ b/third_party/rust/rusqlite/src/types/to_sql.rs @@ -0,0 +1,541 @@ +use super::{Null, Value, ValueRef}; +#[cfg(feature = "array")] +use crate::vtab::array::Array; +use crate::{Error, Result}; +use std::borrow::Cow; +use std::convert::TryFrom; + +/// `ToSqlOutput` represents the possible output types for implementers of the +/// [`ToSql`] trait. +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum ToSqlOutput<'a> { + /// A borrowed SQLite-representable value. + Borrowed(ValueRef<'a>), + + /// An owned SQLite-representable value. + Owned(Value), + + /// A BLOB of the given length that is filled with + /// zeroes. + #[cfg(feature = "blob")] + #[cfg_attr(docsrs, doc(cfg(feature = "blob")))] + ZeroBlob(i32), + + /// `feature = "array"` + #[cfg(feature = "array")] + #[cfg_attr(docsrs, doc(cfg(feature = "array")))] + Array(Array), +} + +// Generically allow any type that can be converted into a ValueRef +// to be converted into a ToSqlOutput as well. +impl<'a, T: ?Sized> From<&'a T> for ToSqlOutput<'a> +where + &'a T: Into<ValueRef<'a>>, +{ + #[inline] + fn from(t: &'a T) -> Self { + ToSqlOutput::Borrowed(t.into()) + } +} + +// We cannot also generically allow any type that can be converted +// into a Value to be converted into a ToSqlOutput because of +// coherence rules (https://github.com/rust-lang/rust/pull/46192), +// so we'll manually implement it for all the types we know can +// be converted into Values. +macro_rules! from_value( + ($t:ty) => ( + impl From<$t> for ToSqlOutput<'_> { + #[inline] + fn from(t: $t) -> Self { ToSqlOutput::Owned(t.into())} + } + ); + (non_zero $t:ty) => ( + impl From<$t> for ToSqlOutput<'_> { + #[inline] + fn from(t: $t) -> Self { ToSqlOutput::Owned(t.get().into())} + } + ) +); +from_value!(String); +from_value!(Null); +from_value!(bool); +from_value!(i8); +from_value!(i16); +from_value!(i32); +from_value!(i64); +from_value!(isize); +from_value!(u8); +from_value!(u16); +from_value!(u32); +from_value!(f32); +from_value!(f64); +from_value!(Vec<u8>); + +from_value!(non_zero std::num::NonZeroI8); +from_value!(non_zero std::num::NonZeroI16); +from_value!(non_zero std::num::NonZeroI32); +from_value!(non_zero std::num::NonZeroI64); +from_value!(non_zero std::num::NonZeroIsize); +from_value!(non_zero std::num::NonZeroU8); +from_value!(non_zero std::num::NonZeroU16); +from_value!(non_zero std::num::NonZeroU32); + +// It would be nice if we could avoid the heap allocation (of the `Vec`) that +// `i128` needs in `Into<Value>`, but it's probably fine for the moment, and not +// worth adding another case to Value. +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +from_value!(i128); + +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +from_value!(non_zero std::num::NonZeroI128); + +#[cfg(feature = "uuid")] +#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] +from_value!(uuid::Uuid); + +impl ToSql for ToSqlOutput<'_> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(match *self { + ToSqlOutput::Borrowed(v) => ToSqlOutput::Borrowed(v), + ToSqlOutput::Owned(ref v) => ToSqlOutput::Borrowed(ValueRef::from(v)), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(i) => ToSqlOutput::ZeroBlob(i), + #[cfg(feature = "array")] + ToSqlOutput::Array(ref a) => ToSqlOutput::Array(a.clone()), + }) + } +} + +/// A trait for types that can be converted into SQLite values. Returns +/// [`Error::ToSqlConversionFailure`] if the conversion fails. +pub trait ToSql { + /// Converts Rust value to SQLite value + fn to_sql(&self) -> Result<ToSqlOutput<'_>>; +} + +impl<T: ToSql + ToOwned + ?Sized> ToSql for Cow<'_, T> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + self.as_ref().to_sql() + } +} + +impl<T: ToSql + ?Sized> ToSql for Box<T> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + self.as_ref().to_sql() + } +} + +impl<T: ToSql + ?Sized> ToSql for std::rc::Rc<T> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + self.as_ref().to_sql() + } +} + +impl<T: ToSql + ?Sized> ToSql for std::sync::Arc<T> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + self.as_ref().to_sql() + } +} + +// We should be able to use a generic impl like this: +// +// impl<T: Copy> ToSql for T where T: Into<Value> { +// fn to_sql(&self) -> Result<ToSqlOutput> { +// Ok(ToSqlOutput::from((*self).into())) +// } +// } +// +// instead of the following macro, but this runs afoul of +// https://github.com/rust-lang/rust/issues/30191 and reports conflicting +// implementations even when there aren't any. + +macro_rules! to_sql_self( + ($t:ty) => ( + impl ToSql for $t { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(*self)) + } + } + ) +); + +to_sql_self!(Null); +to_sql_self!(bool); +to_sql_self!(i8); +to_sql_self!(i16); +to_sql_self!(i32); +to_sql_self!(i64); +to_sql_self!(isize); +to_sql_self!(u8); +to_sql_self!(u16); +to_sql_self!(u32); +to_sql_self!(f32); +to_sql_self!(f64); + +to_sql_self!(std::num::NonZeroI8); +to_sql_self!(std::num::NonZeroI16); +to_sql_self!(std::num::NonZeroI32); +to_sql_self!(std::num::NonZeroI64); +to_sql_self!(std::num::NonZeroIsize); +to_sql_self!(std::num::NonZeroU8); +to_sql_self!(std::num::NonZeroU16); +to_sql_self!(std::num::NonZeroU32); + +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +to_sql_self!(i128); + +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +to_sql_self!(std::num::NonZeroI128); + +#[cfg(feature = "uuid")] +#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] +to_sql_self!(uuid::Uuid); + +macro_rules! to_sql_self_fallible( + ($t:ty) => ( + impl ToSql for $t { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::Owned(Value::Integer( + i64::try_from(*self).map_err( + // TODO: Include the values in the error message. + |err| Error::ToSqlConversionFailure(err.into()) + )? + ))) + } + } + ); + (non_zero $t:ty) => ( + impl ToSql for $t { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::Owned(Value::Integer( + i64::try_from(self.get()).map_err( + // TODO: Include the values in the error message. + |err| Error::ToSqlConversionFailure(err.into()) + )? + ))) + } + } + ) +); + +// Special implementations for usize and u64 because these conversions can fail. +to_sql_self_fallible!(u64); +to_sql_self_fallible!(usize); +to_sql_self_fallible!(non_zero std::num::NonZeroU64); +to_sql_self_fallible!(non_zero std::num::NonZeroUsize); + +impl<T: ?Sized> ToSql for &'_ T +where + T: ToSql, +{ + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + (*self).to_sql() + } +} + +impl ToSql for String { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self.as_str())) + } +} + +impl ToSql for str { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self)) + } +} + +impl ToSql for Vec<u8> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self.as_slice())) + } +} + +impl<const N: usize> ToSql for [u8; N] { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(&self[..])) + } +} + +impl ToSql for [u8] { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self)) + } +} + +impl ToSql for Value { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self)) + } +} + +impl<T: ToSql> ToSql for Option<T> { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + match *self { + None => Ok(ToSqlOutput::from(Null)), + Some(ref t) => t.to_sql(), + } + } +} + +#[cfg(test)] +mod test { + use super::ToSql; + + fn is_to_sql<T: ToSql>() {} + + #[test] + fn test_integral_types() { + is_to_sql::<i8>(); + is_to_sql::<i16>(); + is_to_sql::<i32>(); + is_to_sql::<i64>(); + is_to_sql::<isize>(); + is_to_sql::<u8>(); + is_to_sql::<u16>(); + is_to_sql::<u32>(); + is_to_sql::<u64>(); + is_to_sql::<usize>(); + } + + #[test] + fn test_nonzero_types() { + is_to_sql::<std::num::NonZeroI8>(); + is_to_sql::<std::num::NonZeroI16>(); + is_to_sql::<std::num::NonZeroI32>(); + is_to_sql::<std::num::NonZeroI64>(); + is_to_sql::<std::num::NonZeroIsize>(); + is_to_sql::<std::num::NonZeroU8>(); + is_to_sql::<std::num::NonZeroU16>(); + is_to_sql::<std::num::NonZeroU32>(); + is_to_sql::<std::num::NonZeroU64>(); + is_to_sql::<std::num::NonZeroUsize>(); + } + + #[test] + fn test_u8_array() { + let a: [u8; 99] = [0u8; 99]; + let _a: &[&dyn ToSql] = crate::params![a]; + let r = ToSql::to_sql(&a); + + r.unwrap(); + } + + #[test] + fn test_cow_str() { + use std::borrow::Cow; + let s = "str"; + let cow: Cow<str> = Cow::Borrowed(s); + let r = cow.to_sql(); + r.unwrap(); + let cow: Cow<str> = Cow::Owned::<str>(String::from(s)); + let r = cow.to_sql(); + r.unwrap(); + // Ensure this compiles. + let _p: &[&dyn ToSql] = crate::params![cow]; + } + + #[test] + fn test_box_dyn() { + let s: Box<dyn ToSql> = Box::new("Hello world!"); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = ToSql::to_sql(&s); + + r.unwrap(); + } + + #[test] + fn test_box_deref() { + let s: Box<str> = "Hello world!".into(); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + + r.unwrap(); + } + + #[test] + fn test_box_direct() { + let s: Box<str> = "Hello world!".into(); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = ToSql::to_sql(&s); + + r.unwrap(); + } + + #[test] + fn test_cells() { + use std::{rc::Rc, sync::Arc}; + + let source_str: Box<str> = "Hello world!".into(); + + let s: Rc<Box<str>> = Rc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + r.unwrap(); + + let s: Arc<Box<str>> = Arc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + r.unwrap(); + + let s: Arc<str> = Arc::from(&*source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + r.unwrap(); + + let s: Arc<dyn ToSql> = Arc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + r.unwrap(); + + let s: Rc<str> = Rc::from(&*source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + r.unwrap(); + + let s: Rc<dyn ToSql> = Rc::new(source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + r.unwrap(); + } + + #[cfg(feature = "i128_blob")] + #[test] + fn test_i128() -> crate::Result<()> { + use crate::Connection; + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (i128 BLOB, desc TEXT)")?; + db.execute( + " + INSERT INTO foo(i128, desc) VALUES + (?1, 'zero'), + (?2, 'neg one'), (?3, 'neg two'), + (?4, 'pos one'), (?5, 'pos two'), + (?6, 'min'), (?7, 'max')", + [0i128, -1i128, -2i128, 1i128, 2i128, i128::MIN, i128::MAX], + )?; + + let mut stmt = db.prepare("SELECT i128, desc FROM foo ORDER BY i128 ASC")?; + + let res = stmt + .query_map([], |row| { + Ok((row.get::<_, i128>(0)?, row.get::<_, String>(1)?)) + })? + .collect::<Result<Vec<_>, _>>()?; + + assert_eq!( + res, + &[ + (i128::MIN, "min".to_owned()), + (-2, "neg two".to_owned()), + (-1, "neg one".to_owned()), + (0, "zero".to_owned()), + (1, "pos one".to_owned()), + (2, "pos two".to_owned()), + (i128::MAX, "max".to_owned()), + ] + ); + Ok(()) + } + + #[cfg(feature = "i128_blob")] + #[test] + fn test_non_zero_i128() -> crate::Result<()> { + use std::num::NonZeroI128; + macro_rules! nz { + ($x:expr) => { + NonZeroI128::new($x).unwrap() + }; + } + + let db = crate::Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (i128 BLOB, desc TEXT)")?; + db.execute( + "INSERT INTO foo(i128, desc) VALUES + (?1, 'neg one'), (?2, 'neg two'), + (?3, 'pos one'), (?4, 'pos two'), + (?5, 'min'), (?6, 'max')", + [ + nz!(-1), + nz!(-2), + nz!(1), + nz!(2), + nz!(i128::MIN), + nz!(i128::MAX), + ], + )?; + let mut stmt = db.prepare("SELECT i128, desc FROM foo ORDER BY i128 ASC")?; + + let res = stmt + .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))? + .collect::<Result<Vec<(NonZeroI128, String)>, _>>()?; + + assert_eq!( + res, + &[ + (nz!(i128::MIN), "min".to_owned()), + (nz!(-2), "neg two".to_owned()), + (nz!(-1), "neg one".to_owned()), + (nz!(1), "pos one".to_owned()), + (nz!(2), "pos two".to_owned()), + (nz!(i128::MAX), "max".to_owned()), + ] + ); + let err = db.query_row("SELECT ?1", [0i128], |row| row.get::<_, NonZeroI128>(0)); + assert_eq!(err, Err(crate::Error::IntegralValueOutOfRange(0, 0))); + Ok(()) + } + + #[cfg(feature = "uuid")] + #[test] + fn test_uuid() -> crate::Result<()> { + use crate::{params, Connection}; + use uuid::Uuid; + + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE foo (id BLOB CHECK(length(id) = 16), label TEXT);")?; + + let id = Uuid::new_v4(); + + db.execute( + "INSERT INTO foo (id, label) VALUES (?1, ?2)", + params![id, "target"], + )?; + + let mut stmt = db.prepare("SELECT id, label FROM foo WHERE id = ?1")?; + + let mut rows = stmt.query(params![id])?; + let row = rows.next()?.unwrap(); + + let found_id: Uuid = row.get_unwrap(0); + let found_label: String = row.get_unwrap(1); + + assert_eq!(found_id, id); + assert_eq!(found_label, "target"); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/types/url.rs b/third_party/rust/rusqlite/src/types/url.rs new file mode 100644 index 0000000000..e7fc2039c1 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/url.rs @@ -0,0 +1,82 @@ +//! [`ToSql`] and [`FromSql`] implementation for [`url::Url`]. +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; +use url::Url; + +/// Serialize `Url` to text. +impl ToSql for Url { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self.as_str())) + } +} + +/// Deserialize text to `Url`. +impl FromSql for Url { + #[inline] + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Text(s) => { + let s = std::str::from_utf8(s).map_err(|e| FromSqlError::Other(Box::new(e)))?; + Url::parse(s).map_err(|e| FromSqlError::Other(Box::new(e))) + } + _ => Err(FromSqlError::InvalidType), + } + } +} + +#[cfg(test)] +mod test { + use crate::{params, Connection, Error, Result}; + use url::{ParseError, Url}; + + fn checked_memory_handle() -> Result<Connection> { + let db = Connection::open_in_memory()?; + db.execute_batch("CREATE TABLE urls (i INTEGER, v TEXT)")?; + Ok(db) + } + + fn get_url(db: &Connection, id: i64) -> Result<Url> { + db.query_row("SELECT v FROM urls WHERE i = ?", [id], |r| r.get(0)) + } + + #[test] + fn test_sql_url() -> Result<()> { + let db = &checked_memory_handle()?; + + let url0 = Url::parse("http://www.example1.com").unwrap(); + let url1 = Url::parse("http://www.example1.com/👌").unwrap(); + let url2 = "http://www.example2.com/👌"; + + db.execute( + "INSERT INTO urls (i, v) VALUES (0, ?1), (1, ?2), (2, ?3), (3, ?4)", + // also insert a non-hex encoded url (which might be present if it was + // inserted separately) + params![url0, url1, url2, "illegal"], + )?; + + assert_eq!(get_url(db, 0)?, url0); + + assert_eq!(get_url(db, 1)?, url1); + + // Should successfully read it, even though it wasn't inserted as an + // escaped url. + let out_url2: Url = get_url(db, 2)?; + assert_eq!(out_url2, Url::parse(url2).unwrap()); + + // Make sure the conversion error comes through correctly. + let err = get_url(db, 3).unwrap_err(); + match err { + Error::FromSqlConversionFailure(_, _, e) => { + assert_eq!( + *e.downcast::<ParseError>().unwrap(), + ParseError::RelativeUrlWithoutBase, + ); + } + e => { + panic!("Expected conversion failure, got {e}"); + } + } + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/types/value.rs b/third_party/rust/rusqlite/src/types/value.rs new file mode 100644 index 0000000000..ca3ee9f49d --- /dev/null +++ b/third_party/rust/rusqlite/src/types/value.rs @@ -0,0 +1,142 @@ +use super::{Null, Type}; + +/// Owning [dynamic type value](http://sqlite.org/datatype3.html). Value's type is typically +/// dictated by SQLite (not by the caller). +/// +/// See [`ValueRef`](crate::types::ValueRef) for a non-owning dynamic type +/// value. +#[derive(Clone, Debug, PartialEq)] +pub enum Value { + /// The value is a `NULL` value. + Null, + /// The value is a signed integer. + Integer(i64), + /// The value is a floating point number. + Real(f64), + /// The value is a text string. + Text(String), + /// The value is a blob of data + Blob(Vec<u8>), +} + +impl From<Null> for Value { + #[inline] + fn from(_: Null) -> Value { + Value::Null + } +} + +impl From<bool> for Value { + #[inline] + fn from(i: bool) -> Value { + Value::Integer(i as i64) + } +} + +impl From<isize> for Value { + #[inline] + fn from(i: isize) -> Value { + Value::Integer(i as i64) + } +} + +#[cfg(feature = "i128_blob")] +#[cfg_attr(docsrs, doc(cfg(feature = "i128_blob")))] +impl From<i128> for Value { + #[inline] + fn from(i: i128) -> Value { + // We store these biased (e.g. with the most significant bit flipped) + // so that comparisons with negative numbers work properly. + Value::Blob(i128::to_be_bytes(i ^ (1_i128 << 127)).to_vec()) + } +} + +#[cfg(feature = "uuid")] +#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] +impl From<uuid::Uuid> for Value { + #[inline] + fn from(id: uuid::Uuid) -> Value { + Value::Blob(id.as_bytes().to_vec()) + } +} + +macro_rules! from_i64( + ($t:ty) => ( + impl From<$t> for Value { + #[inline] + fn from(i: $t) -> Value { + Value::Integer(i64::from(i)) + } + } + ) +); + +from_i64!(i8); +from_i64!(i16); +from_i64!(i32); +from_i64!(u8); +from_i64!(u16); +from_i64!(u32); + +impl From<i64> for Value { + #[inline] + fn from(i: i64) -> Value { + Value::Integer(i) + } +} + +impl From<f32> for Value { + #[inline] + fn from(f: f32) -> Value { + Value::Real(f.into()) + } +} + +impl From<f64> for Value { + #[inline] + fn from(f: f64) -> Value { + Value::Real(f) + } +} + +impl From<String> for Value { + #[inline] + fn from(s: String) -> Value { + Value::Text(s) + } +} + +impl From<Vec<u8>> for Value { + #[inline] + fn from(v: Vec<u8>) -> Value { + Value::Blob(v) + } +} + +impl<T> From<Option<T>> for Value +where + T: Into<Value>, +{ + #[inline] + fn from(v: Option<T>) -> Value { + match v { + Some(x) => x.into(), + None => Value::Null, + } + } +} + +impl Value { + /// Returns SQLite fundamental datatype. + #[inline] + #[must_use] + pub fn data_type(&self) -> Type { + match *self { + Value::Null => Type::Null, + Value::Integer(_) => Type::Integer, + Value::Real(_) => Type::Real, + Value::Text(_) => Type::Text, + Value::Blob(_) => Type::Blob, + } + } +} diff --git a/third_party/rust/rusqlite/src/types/value_ref.rs b/third_party/rust/rusqlite/src/types/value_ref.rs new file mode 100644 index 0000000000..aa062f8ba4 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/value_ref.rs @@ -0,0 +1,258 @@ +use super::{Type, Value}; +use crate::types::{FromSqlError, FromSqlResult}; + +/// A non-owning [dynamic type value](http://sqlite.org/datatype3.html). Typically the +/// memory backing this value is owned by SQLite. +/// +/// See [`Value`](Value) for an owning dynamic type value. +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ValueRef<'a> { + /// The value is a `NULL` value. + Null, + /// The value is a signed integer. + Integer(i64), + /// The value is a floating point number. + Real(f64), + /// The value is a text string. + Text(&'a [u8]), + /// The value is a blob of data + Blob(&'a [u8]), +} + +impl ValueRef<'_> { + /// Returns SQLite fundamental datatype. + #[inline] + #[must_use] + pub fn data_type(&self) -> Type { + match *self { + ValueRef::Null => Type::Null, + ValueRef::Integer(_) => Type::Integer, + ValueRef::Real(_) => Type::Real, + ValueRef::Text(_) => Type::Text, + ValueRef::Blob(_) => Type::Blob, + } + } +} + +impl<'a> ValueRef<'a> { + /// If `self` is case `Integer`, returns the integral value. Otherwise, + /// returns [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_i64(&self) -> FromSqlResult<i64> { + match *self { + ValueRef::Integer(i) => Ok(i), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is case `Integer`, returns the integral value. + /// Otherwise returns [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_i64_or_null(&self) -> FromSqlResult<Option<i64>> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Integer(i) => Ok(Some(i)), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Real`, returns the floating point value. Otherwise, + /// returns [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_f64(&self) -> FromSqlResult<f64> { + match *self { + ValueRef::Real(f) => Ok(f), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is case `Real`, returns the floating point value. + /// Otherwise returns [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_f64_or_null(&self) -> FromSqlResult<Option<f64>> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Real(f) => Ok(Some(f)), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Text`, returns the string value. Otherwise, returns + /// [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_str(&self) -> FromSqlResult<&'a str> { + match *self { + ValueRef::Text(t) => { + std::str::from_utf8(t).map_err(|e| FromSqlError::Other(Box::new(e))) + } + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is case `Text`, returns the string value. + /// Otherwise returns [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_str_or_null(&self) -> FromSqlResult<Option<&'a str>> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Text(t) => std::str::from_utf8(t) + .map_err(|e| FromSqlError::Other(Box::new(e))) + .map(Some), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Blob`, returns the byte slice. Otherwise, returns + /// [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_blob(&self) -> FromSqlResult<&'a [u8]> { + match *self { + ValueRef::Blob(b) => Ok(b), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is case `Blob`, returns the byte slice. + /// Otherwise returns [`Err(Error::InvalidColumnType)`](crate::Error::InvalidColumnType). + #[inline] + pub fn as_blob_or_null(&self) -> FromSqlResult<Option<&'a [u8]>> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Blob(b) => Ok(Some(b)), + _ => Err(FromSqlError::InvalidType), + } + } + + /// Returns the byte slice that makes up this `ValueRef` if it's either + /// [`ValueRef::Blob`] or [`ValueRef::Text`]. + #[inline] + pub fn as_bytes(&self) -> FromSqlResult<&'a [u8]> { + match self { + ValueRef::Text(s) | ValueRef::Blob(s) => Ok(s), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Null` returns None. + /// If `self` is [`ValueRef::Blob`] or [`ValueRef::Text`] returns the byte + /// slice that makes up this value + #[inline] + pub fn as_bytes_or_null(&self) -> FromSqlResult<Option<&'a [u8]>> { + match *self { + ValueRef::Null => Ok(None), + ValueRef::Text(s) | ValueRef::Blob(s) => Ok(Some(s)), + _ => Err(FromSqlError::InvalidType), + } + } +} + +impl From<ValueRef<'_>> for Value { + #[inline] + #[track_caller] + fn from(borrowed: ValueRef<'_>) -> Value { + match borrowed { + ValueRef::Null => Value::Null, + ValueRef::Integer(i) => Value::Integer(i), + ValueRef::Real(r) => Value::Real(r), + ValueRef::Text(s) => { + let s = std::str::from_utf8(s).expect("invalid UTF-8"); + Value::Text(s.to_string()) + } + ValueRef::Blob(b) => Value::Blob(b.to_vec()), + } + } +} + +impl<'a> From<&'a str> for ValueRef<'a> { + #[inline] + fn from(s: &str) -> ValueRef<'_> { + ValueRef::Text(s.as_bytes()) + } +} + +impl<'a> From<&'a [u8]> for ValueRef<'a> { + #[inline] + fn from(s: &[u8]) -> ValueRef<'_> { + ValueRef::Blob(s) + } +} + +impl<'a> From<&'a Value> for ValueRef<'a> { + #[inline] + fn from(value: &'a Value) -> ValueRef<'a> { + match *value { + Value::Null => ValueRef::Null, + Value::Integer(i) => ValueRef::Integer(i), + Value::Real(r) => ValueRef::Real(r), + Value::Text(ref s) => ValueRef::Text(s.as_bytes()), + Value::Blob(ref b) => ValueRef::Blob(b), + } + } +} + +impl<'a, T> From<Option<T>> for ValueRef<'a> +where + T: Into<ValueRef<'a>>, +{ + #[inline] + fn from(s: Option<T>) -> ValueRef<'a> { + match s { + Some(x) => x.into(), + None => ValueRef::Null, + } + } +} + +#[cfg(any(feature = "functions", feature = "session", feature = "vtab"))] +impl<'a> ValueRef<'a> { + pub(crate) unsafe fn from_value(value: *mut crate::ffi::sqlite3_value) -> ValueRef<'a> { + use crate::ffi; + use std::slice::from_raw_parts; + + match ffi::sqlite3_value_type(value) { + ffi::SQLITE_NULL => ValueRef::Null, + ffi::SQLITE_INTEGER => ValueRef::Integer(ffi::sqlite3_value_int64(value)), + ffi::SQLITE_FLOAT => ValueRef::Real(ffi::sqlite3_value_double(value)), + ffi::SQLITE_TEXT => { + let text = ffi::sqlite3_value_text(value); + let len = ffi::sqlite3_value_bytes(value); + assert!( + !text.is_null(), + "unexpected SQLITE_TEXT value type with NULL data" + ); + let s = from_raw_parts(text.cast::<u8>(), len as usize); + ValueRef::Text(s) + } + ffi::SQLITE_BLOB => { + let (blob, len) = ( + ffi::sqlite3_value_blob(value), + ffi::sqlite3_value_bytes(value), + ); + + assert!( + len >= 0, + "unexpected negative return from sqlite3_value_bytes" + ); + if len > 0 { + assert!( + !blob.is_null(), + "unexpected SQLITE_BLOB value type with NULL data" + ); + ValueRef::Blob(from_raw_parts(blob.cast::<u8>(), len as usize)) + } else { + // The return value from sqlite3_value_blob() for a zero-length BLOB + // is a NULL pointer. + ValueRef::Blob(&[]) + } + } + _ => unreachable!("sqlite3_value_type returned invalid value"), + } + } + + // TODO sqlite3_value_nochange // 3.22.0 & VTab xUpdate + // TODO sqlite3_value_frombind // 3.28.0 +} diff --git a/third_party/rust/rusqlite/src/unlock_notify.rs b/third_party/rust/rusqlite/src/unlock_notify.rs new file mode 100644 index 0000000000..065c52d72a --- /dev/null +++ b/third_party/rust/rusqlite/src/unlock_notify.rs @@ -0,0 +1,117 @@ +//! [Unlock Notification](http://sqlite.org/unlock_notify.html) + +use std::os::raw::c_int; +use std::os::raw::c_void; +use std::panic::catch_unwind; +use std::sync::{Condvar, Mutex}; + +use crate::ffi; + +struct UnlockNotification { + cond: Condvar, // Condition variable to wait on + mutex: Mutex<bool>, // Mutex to protect structure +} + +#[allow(clippy::mutex_atomic)] +impl UnlockNotification { + fn new() -> UnlockNotification { + UnlockNotification { + cond: Condvar::new(), + mutex: Mutex::new(false), + } + } + + fn fired(&self) { + let mut flag = unpoison(self.mutex.lock()); + *flag = true; + self.cond.notify_one(); + } + + fn wait(&self) { + let mut fired = unpoison(self.mutex.lock()); + while !*fired { + fired = unpoison(self.cond.wait(fired)); + } + } +} + +#[inline] +fn unpoison<T>(r: Result<T, std::sync::PoisonError<T>>) -> T { + r.unwrap_or_else(std::sync::PoisonError::into_inner) +} + +/// This function is an unlock-notify callback +unsafe extern "C" fn unlock_notify_cb(ap_arg: *mut *mut c_void, n_arg: c_int) { + use std::slice::from_raw_parts; + let args = from_raw_parts(ap_arg as *const &UnlockNotification, n_arg as usize); + for un in args { + drop(catch_unwind(std::panic::AssertUnwindSafe(|| un.fired()))); + } +} + +pub unsafe fn is_locked(db: *mut ffi::sqlite3, rc: c_int) -> bool { + rc == ffi::SQLITE_LOCKED_SHAREDCACHE + || (rc & 0xFF) == ffi::SQLITE_LOCKED + && ffi::sqlite3_extended_errcode(db) == ffi::SQLITE_LOCKED_SHAREDCACHE +} + +/// This function assumes that an SQLite API call (either `sqlite3_prepare_v2()` +/// or `sqlite3_step()`) has just returned `SQLITE_LOCKED`. The argument is the +/// associated database connection. +/// +/// This function calls `sqlite3_unlock_notify()` to register for an +/// unlock-notify callback, then blocks until that callback is delivered +/// and returns `SQLITE_OK`. The caller should then retry the failed operation. +/// +/// Or, if `sqlite3_unlock_notify()` indicates that to block would deadlock +/// the system, then this function returns `SQLITE_LOCKED` immediately. In +/// this case the caller should not retry the operation and should roll +/// back the current transaction (if any). +#[cfg(feature = "unlock_notify")] +pub unsafe fn wait_for_unlock_notify(db: *mut ffi::sqlite3) -> c_int { + let un = UnlockNotification::new(); + /* Register for an unlock-notify callback. */ + let rc = ffi::sqlite3_unlock_notify( + db, + Some(unlock_notify_cb), + &un as *const UnlockNotification as *mut c_void, + ); + debug_assert!( + rc == ffi::SQLITE_LOCKED || rc == ffi::SQLITE_LOCKED_SHAREDCACHE || rc == ffi::SQLITE_OK + ); + if rc == ffi::SQLITE_OK { + un.wait(); + } + rc +} + +#[cfg(test)] +mod test { + use crate::{Connection, OpenFlags, Result, Transaction, TransactionBehavior}; + use std::sync::mpsc::sync_channel; + use std::thread; + use std::time; + + #[test] + fn test_unlock_notify() -> Result<()> { + let url = "file::memory:?cache=shared"; + let flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_URI; + let db1 = Connection::open_with_flags(url, flags)?; + db1.execute_batch("CREATE TABLE foo (x)")?; + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db2 = Connection::open_with_flags(url, flags).unwrap(); + let tx2 = Transaction::new(&mut db2, TransactionBehavior::Immediate).unwrap(); + tx2.execute_batch("INSERT INTO foo VALUES (42)").unwrap(); + rx.send(1).unwrap(); + let ten_millis = time::Duration::from_millis(10); + thread::sleep(ten_millis); + tx2.commit().unwrap(); + }); + assert_eq!(tx.recv().unwrap(), 1); + let the_answer: i64 = db1.one_column("SELECT x FROM foo")?; + assert_eq!(42i64, the_answer); + child.join().unwrap(); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/util/mod.rs b/third_party/rust/rusqlite/src/util/mod.rs new file mode 100644 index 0000000000..a759cb98fc --- /dev/null +++ b/third_party/rust/rusqlite/src/util/mod.rs @@ -0,0 +1,9 @@ +// Internal utilities +pub(crate) mod param_cache; +mod small_cstr; +pub(crate) use param_cache::ParamIndexCache; +pub(crate) use small_cstr::SmallCString; + +// Doesn't use any modern features or vtab stuff, but is only used by them. +mod sqlite_string; +pub(crate) use sqlite_string::{alloc, SqliteMallocString}; diff --git a/third_party/rust/rusqlite/src/util/param_cache.rs b/third_party/rust/rusqlite/src/util/param_cache.rs new file mode 100644 index 0000000000..6faced98af --- /dev/null +++ b/third_party/rust/rusqlite/src/util/param_cache.rs @@ -0,0 +1,60 @@ +use super::SmallCString; +use std::cell::RefCell; +use std::collections::BTreeMap; + +/// Maps parameter names to parameter indices. +#[derive(Default, Clone, Debug)] +// BTreeMap seems to do better here unless we want to pull in a custom hash +// function. +pub(crate) struct ParamIndexCache(RefCell<BTreeMap<SmallCString, usize>>); + +impl ParamIndexCache { + pub fn get_or_insert_with<F>(&self, s: &str, func: F) -> Option<usize> + where + F: FnOnce(&std::ffi::CStr) -> Option<usize>, + { + let mut cache = self.0.borrow_mut(); + // Avoid entry API, needs allocation to test membership. + if let Some(v) = cache.get(s) { + return Some(*v); + } + // If there's an internal nul in the name it couldn't have been a + // parameter, so early return here is ok. + let name = SmallCString::new(s).ok()?; + let val = func(&name)?; + cache.insert(name, val); + Some(val) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_cache() { + let p = ParamIndexCache::default(); + let v = p.get_or_insert_with("foo", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "foo"); + Some(3) + }); + assert_eq!(v, Some(3)); + let v = p.get_or_insert_with("foo", |_| { + panic!("shouldn't be called this time"); + }); + assert_eq!(v, Some(3)); + let v = p.get_or_insert_with("gar\0bage", |_| { + panic!("shouldn't be called here either"); + }); + assert_eq!(v, None); + let v = p.get_or_insert_with("bar", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "bar"); + None + }); + assert_eq!(v, None); + let v = p.get_or_insert_with("bar", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "bar"); + Some(30) + }); + assert_eq!(v, Some(30)); + } +} diff --git a/third_party/rust/rusqlite/src/util/small_cstr.rs b/third_party/rust/rusqlite/src/util/small_cstr.rs new file mode 100644 index 0000000000..1ec73744e0 --- /dev/null +++ b/third_party/rust/rusqlite/src/util/small_cstr.rs @@ -0,0 +1,170 @@ +use smallvec::{smallvec, SmallVec}; +use std::ffi::{CStr, CString, NulError}; + +/// Similar to `std::ffi::CString`, but avoids heap allocating if the string is +/// small enough. Also guarantees it's input is UTF-8 -- used for cases where we +/// need to pass a NUL-terminated string to SQLite, and we have a `&str`. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct SmallCString(SmallVec<[u8; 16]>); + +impl SmallCString { + #[inline] + pub fn new(s: &str) -> Result<Self, NulError> { + if s.as_bytes().contains(&0_u8) { + return Err(Self::fabricate_nul_error(s)); + } + let mut buf = SmallVec::with_capacity(s.len() + 1); + buf.extend_from_slice(s.as_bytes()); + buf.push(0); + let res = Self(buf); + res.debug_checks(); + Ok(res) + } + + #[inline] + pub fn as_str(&self) -> &str { + self.debug_checks(); + // Constructor takes a &str so this is safe. + unsafe { std::str::from_utf8_unchecked(self.as_bytes_without_nul()) } + } + + /// Get the bytes not including the NUL terminator. E.g. the bytes which + /// make up our `str`: + /// - `SmallCString::new("foo").as_bytes_without_nul() == b"foo"` + /// - `SmallCString::new("foo").as_bytes_with_nul() == b"foo\0"` + #[inline] + pub fn as_bytes_without_nul(&self) -> &[u8] { + self.debug_checks(); + &self.0[..self.len()] + } + + /// Get the bytes behind this str *including* the NUL terminator. This + /// should never return an empty slice. + #[inline] + pub fn as_bytes_with_nul(&self) -> &[u8] { + self.debug_checks(); + &self.0 + } + + #[inline] + #[cfg(debug_assertions)] + fn debug_checks(&self) { + debug_assert_ne!(self.0.len(), 0); + debug_assert_eq!(self.0[self.0.len() - 1], 0); + let strbytes = &self.0[..(self.0.len() - 1)]; + debug_assert!(!strbytes.contains(&0)); + debug_assert!(std::str::from_utf8(strbytes).is_ok()); + } + + #[inline] + #[cfg(not(debug_assertions))] + fn debug_checks(&self) {} + + #[inline] + pub fn len(&self) -> usize { + debug_assert_ne!(self.0.len(), 0); + self.0.len() - 1 + } + + #[inline] + #[allow(unused)] // clippy wants this function. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline] + pub fn as_cstr(&self) -> &CStr { + let bytes = self.as_bytes_with_nul(); + debug_assert!(CStr::from_bytes_with_nul(bytes).is_ok()); + unsafe { CStr::from_bytes_with_nul_unchecked(bytes) } + } + + #[cold] + fn fabricate_nul_error(b: &str) -> NulError { + CString::new(b).unwrap_err() + } +} + +impl Default for SmallCString { + #[inline] + fn default() -> Self { + Self(smallvec![0]) + } +} + +impl std::fmt::Debug for SmallCString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("SmallCString").field(&self.as_str()).finish() + } +} + +impl std::ops::Deref for SmallCString { + type Target = CStr; + + #[inline] + fn deref(&self) -> &CStr { + self.as_cstr() + } +} + +impl PartialEq<SmallCString> for str { + #[inline] + fn eq(&self, s: &SmallCString) -> bool { + s.as_bytes_without_nul() == self.as_bytes() + } +} + +impl PartialEq<str> for SmallCString { + #[inline] + fn eq(&self, s: &str) -> bool { + self.as_bytes_without_nul() == s.as_bytes() + } +} + +impl std::borrow::Borrow<str> for SmallCString { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_small_cstring() { + // We don't go through the normal machinery for default, so make sure + // things work. + assert_eq!(SmallCString::default().0, SmallCString::new("").unwrap().0); + assert_eq!(SmallCString::new("foo").unwrap().len(), 3); + assert_eq!( + SmallCString::new("foo").unwrap().as_bytes_with_nul(), + b"foo\0" + ); + assert_eq!( + SmallCString::new("foo").unwrap().as_bytes_without_nul(), + b"foo", + ); + + assert_eq!(SmallCString::new("😀").unwrap().len(), 4); + assert_eq!( + SmallCString::new("😀").unwrap().0.as_slice(), + b"\xf0\x9f\x98\x80\0", + ); + assert_eq!( + SmallCString::new("😀").unwrap().as_bytes_without_nul(), + b"\xf0\x9f\x98\x80", + ); + + assert_eq!(SmallCString::new("").unwrap().len(), 0); + assert!(SmallCString::new("").unwrap().is_empty()); + + assert_eq!(SmallCString::new("").unwrap().0.as_slice(), b"\0"); + assert_eq!(SmallCString::new("").unwrap().as_bytes_without_nul(), b""); + + SmallCString::new("\0").unwrap_err(); + SmallCString::new("\0abc").unwrap_err(); + SmallCString::new("abc\0").unwrap_err(); + } +} diff --git a/third_party/rust/rusqlite/src/util/sqlite_string.rs b/third_party/rust/rusqlite/src/util/sqlite_string.rs new file mode 100644 index 0000000000..1d69552c9d --- /dev/null +++ b/third_party/rust/rusqlite/src/util/sqlite_string.rs @@ -0,0 +1,239 @@ +// This is used when either vtab or modern-sqlite is on. Different methods are +// used in each feature. Avoid having to track this for each function. We will +// still warn for anything that's not used by either, though. +#![cfg_attr(not(feature = "vtab"), allow(dead_code))] +use crate::ffi; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int}; +use std::ptr::NonNull; + +// Space to hold this string must be obtained +// from an SQLite memory allocation function +pub(crate) fn alloc(s: &str) -> *mut c_char { + SqliteMallocString::from_str(s).into_raw() +} + +/// A string we own that's allocated on the SQLite heap. Automatically calls +/// `sqlite3_free` when dropped, unless `into_raw` (or `into_inner`) is called +/// on it. If constructed from a rust string, `sqlite3_malloc` is used. +/// +/// It has identical representation to a nonnull `*mut c_char`, so you can use +/// it transparently as one. It's nonnull, so Option<SqliteMallocString> can be +/// used for nullable ones (it's still just one pointer). +/// +/// Most strings shouldn't use this! Only places where the string needs to be +/// freed with `sqlite3_free`. This includes `sqlite3_extended_sql` results, +/// some error message pointers... Note that misuse is extremely dangerous! +/// +/// Note that this is *not* a lossless interface. Incoming strings with internal +/// NULs are modified, and outgoing strings which are non-UTF8 are modified. +/// This seems unavoidable -- it tries very hard to not panic. +#[repr(transparent)] +pub(crate) struct SqliteMallocString { + ptr: NonNull<c_char>, + _boo: PhantomData<Box<[c_char]>>, +} +// This is owned data for a primitive type, and thus it's safe to implement +// these. That said, nothing needs them, and they make things easier to misuse. + +// unsafe impl Send for SqliteMallocString {} +// unsafe impl Sync for SqliteMallocString {} + +impl SqliteMallocString { + /// SAFETY: Caller must be certain that `m` a nul-terminated c string + /// allocated by `sqlite3_malloc`, and that SQLite expects us to free it! + #[inline] + pub(crate) unsafe fn from_raw_nonnull(ptr: NonNull<c_char>) -> Self { + Self { + ptr, + _boo: PhantomData, + } + } + + /// SAFETY: Caller must be certain that `m` a nul-terminated c string + /// allocated by `sqlite3_malloc`, and that SQLite expects us to free it! + #[inline] + pub(crate) unsafe fn from_raw(ptr: *mut c_char) -> Option<Self> { + NonNull::new(ptr).map(|p| Self::from_raw_nonnull(p)) + } + + /// Get the pointer behind `self`. After this is called, we no longer manage + /// it. + #[inline] + pub(crate) fn into_inner(self) -> NonNull<c_char> { + let p = self.ptr; + std::mem::forget(self); + p + } + + /// Get the pointer behind `self`. After this is called, we no longer manage + /// it. + #[inline] + pub(crate) fn into_raw(self) -> *mut c_char { + self.into_inner().as_ptr() + } + + /// Borrow the pointer behind `self`. We still manage it when this function + /// returns. If you want to relinquish ownership, use `into_raw`. + #[inline] + pub(crate) fn as_ptr(&self) -> *const c_char { + self.ptr.as_ptr() + } + + #[inline] + pub(crate) fn as_cstr(&self) -> &std::ffi::CStr { + unsafe { std::ffi::CStr::from_ptr(self.as_ptr()) } + } + + #[inline] + pub(crate) fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + self.as_cstr().to_string_lossy() + } + + /// Convert `s` into a SQLite string. + /// + /// This should almost never be done except for cases like error messages or + /// other strings that SQLite frees. + /// + /// If `s` contains internal NULs, we'll replace them with + /// `NUL_REPLACE_CHAR`. + /// + /// Except for `debug_assert`s which may trigger during testing, this + /// function never panics. If we hit integer overflow or the allocation + /// fails, we call `handle_alloc_error` which aborts the program after + /// calling a global hook. + /// + /// This means it's safe to use in extern "C" functions even outside of + /// `catch_unwind`. + pub(crate) fn from_str(s: &str) -> Self { + let s = if s.as_bytes().contains(&0) { + std::borrow::Cow::Owned(make_nonnull(s)) + } else { + std::borrow::Cow::Borrowed(s) + }; + debug_assert!(!s.as_bytes().contains(&0)); + let bytes: &[u8] = s.as_ref().as_bytes(); + let src_ptr: *const c_char = bytes.as_ptr().cast(); + let src_len = bytes.len(); + let maybe_len_plus_1 = s.len().checked_add(1).and_then(|v| c_int::try_from(v).ok()); + unsafe { + let res_ptr = maybe_len_plus_1 + .and_then(|len_to_alloc| { + // `>` because we added 1. + debug_assert!(len_to_alloc > 0); + debug_assert_eq!((len_to_alloc - 1) as usize, src_len); + NonNull::new(ffi::sqlite3_malloc(len_to_alloc).cast::<c_char>()) + }) + .unwrap_or_else(|| { + use std::alloc::{handle_alloc_error, Layout}; + // Report via handle_alloc_error so that it can be handled with any + // other allocation errors and properly diagnosed. + // + // This is safe: + // - `align` is never 0 + // - `align` is always a power of 2. + // - `size` needs no realignment because it's guaranteed to be aligned + // (everything is aligned to 1) + // - `size` is also never zero, although this function doesn't actually require + // it now. + let len = s.len().saturating_add(1).min(isize::MAX as usize); + let layout = Layout::from_size_align_unchecked(len, 1); + // Note: This call does not return. + handle_alloc_error(layout); + }); + let buf: *mut c_char = res_ptr.as_ptr().cast::<c_char>(); + src_ptr.copy_to_nonoverlapping(buf, src_len); + buf.add(src_len).write(0); + debug_assert_eq!(std::ffi::CStr::from_ptr(res_ptr.as_ptr()).to_bytes(), bytes); + Self::from_raw_nonnull(res_ptr) + } + } +} + +const NUL_REPLACE: &str = "␀"; + +#[cold] +fn make_nonnull(v: &str) -> String { + v.replace('\0', NUL_REPLACE) +} + +impl Drop for SqliteMallocString { + #[inline] + fn drop(&mut self) { + unsafe { ffi::sqlite3_free(self.ptr.as_ptr().cast()) }; + } +} + +impl std::fmt::Debug for SqliteMallocString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_string_lossy().fmt(f) + } +} + +impl std::fmt::Display for SqliteMallocString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_string_lossy().fmt(f) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_from_str() { + let to_check = [ + ("", ""), + ("\0", "␀"), + ("␀", "␀"), + ("\0bar", "␀bar"), + ("foo\0bar", "foo␀bar"), + ("foo\0", "foo␀"), + ("a\0b\0c\0\0d", "a␀b␀c␀␀d"), + ("foobar0123", "foobar0123"), + ]; + + for &(input, output) in &to_check { + let s = SqliteMallocString::from_str(input); + assert_eq!(s.to_string_lossy(), output); + assert_eq!(s.as_cstr().to_str().unwrap(), output); + } + } + + // This will trigger an asan error if into_raw still freed the ptr. + #[test] + fn test_lossy() { + let p = SqliteMallocString::from_str("abcd").into_raw(); + // Make invalid + let s = unsafe { + p.cast::<u8>().write(b'\xff'); + SqliteMallocString::from_raw(p).unwrap() + }; + assert_eq!(s.to_string_lossy().as_ref(), "\u{FFFD}bcd"); + } + + // This will trigger an asan error if into_raw still freed the ptr. + #[test] + fn test_into_raw() { + let mut v = vec![]; + for i in 0..1000 { + v.push(SqliteMallocString::from_str(&i.to_string()).into_raw()); + v.push(SqliteMallocString::from_str(&format!("abc {i} 😀")).into_raw()); + } + unsafe { + for (i, s) in v.chunks_mut(2).enumerate() { + let s0 = std::mem::replace(&mut s[0], std::ptr::null_mut()); + let s1 = std::mem::replace(&mut s[1], std::ptr::null_mut()); + assert_eq!( + std::ffi::CStr::from_ptr(s0).to_str().unwrap(), + &i.to_string() + ); + assert_eq!( + std::ffi::CStr::from_ptr(s1).to_str().unwrap(), + &format!("abc {i} 😀") + ); + let _ = SqliteMallocString::from_raw(s0).unwrap(); + let _ = SqliteMallocString::from_raw(s1).unwrap(); + } + } + } +} diff --git a/third_party/rust/rusqlite/src/version.rs b/third_party/rust/rusqlite/src/version.rs new file mode 100644 index 0000000000..44053b741d --- /dev/null +++ b/third_party/rust/rusqlite/src/version.rs @@ -0,0 +1,27 @@ +use crate::ffi; +use std::ffi::CStr; + +/// Returns the SQLite version as an integer; e.g., `3016002` for version +/// 3.16.2. +/// +/// See [`sqlite3_libversion_number()`](https://www.sqlite.org/c3ref/libversion.html). +#[inline] +#[must_use] +pub fn version_number() -> i32 { + unsafe { ffi::sqlite3_libversion_number() } +} + +/// Returns the SQLite version as a string; e.g., `"3.16.2"` for version 3.16.2. +/// +/// See [`sqlite3_libversion()`](https://www.sqlite.org/c3ref/libversion.html). +/// +/// # Panics +/// +/// Panics when version is not valid UTF-8. +#[inline] +#[must_use] +pub fn version() -> &'static str { + let cstr = unsafe { CStr::from_ptr(ffi::sqlite3_libversion()) }; + cstr.to_str() + .expect("SQLite version string is not valid UTF8 ?!") +} diff --git a/third_party/rust/rusqlite/src/vtab/array.rs b/third_party/rust/rusqlite/src/vtab/array.rs new file mode 100644 index 0000000000..be19a3e127 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/array.rs @@ -0,0 +1,223 @@ +//! Array Virtual Table. +//! +//! Note: `rarray`, not `carray` is the name of the table valued function we +//! define. +//! +//! Port of [carray](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/carray.c) +//! C extension: `https://www.sqlite.org/carray.html` +//! +//! # Example +//! +//! ```rust,no_run +//! # use rusqlite::{types::Value, Connection, Result, params}; +//! # use std::rc::Rc; +//! fn example(db: &Connection) -> Result<()> { +//! // Note: This should be done once (usually when opening the DB). +//! rusqlite::vtab::array::load_module(&db)?; +//! let v = [1i64, 2, 3, 4]; +//! // Note: A `Rc<Vec<Value>>` must be used as the parameter. +//! let values = Rc::new(v.iter().copied().map(Value::from).collect::<Vec<Value>>()); +//! let mut stmt = db.prepare("SELECT value from rarray(?1);")?; +//! let rows = stmt.query_map([values], |row| row.get::<_, i64>(0))?; +//! for value in rows { +//! println!("{}", value?); +//! } +//! Ok(()) +//! } +//! ``` + +use std::default::Default; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int, c_void}; +use std::rc::Rc; + +use crate::ffi; +use crate::types::{ToSql, ToSqlOutput, Value}; +use crate::vtab::{ + eponymous_only_module, Context, IndexConstraintOp, IndexInfo, VTab, VTabConnection, VTabCursor, + Values, +}; +use crate::{Connection, Result}; + +// http://sqlite.org/bindptr.html + +pub(crate) const ARRAY_TYPE: *const c_char = (b"rarray\0" as *const u8).cast::<c_char>(); + +pub(crate) unsafe extern "C" fn free_array(p: *mut c_void) { + drop(Rc::from_raw(p as *const Vec<Value>)); +} + +/// Array parameter / pointer +pub type Array = Rc<Vec<Value>>; + +impl ToSql for Array { + #[inline] + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::Array(self.clone())) + } +} + +/// Register the "rarray" module. +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("rarray", eponymous_only_module::<ArrayTab>(), aux) +} + +// Column numbers +// const CARRAY_COLUMN_VALUE : c_int = 0; +const CARRAY_COLUMN_POINTER: c_int = 1; + +/// An instance of the Array virtual table +#[repr(C)] +struct ArrayTab { + /// Base class. Must be first + base: ffi::sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for ArrayTab { + type Aux = (); + type Cursor = ArrayTabCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + _args: &[&[u8]], + ) -> Result<(String, ArrayTab)> { + let vtab = ArrayTab { + base: ffi::sqlite3_vtab::default(), + }; + Ok(("CREATE TABLE x(value,pointer hidden)".to_owned(), vtab)) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + // Index of the pointer= constraint + let mut ptr_idx = false; + for (constraint, mut constraint_usage) in info.constraints_and_usages() { + if !constraint.is_usable() { + continue; + } + if constraint.operator() != IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ { + continue; + } + if let CARRAY_COLUMN_POINTER = constraint.column() { + ptr_idx = true; + constraint_usage.set_argv_index(1); + constraint_usage.set_omit(true); + } + } + if ptr_idx { + info.set_estimated_cost(1_f64); + info.set_estimated_rows(100); + info.set_idx_num(1); + } else { + info.set_estimated_cost(2_147_483_647_f64); + info.set_estimated_rows(2_147_483_647); + info.set_idx_num(0); + } + Ok(()) + } + + fn open(&mut self) -> Result<ArrayTabCursor<'_>> { + Ok(ArrayTabCursor::new()) + } +} + +/// A cursor for the Array virtual table +#[repr(C)] +struct ArrayTabCursor<'vtab> { + /// Base class. Must be first + base: ffi::sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + /// Pointer to the array of values ("pointer") + ptr: Option<Array>, + phantom: PhantomData<&'vtab ArrayTab>, +} + +impl ArrayTabCursor<'_> { + fn new<'vtab>() -> ArrayTabCursor<'vtab> { + ArrayTabCursor { + base: ffi::sqlite3_vtab_cursor::default(), + row_id: 0, + ptr: None, + phantom: PhantomData, + } + } + + fn len(&self) -> i64 { + match self.ptr { + Some(ref a) => a.len() as i64, + _ => 0, + } + } +} +unsafe impl VTabCursor for ArrayTabCursor<'_> { + fn filter(&mut self, idx_num: c_int, _idx_str: Option<&str>, args: &Values<'_>) -> Result<()> { + if idx_num > 0 { + self.ptr = args.get_array(0); + } else { + self.ptr = None; + } + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.row_id > self.len() + } + + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()> { + match i { + CARRAY_COLUMN_POINTER => Ok(()), + _ => { + if let Some(ref array) = self.ptr { + let value = &array[(self.row_id - 1) as usize]; + ctx.set_result(&value) + } else { + Ok(()) + } + } + } + } + + fn rowid(&self) -> Result<i64> { + Ok(self.row_id) + } +} + +#[cfg(test)] +mod test { + use crate::types::Value; + use crate::vtab::array; + use crate::{Connection, Result}; + use std::rc::Rc; + + #[test] + fn test_array_module() -> Result<()> { + let db = Connection::open_in_memory()?; + array::load_module(&db)?; + + let v = vec![1i64, 2, 3, 4]; + let values: Vec<Value> = v.into_iter().map(Value::from).collect(); + let ptr = Rc::new(values); + { + let mut stmt = db.prepare("SELECT value from rarray(?1);")?; + + let rows = stmt.query_map([&ptr], |row| row.get::<_, i64>(0))?; + assert_eq!(2, Rc::strong_count(&ptr)); + let mut count = 0; + for (i, value) in rows.enumerate() { + assert_eq!(i as i64, value? - 1); + count += 1; + } + assert_eq!(4, count); + } + assert_eq!(1, Rc::strong_count(&ptr)); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/vtab/csvtab.rs b/third_party/rust/rusqlite/src/vtab/csvtab.rs new file mode 100644 index 0000000000..acf3cd8180 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/csvtab.rs @@ -0,0 +1,387 @@ +//! CSV Virtual Table. +//! +//! Port of [csv](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/csv.c) C +//! extension: `https://www.sqlite.org/csv.html` +//! +//! # Example +//! +//! ```rust,no_run +//! # use rusqlite::{Connection, Result}; +//! fn example() -> Result<()> { +//! // Note: This should be done once (usually when opening the DB). +//! let db = Connection::open_in_memory()?; +//! rusqlite::vtab::csvtab::load_module(&db)?; +//! // Assume my_csv.csv +//! let schema = " +//! CREATE VIRTUAL TABLE my_csv_data +//! USING csv(filename = 'my_csv.csv') +//! "; +//! db.execute_batch(schema)?; +//! // Now the `my_csv_data` (virtual) table can be queried as normal... +//! Ok(()) +//! } +//! ``` +use std::fs::File; +use std::marker::PhantomData; +use std::os::raw::c_int; +use std::path::Path; +use std::str; + +use crate::ffi; +use crate::types::Null; +use crate::vtab::{ + escape_double_quote, parse_boolean, read_only_module, Context, CreateVTab, IndexInfo, VTab, + VTabConfig, VTabConnection, VTabCursor, VTabKind, Values, +}; +use crate::{Connection, Error, Result}; + +/// Register the "csv" module. +/// ```sql +/// CREATE VIRTUAL TABLE vtab USING csv( +/// filename=FILENAME -- Name of file containing CSV content +/// [, schema=SCHEMA] -- Alternative CSV schema. 'CREATE TABLE x(col1 TEXT NOT NULL, col2 INT, ...);' +/// [, header=YES|NO] -- First row of CSV defines the names of columns if "yes". Default "no". +/// [, columns=N] -- Assume the CSV file contains N columns. +/// [, delimiter=C] -- CSV delimiter. Default ','. +/// [, quote=C] -- CSV quote. Default '"'. 0 means no quote. +/// ); +/// ``` +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("csv", read_only_module::<CsvTab>(), aux) +} + +/// An instance of the CSV virtual table +#[repr(C)] +struct CsvTab { + /// Base class. Must be first + base: ffi::sqlite3_vtab, + /// Name of the CSV file + filename: String, + has_headers: bool, + delimiter: u8, + quote: u8, + /// Offset to start of data + offset_first_row: csv::Position, +} + +impl CsvTab { + fn reader(&self) -> Result<csv::Reader<File>, csv::Error> { + csv::ReaderBuilder::new() + .has_headers(self.has_headers) + .delimiter(self.delimiter) + .quote(self.quote) + .from_path(&self.filename) + } + + fn parse_byte(arg: &str) -> Option<u8> { + if arg.len() == 1 { + arg.bytes().next() + } else { + None + } + } +} + +unsafe impl<'vtab> VTab<'vtab> for CsvTab { + type Aux = (); + type Cursor = CsvTabCursor<'vtab>; + + fn connect( + db: &mut VTabConnection, + _aux: Option<&()>, + args: &[&[u8]], + ) -> Result<(String, CsvTab)> { + if args.len() < 4 { + return Err(Error::ModuleError("no CSV file specified".to_owned())); + } + + let mut vtab = CsvTab { + base: ffi::sqlite3_vtab::default(), + filename: "".to_owned(), + has_headers: false, + delimiter: b',', + quote: b'"', + offset_first_row: csv::Position::new(), + }; + let mut schema = None; + let mut n_col = None; + + let args = &args[3..]; + for c_slice in args { + let (param, value) = super::parameter(c_slice)?; + match param { + "filename" => { + if !Path::new(value).exists() { + return Err(Error::ModuleError(format!("file '{value}' does not exist"))); + } + vtab.filename = value.to_owned(); + } + "schema" => { + schema = Some(value.to_owned()); + } + "columns" => { + if let Ok(n) = value.parse::<u16>() { + if n_col.is_some() { + return Err(Error::ModuleError( + "more than one 'columns' parameter".to_owned(), + )); + } else if n == 0 { + return Err(Error::ModuleError( + "must have at least one column".to_owned(), + )); + } + n_col = Some(n); + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'columns': {value}" + ))); + } + } + "header" => { + if let Some(b) = parse_boolean(value) { + vtab.has_headers = b; + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'header': {value}" + ))); + } + } + "delimiter" => { + if let Some(b) = CsvTab::parse_byte(value) { + vtab.delimiter = b; + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'delimiter': {value}" + ))); + } + } + "quote" => { + if let Some(b) = CsvTab::parse_byte(value) { + if b == b'0' { + vtab.quote = 0; + } else { + vtab.quote = b; + } + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'quote': {value}" + ))); + } + } + _ => { + return Err(Error::ModuleError(format!( + "unrecognized parameter '{param}'" + ))); + } + } + } + + if vtab.filename.is_empty() { + return Err(Error::ModuleError("no CSV file specified".to_owned())); + } + + let mut cols: Vec<String> = Vec::new(); + if vtab.has_headers || (n_col.is_none() && schema.is_none()) { + let mut reader = vtab.reader()?; + if vtab.has_headers { + { + let headers = reader.headers()?; + // headers ignored if cols is not empty + if n_col.is_none() && schema.is_none() { + cols = headers + .into_iter() + .map(|header| escape_double_quote(header).into_owned()) + .collect(); + } + } + vtab.offset_first_row = reader.position().clone(); + } else { + let mut record = csv::ByteRecord::new(); + if reader.read_byte_record(&mut record)? { + for (i, _) in record.iter().enumerate() { + cols.push(format!("c{i}")); + } + } + } + } else if let Some(n_col) = n_col { + for i in 0..n_col { + cols.push(format!("c{i}")); + } + } + + if cols.is_empty() && schema.is_none() { + return Err(Error::ModuleError("no column specified".to_owned())); + } + + if schema.is_none() { + let mut sql = String::from("CREATE TABLE x("); + for (i, col) in cols.iter().enumerate() { + sql.push('"'); + sql.push_str(col); + sql.push_str("\" TEXT"); + if i == cols.len() - 1 { + sql.push_str(");"); + } else { + sql.push_str(", "); + } + } + schema = Some(sql); + } + db.config(VTabConfig::DirectOnly)?; + Ok((schema.unwrap(), vtab)) + } + + // Only a forward full table scan is supported. + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + info.set_estimated_cost(1_000_000.); + Ok(()) + } + + fn open(&mut self) -> Result<CsvTabCursor<'_>> { + Ok(CsvTabCursor::new(self.reader()?)) + } +} + +impl CreateVTab<'_> for CsvTab { + const KIND: VTabKind = VTabKind::Default; +} + +/// A cursor for the CSV virtual table +#[repr(C)] +struct CsvTabCursor<'vtab> { + /// Base class. Must be first + base: ffi::sqlite3_vtab_cursor, + /// The CSV reader object + reader: csv::Reader<File>, + /// Current cursor position used as rowid + row_number: usize, + /// Values of the current row + cols: csv::StringRecord, + eof: bool, + phantom: PhantomData<&'vtab CsvTab>, +} + +impl CsvTabCursor<'_> { + fn new<'vtab>(reader: csv::Reader<File>) -> CsvTabCursor<'vtab> { + CsvTabCursor { + base: ffi::sqlite3_vtab_cursor::default(), + reader, + row_number: 0, + cols: csv::StringRecord::new(), + eof: false, + phantom: PhantomData, + } + } + + /// Accessor to the associated virtual table. + fn vtab(&self) -> &CsvTab { + unsafe { &*(self.base.pVtab as *const CsvTab) } + } +} + +unsafe impl VTabCursor for CsvTabCursor<'_> { + // Only a full table scan is supported. So `filter` simply rewinds to + // the beginning. + fn filter( + &mut self, + _idx_num: c_int, + _idx_str: Option<&str>, + _args: &Values<'_>, + ) -> Result<()> { + { + let offset_first_row = self.vtab().offset_first_row.clone(); + self.reader.seek(offset_first_row)?; + } + self.row_number = 0; + self.next() + } + + fn next(&mut self) -> Result<()> { + { + self.eof = self.reader.is_done(); + if self.eof { + return Ok(()); + } + + self.eof = !self.reader.read_record(&mut self.cols)?; + } + + self.row_number += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.eof + } + + fn column(&self, ctx: &mut Context, col: c_int) -> Result<()> { + if col < 0 || col as usize >= self.cols.len() { + return Err(Error::ModuleError(format!( + "column index out of bounds: {col}" + ))); + } + if self.cols.is_empty() { + return ctx.set_result(&Null); + } + // TODO Affinity + ctx.set_result(&self.cols[col as usize].to_owned()) + } + + fn rowid(&self) -> Result<i64> { + Ok(self.row_number as i64) + } +} + +impl From<csv::Error> for Error { + #[cold] + fn from(err: csv::Error) -> Error { + Error::ModuleError(err.to_string()) + } +} + +#[cfg(test)] +mod test { + use crate::vtab::csvtab; + use crate::{Connection, Result}; + use fallible_iterator::FallibleIterator; + + #[test] + fn test_csv_module() -> Result<()> { + let db = Connection::open_in_memory()?; + csvtab::load_module(&db)?; + db.execute_batch("CREATE VIRTUAL TABLE vtab USING csv(filename='test.csv', header=yes)")?; + + { + let mut s = db.prepare("SELECT rowid, * FROM vtab")?; + { + let headers = s.column_names(); + assert_eq!(vec!["rowid", "colA", "colB", "colC"], headers); + } + + let ids: Result<Vec<i32>> = s.query([])?.map(|row| row.get::<_, i32>(0)).collect(); + let sum = ids?.iter().sum::<i32>(); + assert_eq!(sum, 15); + } + db.execute_batch("DROP TABLE vtab") + } + + #[test] + fn test_csv_cursor() -> Result<()> { + let db = Connection::open_in_memory()?; + csvtab::load_module(&db)?; + db.execute_batch("CREATE VIRTUAL TABLE vtab USING csv(filename='test.csv', header=yes)")?; + + { + let mut s = db.prepare( + "SELECT v1.rowid, v1.* FROM vtab v1 NATURAL JOIN vtab v2 WHERE \ + v1.rowid < v2.rowid", + )?; + + let mut rows = s.query([])?; + let row = rows.next()?.unwrap(); + assert_eq!(row.get_unwrap::<_, i32>(0), 2); + } + db.execute_batch("DROP TABLE vtab") + } +} diff --git a/third_party/rust/rusqlite/src/vtab/mod.rs b/third_party/rust/rusqlite/src/vtab/mod.rs new file mode 100644 index 0000000000..7e2f5f54f2 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/mod.rs @@ -0,0 +1,1333 @@ +//! Create virtual tables. +//! +//! Follow these steps to create your own virtual table: +//! 1. Write implementation of [`VTab`] and [`VTabCursor`] traits. +//! 2. Create an instance of the [`Module`] structure specialized for [`VTab`] +//! impl. from step 1. +//! 3. Register your [`Module`] structure using [`Connection::create_module`]. +//! 4. Run a `CREATE VIRTUAL TABLE` command that specifies the new module in the +//! `USING` clause. +//! +//! (See [SQLite doc](http://sqlite.org/vtab.html)) +use std::borrow::Cow::{self, Borrowed, Owned}; +use std::marker::PhantomData; +use std::marker::Sync; +use std::os::raw::{c_char, c_int, c_void}; +use std::ptr; +use std::slice; + +use crate::context::set_result; +use crate::error::{error_from_sqlite_code, to_sqlite_error}; +use crate::ffi; +pub use crate::ffi::{sqlite3_vtab, sqlite3_vtab_cursor}; +use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; +use crate::util::alloc; +use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; + +// let conn: Connection = ...; +// let mod: Module = ...; // VTab builder +// conn.create_module("module", mod); +// +// conn.execute("CREATE VIRTUAL TABLE foo USING module(...)"); +// \-> Module::xcreate +// |-> let vtab: VTab = ...; // on the heap +// \-> conn.declare_vtab("CREATE TABLE foo (...)"); +// conn = Connection::open(...); +// \-> Module::xconnect +// |-> let vtab: VTab = ...; // on the heap +// \-> conn.declare_vtab("CREATE TABLE foo (...)"); +// +// conn.close(); +// \-> vtab.xdisconnect +// conn.execute("DROP TABLE foo"); +// \-> vtab.xDestroy +// +// let stmt = conn.prepare("SELECT ... FROM foo WHERE ..."); +// \-> vtab.xbestindex +// stmt.query().next(); +// \-> vtab.xopen +// |-> let cursor: VTabCursor = ...; // on the heap +// |-> cursor.xfilter or xnext +// |-> cursor.xeof +// \-> if not eof { cursor.column or xrowid } else { cursor.xclose } +// + +// db: *mut ffi::sqlite3 => VTabConnection +// module: *const ffi::sqlite3_module => Module +// aux: *mut c_void => Module::Aux +// ffi::sqlite3_vtab => VTab +// ffi::sqlite3_vtab_cursor => VTabCursor + +/// Virtual table kind +pub enum VTabKind { + /// Non-eponymous + Default, + /// [`create`](CreateVTab::create) == [`connect`](VTab::connect) + /// + /// See [SQLite doc](https://sqlite.org/vtab.html#eponymous_virtual_tables) + Eponymous, + /// No [`create`](CreateVTab::create) / [`destroy`](CreateVTab::destroy) or + /// not used + /// + /// SQLite >= 3.9.0 + /// + /// See [SQLite doc](https://sqlite.org/vtab.html#eponymous_only_virtual_tables) + EponymousOnly, +} + +/// Virtual table module +/// +/// (See [SQLite doc](https://sqlite.org/c3ref/module.html)) +#[repr(transparent)] +pub struct Module<'vtab, T: VTab<'vtab>> { + base: ffi::sqlite3_module, + phantom: PhantomData<&'vtab T>, +} + +unsafe impl<'vtab, T: VTab<'vtab>> Send for Module<'vtab, T> {} +unsafe impl<'vtab, T: VTab<'vtab>> Sync for Module<'vtab, T> {} + +union ModuleZeroHack { + bytes: [u8; std::mem::size_of::<ffi::sqlite3_module>()], + module: ffi::sqlite3_module, +} + +// Used as a trailing initializer for sqlite3_module -- this way we avoid having +// the build fail if buildtime_bindgen is on. This is safe, as bindgen-generated +// structs are allowed to be zeroed. +const ZERO_MODULE: ffi::sqlite3_module = unsafe { + ModuleZeroHack { + bytes: [0_u8; std::mem::size_of::<ffi::sqlite3_module>()], + } + .module +}; + +macro_rules! module { + ($lt:lifetime, $vt:ty, $ct:ty, $xc:expr, $xd:expr, $xu:expr) => { + #[allow(clippy::needless_update)] + &Module { + base: ffi::sqlite3_module { + // We don't use V3 + iVersion: 2, + xCreate: $xc, + xConnect: Some(rust_connect::<$vt>), + xBestIndex: Some(rust_best_index::<$vt>), + xDisconnect: Some(rust_disconnect::<$vt>), + xDestroy: $xd, + xOpen: Some(rust_open::<$vt>), + xClose: Some(rust_close::<$ct>), + xFilter: Some(rust_filter::<$ct>), + xNext: Some(rust_next::<$ct>), + xEof: Some(rust_eof::<$ct>), + xColumn: Some(rust_column::<$ct>), + xRowid: Some(rust_rowid::<$ct>), // FIXME optional + xUpdate: $xu, + xBegin: None, + xSync: None, + xCommit: None, + xRollback: None, + xFindFunction: None, + xRename: None, + xSavepoint: None, + xRelease: None, + xRollbackTo: None, + ..ZERO_MODULE + }, + phantom: PhantomData::<&$lt $vt>, + } + }; +} + +/// Create an modifiable virtual table implementation. +/// +/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations). +#[must_use] +pub fn update_module<'vtab, T: UpdateVTab<'vtab>>() -> &'static Module<'vtab, T> { + match T::KIND { + VTabKind::EponymousOnly => { + module!('vtab, T, T::Cursor, None, None, Some(rust_update::<T>)) + } + VTabKind::Eponymous => { + module!('vtab, T, T::Cursor, Some(rust_connect::<T>), Some(rust_disconnect::<T>), Some(rust_update::<T>)) + } + _ => { + module!('vtab, T, T::Cursor, Some(rust_create::<T>), Some(rust_destroy::<T>), Some(rust_update::<T>)) + } + } +} + +/// Create a read-only virtual table implementation. +/// +/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations). +#[must_use] +pub fn read_only_module<'vtab, T: CreateVTab<'vtab>>() -> &'static Module<'vtab, T> { + match T::KIND { + VTabKind::EponymousOnly => eponymous_only_module(), + VTabKind::Eponymous => { + // A virtual table is eponymous if its xCreate method is the exact same function + // as the xConnect method + module!('vtab, T, T::Cursor, Some(rust_connect::<T>), Some(rust_disconnect::<T>), None) + } + _ => { + // The xConnect and xCreate methods may do the same thing, but they must be + // different so that the virtual table is not an eponymous virtual table. + module!('vtab, T, T::Cursor, Some(rust_create::<T>), Some(rust_destroy::<T>), None) + } + } +} + +/// Create an eponymous only virtual table implementation. +/// +/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations). +#[must_use] +pub fn eponymous_only_module<'vtab, T: VTab<'vtab>>() -> &'static Module<'vtab, T> { + // For eponymous-only virtual tables, the xCreate method is NULL + module!('vtab, T, T::Cursor, None, None, None) +} + +/// Virtual table configuration options +#[repr(i32)] +#[non_exhaustive] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum VTabConfig { + /// Equivalent to SQLITE_VTAB_CONSTRAINT_SUPPORT + ConstraintSupport = 1, + /// Equivalent to SQLITE_VTAB_INNOCUOUS + Innocuous = 2, + /// Equivalent to SQLITE_VTAB_DIRECTONLY + DirectOnly = 3, + /// Equivalent to SQLITE_VTAB_USES_ALL_SCHEMAS + UsesAllSchemas = 4, +} + +/// `feature = "vtab"` +pub struct VTabConnection(*mut ffi::sqlite3); + +impl VTabConnection { + /// Configure various facets of the virtual table interface + pub fn config(&mut self, config: VTabConfig) -> Result<()> { + crate::error::check(unsafe { ffi::sqlite3_vtab_config(self.0, config as c_int) }) + } + + // TODO sqlite3_vtab_on_conflict (http://sqlite.org/c3ref/vtab_on_conflict.html) & xUpdate + + /// Get access to the underlying SQLite database connection handle. + /// + /// # Warning + /// + /// You should not need to use this function. If you do need to, please + /// [open an issue on the rusqlite repository](https://github.com/rusqlite/rusqlite/issues) and describe + /// your use case. + /// + /// # Safety + /// + /// This function is unsafe because it gives you raw access + /// to the SQLite connection, and what you do with it could impact the + /// safety of this `Connection`. + pub unsafe fn handle(&mut self) -> *mut ffi::sqlite3 { + self.0 + } +} + +/// Eponymous-only virtual table instance trait. +/// +/// # Safety +/// +/// The first item in a struct implementing `VTab` must be +/// `rusqlite::sqlite3_vtab`, and the struct must be `#[repr(C)]`. +/// +/// ```rust,ignore +/// #[repr(C)] +/// struct MyTab { +/// /// Base class. Must be first +/// base: rusqlite::vtab::sqlite3_vtab, +/// /* Virtual table implementations will typically add additional fields */ +/// } +/// ``` +/// +/// (See [SQLite doc](https://sqlite.org/c3ref/vtab.html)) +pub unsafe trait VTab<'vtab>: Sized { + /// Client data passed to [`Connection::create_module`]. + type Aux; + /// Specific cursor implementation + type Cursor: VTabCursor; + + /// Establish a new connection to an existing virtual table. + /// + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xconnect_method)) + fn connect( + db: &mut VTabConnection, + aux: Option<&Self::Aux>, + args: &[&[u8]], + ) -> Result<(String, Self)>; + + /// Determine the best way to access the virtual table. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xbestindex_method)) + fn best_index(&self, info: &mut IndexInfo) -> Result<()>; + + /// Create a new cursor used for accessing a virtual table. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xopen_method)) + fn open(&'vtab mut self) -> Result<Self::Cursor>; +} + +/// Read-only virtual table instance trait. +/// +/// (See [SQLite doc](https://sqlite.org/c3ref/vtab.html)) +pub trait CreateVTab<'vtab>: VTab<'vtab> { + /// For [`EponymousOnly`](VTabKind::EponymousOnly), + /// [`create`](CreateVTab::create) and [`destroy`](CreateVTab::destroy) are + /// not called + const KIND: VTabKind; + /// Create a new instance of a virtual table in response to a CREATE VIRTUAL + /// TABLE statement. The `db` parameter is a pointer to the SQLite + /// database connection that is executing the CREATE VIRTUAL TABLE + /// statement. + /// + /// Call [`connect`](VTab::connect) by default. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xcreate_method)) + fn create( + db: &mut VTabConnection, + aux: Option<&Self::Aux>, + args: &[&[u8]], + ) -> Result<(String, Self)> { + Self::connect(db, aux, args) + } + + /// Destroy the underlying table implementation. This method undoes the work + /// of [`create`](CreateVTab::create). + /// + /// Do nothing by default. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xdestroy_method)) + fn destroy(&self) -> Result<()> { + Ok(()) + } +} + +/// Writable virtual table instance trait. +/// +/// (See [SQLite doc](https://sqlite.org/vtab.html#xupdate)) +pub trait UpdateVTab<'vtab>: CreateVTab<'vtab> { + /// Delete rowid or PK + fn delete(&mut self, arg: ValueRef<'_>) -> Result<()>; + /// Insert: `args[0] == NULL: old rowid or PK, args[1]: new rowid or PK, + /// args[2]: ...` + /// + /// Return the new rowid. + // TODO Make the distinction between argv[1] == NULL and argv[1] != NULL ? + fn insert(&mut self, args: &Values<'_>) -> Result<i64>; + /// Update: `args[0] != NULL: old rowid or PK, args[1]: new row id or PK, + /// args[2]: ...` + fn update(&mut self, args: &Values<'_>) -> Result<()>; +} + +/// Index constraint operator. +/// See [Virtual Table Constraint Operator Codes](https://sqlite.org/c3ref/c_index_constraint_eq.html) for details. +#[derive(Debug, Eq, PartialEq)] +#[allow(non_snake_case, non_camel_case_types, missing_docs)] +#[allow(clippy::upper_case_acronyms)] +pub enum IndexConstraintOp { + SQLITE_INDEX_CONSTRAINT_EQ, + SQLITE_INDEX_CONSTRAINT_GT, + SQLITE_INDEX_CONSTRAINT_LE, + SQLITE_INDEX_CONSTRAINT_LT, + SQLITE_INDEX_CONSTRAINT_GE, + SQLITE_INDEX_CONSTRAINT_MATCH, + SQLITE_INDEX_CONSTRAINT_LIKE, // 3.10.0 + SQLITE_INDEX_CONSTRAINT_GLOB, // 3.10.0 + SQLITE_INDEX_CONSTRAINT_REGEXP, // 3.10.0 + SQLITE_INDEX_CONSTRAINT_NE, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_ISNOT, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_ISNOTNULL, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_ISNULL, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_IS, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_LIMIT, // 3.38.0 + SQLITE_INDEX_CONSTRAINT_OFFSET, // 3.38.0 + SQLITE_INDEX_CONSTRAINT_FUNCTION(u8), // 3.25.0 +} + +impl From<u8> for IndexConstraintOp { + fn from(code: u8) -> IndexConstraintOp { + match code { + 2 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ, + 4 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GT, + 8 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LE, + 16 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LT, + 32 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GE, + 64 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_MATCH, + 65 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LIKE, + 66 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GLOB, + 67 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_REGEXP, + 68 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_NE, + 69 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNOT, + 70 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNOTNULL, + 71 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNULL, + 72 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_IS, + 73 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LIMIT, + 74 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_OFFSET, + v => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_FUNCTION(v), + } + } +} + +bitflags::bitflags! { + /// Virtual table scan flags + /// See [Function Flags](https://sqlite.org/c3ref/c_index_scan_unique.html) for details. + #[repr(C)] + pub struct IndexFlags: ::std::os::raw::c_int { + /// Default + const NONE = 0; + /// Scan visits at most 1 row. + const SQLITE_INDEX_SCAN_UNIQUE = ffi::SQLITE_INDEX_SCAN_UNIQUE; + } +} + +/// Pass information into and receive the reply from the +/// [`VTab::best_index`] method. +/// +/// (See [SQLite doc](http://sqlite.org/c3ref/index_info.html)) +#[derive(Debug)] +pub struct IndexInfo(*mut ffi::sqlite3_index_info); + +impl IndexInfo { + /// Iterate on index constraint and its associated usage. + #[inline] + pub fn constraints_and_usages(&mut self) -> IndexConstraintAndUsageIter<'_> { + let constraints = + unsafe { slice::from_raw_parts((*self.0).aConstraint, (*self.0).nConstraint as usize) }; + let constraint_usages = unsafe { + slice::from_raw_parts_mut((*self.0).aConstraintUsage, (*self.0).nConstraint as usize) + }; + IndexConstraintAndUsageIter { + iter: constraints.iter().zip(constraint_usages.iter_mut()), + } + } + + /// Record WHERE clause constraints. + #[inline] + #[must_use] + pub fn constraints(&self) -> IndexConstraintIter<'_> { + let constraints = + unsafe { slice::from_raw_parts((*self.0).aConstraint, (*self.0).nConstraint as usize) }; + IndexConstraintIter { + iter: constraints.iter(), + } + } + + /// Information about the ORDER BY clause. + #[inline] + #[must_use] + pub fn order_bys(&self) -> OrderByIter<'_> { + let order_bys = + unsafe { slice::from_raw_parts((*self.0).aOrderBy, (*self.0).nOrderBy as usize) }; + OrderByIter { + iter: order_bys.iter(), + } + } + + /// Number of terms in the ORDER BY clause + #[inline] + #[must_use] + pub fn num_of_order_by(&self) -> usize { + unsafe { (*self.0).nOrderBy as usize } + } + + /// Information about what parameters to pass to [`VTabCursor::filter`]. + #[inline] + pub fn constraint_usage(&mut self, constraint_idx: usize) -> IndexConstraintUsage<'_> { + let constraint_usages = unsafe { + slice::from_raw_parts_mut((*self.0).aConstraintUsage, (*self.0).nConstraint as usize) + }; + IndexConstraintUsage(&mut constraint_usages[constraint_idx]) + } + + /// Number used to identify the index + #[inline] + pub fn set_idx_num(&mut self, idx_num: c_int) { + unsafe { + (*self.0).idxNum = idx_num; + } + } + + /// String used to identify the index + pub fn set_idx_str(&mut self, idx_str: &str) { + unsafe { + (*self.0).idxStr = alloc(idx_str); + (*self.0).needToFreeIdxStr = 1; + } + } + + /// True if output is already ordered + #[inline] + pub fn set_order_by_consumed(&mut self, order_by_consumed: bool) { + unsafe { + (*self.0).orderByConsumed = order_by_consumed as c_int; + } + } + + /// Estimated cost of using this index + #[inline] + pub fn set_estimated_cost(&mut self, estimated_ost: f64) { + unsafe { + (*self.0).estimatedCost = estimated_ost; + } + } + + /// Estimated number of rows returned. + #[inline] + pub fn set_estimated_rows(&mut self, estimated_rows: i64) { + unsafe { + (*self.0).estimatedRows = estimated_rows; + } + } + + /// Mask of SQLITE_INDEX_SCAN_* flags. + #[inline] + pub fn set_idx_flags(&mut self, flags: IndexFlags) { + unsafe { (*self.0).idxFlags = flags.bits() }; + } + + /// Mask of columns used by statement + #[inline] + pub fn col_used(&self) -> u64 { + unsafe { (*self.0).colUsed } + } + + /// Determine the collation for a virtual table constraint + #[cfg(feature = "modern_sqlite")] // SQLite >= 3.22.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn collation(&self, constraint_idx: usize) -> Result<&str> { + use std::ffi::CStr; + let idx = constraint_idx as c_int; + let collation = unsafe { ffi::sqlite3_vtab_collation(self.0, idx) }; + if collation.is_null() { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("{constraint_idx} is out of range")), + )); + } + Ok(unsafe { CStr::from_ptr(collation) }.to_str()?) + } + + /*/// Determine if a virtual table query is DISTINCT + #[cfg(feature = "modern_sqlite")] // SQLite >= 3.38.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn distinct(&self) -> c_int { + unsafe { ffi::sqlite3_vtab_distinct(self.0) } + } + + /// Constraint values + #[cfg(feature = "modern_sqlite")] // SQLite >= 3.38.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn set_rhs_value(&mut self, constraint_idx: c_int, value: ValueRef) -> Result<()> { + // TODO ValueRef to sqlite3_value + crate::error::check(unsafe { ffi::sqlite3_vtab_rhs_value(self.O, constraint_idx, value) }) + } + + /// Identify and handle IN constraints + #[cfg(feature = "modern_sqlite")] // SQLite >= 3.38.0 + #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] + pub fn set_in_constraint(&mut self, constraint_idx: c_int, b_handle: c_int) -> bool { + unsafe { ffi::sqlite3_vtab_in(self.0, constraint_idx, b_handle) != 0 } + } // TODO sqlite3_vtab_in_first / sqlite3_vtab_in_next https://sqlite.org/c3ref/vtab_in_first.html + */ +} + +/// Iterate on index constraint and its associated usage. +pub struct IndexConstraintAndUsageIter<'a> { + iter: std::iter::Zip< + slice::Iter<'a, ffi::sqlite3_index_constraint>, + slice::IterMut<'a, ffi::sqlite3_index_constraint_usage>, + >, +} + +impl<'a> Iterator for IndexConstraintAndUsageIter<'a> { + type Item = (IndexConstraint<'a>, IndexConstraintUsage<'a>); + + #[inline] + fn next(&mut self) -> Option<(IndexConstraint<'a>, IndexConstraintUsage<'a>)> { + self.iter + .next() + .map(|raw| (IndexConstraint(raw.0), IndexConstraintUsage(raw.1))) + } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +/// `feature = "vtab"` +pub struct IndexConstraintIter<'a> { + iter: slice::Iter<'a, ffi::sqlite3_index_constraint>, +} + +impl<'a> Iterator for IndexConstraintIter<'a> { + type Item = IndexConstraint<'a>; + + #[inline] + fn next(&mut self) -> Option<IndexConstraint<'a>> { + self.iter.next().map(IndexConstraint) + } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +/// WHERE clause constraint. +pub struct IndexConstraint<'a>(&'a ffi::sqlite3_index_constraint); + +impl IndexConstraint<'_> { + /// Column constrained. -1 for ROWID + #[inline] + #[must_use] + pub fn column(&self) -> c_int { + self.0.iColumn + } + + /// Constraint operator + #[inline] + #[must_use] + pub fn operator(&self) -> IndexConstraintOp { + IndexConstraintOp::from(self.0.op) + } + + /// True if this constraint is usable + #[inline] + #[must_use] + pub fn is_usable(&self) -> bool { + self.0.usable != 0 + } +} + +/// Information about what parameters to pass to +/// [`VTabCursor::filter`]. +pub struct IndexConstraintUsage<'a>(&'a mut ffi::sqlite3_index_constraint_usage); + +impl IndexConstraintUsage<'_> { + /// if `argv_index` > 0, constraint is part of argv to + /// [`VTabCursor::filter`] + #[inline] + pub fn set_argv_index(&mut self, argv_index: c_int) { + self.0.argvIndex = argv_index; + } + + /// if `omit`, do not code a test for this constraint + #[inline] + pub fn set_omit(&mut self, omit: bool) { + self.0.omit = omit as std::os::raw::c_uchar; + } +} + +/// `feature = "vtab"` +pub struct OrderByIter<'a> { + iter: slice::Iter<'a, ffi::sqlite3_index_orderby>, +} + +impl<'a> Iterator for OrderByIter<'a> { + type Item = OrderBy<'a>; + + #[inline] + fn next(&mut self) -> Option<OrderBy<'a>> { + self.iter.next().map(OrderBy) + } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +/// A column of the ORDER BY clause. +pub struct OrderBy<'a>(&'a ffi::sqlite3_index_orderby); + +impl OrderBy<'_> { + /// Column number + #[inline] + #[must_use] + pub fn column(&self) -> c_int { + self.0.iColumn + } + + /// True for DESC. False for ASC. + #[inline] + #[must_use] + pub fn is_order_by_desc(&self) -> bool { + self.0.desc != 0 + } +} + +/// Virtual table cursor trait. +/// +/// # Safety +/// +/// Implementations must be like: +/// ```rust,ignore +/// #[repr(C)] +/// struct MyTabCursor { +/// /// Base class. Must be first +/// base: rusqlite::vtab::sqlite3_vtab_cursor, +/// /* Virtual table implementations will typically add additional fields */ +/// } +/// ``` +/// +/// (See [SQLite doc](https://sqlite.org/c3ref/vtab_cursor.html)) +pub unsafe trait VTabCursor: Sized { + /// Begin a search of a virtual table. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xfilter_method)) + fn filter(&mut self, idx_num: c_int, idx_str: Option<&str>, args: &Values<'_>) -> Result<()>; + /// Advance cursor to the next row of a result set initiated by + /// [`filter`](VTabCursor::filter). (See [SQLite doc](https://sqlite.org/vtab.html#the_xnext_method)) + fn next(&mut self) -> Result<()>; + /// Must return `false` if the cursor currently points to a valid row of + /// data, or `true` otherwise. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xeof_method)) + fn eof(&self) -> bool; + /// Find the value for the `i`-th column of the current row. + /// `i` is zero-based so the first column is numbered 0. + /// May return its result back to SQLite using one of the specified `ctx`. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xcolumn_method)) + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()>; + /// Return the rowid of row that the cursor is currently pointing at. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xrowid_method)) + fn rowid(&self) -> Result<i64>; +} + +/// Context is used by [`VTabCursor::column`] to specify the +/// cell value. +pub struct Context(*mut ffi::sqlite3_context); + +impl Context { + /// Set current cell value + #[inline] + pub fn set_result<T: ToSql>(&mut self, value: &T) -> Result<()> { + let t = value.to_sql()?; + unsafe { set_result(self.0, &t) }; + Ok(()) + } + + // TODO sqlite3_vtab_nochange (http://sqlite.org/c3ref/vtab_nochange.html) // 3.22.0 & xColumn +} + +/// Wrapper to [`VTabCursor::filter`] arguments, the values +/// requested by [`VTab::best_index`]. +pub struct Values<'a> { + args: &'a [*mut ffi::sqlite3_value], +} + +impl Values<'_> { + /// Returns the number of values. + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.args.len() + } + + /// Returns `true` if there is no value. + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.args.is_empty() + } + + /// Returns value at `idx` + pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> { + let arg = self.args[idx]; + let value = unsafe { ValueRef::from_value(arg) }; + FromSql::column_result(value).map_err(|err| match err { + FromSqlError::InvalidType => Error::InvalidFilterParameterType(idx, value.data_type()), + FromSqlError::Other(err) => { + Error::FromSqlConversionFailure(idx, value.data_type(), err) + } + FromSqlError::InvalidBlobSize { .. } => { + Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err)) + } + FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i), + }) + } + + // `sqlite3_value_type` returns `SQLITE_NULL` for pointer. + // So it seems not possible to enhance `ValueRef::from_value`. + #[cfg(feature = "array")] + #[cfg_attr(docsrs, doc(cfg(feature = "array")))] + fn get_array(&self, idx: usize) -> Option<array::Array> { + use crate::types::Value; + let arg = self.args[idx]; + let ptr = unsafe { ffi::sqlite3_value_pointer(arg, array::ARRAY_TYPE) }; + if ptr.is_null() { + None + } else { + Some(unsafe { + let rc = array::Array::from_raw(ptr as *const Vec<Value>); + let array = rc.clone(); + array::Array::into_raw(rc); // don't consume it + array + }) + } + } + + /// Turns `Values` into an iterator. + #[inline] + #[must_use] + pub fn iter(&self) -> ValueIter<'_> { + ValueIter { + iter: self.args.iter(), + } + } + // TODO sqlite3_vtab_in_first / sqlite3_vtab_in_next https://sqlite.org/c3ref/vtab_in_first.html & 3.38.0 +} + +impl<'a> IntoIterator for &'a Values<'a> { + type IntoIter = ValueIter<'a>; + type Item = ValueRef<'a>; + + #[inline] + fn into_iter(self) -> ValueIter<'a> { + self.iter() + } +} + +/// [`Values`] iterator. +pub struct ValueIter<'a> { + iter: slice::Iter<'a, *mut ffi::sqlite3_value>, +} + +impl<'a> Iterator for ValueIter<'a> { + type Item = ValueRef<'a>; + + #[inline] + fn next(&mut self) -> Option<ValueRef<'a>> { + self.iter + .next() + .map(|&raw| unsafe { ValueRef::from_value(raw) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +impl Connection { + /// Register a virtual table implementation. + /// + /// Step 3 of [Creating New Virtual Table + /// Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations). + #[inline] + pub fn create_module<'vtab, T: VTab<'vtab>>( + &self, + module_name: &str, + module: &'static Module<'vtab, T>, + aux: Option<T::Aux>, + ) -> Result<()> { + self.db.borrow_mut().create_module(module_name, module, aux) + } +} + +impl InnerConnection { + fn create_module<'vtab, T: VTab<'vtab>>( + &mut self, + module_name: &str, + module: &'static Module<'vtab, T>, + aux: Option<T::Aux>, + ) -> Result<()> { + use crate::version; + if version::version_number() < 3_009_000 && module.base.xCreate.is_none() { + return Err(Error::ModuleError(format!( + "Eponymous-only virtual table not supported by SQLite version {}", + version::version() + ))); + } + let c_name = str_to_cstring(module_name)?; + let r = match aux { + Some(aux) => { + let boxed_aux: *mut T::Aux = Box::into_raw(Box::new(aux)); + unsafe { + ffi::sqlite3_create_module_v2( + self.db(), + c_name.as_ptr(), + &module.base, + boxed_aux.cast::<c_void>(), + Some(free_boxed_value::<T::Aux>), + ) + } + } + None => unsafe { + ffi::sqlite3_create_module_v2( + self.db(), + c_name.as_ptr(), + &module.base, + ptr::null_mut(), + None, + ) + }, + }; + self.decode_result(r) + } +} + +/// Escape double-quote (`"`) character occurrences by +/// doubling them (`""`). +#[must_use] +pub fn escape_double_quote(identifier: &str) -> Cow<'_, str> { + if identifier.contains('"') { + // escape quote by doubling them + Owned(identifier.replace('"', "\"\"")) + } else { + Borrowed(identifier) + } +} +/// Dequote string +#[must_use] +pub fn dequote(s: &str) -> &str { + if s.len() < 2 { + return s; + } + match s.bytes().next() { + Some(b) if b == b'"' || b == b'\'' => match s.bytes().next_back() { + Some(e) if e == b => &s[1..s.len() - 1], // FIXME handle inner escaped quote(s) + _ => s, + }, + _ => s, + } +} +/// The boolean can be one of: +/// ```text +/// 1 yes true on +/// 0 no false off +/// ``` +#[must_use] +pub fn parse_boolean(s: &str) -> Option<bool> { + if s.eq_ignore_ascii_case("yes") + || s.eq_ignore_ascii_case("on") + || s.eq_ignore_ascii_case("true") + || s.eq("1") + { + Some(true) + } else if s.eq_ignore_ascii_case("no") + || s.eq_ignore_ascii_case("off") + || s.eq_ignore_ascii_case("false") + || s.eq("0") + { + Some(false) + } else { + None + } +} + +/// `<param_name>=['"]?<param_value>['"]?` => `(<param_name>, <param_value>)` +pub fn parameter(c_slice: &[u8]) -> Result<(&str, &str)> { + let arg = std::str::from_utf8(c_slice)?.trim(); + let mut split = arg.split('='); + if let Some(key) = split.next() { + if let Some(value) = split.next() { + let param = key.trim(); + let value = dequote(value); + return Ok((param, value)); + } + } + Err(Error::ModuleError(format!("illegal argument: '{arg}'"))) +} + +// FIXME copy/paste from function.rs +unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) { + drop(Box::from_raw(p.cast::<T>())); +} + +unsafe extern "C" fn rust_create<'vtab, T>( + db: *mut ffi::sqlite3, + aux: *mut c_void, + argc: c_int, + argv: *const *const c_char, + pp_vtab: *mut *mut ffi::sqlite3_vtab, + err_msg: *mut *mut c_char, +) -> c_int +where + T: CreateVTab<'vtab>, +{ + use std::ffi::CStr; + + let mut conn = VTabConnection(db); + let aux = aux.cast::<T::Aux>(); + let args = slice::from_raw_parts(argv, argc as usize); + let vec = args + .iter() + .map(|&cs| CStr::from_ptr(cs).to_bytes()) // FIXME .to_str() -> Result<&str, Utf8Error> + .collect::<Vec<_>>(); + match T::create(&mut conn, aux.as_ref(), &vec[..]) { + Ok((sql, vtab)) => match std::ffi::CString::new(sql) { + Ok(c_sql) => { + let rc = ffi::sqlite3_declare_vtab(db, c_sql.as_ptr()); + if rc == ffi::SQLITE_OK { + let boxed_vtab: *mut T = Box::into_raw(Box::new(vtab)); + *pp_vtab = boxed_vtab.cast::<ffi::sqlite3_vtab>(); + ffi::SQLITE_OK + } else { + let err = error_from_sqlite_code(rc, None); + to_sqlite_error(&err, err_msg) + } + } + Err(err) => { + *err_msg = alloc(&err.to_string()); + ffi::SQLITE_ERROR + } + }, + Err(err) => to_sqlite_error(&err, err_msg), + } +} + +unsafe extern "C" fn rust_connect<'vtab, T>( + db: *mut ffi::sqlite3, + aux: *mut c_void, + argc: c_int, + argv: *const *const c_char, + pp_vtab: *mut *mut ffi::sqlite3_vtab, + err_msg: *mut *mut c_char, +) -> c_int +where + T: VTab<'vtab>, +{ + use std::ffi::CStr; + + let mut conn = VTabConnection(db); + let aux = aux.cast::<T::Aux>(); + let args = slice::from_raw_parts(argv, argc as usize); + let vec = args + .iter() + .map(|&cs| CStr::from_ptr(cs).to_bytes()) // FIXME .to_str() -> Result<&str, Utf8Error> + .collect::<Vec<_>>(); + match T::connect(&mut conn, aux.as_ref(), &vec[..]) { + Ok((sql, vtab)) => match std::ffi::CString::new(sql) { + Ok(c_sql) => { + let rc = ffi::sqlite3_declare_vtab(db, c_sql.as_ptr()); + if rc == ffi::SQLITE_OK { + let boxed_vtab: *mut T = Box::into_raw(Box::new(vtab)); + *pp_vtab = boxed_vtab.cast::<ffi::sqlite3_vtab>(); + ffi::SQLITE_OK + } else { + let err = error_from_sqlite_code(rc, None); + to_sqlite_error(&err, err_msg) + } + } + Err(err) => { + *err_msg = alloc(&err.to_string()); + ffi::SQLITE_ERROR + } + }, + Err(err) => to_sqlite_error(&err, err_msg), + } +} + +unsafe extern "C" fn rust_best_index<'vtab, T>( + vtab: *mut ffi::sqlite3_vtab, + info: *mut ffi::sqlite3_index_info, +) -> c_int +where + T: VTab<'vtab>, +{ + let vt = vtab.cast::<T>(); + let mut idx_info = IndexInfo(info); + match (*vt).best_index(&mut idx_info) { + Ok(_) => ffi::SQLITE_OK, + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg(vtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg(vtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +unsafe extern "C" fn rust_disconnect<'vtab, T>(vtab: *mut ffi::sqlite3_vtab) -> c_int +where + T: VTab<'vtab>, +{ + if vtab.is_null() { + return ffi::SQLITE_OK; + } + let vtab = vtab.cast::<T>(); + drop(Box::from_raw(vtab)); + ffi::SQLITE_OK +} + +unsafe extern "C" fn rust_destroy<'vtab, T>(vtab: *mut ffi::sqlite3_vtab) -> c_int +where + T: CreateVTab<'vtab>, +{ + if vtab.is_null() { + return ffi::SQLITE_OK; + } + let vt = vtab.cast::<T>(); + match (*vt).destroy() { + Ok(_) => { + drop(Box::from_raw(vt)); + ffi::SQLITE_OK + } + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg(vtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg(vtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +unsafe extern "C" fn rust_open<'vtab, T: 'vtab>( + vtab: *mut ffi::sqlite3_vtab, + pp_cursor: *mut *mut ffi::sqlite3_vtab_cursor, +) -> c_int +where + T: VTab<'vtab>, +{ + let vt = vtab.cast::<T>(); + match (*vt).open() { + Ok(cursor) => { + let boxed_cursor: *mut T::Cursor = Box::into_raw(Box::new(cursor)); + *pp_cursor = boxed_cursor.cast::<ffi::sqlite3_vtab_cursor>(); + ffi::SQLITE_OK + } + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg(vtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg(vtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +unsafe extern "C" fn rust_close<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int +where + C: VTabCursor, +{ + let cr = cursor.cast::<C>(); + drop(Box::from_raw(cr)); + ffi::SQLITE_OK +} + +unsafe extern "C" fn rust_filter<C>( + cursor: *mut ffi::sqlite3_vtab_cursor, + idx_num: c_int, + idx_str: *const c_char, + argc: c_int, + argv: *mut *mut ffi::sqlite3_value, +) -> c_int +where + C: VTabCursor, +{ + use std::ffi::CStr; + use std::str; + let idx_name = if idx_str.is_null() { + None + } else { + let c_slice = CStr::from_ptr(idx_str).to_bytes(); + Some(str::from_utf8_unchecked(c_slice)) + }; + let args = slice::from_raw_parts_mut(argv, argc as usize); + let values = Values { args }; + let cr = cursor as *mut C; + cursor_error(cursor, (*cr).filter(idx_num, idx_name, &values)) +} + +unsafe extern "C" fn rust_next<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int +where + C: VTabCursor, +{ + let cr = cursor as *mut C; + cursor_error(cursor, (*cr).next()) +} + +unsafe extern "C" fn rust_eof<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int +where + C: VTabCursor, +{ + let cr = cursor.cast::<C>(); + (*cr).eof() as c_int +} + +unsafe extern "C" fn rust_column<C>( + cursor: *mut ffi::sqlite3_vtab_cursor, + ctx: *mut ffi::sqlite3_context, + i: c_int, +) -> c_int +where + C: VTabCursor, +{ + let cr = cursor.cast::<C>(); + let mut ctxt = Context(ctx); + result_error(ctx, (*cr).column(&mut ctxt, i)) +} + +unsafe extern "C" fn rust_rowid<C>( + cursor: *mut ffi::sqlite3_vtab_cursor, + p_rowid: *mut ffi::sqlite3_int64, +) -> c_int +where + C: VTabCursor, +{ + let cr = cursor.cast::<C>(); + match (*cr).rowid() { + Ok(rowid) => { + *p_rowid = rowid; + ffi::SQLITE_OK + } + err => cursor_error(cursor, err), + } +} + +unsafe extern "C" fn rust_update<'vtab, T: 'vtab>( + vtab: *mut ffi::sqlite3_vtab, + argc: c_int, + argv: *mut *mut ffi::sqlite3_value, + p_rowid: *mut ffi::sqlite3_int64, +) -> c_int +where + T: UpdateVTab<'vtab>, +{ + assert!(argc >= 1); + let args = slice::from_raw_parts_mut(argv, argc as usize); + let vt = vtab.cast::<T>(); + let r = if args.len() == 1 { + (*vt).delete(ValueRef::from_value(args[0])) + } else if ffi::sqlite3_value_type(args[0]) == ffi::SQLITE_NULL { + // TODO Make the distinction between argv[1] == NULL and argv[1] != NULL ? + let values = Values { args }; + match (*vt).insert(&values) { + Ok(rowid) => { + *p_rowid = rowid; + Ok(()) + } + Err(e) => Err(e), + } + } else { + let values = Values { args }; + (*vt).update(&values) + }; + match r { + Ok(_) => ffi::SQLITE_OK, + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg(vtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg(vtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +/// Virtual table cursors can set an error message by assigning a string to +/// `zErrMsg`. +#[cold] +unsafe fn cursor_error<T>(cursor: *mut ffi::sqlite3_vtab_cursor, result: Result<T>) -> c_int { + match result { + Ok(_) => ffi::SQLITE_OK, + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg((*cursor).pVtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg((*cursor).pVtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +/// Virtual tables methods can set an error message by assigning a string to +/// `zErrMsg`. +#[cold] +unsafe fn set_err_msg(vtab: *mut ffi::sqlite3_vtab, err_msg: &str) { + if !(*vtab).zErrMsg.is_null() { + ffi::sqlite3_free((*vtab).zErrMsg.cast::<c_void>()); + } + (*vtab).zErrMsg = alloc(err_msg); +} + +/// To raise an error, the `column` method should use this method to set the +/// error message and return the error code. +#[cold] +unsafe fn result_error<T>(ctx: *mut ffi::sqlite3_context, result: Result<T>) -> c_int { + match result { + Ok(_) => ffi::SQLITE_OK, + Err(Error::SqliteFailure(err, s)) => { + match err.extended_code { + ffi::SQLITE_TOOBIG => { + ffi::sqlite3_result_error_toobig(ctx); + } + ffi::SQLITE_NOMEM => { + ffi::sqlite3_result_error_nomem(ctx); + } + code => { + ffi::sqlite3_result_error_code(ctx, code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + }; + err.extended_code + } + Err(err) => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_ERROR); + if let Ok(cstr) = str_to_cstring(&err.to_string()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + ffi::SQLITE_ERROR + } + } +} + +#[cfg(feature = "array")] +#[cfg_attr(docsrs, doc(cfg(feature = "array")))] +pub mod array; +#[cfg(feature = "csvtab")] +#[cfg_attr(docsrs, doc(cfg(feature = "csvtab")))] +pub mod csvtab; +#[cfg(feature = "series")] +#[cfg_attr(docsrs, doc(cfg(feature = "series")))] +pub mod series; // SQLite >= 3.9.0 +#[cfg(all(test, feature = "modern_sqlite"))] +mod vtablog; + +#[cfg(test)] +mod test { + #[test] + fn test_dequote() { + assert_eq!("", super::dequote("")); + assert_eq!("'", super::dequote("'")); + assert_eq!("\"", super::dequote("\"")); + assert_eq!("'\"", super::dequote("'\"")); + assert_eq!("", super::dequote("''")); + assert_eq!("", super::dequote("\"\"")); + assert_eq!("x", super::dequote("'x'")); + assert_eq!("x", super::dequote("\"x\"")); + assert_eq!("x", super::dequote("x")); + } + #[test] + fn test_parse_boolean() { + assert_eq!(None, super::parse_boolean("")); + assert_eq!(Some(true), super::parse_boolean("1")); + assert_eq!(Some(true), super::parse_boolean("yes")); + assert_eq!(Some(true), super::parse_boolean("on")); + assert_eq!(Some(true), super::parse_boolean("true")); + assert_eq!(Some(false), super::parse_boolean("0")); + assert_eq!(Some(false), super::parse_boolean("no")); + assert_eq!(Some(false), super::parse_boolean("off")); + assert_eq!(Some(false), super::parse_boolean("false")); + } +} diff --git a/third_party/rust/rusqlite/src/vtab/series.rs b/third_party/rust/rusqlite/src/vtab/series.rs new file mode 100644 index 0000000000..5b67758149 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/series.rs @@ -0,0 +1,341 @@ +//! Generate series virtual table. +//! +//! Port of C [generate series +//! "function"](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/series.c): +//! `https://www.sqlite.org/series.html` +use std::default::Default; +use std::marker::PhantomData; +use std::os::raw::c_int; + +use crate::ffi; +use crate::types::Type; +use crate::vtab::{ + eponymous_only_module, Context, IndexConstraintOp, IndexInfo, VTab, VTabConfig, VTabConnection, + VTabCursor, Values, +}; +use crate::{Connection, Error, Result}; + +/// Register the "generate_series" module. +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("generate_series", eponymous_only_module::<SeriesTab>(), aux) +} + +// Column numbers +// const SERIES_COLUMN_VALUE : c_int = 0; +const SERIES_COLUMN_START: c_int = 1; +const SERIES_COLUMN_STOP: c_int = 2; +const SERIES_COLUMN_STEP: c_int = 3; + +bitflags::bitflags! { + #[derive(Clone, Copy)] + #[repr(C)] + struct QueryPlanFlags: ::std::os::raw::c_int { + // start = $value -- constraint exists + const START = 1; + // stop = $value -- constraint exists + const STOP = 2; + // step = $value -- constraint exists + const STEP = 4; + // output in descending order + const DESC = 8; + // output in ascending order + const ASC = 16; + // Both start and stop + const BOTH = QueryPlanFlags::START.bits() | QueryPlanFlags::STOP.bits(); + } +} + +/// An instance of the Series virtual table +#[repr(C)] +struct SeriesTab { + /// Base class. Must be first + base: ffi::sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for SeriesTab { + type Aux = (); + type Cursor = SeriesTabCursor<'vtab>; + + fn connect( + db: &mut VTabConnection, + _aux: Option<&()>, + _args: &[&[u8]], + ) -> Result<(String, SeriesTab)> { + let vtab = SeriesTab { + base: ffi::sqlite3_vtab::default(), + }; + db.config(VTabConfig::Innocuous)?; + Ok(( + "CREATE TABLE x(value,start hidden,stop hidden,step hidden)".to_owned(), + vtab, + )) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + // The query plan bitmask + let mut idx_num: QueryPlanFlags = QueryPlanFlags::empty(); + // Mask of unusable constraints + let mut unusable_mask: QueryPlanFlags = QueryPlanFlags::empty(); + // Constraints on start, stop, and step + let mut a_idx: [Option<usize>; 3] = [None, None, None]; + for (i, constraint) in info.constraints().enumerate() { + if constraint.column() < SERIES_COLUMN_START { + continue; + } + let (i_col, i_mask) = match constraint.column() { + SERIES_COLUMN_START => (0, QueryPlanFlags::START), + SERIES_COLUMN_STOP => (1, QueryPlanFlags::STOP), + SERIES_COLUMN_STEP => (2, QueryPlanFlags::STEP), + _ => { + unreachable!() + } + }; + if !constraint.is_usable() { + unusable_mask |= i_mask; + } else if constraint.operator() == IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ { + idx_num |= i_mask; + a_idx[i_col] = Some(i); + } + } + // Number of arguments that SeriesTabCursor::filter expects + let mut n_arg = 0; + for j in a_idx.iter().flatten() { + n_arg += 1; + let mut constraint_usage = info.constraint_usage(*j); + constraint_usage.set_argv_index(n_arg); + constraint_usage.set_omit(true); + #[cfg(all(test, feature = "modern_sqlite"))] + debug_assert_eq!(Ok("BINARY"), info.collation(*j)); + } + if !(unusable_mask & !idx_num).is_empty() { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_CONSTRAINT), + None, + )); + } + if idx_num.contains(QueryPlanFlags::BOTH) { + // Both start= and stop= boundaries are available. + #[allow(clippy::bool_to_int_with_if)] + info.set_estimated_cost(f64::from( + 2 - if idx_num.contains(QueryPlanFlags::STEP) { + 1 + } else { + 0 + }, + )); + info.set_estimated_rows(1000); + let order_by_consumed = { + let mut order_bys = info.order_bys(); + if let Some(order_by) = order_bys.next() { + if order_by.column() == 0 { + if order_by.is_order_by_desc() { + idx_num |= QueryPlanFlags::DESC; + } else { + idx_num |= QueryPlanFlags::ASC; + } + true + } else { + false + } + } else { + false + } + }; + if order_by_consumed { + info.set_order_by_consumed(true); + } + } else { + // If either boundary is missing, we have to generate a huge span + // of numbers. Make this case very expensive so that the query + // planner will work hard to avoid it. + info.set_estimated_rows(2_147_483_647); + } + info.set_idx_num(idx_num.bits()); + Ok(()) + } + + fn open(&mut self) -> Result<SeriesTabCursor<'_>> { + Ok(SeriesTabCursor::new()) + } +} + +/// A cursor for the Series virtual table +#[repr(C)] +struct SeriesTabCursor<'vtab> { + /// Base class. Must be first + base: ffi::sqlite3_vtab_cursor, + /// True to count down rather than up + is_desc: bool, + /// The rowid + row_id: i64, + /// Current value ("value") + value: i64, + /// Minimum value ("start") + min_value: i64, + /// Maximum value ("stop") + max_value: i64, + /// Increment ("step") + step: i64, + phantom: PhantomData<&'vtab SeriesTab>, +} + +impl SeriesTabCursor<'_> { + fn new<'vtab>() -> SeriesTabCursor<'vtab> { + SeriesTabCursor { + base: ffi::sqlite3_vtab_cursor::default(), + is_desc: false, + row_id: 0, + value: 0, + min_value: 0, + max_value: 0, + step: 0, + phantom: PhantomData, + } + } +} +#[allow(clippy::comparison_chain)] +unsafe impl VTabCursor for SeriesTabCursor<'_> { + fn filter(&mut self, idx_num: c_int, _idx_str: Option<&str>, args: &Values<'_>) -> Result<()> { + let mut idx_num = QueryPlanFlags::from_bits_truncate(idx_num); + let mut i = 0; + if idx_num.contains(QueryPlanFlags::START) { + self.min_value = args.get::<Option<_>>(i)?.unwrap_or_default(); + i += 1; + } else { + self.min_value = 0; + } + if idx_num.contains(QueryPlanFlags::STOP) { + self.max_value = args.get::<Option<_>>(i)?.unwrap_or_default(); + i += 1; + } else { + self.max_value = 0xffff_ffff; + } + if idx_num.contains(QueryPlanFlags::STEP) { + self.step = args.get::<Option<_>>(i)?.unwrap_or_default(); + if self.step == 0 { + self.step = 1; + } else if self.step < 0 { + self.step = -self.step; + if !idx_num.contains(QueryPlanFlags::ASC) { + idx_num |= QueryPlanFlags::DESC; + } + } + } else { + self.step = 1; + }; + for arg in args.iter() { + if arg.data_type() == Type::Null { + // If any of the constraints have a NULL value, then return no rows. + self.min_value = 1; + self.max_value = 0; + break; + } + } + self.is_desc = idx_num.contains(QueryPlanFlags::DESC); + if self.is_desc { + self.value = self.max_value; + if self.step > 0 { + self.value -= (self.max_value - self.min_value) % self.step; + } + } else { + self.value = self.min_value; + } + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + if self.is_desc { + self.value -= self.step; + } else { + self.value += self.step; + } + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + if self.is_desc { + self.value < self.min_value + } else { + self.value > self.max_value + } + } + + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()> { + let x = match i { + SERIES_COLUMN_START => self.min_value, + SERIES_COLUMN_STOP => self.max_value, + SERIES_COLUMN_STEP => self.step, + _ => self.value, + }; + ctx.set_result(&x) + } + + fn rowid(&self) -> Result<i64> { + Ok(self.row_id) + } +} + +#[cfg(test)] +mod test { + use crate::ffi; + use crate::vtab::series; + use crate::{Connection, Result}; + use fallible_iterator::FallibleIterator; + + #[test] + fn test_series_module() -> Result<()> { + let version = unsafe { ffi::sqlite3_libversion_number() }; + if version < 3_008_012 { + return Ok(()); + } + + let db = Connection::open_in_memory()?; + series::load_module(&db)?; + + let mut s = db.prepare("SELECT * FROM generate_series(0,20,5)")?; + + let series = s.query_map([], |row| row.get::<_, i32>(0))?; + + let mut expected = 0; + for value in series { + assert_eq!(expected, value?); + expected += 5; + } + + let mut s = + db.prepare("SELECT * FROM generate_series WHERE start=1 AND stop=9 AND step=2")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(vec![1, 3, 5, 7, 9], series); + let mut s = db.prepare("SELECT * FROM generate_series LIMIT 5")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(vec![0, 1, 2, 3, 4], series); + let mut s = db.prepare("SELECT * FROM generate_series(0,32,5) ORDER BY value DESC")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(vec![30, 25, 20, 15, 10, 5, 0], series); + + let mut s = db.prepare("SELECT * FROM generate_series(NULL)")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + let empty = Vec::<i32>::new(); + assert_eq!(empty, series); + let mut s = db.prepare("SELECT * FROM generate_series(5,NULL)")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(empty, series); + let mut s = db.prepare("SELECT * FROM generate_series(5,10,NULL)")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(empty, series); + let mut s = db.prepare("SELECT * FROM generate_series(NULL,10,2)")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(empty, series); + let mut s = db.prepare("SELECT * FROM generate_series(5,NULL,2)")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(empty, series); + let mut s = db.prepare("SELECT * FROM generate_series(NULL) ORDER BY value DESC")?; + let series: Vec<i32> = s.query([])?.map(|r| r.get(0)).collect()?; + assert_eq!(empty, series); + + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/vtab/vtablog.rs b/third_party/rust/rusqlite/src/vtab/vtablog.rs new file mode 100644 index 0000000000..e289cfae71 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/vtablog.rs @@ -0,0 +1,297 @@ +//! Port of C [vtablog](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/vtablog.c) +use std::default::Default; +use std::marker::PhantomData; +use std::os::raw::c_int; +use std::str::FromStr; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crate::vtab::{ + update_module, Context, CreateVTab, IndexInfo, UpdateVTab, VTab, VTabConnection, VTabCursor, + VTabKind, Values, +}; +use crate::{ffi, ValueRef}; +use crate::{Connection, Error, Result}; + +/// Register the "vtablog" module. +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("vtablog", update_module::<VTabLog>(), aux) +} + +/// An instance of the vtablog virtual table +#[repr(C)] +struct VTabLog { + /// Base class. Must be first + base: ffi::sqlite3_vtab, + /// Number of rows in the table + n_row: i64, + /// Instance number for this vtablog table + i_inst: usize, + /// Number of cursors created + n_cursor: usize, +} + +impl VTabLog { + fn connect_create( + _: &mut VTabConnection, + _: Option<&()>, + args: &[&[u8]], + is_create: bool, + ) -> Result<(String, VTabLog)> { + static N_INST: AtomicUsize = AtomicUsize::new(1); + let i_inst = N_INST.fetch_add(1, Ordering::SeqCst); + println!( + "VTabLog::{}(tab={}, args={:?}):", + if is_create { "create" } else { "connect" }, + i_inst, + args, + ); + let mut schema = None; + let mut n_row = None; + + let args = &args[3..]; + for c_slice in args { + let (param, value) = super::parameter(c_slice)?; + match param { + "schema" => { + if schema.is_some() { + return Err(Error::ModuleError(format!( + "more than one '{param}' parameter" + ))); + } + schema = Some(value.to_owned()) + } + "rows" => { + if n_row.is_some() { + return Err(Error::ModuleError(format!( + "more than one '{param}' parameter" + ))); + } + if let Ok(n) = i64::from_str(value) { + n_row = Some(n) + } + } + _ => { + return Err(Error::ModuleError(format!( + "unrecognized parameter '{param}'" + ))); + } + } + } + if schema.is_none() { + return Err(Error::ModuleError("no schema defined".to_owned())); + } + let vtab = VTabLog { + base: ffi::sqlite3_vtab::default(), + n_row: n_row.unwrap_or(10), + i_inst, + n_cursor: 0, + }; + Ok((schema.unwrap(), vtab)) + } +} + +impl Drop for VTabLog { + fn drop(&mut self) { + println!("VTabLog::drop({})", self.i_inst); + } +} + +unsafe impl<'vtab> VTab<'vtab> for VTabLog { + type Aux = (); + type Cursor = VTabLogCursor<'vtab>; + + fn connect( + db: &mut VTabConnection, + aux: Option<&Self::Aux>, + args: &[&[u8]], + ) -> Result<(String, Self)> { + VTabLog::connect_create(db, aux, args, false) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + println!("VTabLog::best_index({})", self.i_inst); + info.set_estimated_cost(500.); + info.set_estimated_rows(500); + Ok(()) + } + + fn open(&'vtab mut self) -> Result<Self::Cursor> { + self.n_cursor += 1; + println!( + "VTabLog::open(tab={}, cursor={})", + self.i_inst, self.n_cursor + ); + Ok(VTabLogCursor { + base: ffi::sqlite3_vtab_cursor::default(), + i_cursor: self.n_cursor, + row_id: 0, + phantom: PhantomData, + }) + } +} + +impl<'vtab> CreateVTab<'vtab> for VTabLog { + const KIND: VTabKind = VTabKind::Default; + + fn create( + db: &mut VTabConnection, + aux: Option<&Self::Aux>, + args: &[&[u8]], + ) -> Result<(String, Self)> { + VTabLog::connect_create(db, aux, args, true) + } + + fn destroy(&self) -> Result<()> { + println!("VTabLog::destroy({})", self.i_inst); + Ok(()) + } +} + +impl<'vtab> UpdateVTab<'vtab> for VTabLog { + fn delete(&mut self, arg: ValueRef<'_>) -> Result<()> { + println!("VTabLog::delete({}, {arg:?})", self.i_inst); + Ok(()) + } + + fn insert(&mut self, args: &Values<'_>) -> Result<i64> { + println!( + "VTabLog::insert({}, {:?})", + self.i_inst, + args.iter().collect::<Vec<ValueRef<'_>>>() + ); + Ok(self.n_row) + } + + fn update(&mut self, args: &Values<'_>) -> Result<()> { + println!( + "VTabLog::update({}, {:?})", + self.i_inst, + args.iter().collect::<Vec<ValueRef<'_>>>() + ); + Ok(()) + } +} + +/// A cursor for the Series virtual table +#[repr(C)] +struct VTabLogCursor<'vtab> { + /// Base class. Must be first + base: ffi::sqlite3_vtab_cursor, + /// Cursor number + i_cursor: usize, + /// The rowid + row_id: i64, + phantom: PhantomData<&'vtab VTabLog>, +} + +impl VTabLogCursor<'_> { + fn vtab(&self) -> &VTabLog { + unsafe { &*(self.base.pVtab as *const VTabLog) } + } +} + +impl Drop for VTabLogCursor<'_> { + fn drop(&mut self) { + println!( + "VTabLogCursor::drop(tab={}, cursor={})", + self.vtab().i_inst, + self.i_cursor + ); + } +} + +unsafe impl VTabCursor for VTabLogCursor<'_> { + fn filter(&mut self, _: c_int, _: Option<&str>, _: &Values<'_>) -> Result<()> { + println!( + "VTabLogCursor::filter(tab={}, cursor={})", + self.vtab().i_inst, + self.i_cursor + ); + self.row_id = 0; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + println!( + "VTabLogCursor::next(tab={}, cursor={}): rowid {} -> {}", + self.vtab().i_inst, + self.i_cursor, + self.row_id, + self.row_id + 1 + ); + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + let eof = self.row_id >= self.vtab().n_row; + println!( + "VTabLogCursor::eof(tab={}, cursor={}): {}", + self.vtab().i_inst, + self.i_cursor, + eof, + ); + eof + } + + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()> { + let value = if i < 26 { + format!( + "{}{}", + "abcdefghijklmnopqrstuvwyz".chars().nth(i as usize).unwrap(), + self.row_id + ) + } else { + format!("{i}{}", self.row_id) + }; + println!( + "VTabLogCursor::column(tab={}, cursor={}, i={}): {}", + self.vtab().i_inst, + self.i_cursor, + i, + value, + ); + ctx.set_result(&value) + } + + fn rowid(&self) -> Result<i64> { + println!( + "VTabLogCursor::rowid(tab={}, cursor={}): {}", + self.vtab().i_inst, + self.i_cursor, + self.row_id, + ); + Ok(self.row_id) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result}; + #[test] + fn test_module() -> Result<()> { + let db = Connection::open_in_memory()?; + super::load_module(&db)?; + + db.execute_batch( + "CREATE VIRTUAL TABLE temp.log USING vtablog( + schema='CREATE TABLE x(a,b,c)', + rows=25 + );", + )?; + let mut stmt = db.prepare("SELECT * FROM log;")?; + let mut rows = stmt.query([])?; + while rows.next()?.is_some() {} + db.execute("DELETE FROM log WHERE a = ?1", ["a1"])?; + db.execute( + "INSERT INTO log (a, b, c) VALUES (?1, ?2, ?3)", + ["a", "b", "c"], + )?; + db.execute( + "UPDATE log SET b = ?1, c = ?2 WHERE a = ?3", + ["bn", "cn", "a1"], + )?; + Ok(()) + } +} |