diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
commit | 2aa4a82499d4becd2284cdb482213d541b8804dd (patch) | |
tree | b80bf8bf13c3766139fbacc530efd0dd9d54394c /third_party/rust/rusqlite | |
parent | Initial commit. (diff) | |
download | firefox-upstream.tar.xz firefox-upstream.zip |
Adding upstream version 86.0.1.upstream/86.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
57 files changed, 16164 insertions, 0 deletions
diff --git a/third_party/rust/rusqlite/.cargo-checksum.json b/third_party/rust/rusqlite/.cargo-checksum.json new file mode 100644 index 0000000000..aaee43abb5 --- /dev/null +++ b/third_party/rust/rusqlite/.cargo-checksum.json @@ -0,0 +1 @@ +{"files":{"Cargo.toml":"a68dc0863956a747f6526f566355c2f176dc1f0e522742ee1b4fe83f56dce905","Changelog.md":"34c6aa9cde3eb6b2803ec1072368ae412eac2da7c8a770ab2bd643a3fa1f9c5d","LICENSE":"6e93da4b8f1e937ae6addcac29c06bac66358611488c17fd1f46c6fa485f519c","README.md":"2265be99f12cc6805897266a26fb27227e4bb226293f3952de6edc4e2e47b8ef","appveyor.yml":"0bcc97462ce19ea2d7018b5a0f322d4d0ffc085a31cd18abe540ede05c770ff2","benches/cache.rs":"30459d8d7f65dd684909a53590f9be32e2e545a0775cd93ee0ce85dd08410955","benches/exec.rs":"cb7feb1738d4565c8265538a95a869a99ca360ed20c68bb69e84bd31ecbce094","clippy.toml":"98f90bbe52849b3078daeccfbe87aec1c08d02b94a453a15149a1fcccb7cafd0","codecov.yml":"17e6a8fa616f9a9f46ce420761ad15ddf0d8b5fb64c75c04365e5f269ab60ee5","publish-ghp-docs.sh":"3e2c59e09ce377c46a62511aa17389e6bf824c17d2fc5a818409ede40ba64c23","src/backup.rs":"00ee47c6b3d96393e68a3c9467e61f4082613bdd7e06607ef8b9238b62122de2","src/blob/mod.rs":"a8ea454a6bb31becd8b83e10c32c2e19a5a2fbb91cdfbd9fb0ec8525a162b2a7","src/blob/pos_io.rs":"6a41d298789ba7b2efa339f7cfcefef4162d39daf0400ca4fd3849f2f9534c45","src/busy.rs":"1308fbb2a09940207e59e6730d84165e262d95409ee11d434e4e846bb327caac","src/cache.rs":"024aac5c2a955fe60e54871a5ffd0d4fe56410398d2f2202c53f13c425533820","src/collation.rs":"1871a5d1973aedeaf8db52ab635702d77cc7a958d604f307cb498b5060db0d8e","src/column.rs":"7e7c7d5532202eaed6d453cb4cc3f46e3901f0528c5f1c83bfa08a0fd3d87264","src/config.rs":"70620f037d79bbb59ce4fff406b4aff4333cc14901dab8a6deeb4715d30ebc63","src/context.rs":"7550b77ccc1362ebbefa85e51ad70bf3d4e7d4f7817fcaac8cf34300962b7c33","src/error.rs":"dd7ab850c864aabd7bed4988955e652bff00546c20c16e3fe9ed471a655fb684","src/functions.rs":"ed50aacc01f1087ba7af347c0d46ee443360932dbdc539583cbf8b59a221bfdb","src/hooks.rs":"3ca1bcc9ba80f41c003331d779f1116c5c83daaf2bc5740fd0feb56710babce1","src/inner_connection.rs":"844a609bde6131ca13dd1a496181b405ae4f6a79766721adec01c3cf75e82c2a","src/lib.rs":"cd4c916eb8e1fbcc8eff2435b136e7f3723e7d62a691758018ee6e2bbf4678c5","src/limits.rs":"75d339cf1f21c744f42f339dcfccf44e149bd2d1294ad0fe412b0792d9c809ab","src/load_extension_guard.rs":"cd349613067410fa4be811d7522fc50004fbc48feb050c19f80a097ab919f9fb","src/pragma.rs":"f6820bcccd50ef804be09f348069208a2bc227de6959c0e3d9ac42b9c4c3e454","src/raw_statement.rs":"40864c284d9b29842831e92ca92eecf6ac77e5f166b43fc55043e62405e9e190","src/row.rs":"5f0e20ec6bbe9e6c5bce80e34bb76ba06b3abf54ef42755394f144ed1ad2e0c9","src/session.rs":"a4730dd21a0dd880d6cff46236fd3c7ad5ad5de4f607203c13062011162f3029","src/statement.rs":"32e3df655ff71ae0415539c6d79197f8e21315a583bfcad39906e0f8da4abc38","src/trace.rs":"686f46f9b27c9d446aaeff2eae0e8346177cedc48859c1fdac7453386797fe6e","src/transaction.rs":"41c5b8b0355401cace43996d311065a8e05846cb40a15d043f51bdd9e0606ffb","src/types/chrono.rs":"1d6aebb7b8efc6038ec30c95d91cee7a16c826bf5a78aca1bc75c86349e17c83","src/types/from_sql.rs":"3508835cf9a7eba0628dbd0bbbe06783f6fe57e5e8469ca51b288f139cd73436","src/types/mod.rs":"77ef2a40101d7864a651fa601e84ecda0ad441723e6c16dedd5329f4cbe8de84","src/types/serde_json.rs":"a3f11fd63cf5f3ffc9e89134dd3d139ac275aff1948d315b1eb1673d81e8fe95","src/types/time.rs":"f4ce83e8368274cf67677a5b67133adfaf0adef5c600574dd34e2e68da102754","src/types/to_sql.rs":"8544f6c9c6a0676bd75375ddba234d051d9dcf28829f49e3dbb39534142b7981","src/types/url.rs":"b476ca2386a4cc4e8ac500f870993c0445abe91754d4e004bfd19a1b95b53816","src/types/value.rs":"1bc28c99b215bf74821feef60544d5bfbedfcde99590f229d8336e0fc75354f8","src/types/value_ref.rs":"6fb2f5bfb268a90794bfc4aeda41182c7cead0292c84ff56d5141a4825ff1bae","src/unlock_notify.rs":"cf2a91e6484454fc3174be546a7f80934fd40f5be51d174db6a5674e93c9ec4a","src/util/mod.rs":"330b1822393003003e4c890df183eb88e3cfe9d566c8325fb421426811a54193","src/util/param_cache.rs":"efb480749cd74c09aeca28be074050e3a9aed7c8ed1371ca562342cba9c408dd","src/util/small_cstr.rs":"5ec10de6708837c09a67d12a7c503fc2856e933b318f8776ee2b4a6b2ce6cd83","src/util/sqlite_string.rs":"b35a4fac5962d47e0efa1b488d3447997de138aca8e99be8f5073fcbf2081664","src/version.rs":"6df3d88ff62b1f4c46692b515a16d1f02ff27732a3e371788566e6a75d5c1d4d","src/vtab/array.rs":"485d2b441b01d26a8f6a59130c85970afd627e1b953b52434964835f8b6a2aef","src/vtab/csvtab.rs":"1a4b70bdb3cb1eaeee80caded133278280f760bda2eac7ca9185448e121f32ca","src/vtab/mod.rs":"24119ac8328b8aaaf8d75c59e60820055af87dc56db153dabcb3b14113a5e50a","src/vtab/series.rs":"6b6f1c13c3d453ae726721fd09bb51c90a4252d4e662c503d7b0398812e772b7","test.csv":"3f5649d7b9f80468999b80d4d0e747c6f4f1bba80ade792689eeb4359dc1834a","tests/config_log.rs":"e786bb1a0ca1b560788b24a2091e7d294753a7bee8a39f0e6037b435809267df","tests/deny_single_threaded_sqlite_config.rs":"586ee163d42d886d88ae319cb2184c021bdf9f6da9a8325d799ba10ddeadcbe0","tests/vtab.rs":"0bb466e5ad5f96443ed5ef7f7252656dc5d3c6c3cc44d179195f654190065c04"},"package":"d5f38ee71cbab2c827ec0ac24e76f82eca723cee92c509a65f67dee393c25112"}
\ No newline at end of file diff --git a/third_party/rust/rusqlite/Cargo.toml b/third_party/rust/rusqlite/Cargo.toml new file mode 100644 index 0000000000..8efebabd61 --- /dev/null +++ b/third_party/rust/rusqlite/Cargo.toml @@ -0,0 +1,169 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies +# +# If you believe there's an error in this file please file an +# issue against the rust-lang/cargo repository. If you're +# editing this file be aware that the upstream Cargo.toml +# will likely look very different (and much more reasonable) + +[package] +edition = "2018" +name = "rusqlite" +version = "0.24.2" +authors = ["The rusqlite developers"] +description = "Ergonomic wrapper for SQLite" +documentation = "http://docs.rs/rusqlite/" +readme = "README.md" +keywords = ["sqlite", "database", "ffi"] +categories = ["database"] +license = "MIT" +repository = "https://github.com/rusqlite/rusqlite" +[package.metadata.docs.rs] +all-features = false +default-target = "x86_64-unknown-linux-gnu" +features = ["array", "backup", "blob", "chrono", "collation", "functions", "limits", "load_extension", "serde_json", "time", "trace", "url", "vtab", "window", "modern_sqlite", "column_decltype"] +no-default-features = true + +[package.metadata.playground] +all-features = false +features = ["bundled-full"] + +[lib] +name = "rusqlite" + +[[test]] +name = "config_log" +harness = false + +[[test]] +name = "deny_single_threaded_sqlite_config" + +[[test]] +name = "vtab" + +[[bench]] +name = "cache" +harness = false + +[[bench]] +name = "exec" +harness = false +[dependencies.bitflags] +version = "1.2" + +[dependencies.byteorder] +version = "1.3" +features = ["i128"] +optional = true + +[dependencies.chrono] +version = "0.4" +optional = true + +[dependencies.csv] +version = "1.1" +optional = true + +[dependencies.fallible-iterator] +version = "0.2" + +[dependencies.fallible-streaming-iterator] +version = "0.1" + +[dependencies.hashlink] +version = "0.6" + +[dependencies.lazy_static] +version = "1.4" +optional = true + +[dependencies.libsqlite3-sys] +version = "0.20.1" + +[dependencies.memchr] +version = "2.3" + +[dependencies.serde_json] +version = "1.0" +optional = true + +[dependencies.smallvec] +version = "1.0" + +[dependencies.time] +version = "0.2" +optional = true + +[dependencies.url] +version = "2.1" +optional = true + +[dependencies.uuid] +version = "0.8" +optional = true +[dev-dependencies.bencher] +version = "0.1" + +[dev-dependencies.doc-comment] +version = "0.3" + +[dev-dependencies.lazy_static] +version = "1.4" + +[dev-dependencies.regex] +version = "1.3" + +[dev-dependencies.tempfile] +version = "3.1.0" + +[dev-dependencies.unicase] +version = "2.6.0" + +[dev-dependencies.uuid] +version = "0.8" +features = ["v4"] + +[features] +array = ["vtab"] +backup = ["libsqlite3-sys/min_sqlite_version_3_6_23"] +blob = ["libsqlite3-sys/min_sqlite_version_3_7_7"] +buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"] +bundled = ["libsqlite3-sys/bundled", "modern_sqlite"] +bundled-full = ["array", "backup", "blob", "bundled", "chrono", "collation", "column_decltype", "csvtab", "extra_check", "functions", "hooks", "i128_blob", "limits", "load_extension", "serde_json", "series", "trace", "unlock_notify", "url", "uuid", "vtab", "window"] +bundled-windows = ["libsqlite3-sys/bundled-windows"] +collation = [] +column_decltype = [] +csvtab = ["csv", "vtab"] +extra_check = [] +functions = ["libsqlite3-sys/min_sqlite_version_3_7_7"] +hooks = [] +i128_blob = ["byteorder"] +in_gecko = ["modern_sqlite", "libsqlite3-sys/in_gecko"] +limits = [] +load_extension = [] +modern_sqlite = ["libsqlite3-sys/bundled_bindings"] +series = ["vtab"] +session = ["libsqlite3-sys/session", "hooks"] +sqlcipher = ["libsqlite3-sys/sqlcipher"] +trace = ["libsqlite3-sys/min_sqlite_version_3_6_23"] +unlock_notify = ["libsqlite3-sys/unlock_notify"] +vtab = ["libsqlite3-sys/min_sqlite_version_3_7_7", "lazy_static"] +wasm32-wasi-vfs = ["libsqlite3-sys/wasm32-wasi-vfs"] +window = ["functions"] +winsqlite3 = ["libsqlite3-sys/winsqlite3"] +with-asan = ["libsqlite3-sys/with-asan"] +[badges.appveyor] +repository = "rusqlite/rusqlite" + +[badges.codecov] +repository = "rusqlite/rusqlite" + +[badges.maintenance] +status = "actively-developed" + +[badges.travis-ci] +repository = "rusqlite/rusqlite" diff --git a/third_party/rust/rusqlite/Changelog.md b/third_party/rust/rusqlite/Changelog.md new file mode 100644 index 0000000000..6ba11d70e9 --- /dev/null +++ b/third_party/rust/rusqlite/Changelog.md @@ -0,0 +1,332 @@ +For version 0.15.0 and above, see [Releases](https://github.com/rusqlite/rusqlite/releases) page. + +# Version 0.14.0 (2018-08-17) + +* BREAKING CHANGE: `ToSql` implementation for `time::Timespec` uses RFC 3339 (%Y-%m-%dT%H:%M:%S.%fZ). + Previous format was %Y-%m-%d %H:%M:%S:%f %Z. +* BREAKING CHANGE: Remove potentially conflicting impl of ToSqlOutput (#313). +* BREAKING CHANGE: Replace column index/count type (i32) with usize. +* BREAKING CHANGE: Replace parameter index/count type (i32) with usize. +* BREAKING CHANGE: Replace row changes/count type (i32) with usize. +* BREAKING CHANGE: Scalar functions must be `Send`able and `'static`. +* Bugfix: Commit failure unhandled, database left in unusable state (#366). +* Bugfix: `free_boxed_hook` does not work for `fn`. +* Update the bundled SQLite version to 3.24.0 (#326). +* Add DropBehavior::Panic to enforce intentional commit or rollback. +* Implement `sqlite3_update_hook` (#260, #328), `sqlite3_commit_hook` and `sqlite3_rollback_hook`. +* Add support to unlock notification behind `unlock_notify` feature (#294, #331). +* Make `Statement::column_index` case insensitive (#330). +* Add comment to justify `&mut Connection` in `Transaction`. +* Fix `tyvar_behind_raw_pointer` warnings. +* Fix handful of clippy warnings. +* Fix `Connection::open` documentation (#332) +* Add binding to `sqlite3_get_autocommit` and `sqlite3_stmt_busy`. +* Add binding to `sqlite3_busy_timeout` and `sqlite3_busy_handler`. +* Add binding to `sqlite3_expanded_sql`. +* Use `rerun-if-env-changed` in libsqlite3-sys (#329). +* Return an `InvalidQuery` error when SQL is not read only. + +# Version 0.13.0 (2017-11-13) + +* Added ToSqlConversionFailure case to Error enum. +* Now depends on chrono 0.4, bitflats 1.0, and (optionally) cc 1.0 / bindgen 0.31. +* The ToSql/FromSql implementations for time::Timespec now include + and expect fractional seconds and timezone in the serialized string. +* The RowIndex type used in Row::get is now publicly exported. +* New `sqlcipher` feature allows linking against SQLCipher instead of SQLite. +* Doc link in README now point to docs.rs. + +# Version 0.12.0 (2017-05-29) + +* Defines HAVE\_USLEEP when building with a bundled SQLite (#263). +* Updates dependencies to their latest versions, particularly serde to 1.0. +* Adds support for vcpkg on Windows. +* Adds `ToSql` impls for `str` and `[u8]`. + +# Version 0.11.0 (2017-04-06) + +* Avoid publicly exporting SQLite constants multiple times from libsqlite3-sys. +* Adds `FromSql` and `ToSql` impls for `isize`. Documents why `usize` and `u64` are not included. + +# Version 0.10.1 (2017-03-03) + +* Updates the `bundled` SQLite version to 3.17.0. +* Changes the build process to no longer require `bindgen`. This should improve + build times and no longer require a new-ish Clang. See the README for more + details. + +# Version 0.10.0 (2017-02-28) + +* Re-export the `ErrorCode` enum from `libsqlite3-sys`. +* Adds `version()` and `version_number()` functions for querying the version of SQLite in use. +* Adds the `limits` feature, exposing `limit()` and `set_limit()` methods on `Connection`. +* Updates to `libsqlite3-sys` 0.7.0, which runs rust-bindgen at build-time instead of assuming the + precense of all expected SQLite constants and functions. +* Clarifies supported SQLite versions. Running with SQLite older than 3.6.8 now panics, and + some features will not compile unless a sufficiently-recent SQLite version is used. See + the README for requirements of particular features. +* When running with SQLite 3.6.x, rusqlite attempts to perform SQLite initialization. If it fails, + rusqlite will panic since it cannot ensure the threading mode for SQLite. This check can by + skipped by calling the unsafe function `rusqlite::bypass_sqlite_initialization()`. This is + technically a breaking change but is unlikely to affect anyone in practice, since prior to this + version the check that rusqlite was using would cause a segfault if linked against a SQLite + older than 3.7.0. +* rusqlite now performs a one-time check (prior to the first connection attempt) that the runtime + SQLite version is at least as new as the SQLite version found at buildtime. This check can by + skipped by calling the unsafe function `rusqlite::bypass_sqlite_version_check()`. +* Removes the `libc` dependency in favor of using `std::os::raw` + +# Version 0.9.5 (2017-01-26) + +* Add impls of `Clone`, `Debug`, and `PartialEq` to `ToSqlOutput`. + +# Version 0.9.4 (2017-01-25) + +* Update dependencies. + +# Version 0.9.3 (2017-01-23) + +* Make `ToSqlOutput` itself implement `ToSql`. + +# Version 0.9.2 (2017-01-22) + +* Bugfix: The `FromSql` impl for `i32` now returns an error instead of + truncating if the underlying SQLite value is out of `i32`'s range. +* Added `FromSql` and `ToSql` impls for `i8`, `i16`, `u8`, `u16`, and `u32`. + `i32` and `i64` already had impls. `u64` is omitted because their range + cannot be represented by `i64`, which is the type we use to communicate with + SQLite. + +# Version 0.9.1 (2017-01-20) + +* BREAKING CHANGE: `Connection::close()` now returns a `Result<(), (Connection, Error)>` instead + of a `Result<(), Error>` so callers get the still-open connection back on failure. + +# Version 0.8.0 (2016-12-31) + +* BREAKING CHANGE: The `FromSql` trait has been redesigned. It now requires a single, safe + method instead of the previous definition which required implementing one or two unsafe + methods. +* BREAKING CHANGE: The `ToSql` trait has been redesigned. It can now be implemented without + `unsafe`, and implementors can choose to return either borrowed or owned results. +* BREAKING CHANGE: The closure passed to `query_row`, `query_row_and_then`, `query_row_safe`, + and `query_row_named` now expects a `&Row` instead of a `Row`. The vast majority of calls + to these functions will probably not need to change; see + https://github.com/jgallagher/rusqlite/pull/184. +* BREAKING CHANGE: A few cases of the `Error` enum have sprouted additional information + (e.g., `FromSqlConversionFailure` now also includes the column index and the type returned + by SQLite). +* Added `#[deprecated(since = "...", note = "...")]` flags (new in Rust 1.9 for libraries) to + all deprecated APIs. +* Added `query_row` convenience function to `Statement`. +* Added `bundled` feature which will build SQLite from source instead of attempting to link + against a SQLite that already exists on the system. +* Fixed a bug where using cached prepared statements resulted in attempting to close a connection + failing with `DatabaseBusy`; see https://github.com/jgallagher/rusqlite/issues/186. + +# Version 0.7.3 (2016-06-01) + +* Fixes an incorrect failure from the `insert()` convenience function when back-to-back inserts to + different tables both returned the same row ID + ([#171](https://github.com/jgallagher/rusqlite/issues/171)). + +# Version 0.7.2 (2016-05-19) + +* BREAKING CHANGE: `Rows` no longer implements `Iterator`. It still has a `next()` method, but + the lifetime of the returned `Row` is now tied to the lifetime of the vending `Rows` object. + This behavior is more correct. Previously there were runtime checks to prevent misuse, but + other changes in this release to reset statements as soon as possible introduced yet another + hazard related to the lack of these lifetime connections. We were already recommending the + use of `query_map` and `query_and_then` over raw `query`; both of theose still return handles + that implement `Iterator`. +* BREAKING CHANGE: `Transaction::savepoint()` now returns a `Savepoint` instead of another + `Transaction`. Unlike `Transaction`, `Savepoint`s can be rolled back while keeping the current + savepoint active. +* BREAKING CHANGE: Creating transactions from a `Connection` or savepoints from a `Transaction` + now take `&mut self` instead of `&self` to correctly represent that transactions within a + connection are inherently nested. While a transaction is alive, the parent connection or + transaction is unusable, so `Transaction` now implements `Deref<Target=Connection>`, giving + access to `Connection`'s methods via the `Transaction` itself. +* BREAKING CHANGE: `Transaction::set_commit` and `Transaction::set_rollback` have been replaced + by `Transaction::set_drop_behavior`. +* Adds `Connection::prepare_cached`. `Connection` now keeps an internal cache of any statements + prepared via this method. The size of this cache defaults to 16 (`prepare_cached` will always + work but may re-prepare statements if more are prepared than the cache holds), and can be + controlled via `Connection::set_prepared_statement_cache_capacity`. +* Adds `query_map_named` and `query_and_then_named` to `Statement`. +* Adds `insert` convenience method to `Statement` which returns the row ID of an inserted row. +* Adds `exists` convenience method returning whether a query finds one or more rows. +* Adds support for serializing types from the `serde_json` crate. Requires the `serde_json` feature. +* Adds support for serializing types from the `chrono` crate. Requires the `chrono` feature. +* Removes `load_extension` feature from `libsqlite3-sys`. `load_extension` is still available + on rusqlite itself. +* Fixes crash on nightly Rust when using the `trace` feature. +* Adds optional `clippy` feature and addresses issues it found. +* Adds `column_count()` method to `Statement` and `Row`. +* Adds `types::Value` for dynamic column types. +* Adds support for user-defined aggregate functions (behind the existing `functions` Cargo feature). +* Introduces a `RowIndex` trait allowing columns to be fetched via index (as before) or name (new). +* Introduces `ZeroBlob` type under the `blob` module/feature exposing SQLite's zeroblob API. +* Adds CI testing for Windows via AppVeyor. +* Fixes a warning building libsqlite3-sys under Rust 1.6. +* Adds an unsafe `handle()` method to `Connection`. Please file an issue if you actually use it. + +# Version 0.6.0 (2015-12-17) + +* BREAKING CHANGE: `SqliteError` is now an enum instead of a struct. Previously, we were (ab)using + the error code and message to send back both underlying SQLite errors and errors that occurred + at the Rust level. Now those have been separated out; SQLite errors are returned as + `SqliteFailure` cases (which still include the error code but also include a Rust-friendlier + enum as well), and rusqlite-level errors are captured in other cases. Because of this change, + `SqliteError` no longer implements `PartialEq`. +* BREAKING CHANGE: When opening a new detection, rusqlite now detects if SQLite was compiled or + configured for single-threaded use only; if it was, connection attempts will fail. If this + affects you, please open an issue. +* BREAKING CHANGE: `SqliteTransactionDeferred`, `SqliteTransactionImmediate`, and + `SqliteTransactionExclusive` are no longer exported. Instead, use + `TransactionBehavior::Deferred`, `TransactionBehavior::Immediate`, and + `TransactionBehavior::Exclusive`. +* Removed `Sqlite` prefix on many types: + * `SqliteConnection` is now `Connection` + * `SqliteError` is now `Error` + * `SqliteResult` is now `Result` + * `SqliteStatement` is now `Statement` + * `SqliteRows` is now `Rows` + * `SqliteRow` is now `Row` + * `SqliteOpenFlags` is now `OpenFlags` + * `SqliteTransaction` is now `Transaction`. + * `SqliteTransactionBehavior` is now `TransactionBehavior`. + * `SqliteLoadExtensionGuard` is now `LoadExtensionGuard`. + The old, prefixed names are still exported but are deprecated. +* Adds a variety of `..._named` methods for executing queries using named placeholder parameters. +* Adds `backup` feature that exposes SQLite's online backup API. +* Adds `blob` feature that exposes SQLite's Incremental I/O for BLOB API. +* Adds `functions` feature that allows user-defined scalar functions to be added to + open `SqliteConnection`s. + +# Version 0.5.0 (2015-12-08) + +* Adds `trace` feature that allows the use of SQLite's logging, tracing, and profiling hooks. +* Slight change to the closure types passed to `query_map` and `query_and_then`: + * Remove the `'static` requirement on the closure's output type. + * Give the closure a `&SqliteRow` instead of a `SqliteRow`. +* When building, the environment variable `SQLITE3_LIB_DIR` now takes precedence over pkg-config. +* If `pkg-config` is not available, we will try to find `libsqlite3` in `/usr/lib`. +* Add more documentation for failure modes of functions that return `SqliteResult`s. +* Updates `libc` dependency to 0.2, fixing builds on ARM for Rust 1.6 or newer. + +# Version 0.4.0 (2015-11-03) + +* Adds `Sized` bound to `FromSql` trait as required by RFC 1214. + +# Version 0.3.1 (2015-09-22) + +* Reset underlying SQLite statements as soon as possible after executing, as recommended by + http://www.sqlite.org/cvstrac/wiki?p=ScrollingCursor. + +# Version 0.3.0 (2015-09-21) + +* Removes `get_opt`. Use `get_checked` instead. +* Add `query_row_and_then` and `query_and_then` convenience functions. These are analogous to + `query_row` and `query_map` but allow functions that can fail by returning `Result`s. +* Relax uses of `P: AsRef<...>` from `&P` to `P`. +* Add additional error check for calling `execute` when `query` was intended. +* Improve debug formatting of `SqliteStatement` and `SqliteConnection`. +* Changes documentation of `get_checked` to correctly indicate that it returns errors (not panics) + when given invalid types or column indices. + +# Version 0.2.0 (2015-07-26) + +* Add `column_names()` to `SqliteStatement`. +* By default, include `SQLITE_OPEN_NO_MUTEX` and `SQLITE_OPEN_URI` flags when opening a + new conneciton. +* Fix generated bindings (e.g., `sqlite3_exec` was wrong). +* Use now-generated `sqlite3_destructor_type` to define `SQLITE_STATIC` and `SQLITE_TRANSIENT`. + +# Version 0.1.0 (2015-05-11) + +* [breaking-change] Modify `query_row` to return a `Result` instead of unwrapping. +* Deprecate `query_row_safe` (use `query_row` instead). +* Add `query_map`. +* Add `get_checked`, which asks SQLite to do some basic type-checking of columns. + +# Version 0.0.17 (2015-04-03) + +* Publish version that builds on stable rust (beta). This version lives on the + `stable` branch. Development continues on `master` and still requires a nightly + version of Rust. + +# Version 0.0.16 + +* Updates to track rustc nightly. + +# Version 0.0.15 + +* Make SqliteConnection `Send`. + +# Version 0.0.14 + +* Remove unneeded features (also involves switching to `libc` crate). + +# Version 0.0.13 (2015-03-26) + +* Updates to track rustc nightly. + +# Version 0.0.12 (2015-03-24) + +* Updates to track rustc stabilization. + +# Version 0.0.11 (2015-03-12) + +* Reexport `sqlite3_stmt` from `libsqlite3-sys` for easier `impl`-ing of `ToSql` and `FromSql`. +* Updates to track latest rustc changes. +* Update dependency versions. + +# Version 0.0.10 (2015-02-23) + +* BREAKING CHANGE: `open` now expects a `Path` rather than a `str`. There is a separate + `open_in_memory` constructor for opening in-memory databases. +* Added the ability to load SQLite extensions. This is behind the `load_extension` Cargo feature, + because not all builds of sqlite3 include this ability. Notably the default libsqlite3 that + ships with OS X 10.10 does not support extensions. + +# Version 0.0.9 (2015-02-13) + +* Updates to track latest rustc changes. +* Implement standard `Error` trait for `SqliteError`. + +# Version 0.0.8 (2015-02-04) + +* Updates to track latest rustc changes. + +# Version 0.0.7 (2015-01-20) + +* Use external bitflags from crates.io. + +# Version 0.0.6 (2015-01-10) + +* Updates to track latest rustc changes (1.0.0-alpha). +* Add `query_row_safe`, a `SqliteResult`-returning variant of `query_row`. + +# Version 0.0.5 (2015-01-07) + +* Updates to track latest rustc changes (closure syntax). +* Updates to track latest rust stdlib changes (`std::c_str` -> `std::ffi`). + +# Version 0.0.4 (2015-01-05) + +* Updates to track latest rustc changes. + +# Version 0.0.3 (2014-12-23) + +* Updates to track latest rustc changes. +* Add call to `sqlite3_busy_timeout`. + +# Version 0.0.2 (2014-12-04) + +* Remove use of now-deprecated `std::vec::raw::from_buf`. +* Update to latest version of `time` crate. + +# Version 0.0.1 (2014-11-21) + +* Initial release diff --git a/third_party/rust/rusqlite/LICENSE b/third_party/rust/rusqlite/LICENSE new file mode 100644 index 0000000000..0245a8331f --- /dev/null +++ b/third_party/rust/rusqlite/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2014-2020 The rusqlite developers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/third_party/rust/rusqlite/README.md b/third_party/rust/rusqlite/README.md new file mode 100644 index 0000000000..09f0c20bc7 --- /dev/null +++ b/third_party/rust/rusqlite/README.md @@ -0,0 +1,205 @@ +# Rusqlite + +[![Travis Build Status](https://api.travis-ci.org/rusqlite/rusqlite.svg?branch=master)](https://travis-ci.org/rusqlite/rusqlite) +[![AppVeyor Build Status](https://ci.appveyor.com/api/projects/status/github/rusqlite/rusqlite?branch=master&svg=true)](https://ci.appveyor.com/project/rusqlite/rusqlite) +[![Build Status](https://github.com/rusqlite/rusqlite/workflows/CI/badge.svg)](https://github.com/rusqlite/rusqlite/actions) +[![dependency status](https://deps.rs/repo/github/rusqlite/rusqlite/status.svg)](https://deps.rs/repo/github/rusqlite/rusqlite) +[![Latest Version](https://img.shields.io/crates/v/rusqlite.svg)](https://crates.io/crates/rusqlite) +[![Gitter](https://badges.gitter.im/rusqlite.svg)](https://gitter.im/rusqlite/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) +[![Docs](https://docs.rs/rusqlite/badge.svg)](https://docs.rs/rusqlite) +[![codecov](https://codecov.io/gh/rusqlite/rusqlite/branch/master/graph/badge.svg)](https://codecov.io/gh/rusqlite/rusqlite) + +Rusqlite is an ergonomic wrapper for using SQLite from Rust. It attempts to expose +an interface similar to [rust-postgres](https://github.com/sfackler/rust-postgres). + +```rust +use rusqlite::{params, Connection, Result}; + +#[derive(Debug)] +struct Person { + id: i32, + name: String, + data: Option<Vec<u8>>, +} + +fn main() -> Result<()> { + let conn = Connection::open_in_memory()?; + + conn.execute( + "CREATE TABLE person ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + data BLOB + )", + params![], + )?; + let me = Person { + id: 0, + name: "Steven".to_string(), + data: None, + }; + conn.execute( + "INSERT INTO person (name, data) VALUES (?1, ?2)", + params![me.name, me.data], + )?; + + let mut stmt = conn.prepare("SELECT id, name, data FROM person")?; + let person_iter = stmt.query_map(params![], |row| { + Ok(Person { + id: row.get(0)?, + name: row.get(1)?, + data: row.get(2)?, + }) + })?; + + for person in person_iter { + println!("Found person {:?}", person.unwrap()); + } + Ok(()) +} +``` + +### Supported SQLite Versions + +The base `rusqlite` package supports SQLite version 3.6.8 or newer. If you need +support for older versions, please file an issue. Some cargo features require a +newer SQLite version; see details below. + +### Optional Features + +Rusqlite provides several features that are behind [Cargo +features](https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section). They are: + +* [`load_extension`](https://docs.rs/rusqlite/~0/rusqlite/struct.LoadExtensionGuard.html) + allows loading dynamic library-based SQLite extensions. +* [`backup`](https://docs.rs/rusqlite/~0/rusqlite/backup/index.html) + allows use of SQLite's online backup API. Note: This feature requires SQLite 3.6.11 or later. +* [`functions`](https://docs.rs/rusqlite/~0/rusqlite/functions/index.html) + allows you to load Rust closures into SQLite connections for use in queries. + Note: This feature requires SQLite 3.7.3 or later. +* [`trace`](https://docs.rs/rusqlite/~0/rusqlite/trace/index.html) + allows hooks into SQLite's tracing and profiling APIs. Note: This feature + requires SQLite 3.6.23 or later. +* [`blob`](https://docs.rs/rusqlite/~0/rusqlite/blob/index.html) + gives `std::io::{Read, Write, Seek}` access to SQL BLOBs. Note: This feature + requires SQLite 3.7.4 or later. +* [`limits`](https://docs.rs/rusqlite/~0/rusqlite/struct.Connection.html#method.limit) + allows you to set and retrieve SQLite's per connection limits. +* `chrono` implements [`FromSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.FromSql.html) + and [`ToSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.ToSql.html) for various + types from the [`chrono` crate](https://crates.io/crates/chrono). +* `serde_json` implements [`FromSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.FromSql.html) + and [`ToSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.ToSql.html) for the + `Value` type from the [`serde_json` crate](https://crates.io/crates/serde_json). +* `time` implements [`FromSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.FromSql.html) + and [`ToSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.ToSql.html) for the + `time::OffsetDateTime` type from the [`time` crate](https://crates.io/crates/time). +* `url` implements [`FromSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.FromSql.html) + and [`ToSql`](https://docs.rs/rusqlite/~0/rusqlite/types/trait.ToSql.html) for the + `Url` type from the [`url` crate](https://crates.io/crates/url). +* `bundled` uses a bundled version of sqlite3. This is a good option for cases where linking to sqlite3 is complicated, such as Windows. +* `sqlcipher` looks for the SQLCipher library to link against instead of SQLite. This feature is mutually exclusive with `bundled`. +* `hooks` for [Commit, Rollback](http://sqlite.org/c3ref/commit_hook.html) and [Data Change](http://sqlite.org/c3ref/update_hook.html) notification callbacks. +* `unlock_notify` for [Unlock](https://sqlite.org/unlock_notify.html) notification. +* `vtab` for [virtual table](https://sqlite.org/vtab.html) support (allows you to write virtual table implementations in Rust). Currently, only read-only virtual tables are supported. +* [`csvtab`](https://sqlite.org/csv.html), CSV virtual table written in Rust. +* [`array`](https://sqlite.org/carray.html), The `rarray()` Table-Valued Function. +* `i128_blob` allows storing values of type `i128` type in SQLite databases. Internally, the data is stored as a 16 byte big-endian blob, with the most significant bit flipped, which allows ordering and comparison between different blobs storing i128s to work as expected. +* `uuid` allows storing and retrieving `Uuid` values from the [`uuid`](https://docs.rs/uuid/) crate using blobs. +* [`session`](https://sqlite.org/sessionintro.html), Session module extension. Requires `buildtime_bindgen` feature. + +## Notes on building rusqlite and libsqlite3-sys + +`libsqlite3-sys` is a separate crate from `rusqlite` that provides the Rust +declarations for SQLite's C API. By default, `libsqlite3-sys` attempts to find a SQLite library that already exists on your system using pkg-config, or a +[Vcpkg](https://github.com/Microsoft/vcpkg) installation for MSVC ABI builds. + +You can adjust this behavior in a number of ways: + +* If you use the `bundled` feature, `libsqlite3-sys` will use the + [cc](https://crates.io/crates/cc) crate to compile SQLite from source and + link against that. This source is embedded in the `libsqlite3-sys` crate and + is currently SQLite 3.33.0 (as of `rusqlite` 0.24.1 / `libsqlite3-sys` + 0.20.0). This is probably the simplest solution to any build problems. You can enable this by adding the following in your `Cargo.toml` file: + ```toml + [dependencies.rusqlite] + version = "0.24.2" + features = ["bundled"] + ``` +* You can set the `SQLITE3_LIB_DIR` to point to directory containing the SQLite + library. +* Installing the sqlite3 development packages will usually be all that is required, but + the build helpers for [pkg-config](https://github.com/alexcrichton/pkg-config-rs) + and [vcpkg](https://github.com/mcgoo/vcpkg-rs) have some additional configuration + options. The default when using vcpkg is to dynamically link, + which must be enabled by setting `VCPKGRS_DYNAMIC=1` environment variable before build. + `vcpkg install sqlite3:x64-windows` will install the required library. + +### Binding generation + +We use [bindgen](https://crates.io/crates/bindgen) to generate the Rust +declarations from SQLite's C header file. `bindgen` +[recommends](https://github.com/servo/rust-bindgen#library-usage-with-buildrs) +running this as part of the build process of libraries that used this. We tried +this briefly (`rusqlite` 0.10.0, specifically), but it had some annoyances: + +* The build time for `libsqlite3-sys` (and therefore `rusqlite`) increased + dramatically. +* Running `bindgen` requires a relatively-recent version of Clang, which many + systems do not have installed by default. +* Running `bindgen` also requires the SQLite header file to be present. + +As of `rusqlite` 0.10.1, we avoid running `bindgen` at build-time by shipping +pregenerated bindings for several versions of SQLite. When compiling +`rusqlite`, we use your selected Cargo features to pick the bindings for the +minimum SQLite version that supports your chosen features. If you are using +`libsqlite3-sys` directly, you can use the same features to choose which +pregenerated bindings are chosen: + +* `min_sqlite_version_3_6_8` - SQLite 3.6.8 bindings (this is the default) +* `min_sqlite_version_3_6_23` - SQLite 3.6.23 bindings +* `min_sqlite_version_3_7_7` - SQLite 3.7.7 bindings + +If you use the `bundled` feature, you will get pregenerated bindings for the +bundled version of SQLite. If you need other specific pregenerated binding +versions, please file an issue. If you want to run `bindgen` at buildtime to +produce your own bindings, use the `buildtime_bindgen` Cargo feature. + +If you enable the `modern_sqlite` feature, we'll use the bindings we would have +included with the bundled build. You generally should have `buildtime_bindgen` +enabled if you turn this on, as otherwise you'll need to keep the version of +SQLite you link with in sync with what rusqlite would have bundled, (usually the +most recent release of sqlite). Failing to do this will cause a runtime error. + +## Contributing + +Rusqlite has many features, and many of them impact the build configuration in +incompatible ways. This is unfortunate, and makes testing changes hard. + +To help here: you generally should ensure that you run tests/lint for +`--features bundled`, and `--features bundled-full session buildtime_bindgen`. + +If running bindgen is problematic for you, `--features bundled-full` enables +bundled and all features which don't require binding generation, and can be used +instead. + +### Checklist + +- Run `cargo fmt` to ensure your Rust code is correctly formatted. +- Ensure `cargo clippy --all-targets --workspace --features bundled` passes without warnings. +- Ensure `cargo test --all-targets --workspace --features bundled-full session buildtime_bindgen` reports no failures. +- Ensure `cargo test --all-targets --workspace --features bundled` reports no failures. +- Ensure `cargo test --all-targets --workspace --features bundled-full session buildtime_bindgen` reports no failures. + +## Author + +Rusqlite is the product of hard work by a number of people. A list is available +here: https://github.com/rusqlite/rusqlite/graphs/contributors + +## Community + +Currently there's a gitter channel set up for rusqlite [here](https://gitter.im/rusqlite/community). + +## License + +Rusqlite is available under the MIT license. See the LICENSE file for more info. diff --git a/third_party/rust/rusqlite/appveyor.yml b/third_party/rust/rusqlite/appveyor.yml new file mode 100644 index 0000000000..2d88284f26 --- /dev/null +++ b/third_party/rust/rusqlite/appveyor.yml @@ -0,0 +1,42 @@ +environment: + matrix: + - TARGET: x86_64-pc-windows-gnu + MSYS2_BITS: 64 +# - TARGET: x86_64-pc-windows-msvc +# VCPKG_DEFAULT_TRIPLET: x64-windows +# VCPKGRS_DYNAMIC: 1 +# - TARGET: x86_64-pc-windows-msvc +# VCPKG_DEFAULT_TRIPLET: x64-windows-static +# RUSTFLAGS: -Ctarget-feature=+crt-static +install: + - appveyor-retry appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe + - rustup-init.exe -y --default-host %TARGET% + - set PATH=%PATH%;C:\Users\appveyor\.cargo\bin + - if defined MSYS2_BITS set PATH=%PATH%;C:\msys64\mingw%MSYS2_BITS%\bin + - rustc -V + - cargo -V + # download SQLite dll (useful only when the `bundled` feature is not set) + - appveyor-retry appveyor DownloadFile https://sqlite.org/2018/sqlite-dll-win64-x64-3250200.zip -FileName sqlite-dll-win64-x64.zip + - if not defined VCPKG_DEFAULT_TRIPLET 7z e sqlite-dll-win64-x64.zip -y > nul + # download SQLite headers (useful only when the `bundled` feature is not set) + - appveyor-retry appveyor DownloadFile https://sqlite.org/2018/sqlite-amalgamation-3250200.zip -FileName sqlite-amalgamation.zip + - if not defined VCPKG_DEFAULT_TRIPLET 7z e sqlite-amalgamation.zip -y > nul + # specify where the SQLite dll has been downloaded (useful only when the `bundled` feature is not set) + - if not defined VCPKG_DEFAULT_TRIPLET SET SQLITE3_LIB_DIR=%APPVEYOR_BUILD_FOLDER% + # specify where the SQLite headers have been downloaded (useful only when the `bundled` feature is not set) + - if not defined VCPKG_DEFAULT_TRIPLET SET SQLITE3_INCLUDE_DIR=%APPVEYOR_BUILD_FOLDER% + # install sqlite3 package + - if defined VCPKG_DEFAULT_TRIPLET vcpkg install sqlite3 + +build: false + +test_script: + - cargo test --lib --verbose + - cargo test --lib --verbose --features bundled + - cargo test --lib --features "backup blob chrono collation functions hooks limits load_extension serde_json trace" + - cargo test --lib --features "backup blob chrono functions hooks limits load_extension serde_json trace buildtime_bindgen" + - cargo test --lib --features "backup blob chrono csvtab functions hooks limits load_extension serde_json trace vtab bundled" + - cargo test --lib --features "backup blob chrono csvtab functions hooks limits load_extension serde_json trace vtab bundled buildtime_bindgen" + +cache: + - C:\Users\appveyor\.cargo diff --git a/third_party/rust/rusqlite/benches/cache.rs b/third_party/rust/rusqlite/benches/cache.rs new file mode 100644 index 0000000000..dd3683ec72 --- /dev/null +++ b/third_party/rust/rusqlite/benches/cache.rs @@ -0,0 +1,18 @@ +use bencher::{benchmark_group, benchmark_main, Bencher}; +use rusqlite::Connection; + +fn bench_no_cache(b: &mut Bencher) { + let db = Connection::open_in_memory().unwrap(); + db.set_prepared_statement_cache_capacity(0); + let sql = "SELECT 1, 'test', 3.14 UNION SELECT 2, 'exp', 2.71"; + b.iter(|| db.prepare(sql).unwrap()); +} + +fn bench_cache(b: &mut Bencher) { + let db = Connection::open_in_memory().unwrap(); + let sql = "SELECT 1, 'test', 3.14 UNION SELECT 2, 'exp', 2.71"; + b.iter(|| db.prepare_cached(sql).unwrap()); +} + +benchmark_group!(cache_benches, bench_no_cache, bench_cache); +benchmark_main!(cache_benches); diff --git a/third_party/rust/rusqlite/benches/exec.rs b/third_party/rust/rusqlite/benches/exec.rs new file mode 100644 index 0000000000..360a98b86f --- /dev/null +++ b/third_party/rust/rusqlite/benches/exec.rs @@ -0,0 +1,17 @@ +use bencher::{benchmark_group, benchmark_main, Bencher}; +use rusqlite::{Connection, NO_PARAMS}; + +fn bench_execute(b: &mut Bencher) { + let db = Connection::open_in_memory().unwrap(); + let sql = "PRAGMA user_version=1"; + b.iter(|| db.execute(sql, NO_PARAMS).unwrap()); +} + +fn bench_execute_batch(b: &mut Bencher) { + let db = Connection::open_in_memory().unwrap(); + let sql = "PRAGMA user_version=1"; + b.iter(|| db.execute_batch(sql).unwrap()); +} + +benchmark_group!(exec_benches, bench_execute, bench_execute_batch); +benchmark_main!(exec_benches); diff --git a/third_party/rust/rusqlite/clippy.toml b/third_party/rust/rusqlite/clippy.toml new file mode 100644 index 0000000000..82447d9290 --- /dev/null +++ b/third_party/rust/rusqlite/clippy.toml @@ -0,0 +1 @@ +doc-valid-idents = ["SQLite", "lang_transaction"] diff --git a/third_party/rust/rusqlite/codecov.yml b/third_party/rust/rusqlite/codecov.yml new file mode 100644 index 0000000000..7a4789ee2b --- /dev/null +++ b/third_party/rust/rusqlite/codecov.yml @@ -0,0 +1,11 @@ +ignore: + - "libsqlite3-sys/bindgen-bindings" + - "libsqlite3-sys/sqlite3" +coverage: + status: + project: + default: + informational: true + patch: + default: + informational: true diff --git a/third_party/rust/rusqlite/publish-ghp-docs.sh b/third_party/rust/rusqlite/publish-ghp-docs.sh new file mode 100755 index 0000000000..14f358a32f --- /dev/null +++ b/third_party/rust/rusqlite/publish-ghp-docs.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +git describe --exact-match --tags $(git log -n1 --pretty='%h') >/dev/null 2>&1 +if [[ $? != 0 ]]; then + echo "Should not publish tags from an untagged commit!" + exit 1 +fi + +cd $(git rev-parse --show-toplevel) +rm -rf target/doc/ +rustup run nightly cargo doc --no-deps --features "backup blob chrono functions limits load_extension serde_json trace" +echo '<meta http-equiv=refresh content=0;url=rusqlite/index.html>' > target/doc/index.html +ghp-import target/doc +git push origin gh-pages:gh-pages diff --git a/third_party/rust/rusqlite/src/backup.rs b/third_party/rust/rusqlite/src/backup.rs new file mode 100644 index 0000000000..4654907204 --- /dev/null +++ b/third_party/rust/rusqlite/src/backup.rs @@ -0,0 +1,436 @@ +//! `feature = "backup"` Online SQLite backup API. +//! +//! To create a `Backup`, you must have two distinct `Connection`s - one +//! for the source (which can be used while the backup is running) and one for +//! the destination (which cannot). A `Backup` handle exposes three methods: +//! `step` will attempt to back up a specified number of pages, `progress` gets +//! the current progress of the backup as of the last call to `step`, and +//! `run_to_completion` will attempt to back up the entire source database, +//! allowing you to specify how many pages are backed up at a time and how long +//! the thread should sleep between chunks of pages. +//! +//! The following example is equivalent to "Example 2: Online Backup of a +//! Running Database" from [SQLite's Online Backup API +//! documentation](https://www.sqlite.org/backup.html). +//! +//! ```rust,no_run +//! # use rusqlite::{backup, Connection, Result}; +//! # use std::path::Path; +//! # use std::time; +//! +//! fn backup_db<P: AsRef<Path>>( +//! src: &Connection, +//! dst: P, +//! progress: fn(backup::Progress), +//! ) -> Result<()> { +//! let mut dst = Connection::open(dst)?; +//! let backup = backup::Backup::new(src, &mut dst)?; +//! backup.run_to_completion(5, time::Duration::from_millis(250), Some(progress)) +//! } +//! ``` + +use std::marker::PhantomData; +use std::path::Path; +use std::ptr; + +use std::os::raw::c_int; +use std::thread; +use std::time::Duration; + +use crate::ffi; + +use crate::error::{error_from_handle, error_from_sqlite_code}; +use crate::{Connection, DatabaseName, Result}; + +impl Connection { + /// `feature = "backup"` Back up the `name` database to the given + /// destination path. + /// + /// If `progress` is not `None`, it will be called periodically + /// until the backup completes. + /// + /// For more fine-grained control over the backup process (e.g., + /// to sleep periodically during the backup or to back up to an + /// already-open database connection), see the `backup` module. + /// + /// # Failure + /// + /// Will return `Err` if the destination path cannot be opened + /// or if the backup fails. + pub fn backup<P: AsRef<Path>>( + &self, + name: DatabaseName<'_>, + dst_path: P, + progress: Option<fn(Progress)>, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + let mut dst = Connection::open(dst_path)?; + let backup = Backup::new_with_names(self, name, &mut dst, DatabaseName::Main)?; + + let mut r = More; + while r == More { + r = backup.step(100)?; + if let Some(f) = progress { + f(backup.progress()); + } + } + + match r { + Done => Ok(()), + Busy => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_BUSY) }), + Locked => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_LOCKED) }), + More => unreachable!(), + } + } + + /// `feature = "backup"` Restore the given source path into the + /// `name` database. If `progress` is not `None`, it will be + /// called periodically until the restore completes. + /// + /// For more fine-grained control over the restore process (e.g., + /// to sleep periodically during the restore or to restore from an + /// already-open database connection), see the `backup` module. + /// + /// # Failure + /// + /// Will return `Err` if the destination path cannot be opened + /// or if the restore fails. + pub fn restore<P: AsRef<Path>, F: Fn(Progress)>( + &mut self, + name: DatabaseName<'_>, + src_path: P, + progress: Option<F>, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + let src = Connection::open(src_path)?; + let restore = Backup::new_with_names(&src, DatabaseName::Main, self, name)?; + + let mut r = More; + let mut busy_count = 0i32; + 'restore_loop: while r == More || r == Busy { + r = restore.step(100)?; + if let Some(ref f) = progress { + f(restore.progress()); + } + if r == Busy { + busy_count += 1; + if busy_count >= 3 { + break 'restore_loop; + } + thread::sleep(Duration::from_millis(100)); + } + } + + match r { + Done => Ok(()), + Busy => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_BUSY) }), + Locked => Err(unsafe { error_from_handle(ptr::null_mut(), ffi::SQLITE_LOCKED) }), + More => unreachable!(), + } + } +} + +/// `feature = "backup"` Possible successful results of calling `Backup::step`. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum StepResult { + /// The backup is complete. + Done, + + /// The step was successful but there are still more pages that need to be + /// backed up. + More, + + /// The step failed because appropriate locks could not be aquired. This is + /// not a fatal error - the step can be retried. + Busy, + + /// The step failed because the source connection was writing to the + /// database. This is not a fatal error - the step can be retried. + Locked, +} + +/// `feature = "backup"` Struct specifying the progress of a backup. The +/// percentage completion can be calculated as `(pagecount - remaining) / +/// pagecount`. The progress of a backup is as of the last call to `step` - if +/// the source database is modified after a call to `step`, the progress value +/// will become outdated and potentially incorrect. +#[derive(Copy, Clone, Debug)] +pub struct Progress { + /// Number of pages in the source database that still need to be backed up. + pub remaining: c_int, + /// Total number of pages in the source database. + pub pagecount: c_int, +} + +/// `feature = "backup"` A handle to an online backup. +pub struct Backup<'a, 'b> { + phantom_from: PhantomData<&'a Connection>, + phantom_to: PhantomData<&'b Connection>, + b: *mut ffi::sqlite3_backup, +} + +impl Backup<'_, '_> { + /// Attempt to create a new handle that will allow backups from `from` to + /// `to`. Note that `to` is a `&mut` - this is because SQLite forbids any + /// API calls on the destination of a backup while the backup is taking + /// place. + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_init` call returns + /// `NULL`. + pub fn new<'a, 'b>(from: &'a Connection, to: &'b mut Connection) -> Result<Backup<'a, 'b>> { + Backup::new_with_names(from, DatabaseName::Main, to, DatabaseName::Main) + } + + /// Attempt to create a new handle that will allow backups from the + /// `from_name` database of `from` to the `to_name` database of `to`. Note + /// that `to` is a `&mut` - this is because SQLite forbids any API calls on + /// the destination of a backup while the backup is taking place. + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_init` call returns + /// `NULL`. + pub fn new_with_names<'a, 'b>( + from: &'a Connection, + from_name: DatabaseName<'_>, + to: &'b mut Connection, + to_name: DatabaseName<'_>, + ) -> Result<Backup<'a, 'b>> { + let to_name = to_name.to_cstring()?; + let from_name = from_name.to_cstring()?; + + let to_db = to.db.borrow_mut().db; + + let b = unsafe { + let b = ffi::sqlite3_backup_init( + to_db, + to_name.as_ptr(), + from.db.borrow_mut().db, + from_name.as_ptr(), + ); + if b.is_null() { + return Err(error_from_handle(to_db, ffi::sqlite3_errcode(to_db))); + } + b + }; + + Ok(Backup { + phantom_from: PhantomData, + phantom_to: PhantomData, + b, + }) + } + + /// Gets the progress of the backup as of the last call to `step`. + pub fn progress(&self) -> Progress { + unsafe { + Progress { + remaining: ffi::sqlite3_backup_remaining(self.b), + pagecount: ffi::sqlite3_backup_pagecount(self.b), + } + } + } + + /// Attempts to back up the given number of pages. If `num_pages` is + /// negative, will attempt to back up all remaining pages. This will hold a + /// lock on the source database for the duration, so it is probably not + /// what you want for databases that are currently active (see + /// `run_to_completion` for a better alternative). + /// + /// # Failure + /// + /// Will return `Err` if the underlying `sqlite3_backup_step` call returns + /// an error code other than `DONE`, `OK`, `BUSY`, or `LOCKED`. `BUSY` and + /// `LOCKED` are transient errors and are therefore returned as possible + /// `Ok` values. + pub fn step(&self, num_pages: c_int) -> Result<StepResult> { + use self::StepResult::{Busy, Done, Locked, More}; + + let rc = unsafe { ffi::sqlite3_backup_step(self.b, num_pages) }; + match rc { + ffi::SQLITE_DONE => Ok(Done), + ffi::SQLITE_OK => Ok(More), + ffi::SQLITE_BUSY => Ok(Busy), + ffi::SQLITE_LOCKED => Ok(Locked), + _ => Err(error_from_sqlite_code(rc, None)), + } + } + + /// Attempts to run the entire backup. Will call `step(pages_per_step)` as + /// many times as necessary, sleeping for `pause_between_pages` between + /// each call to give the source database time to process any pending + /// queries. This is a direct implementation of "Example 2: Online Backup + /// of a Running Database" from [SQLite's Online Backup API + /// documentation](https://www.sqlite.org/backup.html). + /// + /// If `progress` is not `None`, it will be called after each step with the + /// current progress of the backup. Note that is possible the progress may + /// not change if the step returns `Busy` or `Locked` even though the + /// backup is still running. + /// + /// # Failure + /// + /// Will return `Err` if any of the calls to `step` return `Err`. + pub fn run_to_completion( + &self, + pages_per_step: c_int, + pause_between_pages: Duration, + progress: Option<fn(Progress)>, + ) -> Result<()> { + use self::StepResult::{Busy, Done, Locked, More}; + + assert!(pages_per_step > 0, "pages_per_step must be positive"); + + loop { + let r = self.step(pages_per_step)?; + if let Some(progress) = progress { + progress(self.progress()) + } + match r { + More | Busy | Locked => thread::sleep(pause_between_pages), + Done => return Ok(()), + } + } + } +} + +impl Drop for Backup<'_, '_> { + fn drop(&mut self) { + unsafe { ffi::sqlite3_backup_finish(self.b) }; + } +} + +#[cfg(test)] +mod test { + use super::Backup; + use crate::{Connection, DatabaseName, NO_PARAMS}; + use std::time::Duration; + + #[test] + fn test_backup() { + let src = Connection::open_in_memory().unwrap(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + src.execute_batch(sql).unwrap(); + + let mut dst = Connection::open_in_memory().unwrap(); + + { + let backup = Backup::new(&src, &mut dst).unwrap(); + backup.step(-1).unwrap(); + } + + let the_answer: i64 = dst + .query_row("SELECT x FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)").unwrap(); + + { + let backup = Backup::new(&src, &mut dst).unwrap(); + backup + .run_to_completion(5, Duration::from_millis(250), None) + .unwrap(); + } + + let the_answer: i64 = dst + .query_row("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(42 + 43, the_answer); + } + + #[test] + fn test_backup_temp() { + let src = Connection::open_in_memory().unwrap(); + let sql = "BEGIN; + CREATE TEMPORARY TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + src.execute_batch(sql).unwrap(); + + let mut dst = Connection::open_in_memory().unwrap(); + + { + let backup = + Backup::new_with_names(&src, DatabaseName::Temp, &mut dst, DatabaseName::Main) + .unwrap(); + backup.step(-1).unwrap(); + } + + let the_answer: i64 = dst + .query_row("SELECT x FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)").unwrap(); + + { + let backup = + Backup::new_with_names(&src, DatabaseName::Temp, &mut dst, DatabaseName::Main) + .unwrap(); + backup + .run_to_completion(5, Duration::from_millis(250), None) + .unwrap(); + } + + let the_answer: i64 = dst + .query_row("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(42 + 43, the_answer); + } + + #[test] + fn test_backup_attached() { + let src = Connection::open_in_memory().unwrap(); + let sql = "ATTACH DATABASE ':memory:' AS my_attached; + BEGIN; + CREATE TABLE my_attached.foo(x INTEGER); + INSERT INTO my_attached.foo VALUES(42); + END;"; + src.execute_batch(sql).unwrap(); + + let mut dst = Connection::open_in_memory().unwrap(); + + { + let backup = Backup::new_with_names( + &src, + DatabaseName::Attached("my_attached"), + &mut dst, + DatabaseName::Main, + ) + .unwrap(); + backup.step(-1).unwrap(); + } + + let the_answer: i64 = dst + .query_row("SELECT x FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(42, the_answer); + + src.execute_batch("INSERT INTO foo VALUES(43)").unwrap(); + + { + let backup = Backup::new_with_names( + &src, + DatabaseName::Attached("my_attached"), + &mut dst, + DatabaseName::Main, + ) + .unwrap(); + backup + .run_to_completion(5, Duration::from_millis(250), None) + .unwrap(); + } + + let the_answer: i64 = dst + .query_row("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(42 + 43, the_answer); + } +} diff --git a/third_party/rust/rusqlite/src/blob/mod.rs b/third_party/rust/rusqlite/src/blob/mod.rs new file mode 100644 index 0000000000..7d7ec3d03a --- /dev/null +++ b/third_party/rust/rusqlite/src/blob/mod.rs @@ -0,0 +1,547 @@ +//! `feature = "blob"` Incremental BLOB I/O. +//! +//! Note that SQLite does not provide API-level access to change the size of a +//! BLOB; that must be performed through SQL statements. +//! +//! There are two choices for how to perform IO on a [`Blob`]. +//! +//! 1. The implementations it provides of the `std::io::Read`, `std::io::Write`, +//! and `std::io::Seek` traits. +//! +//! 2. A positional IO API, e.g. [`Blob::read_at`], [`Blob::write_at`] and +//! similar. +//! +//! Documenting these in order: +//! +//! ## 1. `std::io` trait implementations. +//! +//! `Blob` conforms to `std::io::Read`, `std::io::Write`, and `std::io::Seek`, +//! so it plays nicely with other types that build on these (such as +//! `std::io::BufReader` and `std::io::BufWriter`). However, you must be careful +//! with the size of the blob. For example, when using a `BufWriter`, the +//! `BufWriter` will accept more data than the `Blob` will allow, so make sure +//! to call `flush` and check for errors. (See the unit tests in this module for +//! an example.) +//! +//! ## 2. Positional IO +//! +//! `Blob`s also offer a `pread` / `pwrite`-style positional IO api in the form +//! of [`Blob::read_at`], [`Blob::write_at`], [`Blob::raw_read_at`], +//! [`Blob::read_at_exact`], and [`Blob::raw_read_at_exact`]. +//! +//! These APIs all take the position to read from or write to from as a +//! parameter, instead of using an internal `pos` value. +//! +//! ### Positional IO Read Variants +//! +//! For the `read` functions, there are several functions provided: +//! +//! - [`Blob::read_at`] +//! - [`Blob::raw_read_at`] +//! - [`Blob::read_at_exact`] +//! - [`Blob::raw_read_at_exact`] +//! +//! These can be divided along two axes: raw/not raw, and exact/inexact: +//! +//! 1. Raw/not raw refers to the type of the destination buffer. The raw +//! functions take a `&mut [MaybeUninit<u8>]` as the destination buffer, +//! where the "normal" functions take a `&mut [u8]`. +//! +//! Using `MaybeUninit` here can be more efficient in some cases, but is +//! often inconvenient, so both are provided. +//! +//! 2. Exact/inexact refers to to whether or not the entire buffer must be +//! filled in order for the call to be considered a success. +//! +//! The "exact" functions require the provided buffer be entirely filled, or +//! they return an error, wheras the "inexact" functions read as much out of +//! the blob as is available, and return how much they were able to read. +//! +//! The inexact functions are preferrable if you do not know the size of the +//! blob already, and the exact functions are preferrable if you do. +//! +//! ### Comparison to using the `std::io` traits: +//! +//! In general, the positional methods offer the following Pro/Cons compared to +//! using the implementation `std::io::{Read, Write, Seek}` we provide for +//! `Blob`: +//! +//! 1. (Pro) There is no need to first seek to a position in order to perform IO +//! on it as the position is a parameter. +//! +//! 2. (Pro) `Blob`'s positional read functions don't mutate the blob in any +//! way, and take `&self`. No `&mut` access required. +//! +//! 3. (Pro) Positional IO functions return `Err(rusqlite::Error)` on failure, +//! rather than `Err(std::io::Error)`. Returning `rusqlite::Error` is more +//! accurate and convenient. +//! +//! Note that for the `std::io` API, no data is lost however, and it can be +//! recovered with `io_err.downcast::<rusqlite::Error>()` (this can be easy +//! to forget, though). +//! +//! 4. (Pro, for now). A `raw` version of the read API exists which can allow +//! reading into a `&mut [MaybeUninit<u8>]` buffer, which avoids a potential +//! costly initialization step. (However, `std::io` traits will certainly +//! gain this someday, which is why this is only a "Pro, for now"). +//! +//! 5. (Con) The set of functions is more bare-bones than what is offered in +//! `std::io`, which has a number of adapters, handy algorithms, further +//! traits. +//! +//! 6. (Con) No meaningful interoperability with other crates, so if you need +//! that you must use `std::io`. +//! +//! To generalize: the `std::io` traits are useful because they conform to a +//! standard interface that a lot of code knows how to handle, however that +//! interface is not a perfect fit for [`Blob`], so another small set of +//! functions is provided as well. +//! +//! # Example (`std::io`) +//! +//! ```rust +//! # use rusqlite::blob::ZeroBlob; +//! # use rusqlite::{Connection, DatabaseName, NO_PARAMS}; +//! # use std::error::Error; +//! # use std::io::{Read, Seek, SeekFrom, Write}; +//! # fn main() -> Result<(), Box<dyn Error>> { +//! let db = Connection::open_in_memory()?; +//! db.execute_batch("CREATE TABLE test_table (content BLOB);")?; +//! +//! // Insert a BLOB into the `content` column of `test_table`. Note that the Blob +//! // I/O API provides no way of inserting or resizing BLOBs in the DB -- this +//! // must be done via SQL. +//! db.execute( +//! "INSERT INTO test_table (content) VALUES (ZEROBLOB(10))", +//! NO_PARAMS, +//! )?; +//! +//! // Get the row id off the BLOB we just inserted. +//! let rowid = db.last_insert_rowid(); +//! // Open the BLOB we just inserted for IO. +//! let mut blob = db.blob_open(DatabaseName::Main, "test_table", "content", rowid, false)?; +//! +//! // Write some data into the blob. Make sure to test that the number of bytes +//! // written matches what you expect; if you try to write too much, the data +//! // will be truncated to the size of the BLOB. +//! let bytes_written = blob.write(b"01234567")?; +//! assert_eq!(bytes_written, 8); +//! +//! // Move back to the start and read into a local buffer. +//! // Same guidance - make sure you check the number of bytes read! +//! blob.seek(SeekFrom::Start(0))?; +//! let mut buf = [0u8; 20]; +//! let bytes_read = blob.read(&mut buf[..])?; +//! assert_eq!(bytes_read, 10); // note we read 10 bytes because the blob has size 10 +//! +//! // Insert another BLOB, this time using a parameter passed in from +//! // rust (potentially with a dynamic size). +//! db.execute("INSERT INTO test_table (content) VALUES (?)", &[ZeroBlob(64)])?; +//! +//! // given a new row ID, we can reopen the blob on that row +//! let rowid = db.last_insert_rowid(); +//! blob.reopen(rowid)?; +//! // Just check that the size is right. +//! assert_eq!(blob.len(), 64); +//! # Ok(()) +//! # } +//! ``` +//! +//! # Example (Positional) +//! +//! ```rust +//! # use rusqlite::blob::ZeroBlob; +//! # use rusqlite::{Connection, DatabaseName, NO_PARAMS}; +//! # use std::error::Error; +//! # fn main() -> Result<(), Box<dyn Error>> { +//! let db = Connection::open_in_memory()?; +//! db.execute_batch("CREATE TABLE test_table (content BLOB);")?; +//! // Insert a blob into the `content` column of `test_table`. Note that the Blob +//! // I/O API provides no way of inserting or resizing blobs in the DB -- this +//! // must be done via SQL. +//! db.execute( +//! "INSERT INTO test_table (content) VALUES (ZEROBLOB(10))", +//! NO_PARAMS, +//! )?; +//! // Get the row id off the blob we just inserted. +//! let rowid = db.last_insert_rowid(); +//! // Open the blob we just inserted for IO. +//! let mut blob = db.blob_open(DatabaseName::Main, "test_table", "content", rowid, false)?; +//! // Write some data into the blob. +//! blob.write_at(b"ABCDEF", 2)?; +//! +//! // Read the whole blob into a local buffer. +//! let mut buf = [0u8; 10]; +//! blob.read_at_exact(&mut buf, 0)?; +//! assert_eq!(&buf, b"\0\0ABCDEF\0\0"); +//! +//! // Insert another blob, this time using a parameter passed in from +//! // rust (potentially with a dynamic size). +//! db.execute("INSERT INTO test_table (content) VALUES (?)", &[ZeroBlob(64)])?; +//! +//! // given a new row ID, we can reopen the blob on that row +//! let rowid = db.last_insert_rowid(); +//! blob.reopen(rowid)?; +//! assert_eq!(blob.len(), 64); +//! # Ok(()) +//! # } +//! ``` +use std::cmp::min; +use std::io; +use std::ptr; + +use super::ffi; +use super::types::{ToSql, ToSqlOutput}; +use crate::{Connection, DatabaseName, Result}; + +mod pos_io; + +/// `feature = "blob"` Handle to an open BLOB. See [`rusqlite::blob`](crate::blob) documentation for +/// in-depth discussion. +pub struct Blob<'conn> { + conn: &'conn Connection, + blob: *mut ffi::sqlite3_blob, + // used by std::io implementations, + pos: i32, +} + +impl Connection { + /// `feature = "blob"` Open a handle to the BLOB located in `row_id`, + /// `column`, `table` in database `db`. + /// + /// # Failure + /// + /// Will return `Err` if `db`/`table`/`column` cannot be converted to a + /// C-compatible string or if the underlying SQLite BLOB open call + /// fails. + pub fn blob_open<'a>( + &'a self, + db: DatabaseName<'_>, + table: &str, + column: &str, + row_id: i64, + read_only: bool, + ) -> Result<Blob<'a>> { + let mut c = self.db.borrow_mut(); + let mut blob = ptr::null_mut(); + let db = db.to_cstring()?; + let table = super::str_to_cstring(table)?; + let column = super::str_to_cstring(column)?; + let rc = unsafe { + ffi::sqlite3_blob_open( + c.db(), + db.as_ptr(), + table.as_ptr(), + column.as_ptr(), + row_id, + if read_only { 0 } else { 1 }, + &mut blob, + ) + }; + c.decode_result(rc).map(|_| Blob { + conn: self, + blob, + pos: 0, + }) + } +} + +impl Blob<'_> { + /// Move a BLOB handle to a new row. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite BLOB reopen call fails. + pub fn reopen(&mut self, row: i64) -> Result<()> { + let rc = unsafe { ffi::sqlite3_blob_reopen(self.blob, row) }; + if rc != ffi::SQLITE_OK { + return self.conn.decode_result(rc); + } + self.pos = 0; + Ok(()) + } + + /// Return the size in bytes of the BLOB. + pub fn size(&self) -> i32 { + unsafe { ffi::sqlite3_blob_bytes(self.blob) } + } + + /// Return the current size in bytes of the BLOB. + pub fn len(&self) -> usize { + use std::convert::TryInto; + self.size().try_into().unwrap() + } + + /// Return true if the BLOB is empty. + pub fn is_empty(&self) -> bool { + self.size() == 0 + } + + /// Close a BLOB handle. + /// + /// Calling `close` explicitly is not required (the BLOB will be closed + /// when the `Blob` is dropped), but it is available so you can get any + /// errors that occur. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite close call fails. + pub fn close(mut self) -> Result<()> { + self.close_() + } + + fn close_(&mut self) -> Result<()> { + let rc = unsafe { ffi::sqlite3_blob_close(self.blob) }; + self.blob = ptr::null_mut(); + self.conn.decode_result(rc) + } +} + +impl io::Read for Blob<'_> { + /// Read data from a BLOB incrementally. Will return Ok(0) if the end of + /// the blob has been reached. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite read call fails. + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + let max_allowed_len = (self.size() - self.pos) as usize; + let n = min(buf.len(), max_allowed_len) as i32; + if n <= 0 { + return Ok(0); + } + let rc = + unsafe { ffi::sqlite3_blob_read(self.blob, buf.as_mut_ptr() as *mut _, n, self.pos) }; + self.conn + .decode_result(rc) + .map(|_| { + self.pos += n; + n as usize + }) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + } +} + +impl io::Write for Blob<'_> { + /// Write data into a BLOB incrementally. Will return `Ok(0)` if the end of + /// the blob has been reached; consider using `Write::write_all(buf)` + /// if you want to get an error if the entirety of the buffer cannot be + /// written. + /// + /// This function may only modify the contents of the BLOB; it is not + /// possible to increase the size of a BLOB using this API. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite write call fails. + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let max_allowed_len = (self.size() - self.pos) as usize; + let n = min(buf.len(), max_allowed_len) as i32; + if n <= 0 { + return Ok(0); + } + let rc = unsafe { ffi::sqlite3_blob_write(self.blob, buf.as_ptr() as *mut _, n, self.pos) }; + self.conn + .decode_result(rc) + .map(|_| { + self.pos += n; + n as usize + }) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl io::Seek for Blob<'_> { + /// Seek to an offset, in bytes, in BLOB. + fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { + let pos = match pos { + io::SeekFrom::Start(offset) => offset as i64, + io::SeekFrom::Current(offset) => i64::from(self.pos) + offset, + io::SeekFrom::End(offset) => i64::from(self.size()) + offset, + }; + + if pos < 0 { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid seek to negative position", + )) + } else if pos > i64::from(self.size()) { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid seek to position past end of blob", + )) + } else { + self.pos = pos as i32; + Ok(pos as u64) + } + } +} + +#[allow(unused_must_use)] +impl Drop for Blob<'_> { + fn drop(&mut self) { + self.close_(); + } +} + +/// `feature = "blob"` BLOB of length N that is filled with zeroes. +/// +/// Zeroblobs are intended to serve as placeholders for BLOBs whose content is +/// later written using incremental BLOB I/O routines. +/// +/// A negative value for the zeroblob results in a zero-length BLOB. +#[derive(Copy, Clone)] +pub struct ZeroBlob(pub i32); + +impl ToSql for ZeroBlob { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let ZeroBlob(length) = *self; + Ok(ToSqlOutput::ZeroBlob(length)) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, DatabaseName, Result}; + use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write}; + + fn db_with_test_blob() -> Result<(Connection, i64)> { + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE test (content BLOB); + INSERT INTO test VALUES (ZEROBLOB(10)); + END;"; + db.execute_batch(sql)?; + let rowid = db.last_insert_rowid(); + Ok((db, rowid)) + } + + #[test] + fn test_blob() { + let (db, rowid) = db_with_test_blob().unwrap(); + + let mut blob = db + .blob_open(DatabaseName::Main, "test", "content", rowid, false) + .unwrap(); + assert_eq!(4, blob.write(b"Clob").unwrap()); + assert_eq!(6, blob.write(b"567890xxxxxx").unwrap()); // cannot write past 10 + assert_eq!(0, blob.write(b"5678").unwrap()); // still cannot write past 10 + + blob.reopen(rowid).unwrap(); + blob.close().unwrap(); + + blob = db + .blob_open(DatabaseName::Main, "test", "content", rowid, true) + .unwrap(); + let mut bytes = [0u8; 5]; + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"Clob5"); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"67890"); + assert_eq!(0, blob.read(&mut bytes[..]).unwrap()); + + blob.seek(SeekFrom::Start(2)).unwrap(); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"ob567"); + + // only first 4 bytes of `bytes` should be read into + blob.seek(SeekFrom::Current(-1)).unwrap(); + assert_eq!(4, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"78907"); + + blob.seek(SeekFrom::End(-6)).unwrap(); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"56789"); + + blob.reopen(rowid).unwrap(); + assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"Clob5"); + + // should not be able to seek negative or past end + assert!(blob.seek(SeekFrom::Current(-20)).is_err()); + assert!(blob.seek(SeekFrom::End(0)).is_ok()); + assert!(blob.seek(SeekFrom::Current(1)).is_err()); + + // write_all should detect when we return Ok(0) because there is no space left, + // and return a write error + blob.reopen(rowid).unwrap(); + assert!(blob.write_all(b"0123456789x").is_err()); + } + + #[test] + fn test_blob_in_bufreader() { + let (db, rowid) = db_with_test_blob().unwrap(); + + let mut blob = db + .blob_open(DatabaseName::Main, "test", "content", rowid, false) + .unwrap(); + assert_eq!(8, blob.write(b"one\ntwo\n").unwrap()); + + blob.reopen(rowid).unwrap(); + let mut reader = BufReader::new(blob); + + let mut line = String::new(); + assert_eq!(4, reader.read_line(&mut line).unwrap()); + assert_eq!("one\n", line); + + line.truncate(0); + assert_eq!(4, reader.read_line(&mut line).unwrap()); + assert_eq!("two\n", line); + + line.truncate(0); + assert_eq!(2, reader.read_line(&mut line).unwrap()); + assert_eq!("\0\0", line); + } + + #[test] + fn test_blob_in_bufwriter() { + let (db, rowid) = db_with_test_blob().unwrap(); + + { + let blob = db + .blob_open(DatabaseName::Main, "test", "content", rowid, false) + .unwrap(); + let mut writer = BufWriter::new(blob); + + // trying to write too much and then flush should fail + assert_eq!(8, writer.write(b"01234567").unwrap()); + assert_eq!(8, writer.write(b"01234567").unwrap()); + assert!(writer.flush().is_err()); + } + + { + // ... but it should've written the first 10 bytes + let mut blob = db + .blob_open(DatabaseName::Main, "test", "content", rowid, false) + .unwrap(); + let mut bytes = [0u8; 10]; + assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(b"0123456701", &bytes); + } + + { + let blob = db + .blob_open(DatabaseName::Main, "test", "content", rowid, false) + .unwrap(); + let mut writer = BufWriter::new(blob); + + // trying to write_all too much should fail + writer.write_all(b"aaaaaaaaaabbbbb").unwrap(); + assert!(writer.flush().is_err()); + } + + { + // ... but it should've written the first 10 bytes + let mut blob = db + .blob_open(DatabaseName::Main, "test", "content", rowid, false) + .unwrap(); + let mut bytes = [0u8; 10]; + assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(b"aaaaaaaaaa", &bytes); + } + } +} diff --git a/third_party/rust/rusqlite/src/blob/pos_io.rs b/third_party/rust/rusqlite/src/blob/pos_io.rs new file mode 100644 index 0000000000..9f1f994a99 --- /dev/null +++ b/third_party/rust/rusqlite/src/blob/pos_io.rs @@ -0,0 +1,281 @@ +use super::Blob; + +use std::convert::TryFrom; +use std::mem::MaybeUninit; +use std::slice::from_raw_parts_mut; + +use crate::ffi; +use crate::{Error, Result}; + +impl<'conn> Blob<'conn> { + /// Write `buf` to `self` starting at `write_start`, returning an error if + /// `write_start + buf.len()` is past the end of the blob. + /// + /// If an error is returned, no data is written. + /// + /// Note: the blob cannot be resized using this function -- that must be + /// done using SQL (for example, an `UPDATE` statement). + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position write to, instead of using the internal position that can be + /// manipulated by the `std::io` traits. + /// + /// Unlike the similarly named [`FileExt::write_at`][fext_write_at] function + /// (from `std::os::unix`), it's always an error to perform a "short write". + /// + /// [fext_write_at]: https://doc.rust-lang.org/std/os/unix/fs/trait.FileExt.html#tymethod.write_at + #[inline] + pub fn write_at(&mut self, buf: &[u8], write_start: usize) -> Result<()> { + let len = self.len(); + + if buf.len().saturating_add(write_start) > len { + return Err(Error::BlobSizeError); + } + // We know `len` fits in an `i32`, so either: + // + // 1. `buf.len() + write_start` overflows, in which case we'd hit the + // return above (courtesy of `saturating_add`). + // + // 2. `buf.len() + write_start` doesn't overflow but is larger than len, + // in which case ditto. + // + // 3. `buf.len() + write_start` doesn't overflow but is less than len. + // This means that both `buf.len()` and `write_start` can also be + // losslessly converted to i32, since `len` came from an i32. + // Sanity check the above. + debug_assert!(i32::try_from(write_start).is_ok() && i32::try_from(buf.len()).is_ok()); + unsafe { + check!(ffi::sqlite3_blob_write( + self.blob, + buf.as_ptr() as *const _, + buf.len() as i32, + write_start as i32, + )); + } + Ok(()) + } + + /// An alias for `write_at` provided for compatibility with the conceptually + /// equivalent [`std::os::unix::FileExt::write_all_at`][write_all_at] + /// function from libstd: + /// + /// [write_all_at]: https://doc.rust-lang.org/std/os/unix/fs/trait.FileExt.html#method.write_all_at + #[inline] + pub fn write_all_at(&mut self, buf: &[u8], write_start: usize) -> Result<()> { + self.write_at(buf, write_start) + } + + /// Read as much as possible from `offset` to `offset + buf.len()` out of + /// `self`, writing into `buf`. On success, returns the number of bytes + /// written. + /// + /// If there's insufficient data in `self`, then the returned value will be + /// less than `buf.len()`. + /// + /// See also [`Blob::raw_read_at`], which can take an uninitialized buffer, + /// or [`Blob::read_at_exact`] which returns an error if the entire `buf` is + /// not read. + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position to read from, instead of using the internal position that can + /// be manipulated by the `std::io` traits. Consequently, it does not change + /// that value either. + #[inline] + pub fn read_at(&self, buf: &mut [u8], read_start: usize) -> Result<usize> { + // Safety: this is safe because `raw_read_at` never stores uninitialized + // data into `as_uninit`. + let as_uninit: &mut [MaybeUninit<u8>] = + unsafe { from_raw_parts_mut(buf.as_mut_ptr() as *mut _, buf.len()) }; + self.raw_read_at(as_uninit, read_start).map(|s| s.len()) + } + + /// Read as much as possible from `offset` to `offset + buf.len()` out of + /// `self`, writing into `buf`. On success, returns the portion of `buf` + /// which was initialized by this call. + /// + /// If there's insufficient data in `self`, then the returned value will be + /// shorter than `buf`. + /// + /// See also [`Blob::read_at`], which takes a `&mut [u8]` buffer instead of + /// a slice of `MaybeUninit<u8>`. + /// + /// Note: This is part of the positional I/O API, and thus takes an absolute + /// position to read from, instead of using the internal position that can + /// be manipulated by the `std::io` traits. Consequently, it does not change + /// that value either. + #[inline] + pub fn raw_read_at<'a>( + &self, + buf: &'a mut [MaybeUninit<u8>], + read_start: usize, + ) -> Result<&'a mut [u8]> { + let len = self.len(); + + let read_len = match len.checked_sub(read_start) { + None | Some(0) => 0, + Some(v) => v.min(buf.len()), + }; + + if read_len == 0 { + // We could return `Ok(&mut [])`, but it seems confusing that the + // pointers don't match, so fabricate a empty slice of u8 with the + // same base pointer as `buf`. + let empty = unsafe { from_raw_parts_mut(buf.as_mut_ptr() as *mut u8, 0) }; + return Ok(empty); + } + + // At this point we believe `read_start as i32` is lossless because: + // + // 1. `len as i32` is known to be lossless, since it comes from a SQLite + // api returning an i32. + // + // 2. If we got here, `len.checked_sub(read_start)` was Some (or else + // we'd have hit the `if read_len == 0` early return), so `len` must + // be larger than `read_start`, and so it must fit in i32 as well. + debug_assert!(i32::try_from(read_start).is_ok()); + + // We also believe that `read_start + read_len <= len` because: + // + // 1. This is equivalent to `read_len <= len - read_start` via algebra. + // 2. We know that `read_len` is `min(len - read_start, buf.len())` + // 3. Expanding, this is `min(len - read_start, buf.len()) <= len - read_start`, + // or `min(A, B) <= A` which is clearly true. + // + // Note that this stuff is in debug_assert so no need to use checked_add + // and such -- we'll always panic on overflow in debug builds. + debug_assert!(read_start + read_len <= len); + + // These follow naturally. + debug_assert!(buf.len() >= read_len); + debug_assert!(i32::try_from(buf.len()).is_ok()); + debug_assert!(i32::try_from(read_len).is_ok()); + + unsafe { + check!(ffi::sqlite3_blob_read( + self.blob, + buf.as_mut_ptr() as *mut _, + read_len as i32, + read_start as i32, + )); + + Ok(from_raw_parts_mut(buf.as_mut_ptr() as *mut u8, read_len)) + } + } + + /// Equivalent to [`Blob::read_at`], but returns a `BlobSizeError` if `buf` + /// is not fully initialized. + #[inline] + pub fn read_at_exact(&self, buf: &mut [u8], read_start: usize) -> Result<()> { + let n = self.read_at(buf, read_start)?; + if n != buf.len() { + Err(Error::BlobSizeError) + } else { + Ok(()) + } + } + + /// Equivalent to [`Blob::raw_read_at`], but returns a `BlobSizeError` if + /// `buf` is not fully initialized. + #[inline] + pub fn raw_read_at_exact<'a>( + &self, + buf: &'a mut [MaybeUninit<u8>], + read_start: usize, + ) -> Result<&'a mut [u8]> { + let buflen = buf.len(); + let initted = self.raw_read_at(buf, read_start)?; + if initted.len() != buflen { + Err(Error::BlobSizeError) + } else { + Ok(initted) + } + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, DatabaseName, NO_PARAMS}; + // to ensure we don't modify seek pos + use std::io::Seek as _; + + #[test] + fn test_pos_io() { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE test_table(content BLOB);") + .unwrap(); + db.execute( + "INSERT INTO test_table(content) VALUES (ZEROBLOB(10))", + NO_PARAMS, + ) + .unwrap(); + + let rowid = db.last_insert_rowid(); + let mut blob = db + .blob_open(DatabaseName::Main, "test_table", "content", rowid, false) + .unwrap(); + // modify the seek pos to ensure we aren't using it or modifying it. + blob.seek(std::io::SeekFrom::Start(1)).unwrap(); + + let one2ten: [u8; 10] = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + blob.write_at(&one2ten, 0).unwrap(); + + let mut s = [0u8; 10]; + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &one2ten, "write should go through"); + assert!(blob.read_at_exact(&mut s, 1).is_err()); + + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &one2ten, "should be unchanged"); + + let mut fives = [0u8; 5]; + blob.read_at_exact(&mut fives, 0).unwrap(); + assert_eq!(&fives, &[1u8, 2, 3, 4, 5]); + + blob.read_at_exact(&mut fives, 5).unwrap(); + assert_eq!(&fives, &[6u8, 7, 8, 9, 10]); + assert!(blob.read_at_exact(&mut fives, 7).is_err()); + assert!(blob.read_at_exact(&mut fives, 12).is_err()); + assert!(blob.read_at_exact(&mut fives, 10).is_err()); + assert!(blob.read_at_exact(&mut fives, i32::MAX as usize).is_err()); + assert!(blob + .read_at_exact(&mut fives, i32::MAX as usize + 1) + .is_err()); + + // zero length writes are fine if in bounds + blob.read_at_exact(&mut [], 10).unwrap(); + blob.read_at_exact(&mut [], 0).unwrap(); + blob.read_at_exact(&mut [], 5).unwrap(); + + blob.write_all_at(&[16, 17, 18, 19, 20], 5).unwrap(); + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &[1u8, 2, 3, 4, 5, 16, 17, 18, 19, 20]); + + assert!(blob.write_at(&[100, 99, 98, 97, 96], 6).is_err()); + assert!(blob + .write_at(&[100, 99, 98, 97, 96], i32::MAX as usize) + .is_err()); + assert!(blob + .write_at(&[100, 99, 98, 97, 96], i32::MAX as usize + 1) + .is_err()); + + blob.read_at_exact(&mut s, 0).unwrap(); + assert_eq!(&s, &[1u8, 2, 3, 4, 5, 16, 17, 18, 19, 20]); + + let mut s2: [std::mem::MaybeUninit<u8>; 10] = [std::mem::MaybeUninit::uninit(); 10]; + { + let read = blob.raw_read_at_exact(&mut s2, 0).unwrap(); + assert_eq!(read, &s); + assert!(std::ptr::eq(read.as_ptr(), s2.as_ptr().cast())); + } + + let mut empty = []; + assert!(std::ptr::eq( + blob.raw_read_at_exact(&mut empty, 0).unwrap().as_ptr(), + empty.as_ptr().cast(), + )); + assert!(blob.raw_read_at_exact(&mut s2, 5).is_err()); + + let end_pos = blob.seek(std::io::SeekFrom::Current(0)).unwrap(); + assert_eq!(end_pos, 1); + } +} diff --git a/third_party/rust/rusqlite/src/busy.rs b/third_party/rust/rusqlite/src/busy.rs new file mode 100644 index 0000000000..b87504a974 --- /dev/null +++ b/third_party/rust/rusqlite/src/busy.rs @@ -0,0 +1,176 @@ +///! Busy handler (when the database is locked) +use std::convert::TryInto; +use std::mem; +use std::os::raw::{c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; +use std::time::Duration; + +use crate::ffi; +use crate::{Connection, InnerConnection, Result}; + +impl Connection { + /// Set a busy handler that sleeps for a specified amount of time when a + /// table is locked. The handler will sleep multiple times until at + /// least "ms" milliseconds of sleeping have accumulated. + /// + /// Calling this routine with an argument equal to zero turns off all busy + /// handlers. + /// + /// There can only be a single busy handler for a particular database + /// connection at any given moment. If another busy handler was defined + /// (using `busy_handler`) prior to calling this routine, that other + /// busy handler is cleared. + pub fn busy_timeout(&self, timeout: Duration) -> Result<()> { + let ms: i32 = timeout + .as_secs() + .checked_mul(1000) + .and_then(|t| t.checked_add(timeout.subsec_millis().into())) + .and_then(|t| t.try_into().ok()) + .expect("too big"); + self.db.borrow_mut().busy_timeout(ms) + } + + /// Register a callback to handle `SQLITE_BUSY` errors. + /// + /// If the busy callback is `None`, then `SQLITE_BUSY is returned + /// immediately upon encountering the lock.` The argument to the busy + /// handler callback is the number of times that the + /// busy handler has been invoked previously for the + /// same locking event. If the busy callback returns `false`, then no + /// additional attempts are made to access the + /// database and `SQLITE_BUSY` is returned to the + /// application. If the callback returns `true`, then another attempt + /// is made to access the database and the cycle repeats. + /// + /// There can only be a single busy handler defined for each database + /// connection. Setting a new busy handler clears any previously set + /// handler. Note that calling `busy_timeout()` or evaluating `PRAGMA + /// busy_timeout=N` will change the busy handler and thus + /// clear any previously set busy handler. + pub fn busy_handler(&self, callback: Option<fn(i32) -> bool>) -> Result<()> { + unsafe extern "C" fn busy_handler_callback(p_arg: *mut c_void, count: c_int) -> c_int { + let handler_fn: fn(i32) -> bool = mem::transmute(p_arg); + if let Ok(true) = catch_unwind(|| handler_fn(count)) { + 1 + } else { + 0 + } + } + let mut c = self.db.borrow_mut(); + let r = match callback { + Some(f) => unsafe { + ffi::sqlite3_busy_handler(c.db(), Some(busy_handler_callback), f as *mut c_void) + }, + None => unsafe { ffi::sqlite3_busy_handler(c.db(), None, ptr::null_mut()) }, + }; + c.decode_result(r) + } +} + +impl InnerConnection { + fn busy_timeout(&mut self, timeout: c_int) -> Result<()> { + let r = unsafe { ffi::sqlite3_busy_timeout(self.db, timeout) }; + self.decode_result(r) + } +} + +#[cfg(test)] +mod test { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::mpsc::sync_channel; + use std::thread; + use std::time::Duration; + + use crate::{Connection, Error, ErrorCode, Result, TransactionBehavior, NO_PARAMS}; + + #[test] + fn test_default_busy() { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let mut db1 = Connection::open(&path).unwrap(); + let tx1 = db1 + .transaction_with_behavior(TransactionBehavior::Exclusive) + .unwrap(); + let db2 = Connection::open(&path).unwrap(); + let r: Result<()> = db2.query_row("PRAGMA schema_version", NO_PARAMS, |_| unreachable!()); + match r.unwrap_err() { + Error::SqliteFailure(err, _) => { + assert_eq!(err.code, ErrorCode::DatabaseBusy); + } + err => panic!("Unexpected error {}", err), + } + tx1.rollback().unwrap(); + } + + #[test] + #[ignore] // FIXME: unstable + fn test_busy_timeout() { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let db2 = Connection::open(&path).unwrap(); + db2.busy_timeout(Duration::from_secs(1)).unwrap(); + + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db1 = Connection::open(&path).unwrap(); + let tx1 = db1 + .transaction_with_behavior(TransactionBehavior::Exclusive) + .unwrap(); + rx.send(1).unwrap(); + thread::sleep(Duration::from_millis(100)); + tx1.rollback().unwrap(); + }); + + assert_eq!(tx.recv().unwrap(), 1); + let _ = db2 + .query_row("PRAGMA schema_version", NO_PARAMS, |row| { + row.get::<_, i32>(0) + }) + .expect("unexpected error"); + + child.join().unwrap(); + } + + #[test] + #[ignore] // FIXME: unstable + fn test_busy_handler() { + lazy_static::lazy_static! { + static ref CALLED: AtomicBool = AtomicBool::new(false); + } + fn busy_handler(_: i32) -> bool { + CALLED.store(true, Ordering::Relaxed); + thread::sleep(Duration::from_millis(100)); + true + } + + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + let db2 = Connection::open(&path).unwrap(); + db2.busy_handler(Some(busy_handler)).unwrap(); + + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db1 = Connection::open(&path).unwrap(); + let tx1 = db1 + .transaction_with_behavior(TransactionBehavior::Exclusive) + .unwrap(); + rx.send(1).unwrap(); + thread::sleep(Duration::from_millis(100)); + tx1.rollback().unwrap(); + }); + + assert_eq!(tx.recv().unwrap(), 1); + let _ = db2 + .query_row("PRAGMA schema_version", NO_PARAMS, |row| { + row.get::<_, i32>(0) + }) + .expect("unexpected error"); + assert_eq!(CALLED.load(Ordering::Relaxed), true); + + child.join().unwrap(); + } +} diff --git a/third_party/rust/rusqlite/src/cache.rs b/third_party/rust/rusqlite/src/cache.rs new file mode 100644 index 0000000000..7dc9d235c6 --- /dev/null +++ b/third_party/rust/rusqlite/src/cache.rs @@ -0,0 +1,360 @@ +//! Prepared statements cache for faster execution. + +use crate::raw_statement::RawStatement; +use crate::{Connection, Result, Statement}; +use hashlink::LruCache; +use std::cell::RefCell; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +impl Connection { + /// Prepare a SQL statement for execution, returning a previously prepared + /// (but not currently in-use) statement if one is available. The + /// returned statement will be cached for reuse by future calls to + /// `prepare_cached` once it is dropped. + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert_new_people(conn: &Connection) -> Result<()> { + /// { + /// let mut stmt = conn.prepare_cached("INSERT INTO People (name) VALUES (?)")?; + /// stmt.execute(&["Joe Smith"])?; + /// } + /// { + /// // This will return the same underlying SQLite statement handle without + /// // having to prepare it again. + /// let mut stmt = conn.prepare_cached("INSERT INTO People (name) VALUES (?)")?; + /// stmt.execute(&["Bob Jones"])?; + /// } + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> { + self.cache.get(self, sql) + } + + /// Set the maximum number of cached prepared statements this connection + /// will hold. By default, a connection will hold a relatively small + /// number of cached statements. If you need more, or know that you + /// will not use cached statements, you + /// can set the capacity manually using this method. + pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) { + self.cache.set_capacity(capacity) + } + + /// Remove/finalize all prepared statements currently in the cache. + pub fn flush_prepared_statement_cache(&self) { + self.cache.flush() + } +} + +/// Prepared statements LRU cache. +// #[derive(Debug)] // FIXME: https://github.com/kyren/hashlink/pull/4 +pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>); + +/// Cacheable statement. +/// +/// Statement will return automatically to the cache by default. +/// If you want the statement to be discarded, call `discard()` on it. +pub struct CachedStatement<'conn> { + stmt: Option<Statement<'conn>>, + cache: &'conn StatementCache, +} + +impl<'conn> Deref for CachedStatement<'conn> { + type Target = Statement<'conn>; + + fn deref(&self) -> &Statement<'conn> { + self.stmt.as_ref().unwrap() + } +} + +impl<'conn> DerefMut for CachedStatement<'conn> { + fn deref_mut(&mut self) -> &mut Statement<'conn> { + self.stmt.as_mut().unwrap() + } +} + +impl Drop for CachedStatement<'_> { + #[allow(unused_must_use)] + fn drop(&mut self) { + if let Some(stmt) = self.stmt.take() { + self.cache.cache_stmt(unsafe { stmt.into_raw() }); + } + } +} + +impl CachedStatement<'_> { + fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> { + CachedStatement { + stmt: Some(stmt), + cache, + } + } + + /// Discard the statement, preventing it from being returned to its + /// `Connection`'s collection of cached statements. + pub fn discard(mut self) { + self.stmt = None; + } +} + +impl StatementCache { + /// Create a statement cache. + pub fn with_capacity(capacity: usize) -> StatementCache { + StatementCache(RefCell::new(LruCache::new(capacity))) + } + + fn set_capacity(&self, capacity: usize) { + self.0.borrow_mut().set_capacity(capacity) + } + + // Search the cache for a prepared-statement object that implements `sql`. + // If no such prepared-statement can be found, allocate and prepare a new one. + // + // # Failure + // + // Will return `Err` if no cached statement can be found and the underlying + // SQLite prepare call fails. + fn get<'conn>( + &'conn self, + conn: &'conn Connection, + sql: &str, + ) -> Result<CachedStatement<'conn>> { + let trimmed = sql.trim(); + let mut cache = self.0.borrow_mut(); + let stmt = match cache.remove(trimmed) { + Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)), + None => conn.prepare(trimmed), + }; + stmt.map(|mut stmt| { + stmt.stmt.set_statement_cache_key(trimmed); + CachedStatement::new(stmt, self) + }) + } + + // Return a statement to the cache. + fn cache_stmt(&self, stmt: RawStatement) { + if stmt.is_null() { + return; + } + let mut cache = self.0.borrow_mut(); + stmt.clear_bindings(); + if let Some(sql) = stmt.statement_cache_key() { + cache.insert(sql, stmt); + } else { + debug_assert!( + false, + "bug in statement cache code, statement returned to cache that without key" + ); + } + } + + fn flush(&self) { + let mut cache = self.0.borrow_mut(); + cache.clear() + } +} + +#[cfg(test)] +mod test { + use super::StatementCache; + use crate::{Connection, NO_PARAMS}; + use fallible_iterator::FallibleIterator; + + impl StatementCache { + fn clear(&self) { + self.0.borrow_mut().clear(); + } + + fn len(&self) -> usize { + self.0.borrow().len() + } + + fn capacity(&self) -> usize { + self.0.borrow().capacity() + } + } + + #[test] + fn test_cache() { + let db = Connection::open_in_memory().unwrap(); + let cache = &db.cache; + let initial_capacity = cache.capacity(); + assert_eq!(0, cache.len()); + assert!(initial_capacity > 0); + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!( + 0, + stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap() + ); + } + assert_eq!(1, cache.len()); + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!( + 0, + stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap() + ); + } + assert_eq!(1, cache.len()); + + cache.clear(); + assert_eq!(0, cache.len()); + assert_eq!(initial_capacity, cache.capacity()); + } + + #[test] + fn test_set_capacity() { + let db = Connection::open_in_memory().unwrap(); + let cache = &db.cache; + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!( + 0, + stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap() + ); + } + assert_eq!(1, cache.len()); + + db.set_prepared_statement_cache_capacity(0); + assert_eq!(0, cache.len()); + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!( + 0, + stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap() + ); + } + assert_eq!(0, cache.len()); + + db.set_prepared_statement_cache_capacity(8); + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!( + 0, + stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap() + ); + } + assert_eq!(1, cache.len()); + } + + #[test] + fn test_discard() { + let db = Connection::open_in_memory().unwrap(); + let cache = &db.cache; + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!( + 0, + stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap() + ); + stmt.discard(); + } + assert_eq!(0, cache.len()); + } + + #[test] + fn test_ddl() { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch( + r#" + CREATE TABLE foo (x INT); + INSERT INTO foo VALUES (1); + "#, + ) + .unwrap(); + + let sql = "SELECT * FROM foo"; + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!( + Ok(Some(1i32)), + stmt.query(NO_PARAMS).unwrap().map(|r| r.get(0)).next() + ); + } + + db.execute_batch( + r#" + ALTER TABLE foo ADD COLUMN y INT; + UPDATE foo SET y = 2; + "#, + ) + .unwrap(); + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!( + Ok(Some((1i32, 2i32))), + stmt.query(NO_PARAMS) + .unwrap() + .map(|r| Ok((r.get(0)?, r.get(1)?))) + .next() + ); + } + } + + #[test] + fn test_connection_close() { + let conn = Connection::open_in_memory().unwrap(); + conn.prepare_cached("SELECT * FROM sqlite_master;").unwrap(); + + conn.close().expect("connection not closed"); + } + + #[test] + fn test_cache_key() { + let db = Connection::open_in_memory().unwrap(); + let cache = &db.cache; + assert_eq!(0, cache.len()); + + //let sql = " PRAGMA schema_version; -- comment"; + let sql = "PRAGMA schema_version; "; + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!( + 0, + stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap() + ); + } + assert_eq!(1, cache.len()); + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!( + 0, + stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap() + ); + } + assert_eq!(1, cache.len()); + } + + #[test] + fn test_empty_stmt() { + let conn = Connection::open_in_memory().unwrap(); + conn.prepare_cached("").unwrap(); + } +} diff --git a/third_party/rust/rusqlite/src/collation.rs b/third_party/rust/rusqlite/src/collation.rs new file mode 100644 index 0000000000..1168b75c56 --- /dev/null +++ b/third_party/rust/rusqlite/src/collation.rs @@ -0,0 +1,206 @@ +//! `feature = "collation"` Add, remove, or modify a collation +use std::cmp::Ordering; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::{catch_unwind, UnwindSafe}; +use std::ptr; +use std::slice; + +use crate::ffi; +use crate::{str_to_cstring, Connection, InnerConnection, Result}; + +// FIXME copy/paste from function.rs +unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) { + drop(Box::from_raw(p as *mut T)); +} + +impl Connection { + /// `feature = "collation"` Add or modify a collation. + pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()> + where + C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, + { + self.db + .borrow_mut() + .create_collation(collation_name, x_compare) + } + + /// `feature = "collation"` Collation needed callback + pub fn collation_needed( + &self, + x_coll_needed: fn(&Connection, &str) -> Result<()>, + ) -> Result<()> { + self.db.borrow_mut().collation_needed(x_coll_needed) + } + + /// `feature = "collation"` Remove collation. + pub fn remove_collation(&self, collation_name: &str) -> Result<()> { + self.db.borrow_mut().remove_collation(collation_name) + } +} + +impl InnerConnection { + fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()> + where + C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure<C>( + arg1: *mut c_void, + arg2: c_int, + arg3: *const c_void, + arg4: c_int, + arg5: *const c_void, + ) -> c_int + where + C: Fn(&str, &str) -> Ordering, + { + let r = catch_unwind(|| { + let boxed_f: *mut C = arg1 as *mut C; + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + let s1 = { + let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize); + String::from_utf8_lossy(c_slice) + }; + let s2 = { + let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize); + String::from_utf8_lossy(c_slice) + }; + (*boxed_f)(s1.as_ref(), s2.as_ref()) + }); + let t = match r { + Err(_) => { + return -1; // FIXME How ? + } + Ok(r) => r, + }; + + match t { + Ordering::Less => -1, + Ordering::Equal => 0, + Ordering::Greater => 1, + } + } + + let boxed_f: *mut C = Box::into_raw(Box::new(x_compare)); + let c_name = str_to_cstring(collation_name)?; + let flags = ffi::SQLITE_UTF8; + let r = unsafe { + ffi::sqlite3_create_collation_v2( + self.db(), + c_name.as_ptr(), + flags, + boxed_f as *mut c_void, + Some(call_boxed_closure::<C>), + Some(free_boxed_value::<C>), + ) + }; + self.decode_result(r) + } + + fn collation_needed( + &mut self, + x_coll_needed: fn(&Connection, &str) -> Result<()>, + ) -> Result<()> { + use std::mem; + unsafe extern "C" fn collation_needed_callback( + arg1: *mut c_void, + arg2: *mut ffi::sqlite3, + e_text_rep: c_int, + arg3: *const c_char, + ) { + use std::ffi::CStr; + use std::str; + + if e_text_rep != ffi::SQLITE_UTF8 { + // TODO: validate + return; + } + + let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1); + let res = catch_unwind(|| { + let conn = Connection::from_handle(arg2).unwrap(); + let collation_name = { + let c_slice = CStr::from_ptr(arg3).to_bytes(); + str::from_utf8(c_slice).expect("illegal coallation sequence name") + }; + callback(&conn, collation_name) + }); + if res.is_err() { + return; // FIXME How ? + } + } + + let r = unsafe { + ffi::sqlite3_collation_needed( + self.db(), + x_coll_needed as *mut c_void, + Some(collation_needed_callback), + ) + }; + self.decode_result(r) + } + + fn remove_collation(&mut self, collation_name: &str) -> Result<()> { + let c_name = str_to_cstring(collation_name)?; + let r = unsafe { + ffi::sqlite3_create_collation_v2( + self.db(), + c_name.as_ptr(), + ffi::SQLITE_UTF8, + ptr::null_mut(), + None, + None, + ) + }; + self.decode_result(r) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result, NO_PARAMS}; + use fallible_streaming_iterator::FallibleStreamingIterator; + use std::cmp::Ordering; + use unicase::UniCase; + + fn unicase_compare(s1: &str, s2: &str) -> Ordering { + UniCase::new(s1).cmp(&UniCase::new(s2)) + } + + #[test] + fn test_unicase() { + let db = Connection::open_in_memory().unwrap(); + + db.create_collation("unicase", unicase_compare).unwrap(); + + collate(db); + } + + fn collate(db: Connection) { + db.execute_batch( + "CREATE TABLE foo (bar); + INSERT INTO foo (bar) VALUES ('Maße'); + INSERT INTO foo (bar) VALUES ('MASSE');", + ) + .unwrap(); + let mut stmt = db + .prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1") + .unwrap(); + let rows = stmt.query(NO_PARAMS).unwrap(); + assert_eq!(rows.count().unwrap(), 1); + } + + fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> { + if "unicase" == collation_name { + db.create_collation(collation_name, unicase_compare) + } else { + Ok(()) + } + } + + #[test] + fn test_collation_needed() { + let db = Connection::open_in_memory().unwrap(); + db.collation_needed(collation_needed).unwrap(); + collate(db); + } +} diff --git a/third_party/rust/rusqlite/src/column.rs b/third_party/rust/rusqlite/src/column.rs new file mode 100644 index 0000000000..4f6daac91c --- /dev/null +++ b/third_party/rust/rusqlite/src/column.rs @@ -0,0 +1,220 @@ +use std::str; + +use crate::{Error, Result, Row, Rows, Statement}; + +/// Information about a column of a SQLite query. +#[derive(Debug)] +pub struct Column<'stmt> { + name: &'stmt str, + decl_type: Option<&'stmt str>, +} + +impl Column<'_> { + /// Returns the name of the column. + pub fn name(&self) -> &str { + self.name + } + + /// Returns the type of the column (`None` for expression). + pub fn decl_type(&self) -> Option<&str> { + self.decl_type + } +} + +impl Statement<'_> { + /// Get all the column names in the result set of the prepared statement. + pub fn column_names(&self) -> Vec<&str> { + let n = self.column_count(); + let mut cols = Vec::with_capacity(n as usize); + for i in 0..n { + let s = self.column_name_unwrap(i); + cols.push(s); + } + cols + } + + /// Return the number of columns in the result set returned by the prepared + /// statement. + pub fn column_count(&self) -> usize { + self.stmt.column_count() + } + + pub(super) fn column_name_unwrap(&self, col: usize) -> &str { + // Just panic if the bounds are wrong for now, we never call this + // without checking first. + self.column_name(col).expect("Column out of bounds") + } + + /// Returns the name assigned to a particular column in the result set + /// returned by the prepared statement. + /// + /// ## Failure + /// + /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid + /// column range for this row. + /// + /// Panics when column name is not valid UTF-8. + pub fn column_name(&self, col: usize) -> Result<&str> { + self.stmt + .column_name(col) + .ok_or(Error::InvalidColumnIndex(col)) + .map(|slice| { + str::from_utf8(slice.to_bytes()).expect("Invalid UTF-8 sequence in column name") + }) + } + + /// Returns the column index in the result set for a given column name. + /// + /// If there is no AS clause then the name of the column is unspecified and + /// may change from one release of SQLite to the next. + /// + /// # Failure + /// + /// Will return an `Error::InvalidColumnName` when there is no column with + /// the specified `name`. + pub fn column_index(&self, name: &str) -> Result<usize> { + let bytes = name.as_bytes(); + let n = self.column_count(); + for i in 0..n { + // Note: `column_name` is only fallible if `i` is out of bounds, + // which we've already checked. + if bytes.eq_ignore_ascii_case(self.stmt.column_name(i).unwrap().to_bytes()) { + return Ok(i); + } + } + Err(Error::InvalidColumnName(String::from(name))) + } + + /// Returns a slice describing the columns of the result of the query. + #[cfg(feature = "column_decltype")] + pub fn columns(&self) -> Vec<Column> { + let n = self.column_count(); + let mut cols = Vec::with_capacity(n as usize); + for i in 0..n { + let name = self.column_name_unwrap(i); + let slice = self.stmt.column_decltype(i); + let decl_type = slice.map(|s| { + str::from_utf8(s.to_bytes()).expect("Invalid UTF-8 sequence in column declaration") + }); + cols.push(Column { name, decl_type }); + } + cols + } +} + +impl<'stmt> Rows<'stmt> { + /// Get all the column names. + pub fn column_names(&self) -> Option<Vec<&str>> { + self.stmt.map(Statement::column_names) + } + + /// Return the number of columns. + pub fn column_count(&self) -> Option<usize> { + self.stmt.map(Statement::column_count) + } + + /// Return the name of the column. + pub fn column_name(&self, col: usize) -> Option<Result<&str>> { + self.stmt.map(|stmt| stmt.column_name(col)) + } + + /// Return the index of the column. + pub fn column_index(&self, name: &str) -> Option<Result<usize>> { + self.stmt.map(|stmt| stmt.column_index(name)) + } + + /// Returns a slice describing the columns of the Rows. + #[cfg(feature = "column_decltype")] + pub fn columns(&self) -> Option<Vec<Column>> { + self.stmt.map(Statement::columns) + } +} + +impl<'stmt> Row<'stmt> { + /// Get all the column names of the Row. + pub fn column_names(&self) -> Vec<&str> { + self.stmt.column_names() + } + + /// Return the number of columns in the current row. + pub fn column_count(&self) -> usize { + self.stmt.column_count() + } + + /// Return the name of the column. + pub fn column_name(&self, col: usize) -> Result<&str> { + self.stmt.column_name(col) + } + + /// Return the index of the column. + pub fn column_index(&self, name: &str) -> Result<usize> { + self.stmt.column_index(name) + } + + /// Returns a slice describing the columns of the Row. + #[cfg(feature = "column_decltype")] + pub fn columns(&self) -> Vec<Column> { + self.stmt.columns() + } +} + +#[cfg(test)] +mod test { + use crate::Connection; + + #[test] + #[cfg(feature = "column_decltype")] + fn test_columns() { + use super::Column; + + let db = Connection::open_in_memory().unwrap(); + let query = db.prepare("SELECT * FROM sqlite_master").unwrap(); + let columns = query.columns(); + let column_names: Vec<&str> = columns.iter().map(Column::name).collect(); + assert_eq!( + column_names.as_slice(), + &["type", "name", "tbl_name", "rootpage", "sql"] + ); + let column_types: Vec<Option<&str>> = columns.iter().map(Column::decl_type).collect(); + assert_eq!( + &column_types[..3], + &[Some("text"), Some("text"), Some("text"),] + ); + } + + #[test] + fn test_column_name_in_error() { + use crate::{types::Type, Error}; + let db = Connection::open_in_memory().unwrap(); + db.execute_batch( + "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, NULL); + END;", + ) + .unwrap(); + let mut stmt = db.prepare("SELECT x as renamed, y FROM foo").unwrap(); + let mut rows = stmt.query(crate::NO_PARAMS).unwrap(); + let row = rows.next().unwrap().unwrap(); + match row.get::<_, String>(0).unwrap_err() { + Error::InvalidColumnType(idx, name, ty) => { + assert_eq!(idx, 0); + assert_eq!(name, "renamed"); + assert_eq!(ty, Type::Integer); + } + e => { + panic!("Unexpected error type: {:?}", e); + } + } + match row.get::<_, String>("y").unwrap_err() { + Error::InvalidColumnType(idx, name, ty) => { + assert_eq!(idx, 1); + assert_eq!(name, "y"); + assert_eq!(ty, Type::Null); + } + e => { + panic!("Unexpected error type: {:?}", e); + } + } + } +} diff --git a/third_party/rust/rusqlite/src/config.rs b/third_party/rust/rusqlite/src/config.rs new file mode 100644 index 0000000000..797069e2e3 --- /dev/null +++ b/third_party/rust/rusqlite/src/config.rs @@ -0,0 +1,151 @@ +//! Configure database connections + +use std::os::raw::c_int; + +use crate::ffi; +use crate::{Connection, Result}; + +/// Database Connection Configuration Options +/// See [Database Connection Configuration Options](https://sqlite.org/c3ref/c_dbconfig_enable_fkey.html) for details. +#[repr(i32)] +#[allow(non_snake_case, non_camel_case_types)] +#[non_exhaustive] +pub enum DbConfig { + //SQLITE_DBCONFIG_MAINDBNAME = 1000, /* const char* */ + //SQLITE_DBCONFIG_LOOKASIDE = 1001, /* void* int int */ + /// Enable or disable the enforcement of foreign key constraints. + SQLITE_DBCONFIG_ENABLE_FKEY = 1002, + /// Enable or disable triggers. + SQLITE_DBCONFIG_ENABLE_TRIGGER = 1003, + /// Enable or disable the fts3_tokenizer() function which is part of the + /// FTS3 full-text search engine extension. + SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER = 1004, // 3.12.0 + //SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION = 1005, + /// In WAL mode, enable or disable the checkpoint operation before closing + /// the connection. + SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006, // 3.16.2 + /// Activates or deactivates the query planner stability guarantee (QPSG). + SQLITE_DBCONFIG_ENABLE_QPSG = 1007, // 3.20.0 + /// Includes or excludes output for any operations performed by trigger + /// programs from the output of EXPLAIN QUERY PLAN commands. + SQLITE_DBCONFIG_TRIGGER_EQP = 1008, // 3.22.0 + //SQLITE_DBCONFIG_RESET_DATABASE = 1009, + /// Activates or deactivates the "defensive" flag for a database connection. + SQLITE_DBCONFIG_DEFENSIVE = 1010, // 3.26.0 + /// Activates or deactivates the "writable_schema" flag. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_WRITABLE_SCHEMA = 1011, // 3.28.0 + /// Activates or deactivates the legacy behavior of the ALTER TABLE RENAME + /// command. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_LEGACY_ALTER_TABLE = 1012, // 3.29 + /// Activates or deactivates the legacy double-quoted string literal + /// misfeature for DML statements only. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_DQS_DML = 1013, // 3.29.0 + /// Activates or deactivates the legacy double-quoted string literal + /// misfeature for DDL statements. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_DQS_DDL = 1014, // 3.29.0 + /// Enable or disable views. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_ENABLE_VIEW = 1015, // 3.30.0 + /// Activates or deactivates the legacy file format flag. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_LEGACY_FILE_FORMAT = 1016, // 3.31.0 + /// Tells SQLite to assume that database schemas (the contents of the + /// sqlite_master tables) are untainted by malicious content. + #[cfg(feature = "modern_sqlite")] + SQLITE_DBCONFIG_TRUSTED_SCHEMA = 1017, // 3.31.0 +} + +impl Connection { + /// Returns the current value of a `config`. + /// + /// - SQLITE_DBCONFIG_ENABLE_FKEY: return `false` or `true` to indicate + /// whether FK enforcement is off or on + /// - SQLITE_DBCONFIG_ENABLE_TRIGGER: return `false` or `true` to indicate + /// whether triggers are disabled or enabled + /// - SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER: return `false` or `true` to + /// indicate whether fts3_tokenizer are disabled or enabled + /// - SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE: return `false` to indicate + /// checkpoints-on-close are not disabled or `true` if they are + /// - SQLITE_DBCONFIG_ENABLE_QPSG: return `false` or `true` to indicate + /// whether the QPSG is disabled or enabled + /// - SQLITE_DBCONFIG_TRIGGER_EQP: return `false` to indicate + /// output-for-trigger are not disabled or `true` if it is + pub fn db_config(&self, config: DbConfig) -> Result<bool> { + let c = self.db.borrow(); + unsafe { + let mut val = 0; + check!(ffi::sqlite3_db_config( + c.db(), + config as c_int, + -1, + &mut val + )); + Ok(val != 0) + } + } + + /// Make configuration changes to a database connection + /// + /// - SQLITE_DBCONFIG_ENABLE_FKEY: `false` to disable FK enforcement, `true` + /// to enable FK enforcement + /// - SQLITE_DBCONFIG_ENABLE_TRIGGER: `false` to disable triggers, `true` to + /// enable triggers + /// - SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER: `false` to disable + /// fts3_tokenizer(), `true` to enable fts3_tokenizer() + /// - SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE: `false` (the default) to enable + /// checkpoints-on-close, `true` to disable them + /// - SQLITE_DBCONFIG_ENABLE_QPSG: `false` to disable the QPSG, `true` to + /// enable QPSG + /// - SQLITE_DBCONFIG_TRIGGER_EQP: `false` to disable output for trigger + /// programs, `true` to enable it + pub fn set_db_config(&self, config: DbConfig, new_val: bool) -> Result<bool> { + let c = self.db.borrow_mut(); + unsafe { + let mut val = 0; + check!(ffi::sqlite3_db_config( + c.db(), + config as c_int, + if new_val { 1 } else { 0 }, + &mut val + )); + Ok(val != 0) + } + } +} + +#[cfg(test)] +mod test { + use super::DbConfig; + use crate::Connection; + + #[test] + fn test_db_config() { + let db = Connection::open_in_memory().unwrap(); + + let opposite = !db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY).unwrap(); + assert_eq!( + db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY, opposite), + Ok(opposite) + ); + assert_eq!( + db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY), + Ok(opposite) + ); + + let opposite = !db + .db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER) + .unwrap(); + assert_eq!( + db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER, opposite), + Ok(opposite) + ); + assert_eq!( + db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER), + Ok(opposite) + ); + } +} diff --git a/third_party/rust/rusqlite/src/context.rs b/third_party/rust/rusqlite/src/context.rs new file mode 100644 index 0000000000..b7e8bc8893 --- /dev/null +++ b/third_party/rust/rusqlite/src/context.rs @@ -0,0 +1,68 @@ +//! Code related to `sqlite3_context` common to `functions` and `vtab` modules. + +use std::os::raw::{c_int, c_void}; +#[cfg(feature = "array")] +use std::rc::Rc; + +use crate::ffi; +use crate::ffi::sqlite3_context; + +use crate::str_for_sqlite; +use crate::types::{ToSqlOutput, ValueRef}; +#[cfg(feature = "array")] +use crate::vtab::array::{free_array, ARRAY_TYPE}; + +pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<'_>) { + let value = match *result { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(len) => { + return ffi::sqlite3_result_zeroblob(ctx, len); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(ref a) => { + return ffi::sqlite3_result_pointer( + ctx, + Rc::into_raw(a.clone()) as *mut c_void, + ARRAY_TYPE, + Some(free_array), + ); + } + }; + + match value { + ValueRef::Null => ffi::sqlite3_result_null(ctx), + ValueRef::Integer(i) => ffi::sqlite3_result_int64(ctx, i), + ValueRef::Real(r) => ffi::sqlite3_result_double(ctx, r), + ValueRef::Text(s) => { + let length = s.len(); + if length > c_int::max_value() as usize { + ffi::sqlite3_result_error_toobig(ctx); + } else { + let (c_str, len, destructor) = match str_for_sqlite(s) { + Ok(c_str) => c_str, + // TODO sqlite3_result_error + Err(_) => return ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), + }; + ffi::sqlite3_result_text(ctx, c_str, len, destructor); + } + } + ValueRef::Blob(b) => { + let length = b.len(); + if length > c_int::max_value() as usize { + ffi::sqlite3_result_error_toobig(ctx); + } else if length == 0 { + ffi::sqlite3_result_zeroblob(ctx, 0) + } else { + ffi::sqlite3_result_blob( + ctx, + b.as_ptr() as *const c_void, + length as c_int, + ffi::SQLITE_TRANSIENT(), + ); + } + } + } +} diff --git a/third_party/rust/rusqlite/src/error.rs b/third_party/rust/rusqlite/src/error.rs new file mode 100644 index 0000000000..98583cb8a2 --- /dev/null +++ b/third_party/rust/rusqlite/src/error.rs @@ -0,0 +1,349 @@ +use crate::types::FromSqlError; +use crate::types::Type; +use crate::{errmsg_to_string, ffi}; +use std::error; +use std::fmt; +use std::os::raw::c_int; +use std::path::PathBuf; +use std::str; + +/// Enum listing possible errors from rusqlite. +#[derive(Debug)] +#[allow(clippy::enum_variant_names)] +#[non_exhaustive] +pub enum Error { + /// An error from an underlying SQLite call. + SqliteFailure(ffi::Error, Option<String>), + + /// Error reported when attempting to open a connection when SQLite was + /// configured to allow single-threaded use only. + SqliteSingleThreadedMode, + + /// Error when the value of a particular column is requested, but it cannot + /// be converted to the requested Rust type. + FromSqlConversionFailure(usize, Type, Box<dyn error::Error + Send + Sync + 'static>), + + /// Error when SQLite gives us an integral value outside the range of the + /// requested type (e.g., trying to get the value 1000 into a `u8`). + /// The associated `usize` is the column index, + /// and the associated `i64` is the value returned by SQLite. + IntegralValueOutOfRange(usize, i64), + + /// Error converting a string to UTF-8. + Utf8Error(str::Utf8Error), + + /// Error converting a string to a C-compatible string because it contained + /// an embedded nul. + NulError(::std::ffi::NulError), + + /// Error when using SQL named parameters and passing a parameter name not + /// present in the SQL. + InvalidParameterName(String), + + /// Error converting a file path to a string. + InvalidPath(PathBuf), + + /// Error returned when an `execute` call returns rows. + ExecuteReturnedResults, + + /// Error when a query that was expected to return at least one row (e.g., + /// for `query_row`) did not return any. + QueryReturnedNoRows, + + /// Error when the value of a particular column is requested, but the index + /// is out of range for the statement. + InvalidColumnIndex(usize), + + /// Error when the value of a named column is requested, but no column + /// matches the name for the statement. + InvalidColumnName(String), + + /// Error when the value of a particular column is requested, but the type + /// of the result in that column cannot be converted to the requested + /// Rust type. + InvalidColumnType(usize, String, Type), + + /// Error when a query that was expected to insert one row did not insert + /// any or insert many. + StatementChangedRows(usize), + + /// Error returned by `functions::Context::get` when the function argument + /// cannot be converted to the requested type. + #[cfg(feature = "functions")] + InvalidFunctionParameterType(usize, Type), + /// Error returned by `vtab::Values::get` when the filter argument cannot + /// be converted to the requested type. + #[cfg(feature = "vtab")] + InvalidFilterParameterType(usize, Type), + + /// An error case available for implementors of custom user functions (e.g., + /// `create_scalar_function`). + #[cfg(feature = "functions")] + #[allow(dead_code)] + UserFunctionError(Box<dyn error::Error + Send + Sync + 'static>), + + /// Error available for the implementors of the `ToSql` trait. + ToSqlConversionFailure(Box<dyn error::Error + Send + Sync + 'static>), + + /// Error when the SQL is not a `SELECT`, is not read-only. + InvalidQuery, + + /// An error case available for implementors of custom modules (e.g., + /// `create_module`). + #[cfg(feature = "vtab")] + #[allow(dead_code)] + ModuleError(String), + + /// An unwinding panic occurs in an UDF (user-defined function). + #[cfg(feature = "functions")] + UnwindingPanic, + + /// An error returned when `Context::get_aux` attempts to retrieve data + /// of a different type than what had been stored using `Context::set_aux`. + #[cfg(feature = "functions")] + GetAuxWrongType, + + /// Error when the SQL contains multiple statements. + MultipleStatement, + /// Error when the number of bound parameters does not match the number of + /// parameters in the query. The first `usize` is how many parameters were + /// given, the 2nd is how many were expected. + InvalidParameterCount(usize, usize), + + /// Returned from various functions in the Blob IO positional API. For + /// example, [`Blob::raw_read_at_exact`](crate::blob::Blob::raw_read_at_exact) + /// will return it if the blob has insufficient data. + #[cfg(feature = "blob")] + BlobSizeError, +} + +impl PartialEq for Error { + fn eq(&self, other: &Error) -> bool { + match (self, other) { + (Error::SqliteFailure(e1, s1), Error::SqliteFailure(e2, s2)) => e1 == e2 && s1 == s2, + (Error::SqliteSingleThreadedMode, Error::SqliteSingleThreadedMode) => true, + (Error::IntegralValueOutOfRange(i1, n1), Error::IntegralValueOutOfRange(i2, n2)) => { + i1 == i2 && n1 == n2 + } + (Error::Utf8Error(e1), Error::Utf8Error(e2)) => e1 == e2, + (Error::NulError(e1), Error::NulError(e2)) => e1 == e2, + (Error::InvalidParameterName(n1), Error::InvalidParameterName(n2)) => n1 == n2, + (Error::InvalidPath(p1), Error::InvalidPath(p2)) => p1 == p2, + (Error::ExecuteReturnedResults, Error::ExecuteReturnedResults) => true, + (Error::QueryReturnedNoRows, Error::QueryReturnedNoRows) => true, + (Error::InvalidColumnIndex(i1), Error::InvalidColumnIndex(i2)) => i1 == i2, + (Error::InvalidColumnName(n1), Error::InvalidColumnName(n2)) => n1 == n2, + (Error::InvalidColumnType(i1, n1, t1), Error::InvalidColumnType(i2, n2, t2)) => { + i1 == i2 && t1 == t2 && n1 == n2 + } + (Error::StatementChangedRows(n1), Error::StatementChangedRows(n2)) => n1 == n2, + #[cfg(feature = "functions")] + ( + Error::InvalidFunctionParameterType(i1, t1), + Error::InvalidFunctionParameterType(i2, t2), + ) => i1 == i2 && t1 == t2, + #[cfg(feature = "vtab")] + ( + Error::InvalidFilterParameterType(i1, t1), + Error::InvalidFilterParameterType(i2, t2), + ) => i1 == i2 && t1 == t2, + (Error::InvalidQuery, Error::InvalidQuery) => true, + #[cfg(feature = "vtab")] + (Error::ModuleError(s1), Error::ModuleError(s2)) => s1 == s2, + #[cfg(feature = "functions")] + (Error::UnwindingPanic, Error::UnwindingPanic) => true, + #[cfg(feature = "functions")] + (Error::GetAuxWrongType, Error::GetAuxWrongType) => true, + (Error::InvalidParameterCount(i1, n1), Error::InvalidParameterCount(i2, n2)) => { + i1 == i2 && n1 == n2 + } + #[cfg(feature = "blob")] + (Error::BlobSizeError, Error::BlobSizeError) => true, + (..) => false, + } + } +} + +impl From<str::Utf8Error> for Error { + fn from(err: str::Utf8Error) -> Error { + Error::Utf8Error(err) + } +} + +impl From<::std::ffi::NulError> for Error { + fn from(err: ::std::ffi::NulError) -> Error { + Error::NulError(err) + } +} + +const UNKNOWN_COLUMN: usize = std::usize::MAX; + +/// The conversion isn't precise, but it's convenient to have it +/// to allow use of `get_raw(…).as_…()?` in callbacks that take `Error`. +impl From<FromSqlError> for Error { + fn from(err: FromSqlError) -> Error { + // The error type requires index and type fields, but they aren't known in this + // context. + match err { + FromSqlError::OutOfRange(val) => Error::IntegralValueOutOfRange(UNKNOWN_COLUMN, val), + #[cfg(feature = "i128_blob")] + FromSqlError::InvalidI128Size(_) => { + Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Blob, Box::new(err)) + } + #[cfg(feature = "uuid")] + FromSqlError::InvalidUuidSize(_) => { + Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Blob, Box::new(err)) + } + FromSqlError::Other(source) => { + Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Null, source) + } + _ => Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Null, Box::new(err)), + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Error::SqliteFailure(ref err, None) => err.fmt(f), + Error::SqliteFailure(_, Some(ref s)) => write!(f, "{}", s), + Error::SqliteSingleThreadedMode => write!( + f, + "SQLite was compiled or configured for single-threaded use only" + ), + Error::FromSqlConversionFailure(i, ref t, ref err) => { + if i != UNKNOWN_COLUMN { + write!( + f, + "Conversion error from type {} at index: {}, {}", + t, i, err + ) + } else { + err.fmt(f) + } + } + Error::IntegralValueOutOfRange(col, val) => { + if col != UNKNOWN_COLUMN { + write!(f, "Integer {} out of range at index {}", val, col) + } else { + write!(f, "Integer {} out of range", val) + } + } + Error::Utf8Error(ref err) => err.fmt(f), + Error::NulError(ref err) => err.fmt(f), + Error::InvalidParameterName(ref name) => write!(f, "Invalid parameter name: {}", name), + Error::InvalidPath(ref p) => write!(f, "Invalid path: {}", p.to_string_lossy()), + Error::ExecuteReturnedResults => { + write!(f, "Execute returned results - did you mean to call query?") + } + Error::QueryReturnedNoRows => write!(f, "Query returned no rows"), + Error::InvalidColumnIndex(i) => write!(f, "Invalid column index: {}", i), + Error::InvalidColumnName(ref name) => write!(f, "Invalid column name: {}", name), + Error::InvalidColumnType(i, ref name, ref t) => write!( + f, + "Invalid column type {} at index: {}, name: {}", + t, i, name + ), + Error::InvalidParameterCount(i1, n1) => write!( + f, + "Wrong number of parameters passed to query. Got {}, needed {}", + i1, n1 + ), + Error::StatementChangedRows(i) => write!(f, "Query changed {} rows", i), + + #[cfg(feature = "functions")] + Error::InvalidFunctionParameterType(i, ref t) => { + write!(f, "Invalid function parameter type {} at index {}", t, i) + } + #[cfg(feature = "vtab")] + Error::InvalidFilterParameterType(i, ref t) => { + write!(f, "Invalid filter parameter type {} at index {}", t, i) + } + #[cfg(feature = "functions")] + Error::UserFunctionError(ref err) => err.fmt(f), + Error::ToSqlConversionFailure(ref err) => err.fmt(f), + Error::InvalidQuery => write!(f, "Query is not read-only"), + #[cfg(feature = "vtab")] + Error::ModuleError(ref desc) => write!(f, "{}", desc), + #[cfg(feature = "functions")] + Error::UnwindingPanic => write!(f, "unwinding panic"), + #[cfg(feature = "functions")] + Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"), + Error::MultipleStatement => write!(f, "Multiple statements provided"), + + #[cfg(feature = "blob")] + Error::BlobSizeError => "Blob size is insufficient".fmt(f), + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match *self { + Error::SqliteFailure(ref err, _) => Some(err), + Error::Utf8Error(ref err) => Some(err), + Error::NulError(ref err) => Some(err), + + Error::IntegralValueOutOfRange(..) + | Error::SqliteSingleThreadedMode + | Error::InvalidParameterName(_) + | Error::ExecuteReturnedResults + | Error::QueryReturnedNoRows + | Error::InvalidColumnIndex(_) + | Error::InvalidColumnName(_) + | Error::InvalidColumnType(..) + | Error::InvalidPath(_) + | Error::InvalidParameterCount(..) + | Error::StatementChangedRows(_) + | Error::InvalidQuery + | Error::MultipleStatement => None, + + #[cfg(feature = "functions")] + Error::InvalidFunctionParameterType(..) => None, + #[cfg(feature = "vtab")] + Error::InvalidFilterParameterType(..) => None, + + #[cfg(feature = "functions")] + Error::UserFunctionError(ref err) => Some(&**err), + + Error::FromSqlConversionFailure(_, _, ref err) + | Error::ToSqlConversionFailure(ref err) => Some(&**err), + + #[cfg(feature = "vtab")] + Error::ModuleError(_) => None, + + #[cfg(feature = "functions")] + Error::UnwindingPanic => None, + + #[cfg(feature = "functions")] + Error::GetAuxWrongType => None, + + #[cfg(feature = "blob")] + Error::BlobSizeError => None, + } + } +} + +// These are public but not re-exported by lib.rs, so only visible within crate. + +pub fn error_from_sqlite_code(code: c_int, message: Option<String>) -> Error { + Error::SqliteFailure(ffi::Error::new(code), message) +} + +pub unsafe fn error_from_handle(db: *mut ffi::sqlite3, code: c_int) -> Error { + let message = if db.is_null() { + None + } else { + Some(errmsg_to_string(ffi::sqlite3_errmsg(db))) + }; + error_from_sqlite_code(code, message) +} + +macro_rules! check { + ($funcall:expr) => {{ + let rc = $funcall; + if rc != crate::ffi::SQLITE_OK { + return Err(crate::error::error_from_sqlite_code(rc, None).into()); + } + }}; +} diff --git a/third_party/rust/rusqlite/src/functions.rs b/third_party/rust/rusqlite/src/functions.rs new file mode 100644 index 0000000000..3531391882 --- /dev/null +++ b/third_party/rust/rusqlite/src/functions.rs @@ -0,0 +1,1045 @@ +//! `feature = "functions"` Create or redefine SQL functions. +//! +//! # Example +//! +//! Adding a `regexp` function to a connection in which compiled regular +//! expressions are cached in a `HashMap`. For an alternative implementation +//! that uses SQLite's [Function Auxilliary Data](https://www.sqlite.org/c3ref/get_auxdata.html) interface +//! to avoid recompiling regular expressions, see the unit tests for this +//! module. +//! +//! ```rust +//! use regex::Regex; +//! use rusqlite::functions::FunctionFlags; +//! use rusqlite::{Connection, Error, Result, NO_PARAMS}; +//! use std::sync::Arc; +//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; +//! +//! fn add_regexp_function(db: &Connection) -> Result<()> { +//! db.create_scalar_function( +//! "regexp", +//! 2, +//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, +//! move |ctx| { +//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); +//! let regexp: Arc<Regex> = ctx +//! .get_or_create_aux(0, |vr| -> Result<_, BoxError> { +//! Ok(Regex::new(vr.as_str()?)?) +//! })?; +//! let is_match = { +//! let text = ctx +//! .get_raw(1) +//! .as_str() +//! .map_err(|e| Error::UserFunctionError(e.into()))?; +//! +//! regexp.is_match(text) +//! }; +//! +//! Ok(is_match) +//! }, +//! ) +//! } +//! +//! fn main() -> Result<()> { +//! let db = Connection::open_in_memory()?; +//! add_regexp_function(&db)?; +//! +//! let is_match: bool = db.query_row( +//! "SELECT regexp('[aeiou]*', 'aaaaeeeiii')", +//! NO_PARAMS, +//! |row| row.get(0), +//! )?; +//! +//! assert!(is_match); +//! Ok(()) +//! } +//! ``` +use std::any::Any; +use std::os::raw::{c_int, c_void}; +use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; +use std::ptr; +use std::slice; +use std::sync::Arc; + +use crate::ffi; +use crate::ffi::sqlite3_context; +use crate::ffi::sqlite3_value; + +use crate::context::set_result; +use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; + +use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; + +unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) { + // Extended constraint error codes were added in SQLite 3.7.16. We don't have + // an explicit feature check for that, and this doesn't really warrant one. + // We'll use the extended code if we're on the bundled version (since it's + // at least 3.17.0) and the normal constraint error code if not. + #[cfg(feature = "modern_sqlite")] + fn constraint_error_code() -> i32 { + ffi::SQLITE_CONSTRAINT_FUNCTION + } + #[cfg(not(feature = "modern_sqlite"))] + fn constraint_error_code() -> i32 { + ffi::SQLITE_CONSTRAINT + } + + match *err { + Error::SqliteFailure(ref err, ref s) => { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + _ => { + ffi::sqlite3_result_error_code(ctx, constraint_error_code()); + if let Ok(cstr) = str_to_cstring(&err.to_string()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + } +} + +unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) { + drop(Box::from_raw(p as *mut T)); +} + +/// `feature = "functions"` Context is a wrapper for the SQLite function +/// evaluation context. +pub struct Context<'a> { + ctx: *mut sqlite3_context, + args: &'a [*mut sqlite3_value], +} + +impl Context<'_> { + /// Returns the number of arguments to the function. + pub fn len(&self) -> usize { + self.args.len() + } + + /// Returns `true` when there is no argument. + pub fn is_empty(&self) -> bool { + self.args.is_empty() + } + + /// Returns the `idx`th argument as a `T`. + /// + /// # Failure + /// + /// Will panic if `idx` is greater than or equal to `self.len()`. + /// + /// Will return Err if the underlying SQLite type cannot be converted to a + /// `T`. + pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> { + let arg = self.args[idx]; + let value = unsafe { ValueRef::from_value(arg) }; + FromSql::column_result(value).map_err(|err| match err { + FromSqlError::InvalidType => { + Error::InvalidFunctionParameterType(idx, value.data_type()) + } + FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i), + FromSqlError::Other(err) => { + Error::FromSqlConversionFailure(idx, value.data_type(), err) + } + #[cfg(feature = "i128_blob")] + FromSqlError::InvalidI128Size(_) => { + Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err)) + } + #[cfg(feature = "uuid")] + FromSqlError::InvalidUuidSize(_) => { + Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err)) + } + }) + } + + /// Returns the `idx`th argument as a `ValueRef`. + /// + /// # Failure + /// + /// Will panic if `idx` is greater than or equal to `self.len()`. + pub fn get_raw(&self, idx: usize) -> ValueRef<'_> { + let arg = self.args[idx]; + unsafe { ValueRef::from_value(arg) } + } + + /// Fetch or insert the the auxilliary data associated with a particular + /// parameter. This is intended to be an easier-to-use way of fetching it + /// compared to calling `get_aux` and `set_aux` separately. + /// + /// See https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of + /// this feature, or the unit tests of this module for an example. + pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>> + where + T: Send + Sync + 'static, + E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>, + F: FnOnce(ValueRef<'_>) -> Result<T, E>, + { + if let Some(v) = self.get_aux(arg)? { + Ok(v) + } else { + let vr = self.get_raw(arg as usize); + self.set_aux( + arg, + func(vr).map_err(|e| Error::UserFunctionError(e.into()))?, + ) + } + } + + /// Sets the auxilliary data associated with a particular parameter. See + /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of + /// this feature, or the unit tests of this module for an example. + pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> { + let orig: Arc<T> = Arc::new(value); + let inner: AuxInner = orig.clone(); + let outer = Box::new(inner); + let raw: *mut AuxInner = Box::into_raw(outer); + unsafe { + ffi::sqlite3_set_auxdata( + self.ctx, + arg, + raw as *mut _, + Some(free_boxed_value::<AuxInner>), + ) + }; + Ok(orig) + } + + /// Gets the auxilliary data that was associated with a given parameter via + /// `set_aux`. Returns `Ok(None)` if no data has been associated, and + /// Ok(Some(v)) if it has. Returns an error if the requested type does not + /// match. + pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> { + let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner }; + if p.is_null() { + Ok(None) + } else { + let v: AuxInner = AuxInner::clone(unsafe { &*p }); + v.downcast::<T>() + .map(Some) + .map_err(|_| Error::GetAuxWrongType) + } + } +} + +type AuxInner = Arc<dyn Any + Send + Sync + 'static>; + +/// `feature = "functions"` Aggregate is the callback interface for user-defined +/// aggregate function. +/// +/// `A` is the type of the aggregation context and `T` is the type of the final +/// result. Implementations should be stateless. +pub trait Aggregate<A, T> +where + A: RefUnwindSafe + UnwindSafe, + T: ToSql, +{ + /// Initializes the aggregation context. Will be called prior to the first + /// call to `step()` to set up the context for an invocation of the + /// function. (Note: `init()` will not be called if there are no rows.) + fn init(&self) -> A; + + /// "step" function called once for each row in an aggregate group. May be + /// called 0 times if there are no rows. + fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>; + + /// Computes and returns the final result. Will be called exactly once for + /// each invocation of the function. If `step()` was called at least + /// once, will be given `Some(A)` (the same `A` as was created by + /// `init` and given to `step`); if `step()` was not called (because + /// the function is running against 0 rows), will be given `None`. + fn finalize(&self, _: Option<A>) -> Result<T>; +} + +/// `feature = "window"` WindowAggregate is the callback interface for +/// user-defined aggregate window function. +#[cfg(feature = "window")] +pub trait WindowAggregate<A, T>: Aggregate<A, T> +where + A: RefUnwindSafe + UnwindSafe, + T: ToSql, +{ + /// Returns the current value of the aggregate. Unlike xFinal, the + /// implementation should not delete any context. + fn value(&self, _: Option<&A>) -> Result<T>; + + /// Removes a row from the current window. + fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>; +} + +bitflags::bitflags! { + /// Function Flags. + /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html) + /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details. + #[repr(C)] + pub struct FunctionFlags: ::std::os::raw::c_int { + /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters. + const SQLITE_UTF8 = ffi::SQLITE_UTF8; + /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters. + const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE; + /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters. + const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE; + /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters. + const SQLITE_UTF16 = ffi::SQLITE_UTF16; + /// Means that the function always gives the same output when the input parameters are the same. + const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; + /// Means that the function may only be invoked from top-level SQL. + const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0 + /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the sub-types of its arguments. + const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0 + /// Means that the function is unlikely to cause problems even if misused. + const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0 + } +} + +impl Default for FunctionFlags { + fn default() -> FunctionFlags { + FunctionFlags::SQLITE_UTF8 + } +} + +impl Connection { + /// `feature = "functions"` Attach a user-defined scalar function to + /// this database connection. + /// + /// `fn_name` is the name the function will be accessible from SQL. + /// `n_arg` is the number of arguments to the function. Use `-1` for a + /// variable number. If the function always returns the same value + /// given the same input, `deterministic` should be `true`. + /// + /// The function will remain available until the connection is closed or + /// until it is explicitly removed via `remove_function`. + /// + /// # Example + /// + /// ```rust + /// # use rusqlite::{Connection, Result, NO_PARAMS}; + /// # use rusqlite::functions::FunctionFlags; + /// fn scalar_function_example(db: Connection) -> Result<()> { + /// db.create_scalar_function( + /// "halve", + /// 1, + /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + /// |ctx| { + /// let value = ctx.get::<f64>(0)?; + /// Ok(value / 2f64) + /// }, + /// )?; + /// + /// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?; + /// assert_eq!(six_halved, 3f64); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return Err if the function could not be attached to the connection. + pub fn create_scalar_function<F, T>( + &self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + x_func: F, + ) -> Result<()> + where + F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, + T: ToSql, + { + self.db + .borrow_mut() + .create_scalar_function(fn_name, n_arg, flags, x_func) + } + + /// `feature = "functions"` Attach a user-defined aggregate function to this + /// database connection. + /// + /// # Failure + /// + /// Will return Err if the function could not be attached to the connection. + pub fn create_aggregate_function<A, D, T>( + &self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + aggr: D, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate<A, T>, + T: ToSql, + { + self.db + .borrow_mut() + .create_aggregate_function(fn_name, n_arg, flags, aggr) + } + + /// `feature = "window"` Attach a user-defined aggregate window function to + /// this database connection. + /// + /// See https://sqlite.org/windowfunctions.html#udfwinfunc for more + /// information. + #[cfg(feature = "window")] + pub fn create_window_function<A, W, T>( + &self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + aggr: W, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate<A, T>, + T: ToSql, + { + self.db + .borrow_mut() + .create_window_function(fn_name, n_arg, flags, aggr) + } + + /// `feature = "functions"` Removes a user-defined function from this + /// database connection. + /// + /// `fn_name` and `n_arg` should match the name and number of arguments + /// given to `create_scalar_function` or `create_aggregate_function`. + /// + /// # Failure + /// + /// Will return Err if the function could not be removed. + pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> { + self.db.borrow_mut().remove_function(fn_name, n_arg) + } +} + +impl InnerConnection { + fn create_scalar_function<F, T>( + &mut self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + x_func: F, + ) -> Result<()> + where + F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, + T: ToSql, + { + unsafe extern "C" fn call_boxed_closure<F, T>( + ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, + ) where + F: FnMut(&Context<'_>) -> Result<T>, + T: ToSql, + { + let r = catch_unwind(|| { + let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F; + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + let ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_f)(&ctx) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } + } + + let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); + let c_name = str_to_cstring(fn_name)?; + let r = unsafe { + ffi::sqlite3_create_function_v2( + self.db(), + c_name.as_ptr(), + n_arg, + flags.bits(), + boxed_f as *mut c_void, + Some(call_boxed_closure::<F, T>), + None, + None, + Some(free_boxed_value::<F>), + ) + }; + self.decode_result(r) + } + + fn create_aggregate_function<A, D, T>( + &mut self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + aggr: D, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate<A, T>, + T: ToSql, + { + let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); + let c_name = str_to_cstring(fn_name)?; + let r = unsafe { + ffi::sqlite3_create_function_v2( + self.db(), + c_name.as_ptr(), + n_arg, + flags.bits(), + boxed_aggr as *mut c_void, + None, + Some(call_boxed_step::<A, D, T>), + Some(call_boxed_final::<A, D, T>), + Some(free_boxed_value::<D>), + ) + }; + self.decode_result(r) + } + + #[cfg(feature = "window")] + fn create_window_function<A, W, T>( + &mut self, + fn_name: &str, + n_arg: c_int, + flags: FunctionFlags, + aggr: W, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate<A, T>, + T: ToSql, + { + let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr)); + let c_name = str_to_cstring(fn_name)?; + let r = unsafe { + ffi::sqlite3_create_window_function( + self.db(), + c_name.as_ptr(), + n_arg, + flags.bits(), + boxed_aggr as *mut c_void, + Some(call_boxed_step::<A, W, T>), + Some(call_boxed_final::<A, W, T>), + Some(call_boxed_value::<A, W, T>), + Some(call_boxed_inverse::<A, W, T>), + Some(free_boxed_value::<W>), + ) + }; + self.decode_result(r) + } + + fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { + let c_name = str_to_cstring(fn_name)?; + let r = unsafe { + ffi::sqlite3_create_function_v2( + self.db(), + c_name.as_ptr(), + n_arg, + ffi::SQLITE_UTF8, + ptr::null_mut(), + None, + None, + None, + None, + ) + }; + self.decode_result(r) + } +} + +unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> { + let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A; + if pac.is_null() { + return None; + } + Some(pac) +} + +unsafe extern "C" fn call_boxed_step<A, D, T>( + ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, +) where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate<A, T>, + T: ToSql, +{ + let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { + Some(pac) => pac, + None => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + if (*pac as *mut A).is_null() { + *pac = Box::into_raw(Box::new((*boxed_aggr).init())); + } + let mut ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_aggr).step(&mut ctx, &mut **pac) + }); + let r = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + match r { + Ok(_) => {} + Err(err) => report_error(ctx, &err), + }; +} + +#[cfg(feature = "window")] +unsafe extern "C" fn call_boxed_inverse<A, W, T>( + ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, +) where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate<A, T>, + T: ToSql, +{ + let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { + Some(pac) => pac, + None => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + let mut ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_aggr).inverse(&mut ctx, &mut **pac) + }); + let r = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + match r { + Ok(_) => {} + Err(err) => report_error(ctx, &err), + }; +} + +unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) +where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate<A, T>, + T: ToSql, +{ + // Within the xFinal callback, it is customary to set N=0 in calls to + // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. + let a: Option<A> = match aggregate_context(ctx, 0) { + Some(pac) => { + if (*pac as *mut A).is_null() { + None + } else { + let a = Box::from_raw(*pac); + Some(*a) + } + } + None => None, + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + (*boxed_aggr).finalize(a) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } +} + +#[cfg(feature = "window")] +unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context) +where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate<A, T>, + T: ToSql, +{ + // Within the xValue callback, it is customary to set N=0 in calls to + // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. + let a: Option<&A> = match aggregate_context(ctx, 0) { + Some(pac) => { + if (*pac as *mut A).is_null() { + None + } else { + let a = &**pac; + Some(a) + } + } + None => None, + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + (*boxed_aggr).value(a) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } +} + +#[cfg(test)] +mod test { + use regex::Regex; + use std::f64::EPSILON; + use std::os::raw::c_double; + + #[cfg(feature = "window")] + use crate::functions::WindowAggregate; + use crate::functions::{Aggregate, Context, FunctionFlags}; + use crate::{Connection, Error, Result, NO_PARAMS}; + + fn half(ctx: &Context<'_>) -> Result<c_double> { + assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); + let value = ctx.get::<c_double>(0)?; + Ok(value / 2f64) + } + + #[test] + fn test_function_half() { + let db = Connection::open_in_memory().unwrap(); + db.create_scalar_function( + "half", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + half, + ) + .unwrap(); + let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); + + assert!((3f64 - result.unwrap()).abs() < EPSILON); + } + + #[test] + fn test_remove_function() { + let db = Connection::open_in_memory().unwrap(); + db.create_scalar_function( + "half", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + half, + ) + .unwrap(); + let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); + assert!((3f64 - result.unwrap()).abs() < EPSILON); + + db.remove_function("half", 1).unwrap(); + let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); + assert!(result.is_err()); + } + + // This implementation of a regexp scalar function uses SQLite's auxilliary data + // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular + // expression multiple times within one query. + fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> { + assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); + type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; + let regexp: std::sync::Arc<Regex> = ctx + .get_or_create_aux(0, |vr| -> Result<_, BoxError> { + Ok(Regex::new(vr.as_str()?)?) + })?; + + let is_match = { + let text = ctx + .get_raw(1) + .as_str() + .map_err(|e| Error::UserFunctionError(e.into()))?; + + regexp.is_match(text) + }; + + Ok(is_match) + } + + #[test] + fn test_function_regexp_with_auxilliary() { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch( + "BEGIN; + CREATE TABLE foo (x string); + INSERT INTO foo VALUES ('lisa'); + INSERT INTO foo VALUES ('lXsi'); + INSERT INTO foo VALUES ('lisX'); + END;", + ) + .unwrap(); + db.create_scalar_function( + "regexp", + 2, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + regexp_with_auxilliary, + ) + .unwrap(); + + let result: Result<bool> = + db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| { + r.get(0) + }); + + assert_eq!(true, result.unwrap()); + + let result: Result<i64> = db.query_row( + "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", + NO_PARAMS, + |r| r.get(0), + ); + + assert_eq!(2, result.unwrap()); + } + + #[test] + fn test_varargs_function() { + let db = Connection::open_in_memory().unwrap(); + db.create_scalar_function( + "my_concat", + -1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + |ctx| { + let mut ret = String::new(); + + for idx in 0..ctx.len() { + let s = ctx.get::<String>(idx)?; + ret.push_str(&s); + } + + Ok(ret) + }, + ) + .unwrap(); + + for &(expected, query) in &[ + ("", "SELECT my_concat()"), + ("onetwo", "SELECT my_concat('one', 'two')"), + ("abc", "SELECT my_concat('a', 'b', 'c')"), + ] { + let result: String = db.query_row(query, NO_PARAMS, |r| r.get(0)).unwrap(); + assert_eq!(expected, result); + } + } + + #[test] + fn test_get_aux_type_checking() { + let db = Connection::open_in_memory().unwrap(); + db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| { + if !ctx.get::<bool>(1)? { + ctx.set_aux::<i64>(0, 100)?; + } else { + assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); + assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100); + } + Ok(true) + }) + .unwrap(); + + let res: bool = db + .query_row( + "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)", + NO_PARAMS, + |r| r.get(0), + ) + .unwrap(); + // Doesn't actually matter, we'll assert in the function if there's a problem. + assert!(res); + } + + struct Sum; + struct Count; + + impl Aggregate<i64, Option<i64>> for Sum { + fn init(&self) -> i64 { + 0 + } + + fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { + *sum += ctx.get::<i64>(0)?; + Ok(()) + } + + fn finalize(&self, sum: Option<i64>) -> Result<Option<i64>> { + Ok(sum) + } + } + + impl Aggregate<i64, i64> for Count { + fn init(&self) -> i64 { + 0 + } + + fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { + *sum += 1; + Ok(()) + } + + fn finalize(&self, sum: Option<i64>) -> Result<i64> { + Ok(sum.unwrap_or(0)) + } + } + + #[test] + fn test_sum() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function( + "my_sum", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Sum, + ) + .unwrap(); + + // sum should return NULL when given no columns (contrast with count below) + let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: Option<i64> = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap(); + assert!(result.is_none()); + + let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap(); + assert_eq!(4, result); + + let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ + 2, 1)"; + let result: (i64, i64) = db + .query_row(dual_sum, NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?))) + .unwrap(); + assert_eq!((4, 2), result); + } + + #[test] + fn test_count() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function( + "my_count", + -1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Count, + ) + .unwrap(); + + // count should return 0 when given no columns (contrast with sum above) + let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: i64 = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap(); + assert_eq!(result, 0); + + let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap(); + assert_eq!(2, result); + } + + #[cfg(feature = "window")] + impl WindowAggregate<i64, Option<i64>> for Sum { + fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { + *sum -= ctx.get::<i64>(0)?; + Ok(()) + } + + fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> { + Ok(sum.copied()) + } + } + + #[test] + #[cfg(feature = "window")] + fn test_window() { + use fallible_iterator::FallibleIterator; + + let db = Connection::open_in_memory().unwrap(); + db.create_window_function( + "sumint", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Sum, + ) + .unwrap(); + db.execute_batch( + "CREATE TABLE t3(x, y); + INSERT INTO t3 VALUES('a', 4), + ('b', 5), + ('c', 3), + ('d', 8), + ('e', 1);", + ) + .unwrap(); + + let mut stmt = db + .prepare( + "SELECT x, sumint(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM t3 ORDER BY x;", + ) + .unwrap(); + + let results: Vec<(String, i64)> = stmt + .query(NO_PARAMS) + .unwrap() + .map(|row| Ok((row.get("x")?, row.get("sum_y")?))) + .collect() + .unwrap(); + let expected = vec![ + ("a".to_owned(), 9), + ("b".to_owned(), 12), + ("c".to_owned(), 16), + ("d".to_owned(), 12), + ("e".to_owned(), 9), + ]; + assert_eq!(expected, results); + } +} diff --git a/third_party/rust/rusqlite/src/hooks.rs b/third_party/rust/rusqlite/src/hooks.rs new file mode 100644 index 0000000000..53dc041185 --- /dev/null +++ b/third_party/rust/rusqlite/src/hooks.rs @@ -0,0 +1,314 @@ +//! `feature = "hooks"` Commit, Data Change and Rollback Notification Callbacks +#![allow(non_camel_case_types)] + +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; + +use crate::ffi; + +use crate::{Connection, InnerConnection}; + +/// `feature = "hooks"` Action Codes +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(i32)] +#[non_exhaustive] +pub enum Action { + /// Unsupported / unexpected action + UNKNOWN = -1, + /// DELETE command + SQLITE_DELETE = ffi::SQLITE_DELETE, + /// INSERT command + SQLITE_INSERT = ffi::SQLITE_INSERT, + /// UPDATE command + SQLITE_UPDATE = ffi::SQLITE_UPDATE, +} + +impl From<i32> for Action { + fn from(code: i32) -> Action { + match code { + ffi::SQLITE_DELETE => Action::SQLITE_DELETE, + ffi::SQLITE_INSERT => Action::SQLITE_INSERT, + ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE, + _ => Action::UNKNOWN, + } + } +} + +impl Connection { + /// `feature = "hooks"` Register a callback function to be invoked whenever + /// a transaction is committed. + /// + /// The callback returns `true` to rollback. + pub fn commit_hook<F>(&self, hook: Option<F>) + where + F: FnMut() -> bool + Send + 'static, + { + self.db.borrow_mut().commit_hook(hook); + } + + /// `feature = "hooks"` Register a callback function to be invoked whenever + /// a transaction is committed. + /// + /// The callback returns `true` to rollback. + pub fn rollback_hook<F>(&self, hook: Option<F>) + where + F: FnMut() + Send + 'static, + { + self.db.borrow_mut().rollback_hook(hook); + } + + /// `feature = "hooks"` Register a callback function to be invoked whenever + /// a row is updated, inserted or deleted in a rowid table. + /// + /// The callback parameters are: + /// + /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or + /// SQLITE_DELETE), + /// - the name of the database ("main", "temp", ...), + /// - the name of the table that is updated, + /// - the ROWID of the row that is updated. + pub fn update_hook<F>(&self, hook: Option<F>) + where + F: FnMut(Action, &str, &str, i64) + Send + 'static, + { + self.db.borrow_mut().update_hook(hook); + } +} + +impl InnerConnection { + pub fn remove_hooks(&mut self) { + self.update_hook(None::<fn(Action, &str, &str, i64)>); + self.commit_hook(None::<fn() -> bool>); + self.rollback_hook(None::<fn()>); + } + + fn commit_hook<F>(&mut self, hook: Option<F>) + where + F: FnMut() -> bool + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int + where + F: FnMut() -> bool, + { + let r = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)() + }); + if let Ok(true) = r { + 1 + } else { + 0 + } + } + + // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with + // `sqlite3_commit_hook`. so we keep the `xDestroy` function in + // `InnerConnection.free_boxed_hook`. + let free_commit_hook = if hook.is_some() { + Some(free_boxed_hook::<F> as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_commit_hook( + self.db(), + Some(call_boxed_closure::<F>), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_commit_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_commit_hook = free_commit_hook; + } + + fn rollback_hook<F>(&mut self, hook: Option<F>) + where + F: FnMut() + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) + where + F: FnMut(), + { + let _ = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)(); + }); + } + + let free_rollback_hook = if hook.is_some() { + Some(free_boxed_hook::<F> as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_rollback_hook( + self.db(), + Some(call_boxed_closure::<F>), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_rollback_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_rollback_hook = free_rollback_hook; + } + + fn update_hook<F>(&mut self, hook: Option<F>) + where + F: FnMut(Action, &str, &str, i64) + Send + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>( + p_arg: *mut c_void, + action_code: c_int, + db_str: *const c_char, + tbl_str: *const c_char, + row_id: i64, + ) where + F: FnMut(Action, &str, &str, i64), + { + use std::ffi::CStr; + use std::str; + + let action = Action::from(action_code); + let db_name = { + let c_slice = CStr::from_ptr(db_str).to_bytes(); + str::from_utf8(c_slice) + }; + let tbl_name = { + let c_slice = CStr::from_ptr(tbl_str).to_bytes(); + str::from_utf8(c_slice) + }; + + let _ = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)( + action, + db_name.expect("illegal db name"), + tbl_name.expect("illegal table name"), + row_id, + ); + }); + } + + let free_update_hook = if hook.is_some() { + Some(free_boxed_hook::<F> as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_update_hook( + self.db(), + Some(call_boxed_closure::<F>), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_update_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_update_hook = free_update_hook; + } +} + +unsafe fn free_boxed_hook<F>(p: *mut c_void) { + drop(Box::from_raw(p as *mut F)); +} + +#[cfg(test)] +mod test { + use super::Action; + use crate::Connection; + use lazy_static::lazy_static; + use std::sync::atomic::{AtomicBool, Ordering}; + + #[test] + fn test_commit_hook() { + let db = Connection::open_in_memory().unwrap(); + + lazy_static! { + static ref CALLED: AtomicBool = AtomicBool::new(false); + } + db.commit_hook(Some(|| { + CALLED.store(true, Ordering::Relaxed); + false + })); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap(); + assert!(CALLED.load(Ordering::Relaxed)); + } + + #[test] + fn test_fn_commit_hook() { + let db = Connection::open_in_memory().unwrap(); + + fn hook() -> bool { + true + } + + db.commit_hook(Some(hook)); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap_err(); + } + + #[test] + fn test_rollback_hook() { + let db = Connection::open_in_memory().unwrap(); + + lazy_static! { + static ref CALLED: AtomicBool = AtomicBool::new(false); + } + db.rollback_hook(Some(|| { + CALLED.store(true, Ordering::Relaxed); + })); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;") + .unwrap(); + assert!(CALLED.load(Ordering::Relaxed)); + } + + #[test] + fn test_update_hook() { + let db = Connection::open_in_memory().unwrap(); + + lazy_static! { + static ref CALLED: AtomicBool = AtomicBool::new(false); + } + db.update_hook(Some(|action, db: &str, tbl: &str, row_id| { + assert_eq!(Action::SQLITE_INSERT, action); + assert_eq!("main", db); + assert_eq!("foo", tbl); + assert_eq!(1, row_id); + CALLED.store(true, Ordering::Relaxed); + })); + db.execute_batch("CREATE TABLE foo (t TEXT)").unwrap(); + db.execute_batch("INSERT INTO foo VALUES ('lisa')").unwrap(); + assert!(CALLED.load(Ordering::Relaxed)); + } +} diff --git a/third_party/rust/rusqlite/src/inner_connection.rs b/third_party/rust/rusqlite/src/inner_connection.rs new file mode 100644 index 0000000000..dd786fed0f --- /dev/null +++ b/third_party/rust/rusqlite/src/inner_connection.rs @@ -0,0 +1,430 @@ +use std::ffi::CStr; +use std::os::raw::{c_char, c_int}; +#[cfg(feature = "load_extension")] +use std::path::Path; +use std::ptr; +use std::str; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use super::ffi; +use super::str_for_sqlite; +use super::{Connection, InterruptHandle, OpenFlags, Result}; +use crate::error::{error_from_handle, error_from_sqlite_code, Error}; +use crate::raw_statement::RawStatement; +use crate::statement::Statement; +use crate::unlock_notify; +use crate::version::version_number; + +pub struct InnerConnection { + pub db: *mut ffi::sqlite3, + // It's unsafe to call `sqlite3_close` while another thread is performing + // a `sqlite3_interrupt`, and vice versa, so we take this mutex during + // those functions. This protects a copy of the `db` pointer (which is + // cleared on closing), however the main copy, `db`, is unprotected. + // Otherwise, a long running query would prevent calling interrupt, as + // interrupt would only acquire the lock after the query's completion. + interrupt_lock: Arc<Mutex<*mut ffi::sqlite3>>, + #[cfg(feature = "hooks")] + pub free_commit_hook: Option<unsafe fn(*mut ::std::os::raw::c_void)>, + #[cfg(feature = "hooks")] + pub free_rollback_hook: Option<unsafe fn(*mut ::std::os::raw::c_void)>, + #[cfg(feature = "hooks")] + pub free_update_hook: Option<unsafe fn(*mut ::std::os::raw::c_void)>, + owned: bool, +} + +impl InnerConnection { + #[allow(clippy::mutex_atomic)] + pub unsafe fn new(db: *mut ffi::sqlite3, owned: bool) -> InnerConnection { + InnerConnection { + db, + interrupt_lock: Arc::new(Mutex::new(db)), + #[cfg(feature = "hooks")] + free_commit_hook: None, + #[cfg(feature = "hooks")] + free_rollback_hook: None, + #[cfg(feature = "hooks")] + free_update_hook: None, + owned, + } + } + + pub fn open_with_flags( + c_path: &CStr, + flags: OpenFlags, + vfs: Option<&CStr>, + ) -> Result<InnerConnection> { + #[cfg(not(feature = "bundled"))] + ensure_valid_sqlite_version(); + ensure_safe_sqlite_threading_mode()?; + + // Replicate the check for sane open flags from SQLite, because the check in + // SQLite itself wasn't added until version 3.7.3. + debug_assert_eq!(1 << OpenFlags::SQLITE_OPEN_READ_ONLY.bits, 0x02); + debug_assert_eq!(1 << OpenFlags::SQLITE_OPEN_READ_WRITE.bits, 0x04); + debug_assert_eq!( + 1 << (OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE).bits, + 0x40 + ); + if (1 << (flags.bits & 0x7)) & 0x46 == 0 { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + None, + )); + } + + let z_vfs = match vfs { + Some(c_vfs) => c_vfs.as_ptr(), + None => ptr::null(), + }; + + unsafe { + let mut db: *mut ffi::sqlite3 = ptr::null_mut(); + let r = ffi::sqlite3_open_v2(c_path.as_ptr(), &mut db, flags.bits(), z_vfs); + if r != ffi::SQLITE_OK { + let e = if db.is_null() { + error_from_sqlite_code(r, Some(c_path.to_string_lossy().to_string())) + } else { + let mut e = error_from_handle(db, r); + if let Error::SqliteFailure( + ffi::Error { + code: ffi::ErrorCode::CannotOpen, + .. + }, + Some(msg), + ) = e + { + e = Error::SqliteFailure( + ffi::Error::new(r), + Some(format!("{}: {}", msg, c_path.to_string_lossy())), + ); + } + ffi::sqlite3_close(db); + e + }; + + return Err(e); + } + + // attempt to turn on extended results code; don't fail if we can't. + ffi::sqlite3_extended_result_codes(db, 1); + + let r = ffi::sqlite3_busy_timeout(db, 5000); + if r != ffi::SQLITE_OK { + let e = error_from_handle(db, r); + ffi::sqlite3_close(db); + return Err(e); + } + + Ok(InnerConnection::new(db, true)) + } + } + + pub fn db(&self) -> *mut ffi::sqlite3 { + self.db + } + + pub fn decode_result(&mut self, code: c_int) -> Result<()> { + unsafe { InnerConnection::decode_result_raw(self.db(), code) } + } + + unsafe fn decode_result_raw(db: *mut ffi::sqlite3, code: c_int) -> Result<()> { + if code == ffi::SQLITE_OK { + Ok(()) + } else { + Err(error_from_handle(db, code)) + } + } + + #[allow(clippy::mutex_atomic)] + pub fn close(&mut self) -> Result<()> { + if self.db.is_null() { + return Ok(()); + } + self.remove_hooks(); + let mut shared_handle = self.interrupt_lock.lock().unwrap(); + assert!( + !shared_handle.is_null(), + "Bug: Somehow interrupt_lock was cleared before the DB was closed" + ); + if !self.owned { + self.db = ptr::null_mut(); + return Ok(()); + } + unsafe { + let r = ffi::sqlite3_close(self.db); + // Need to use _raw because _guard has a reference out, and + // decode_result takes &mut self. + let r = InnerConnection::decode_result_raw(self.db, r); + if r.is_ok() { + *shared_handle = ptr::null_mut(); + self.db = ptr::null_mut(); + } + r + } + } + + pub fn get_interrupt_handle(&self) -> InterruptHandle { + InterruptHandle { + db_lock: Arc::clone(&self.interrupt_lock), + } + } + + #[cfg(feature = "load_extension")] + pub fn enable_load_extension(&mut self, onoff: c_int) -> Result<()> { + let r = unsafe { ffi::sqlite3_enable_load_extension(self.db, onoff) }; + self.decode_result(r) + } + + #[cfg(feature = "load_extension")] + pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> { + let dylib_str = super::path_to_cstring(dylib_path)?; + unsafe { + let mut errmsg: *mut c_char = ptr::null_mut(); + let r = if let Some(entry_point) = entry_point { + let c_entry = crate::str_to_cstring(entry_point)?; + ffi::sqlite3_load_extension( + self.db, + dylib_str.as_ptr(), + c_entry.as_ptr(), + &mut errmsg, + ) + } else { + ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), ptr::null(), &mut errmsg) + }; + if r == ffi::SQLITE_OK { + Ok(()) + } else { + let message = super::errmsg_to_string(errmsg); + ffi::sqlite3_free(errmsg as *mut ::std::os::raw::c_void); + Err(error_from_sqlite_code(r, Some(message))) + } + } + } + + pub fn last_insert_rowid(&self) -> i64 { + unsafe { ffi::sqlite3_last_insert_rowid(self.db()) } + } + + pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result<Statement<'a>> { + let mut c_stmt = ptr::null_mut(); + let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?; + let mut c_tail = ptr::null(); + let r = unsafe { + if cfg!(feature = "unlock_notify") { + let mut rc; + loop { + rc = ffi::sqlite3_prepare_v2( + self.db(), + c_sql, + len, + &mut c_stmt as *mut *mut ffi::sqlite3_stmt, + &mut c_tail as *mut *const c_char, + ); + if !unlock_notify::is_locked(self.db, rc) { + break; + } + rc = unlock_notify::wait_for_unlock_notify(self.db); + if rc != ffi::SQLITE_OK { + break; + } + } + rc + } else { + ffi::sqlite3_prepare_v2( + self.db(), + c_sql, + len, + &mut c_stmt as *mut *mut ffi::sqlite3_stmt, + &mut c_tail as *mut *const c_char, + ) + } + }; + // If there is an error, *ppStmt is set to NULL. + self.decode_result(r)?; + // If the input text contains no SQL (if the input is an empty string or a + // comment) then *ppStmt is set to NULL. + let c_stmt: *mut ffi::sqlite3_stmt = c_stmt; + let c_tail: *const c_char = c_tail; + let tail = if c_tail.is_null() { + 0 + } else { + // TODO nightly feature ptr_offset_from #41079 + let n = (c_tail as isize) - (c_sql as isize); + if n <= 0 || n >= len as isize { + 0 + } else { + n as usize + } + }; + Ok(Statement::new(conn, unsafe { + RawStatement::new(c_stmt, tail) + })) + } + + pub fn changes(&mut self) -> usize { + unsafe { ffi::sqlite3_changes(self.db()) as usize } + } + + pub fn is_autocommit(&self) -> bool { + unsafe { ffi::sqlite3_get_autocommit(self.db()) != 0 } + } + + #[cfg(feature = "modern_sqlite")] // 3.8.6 + pub fn is_busy(&self) -> bool { + let db = self.db(); + unsafe { + let mut stmt = ffi::sqlite3_next_stmt(db, ptr::null_mut()); + while !stmt.is_null() { + if ffi::sqlite3_stmt_busy(stmt) != 0 { + return true; + } + stmt = ffi::sqlite3_next_stmt(db, stmt); + } + } + false + } + + #[cfg(not(feature = "hooks"))] + fn remove_hooks(&mut self) {} +} + +impl Drop for InnerConnection { + #[allow(unused_must_use)] + fn drop(&mut self) { + use std::thread::panicking; + + if let Err(e) = self.close() { + if panicking() { + eprintln!("Error while closing SQLite connection: {:?}", e); + } else { + panic!("Error while closing SQLite connection: {:?}", e); + } + } + } +} + +#[cfg(not(feature = "bundled"))] +static SQLITE_VERSION_CHECK: std::sync::Once = std::sync::Once::new(); +#[cfg(not(feature = "bundled"))] +pub static BYPASS_VERSION_CHECK: AtomicBool = AtomicBool::new(false); + +#[cfg(not(feature = "bundled"))] +fn ensure_valid_sqlite_version() { + use crate::version::version; + + SQLITE_VERSION_CHECK.call_once(|| { + let version_number = version_number(); + + // Check our hard floor. + if version_number < 3_006_008 { + panic!("rusqlite requires SQLite 3.6.8 or newer"); + } + + // Check that the major version number for runtime and buildtime match. + let buildtime_major = ffi::SQLITE_VERSION_NUMBER / 1_000_000; + let runtime_major = version_number / 1_000_000; + if buildtime_major != runtime_major { + panic!( + "rusqlite was built against SQLite {} but is running with SQLite {}", + str::from_utf8(ffi::SQLITE_VERSION).unwrap(), + version() + ); + } + + if BYPASS_VERSION_CHECK.load(Ordering::Relaxed) { + return; + } + + // Check that the runtime version number is compatible with the version number + // we found at build-time. + if version_number < ffi::SQLITE_VERSION_NUMBER { + panic!( + "\ +rusqlite was built against SQLite {} but the runtime SQLite version is {}. To fix this, either: +* Recompile rusqlite and link against the SQLite version you are using at runtime, or +* Call rusqlite::bypass_sqlite_version_check() prior to your first connection attempt. Doing this + means you're sure everything will work correctly even though the runtime version is older than + the version we found at build time.", + str::from_utf8(ffi::SQLITE_VERSION).unwrap(), + version() + ); + } + }); +} + +#[cfg(not(any(target_arch = "wasm32")))] +static SQLITE_INIT: std::sync::Once = std::sync::Once::new(); + +pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false); + +// threading mode checks are not necessary (and do not work) on target +// platforms that do not have threading (such as webassembly) +#[cfg(any(target_arch = "wasm32"))] +fn ensure_safe_sqlite_threading_mode() -> Result<()> { + Ok(()) +} + +#[cfg(not(any(target_arch = "wasm32")))] +fn ensure_safe_sqlite_threading_mode() -> Result<()> { + // Ensure SQLite was compiled in thredsafe mode. + if unsafe { ffi::sqlite3_threadsafe() == 0 } { + return Err(Error::SqliteSingleThreadedMode); + } + + // Now we know SQLite is _capable_ of being in Multi-thread of Serialized mode, + // but it's possible someone configured it to be in Single-thread mode + // before calling into us. That would mean we're exposing an unsafe API via + // a safe one (in Rust terminology), which is no good. We have two options + // to protect against this, depending on the version of SQLite we're linked + // with: + // + // 1. If we're on 3.7.0 or later, we can ask SQLite for a mutex and check for + // the magic value 8. This isn't documented, but it's what SQLite + // returns for its mutex allocation function in Single-thread mode. + // 2. If we're prior to SQLite 3.7.0, AFAIK there's no way to check the + // threading mode. The check we perform for >= 3.7.0 will segfault. + // Instead, we insist on being able to call sqlite3_config and + // sqlite3_initialize ourself, ensuring we know the threading + // mode. This will fail if someone else has already initialized SQLite + // even if they initialized it safely. That's not ideal either, which is + // why we expose bypass_sqlite_initialization above. + if version_number() >= 3_007_000 { + const SQLITE_SINGLETHREADED_MUTEX_MAGIC: usize = 8; + let is_singlethreaded = unsafe { + let mutex_ptr = ffi::sqlite3_mutex_alloc(0); + let is_singlethreaded = mutex_ptr as usize == SQLITE_SINGLETHREADED_MUTEX_MAGIC; + ffi::sqlite3_mutex_free(mutex_ptr); + is_singlethreaded + }; + if is_singlethreaded { + Err(Error::SqliteSingleThreadedMode) + } else { + Ok(()) + } + } else { + SQLITE_INIT.call_once(|| { + if BYPASS_SQLITE_INIT.load(Ordering::Relaxed) { + return; + } + + unsafe { + let msg = "\ +Could not ensure safe initialization of SQLite. +To fix this, either: +* Upgrade SQLite to at least version 3.7.0 +* Ensure that SQLite has been initialized in Multi-thread or Serialized mode and call + rusqlite::bypass_sqlite_initialization() prior to your first connection attempt."; + + if ffi::sqlite3_config(ffi::SQLITE_CONFIG_MULTITHREAD) != ffi::SQLITE_OK { + panic!(msg); + } + if ffi::sqlite3_initialize() != ffi::SQLITE_OK { + panic!(msg); + } + } + }); + Ok(()) + } +} diff --git a/third_party/rust/rusqlite/src/lib.rs b/third_party/rust/rusqlite/src/lib.rs new file mode 100644 index 0000000000..53f1773fd3 --- /dev/null +++ b/third_party/rust/rusqlite/src/lib.rs @@ -0,0 +1,1811 @@ +//! Rusqlite is an ergonomic wrapper for using SQLite from Rust. It attempts to +//! expose an interface similar to [rust-postgres](https://github.com/sfackler/rust-postgres). +//! +//! ```rust +//! use rusqlite::{params, Connection, Result}; +//! +//! #[derive(Debug)] +//! struct Person { +//! id: i32, +//! name: String, +//! data: Option<Vec<u8>>, +//! } +//! +//! fn main() -> Result<()> { +//! let conn = Connection::open_in_memory()?; +//! +//! conn.execute( +//! "CREATE TABLE person ( +//! id INTEGER PRIMARY KEY, +//! name TEXT NOT NULL, +//! data BLOB +//! )", +//! params![], +//! )?; +//! let me = Person { +//! id: 0, +//! name: "Steven".to_string(), +//! data: None, +//! }; +//! conn.execute( +//! "INSERT INTO person (name, data) VALUES (?1, ?2)", +//! params![me.name, me.data], +//! )?; +//! +//! let mut stmt = conn.prepare("SELECT id, name, data FROM person")?; +//! let person_iter = stmt.query_map(params![], |row| { +//! Ok(Person { +//! id: row.get(0)?, +//! name: row.get(1)?, +//! data: row.get(2)?, +//! }) +//! })?; +//! +//! for person in person_iter { +//! println!("Found person {:?}", person.unwrap()); +//! } +//! Ok(()) +//! } +//! ``` +#![warn(missing_docs)] + +pub use libsqlite3_sys as ffi; + +use std::cell::RefCell; +use std::convert; +use std::default::Default; +use std::ffi::{CStr, CString}; +use std::fmt; +use std::os::raw::{c_char, c_int}; + +use std::path::{Path, PathBuf}; +use std::result; +use std::str; +use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex}; + +use crate::cache::StatementCache; +use crate::inner_connection::{InnerConnection, BYPASS_SQLITE_INIT}; +use crate::raw_statement::RawStatement; +use crate::types::ValueRef; + +pub use crate::cache::CachedStatement; +pub use crate::column::Column; +pub use crate::error::Error; +pub use crate::ffi::ErrorCode; +#[cfg(feature = "hooks")] +pub use crate::hooks::Action; +#[cfg(feature = "load_extension")] +pub use crate::load_extension_guard::LoadExtensionGuard; +pub use crate::row::{AndThenRows, Map, MappedRows, Row, RowIndex, Rows}; +pub use crate::statement::{Statement, StatementStatus}; +pub use crate::transaction::{DropBehavior, Savepoint, Transaction, TransactionBehavior}; +pub use crate::types::ToSql; +pub use crate::version::*; + +#[macro_use] +mod error; + +#[cfg(feature = "backup")] +pub mod backup; +#[cfg(feature = "blob")] +pub mod blob; +mod busy; +mod cache; +#[cfg(feature = "collation")] +mod collation; +mod column; +pub mod config; +#[cfg(any(feature = "functions", feature = "vtab"))] +mod context; +#[cfg(feature = "functions")] +pub mod functions; +#[cfg(feature = "hooks")] +mod hooks; +mod inner_connection; +#[cfg(feature = "limits")] +pub mod limits; +#[cfg(feature = "load_extension")] +mod load_extension_guard; +mod pragma; +mod raw_statement; +mod row; +#[cfg(feature = "session")] +pub mod session; +mod statement; +#[cfg(feature = "trace")] +pub mod trace; +mod transaction; +pub mod types; +mod unlock_notify; +mod version; +#[cfg(feature = "vtab")] +pub mod vtab; + +pub(crate) mod util; +pub(crate) use util::SmallCString; + +// Number of cached prepared statements we'll hold on to. +const STATEMENT_CACHE_DEFAULT_CAPACITY: usize = 16; +/// To be used when your statement has no [parameter](https://sqlite.org/lang_expr.html#varparam). +pub const NO_PARAMS: &[&dyn ToSql] = &[]; + +/// A macro making it more convenient to pass heterogeneous lists +/// of parameters as a `&[&dyn ToSql]`. +/// +/// # Example +/// +/// ```rust,no_run +/// # use rusqlite::{Result, Connection, params}; +/// +/// struct Person { +/// name: String, +/// age_in_years: u8, +/// data: Option<Vec<u8>>, +/// } +/// +/// fn add_person(conn: &Connection, person: &Person) -> Result<()> { +/// conn.execute("INSERT INTO person (name, age_in_years, data) +/// VALUES (?1, ?2, ?3)", +/// params![person.name, person.age_in_years, person.data])?; +/// Ok(()) +/// } +/// ``` +#[macro_export] +macro_rules! params { + () => { + $crate::NO_PARAMS + }; + ($($param:expr),+ $(,)?) => { + &[$(&$param as &dyn $crate::ToSql),+] as &[&dyn $crate::ToSql] + }; +} + +/// A macro making it more convenient to pass lists of named parameters +/// as a `&[(&str, &dyn ToSql)]`. +/// +/// # Example +/// +/// ```rust,no_run +/// # use rusqlite::{Result, Connection, named_params}; +/// +/// struct Person { +/// name: String, +/// age_in_years: u8, +/// data: Option<Vec<u8>>, +/// } +/// +/// fn add_person(conn: &Connection, person: &Person) -> Result<()> { +/// conn.execute_named( +/// "INSERT INTO person (name, age_in_years, data) +/// VALUES (:name, :age, :data)", +/// named_params!{ +/// ":name": person.name, +/// ":age": person.age_in_years, +/// ":data": person.data, +/// } +/// )?; +/// Ok(()) +/// } +/// ``` +#[macro_export] +macro_rules! named_params { + () => { + &[] + }; + // Note: It's a lot more work to support this as part of the same macro as + // `params!`, unfortunately. + ($($param_name:literal: $param_val:expr),+ $(,)?) => { + &[$(($param_name, &$param_val as &dyn $crate::ToSql)),+] + }; +} + +/// A typedef of the result returned by many methods. +pub type Result<T, E = Error> = result::Result<T, E>; + +/// See the [method documentation](#tymethod.optional). +pub trait OptionalExtension<T> { + /// Converts a `Result<T>` into a `Result<Option<T>>`. + /// + /// By default, Rusqlite treats 0 rows being returned from a query that is + /// expected to return 1 row as an error. This method will + /// handle that error, and give you back an `Option<T>` instead. + fn optional(self) -> Result<Option<T>>; +} + +impl<T> OptionalExtension<T> for Result<T> { + fn optional(self) -> Result<Option<T>> { + match self { + Ok(value) => Ok(Some(value)), + Err(Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e), + } + } +} + +unsafe fn errmsg_to_string(errmsg: *const c_char) -> String { + let c_slice = CStr::from_ptr(errmsg).to_bytes(); + String::from_utf8_lossy(c_slice).into_owned() +} + +fn str_to_cstring(s: &str) -> Result<SmallCString> { + Ok(SmallCString::new(s)?) +} + +/// Returns `Ok((string ptr, len as c_int, SQLITE_STATIC | SQLITE_TRANSIENT))` +/// normally. +/// Returns error if the string is too large for sqlite. +/// The `sqlite3_destructor_type` item is always `SQLITE_TRANSIENT` unless +/// the string was empty (in which case it's `SQLITE_STATIC`, and the ptr is +/// static). +fn str_for_sqlite(s: &[u8]) -> Result<(*const c_char, c_int, ffi::sqlite3_destructor_type)> { + let len = len_as_c_int(s.len())?; + let (ptr, dtor_info) = if len != 0 { + (s.as_ptr() as *const c_char, ffi::SQLITE_TRANSIENT()) + } else { + // Return a pointer guaranteed to live forever + ("".as_ptr() as *const c_char, ffi::SQLITE_STATIC()) + }; + Ok((ptr, len, dtor_info)) +} + +// Helper to cast to c_int safely, returning the correct error type if the cast +// failed. +fn len_as_c_int(len: usize) -> Result<c_int> { + if len >= (c_int::max_value() as usize) { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_TOOBIG), + None, + )) + } else { + Ok(len as c_int) + } +} + +#[cfg(unix)] +fn path_to_cstring(p: &Path) -> Result<CString> { + use std::os::unix::ffi::OsStrExt; + Ok(CString::new(p.as_os_str().as_bytes())?) +} + +#[cfg(not(unix))] +fn path_to_cstring(p: &Path) -> Result<CString> { + let s = p.to_str().ok_or_else(|| Error::InvalidPath(p.to_owned()))?; + Ok(CString::new(s)?) +} + +/// Name for a database within a SQLite connection. +#[derive(Copy, Clone)] +pub enum DatabaseName<'a> { + /// The main database. + Main, + + /// The temporary database (e.g., any "CREATE TEMPORARY TABLE" tables). + Temp, + + /// A database that has been attached via "ATTACH DATABASE ...". + Attached(&'a str), +} + +// Currently DatabaseName is only used by the backup and blob mods, so hide +// this (private) impl to avoid dead code warnings. +#[cfg(any( + feature = "backup", + feature = "blob", + feature = "session", + feature = "modern_sqlite" +))] +impl DatabaseName<'_> { + fn to_cstring(&self) -> Result<util::SmallCString> { + use self::DatabaseName::{Attached, Main, Temp}; + match *self { + Main => str_to_cstring("main"), + Temp => str_to_cstring("temp"), + Attached(s) => str_to_cstring(s), + } + } +} + +/// A connection to a SQLite database. +pub struct Connection { + db: RefCell<InnerConnection>, + cache: StatementCache, + path: Option<PathBuf>, +} + +unsafe impl Send for Connection {} + +impl Drop for Connection { + fn drop(&mut self) { + self.flush_prepared_statement_cache(); + } +} + +impl Connection { + /// Open a new connection to a SQLite database. + /// + /// `Connection::open(path)` is equivalent to + /// `Connection::open_with_flags(path, + /// OpenFlags::SQLITE_OPEN_READ_WRITE | + /// OpenFlags::SQLITE_OPEN_CREATE)`. + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn open_my_db() -> Result<()> { + /// let path = "./my_db.db3"; + /// let db = Connection::open(&path)?; + /// println!("{}", db.is_autocommit()); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `path` cannot be converted to a C-compatible + /// string or if the underlying SQLite open call fails. + pub fn open<P: AsRef<Path>>(path: P) -> Result<Connection> { + let flags = OpenFlags::default(); + Connection::open_with_flags(path, flags) + } + + /// Open a new connection to an in-memory SQLite database. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite open call fails. + pub fn open_in_memory() -> Result<Connection> { + let flags = OpenFlags::default(); + Connection::open_in_memory_with_flags(flags) + } + + /// Open a new connection to a SQLite database. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if `path` cannot be converted to a C-compatible + /// string or if the underlying SQLite open call fails. + pub fn open_with_flags<P: AsRef<Path>>(path: P, flags: OpenFlags) -> Result<Connection> { + let c_path = path_to_cstring(path.as_ref())?; + InnerConnection::open_with_flags(&c_path, flags, None).map(|db| Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + path: Some(path.as_ref().to_path_buf()), + }) + } + + /// Open a new connection to a SQLite database using the specific flags and + /// vfs name. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if either `path` or `vfs` cannot be converted to a + /// C-compatible string or if the underlying SQLite open call fails. + pub fn open_with_flags_and_vfs<P: AsRef<Path>>( + path: P, + flags: OpenFlags, + vfs: &str, + ) -> Result<Connection> { + let c_path = path_to_cstring(path.as_ref())?; + let c_vfs = str_to_cstring(vfs)?; + InnerConnection::open_with_flags(&c_path, flags, Some(&c_vfs)).map(|db| Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + path: Some(path.as_ref().to_path_buf()), + }) + } + + /// Open a new connection to an in-memory SQLite database. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite open call fails. + pub fn open_in_memory_with_flags(flags: OpenFlags) -> Result<Connection> { + Connection::open_with_flags(":memory:", flags) + } + + /// Open a new connection to an in-memory SQLite database using the specific + /// flags and vfs name. + /// + /// [Database Connection](http://www.sqlite.org/c3ref/open.html) for a description of valid + /// flag combinations. + /// + /// # Failure + /// + /// Will return `Err` if vfs` cannot be converted to a C-compatible + /// string or if the underlying SQLite open call fails. + pub fn open_in_memory_with_flags_and_vfs(flags: OpenFlags, vfs: &str) -> Result<Connection> { + Connection::open_with_flags_and_vfs(":memory:", flags, vfs) + } + + /// Convenience method to run multiple SQL statements (that cannot take any + /// parameters). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn create_tables(conn: &Connection) -> Result<()> { + /// conn.execute_batch( + /// "BEGIN; + /// CREATE TABLE foo(x INTEGER); + /// CREATE TABLE bar(y TEXT); + /// COMMIT;", + /// ) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn execute_batch(&self, sql: &str) -> Result<()> { + let mut sql = sql; + while !sql.is_empty() { + let stmt = self.prepare(sql)?; + if !stmt.stmt.is_null() && stmt.step()? && cfg!(feature = "extra_check") { + // Some PRAGMA may return rows + return Err(Error::ExecuteReturnedResults); + } + let tail = stmt.stmt.tail(); + if tail == 0 || tail >= sql.len() { + break; + } + sql = &sql[tail..]; + } + Ok(()) + } + + /// Convenience method to prepare and execute a single SQL statement. + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection}; + /// fn update_rows(conn: &Connection) { + /// match conn.execute("UPDATE foo SET bar = 'baz' WHERE qux = ?", &[1i32]) { + /// Ok(updated) => println!("{} rows were updated", updated), + /// Err(err) => println!("update failed: {}", err), + /// } + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn execute<P>(&self, sql: &str, params: P) -> Result<usize> + where + P: IntoIterator, + P::Item: ToSql, + { + self.prepare(sql) + .and_then(|mut stmt| stmt.check_no_tail().and_then(|_| stmt.execute(params))) + } + + /// Convenience method to prepare and execute a single SQL statement with + /// named parameter(s). + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert(conn: &Connection) -> Result<usize> { + /// conn.execute_named( + /// "INSERT INTO test (name) VALUES (:name)", + /// &[(":name", &"one")], + /// ) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn execute_named(&self, sql: &str, params: &[(&str, &dyn ToSql)]) -> Result<usize> { + self.prepare(sql).and_then(|mut stmt| { + stmt.check_no_tail() + .and_then(|_| stmt.execute_named(params)) + }) + } + + /// Get the SQLite rowid of the most recent successful INSERT. + /// + /// Uses [sqlite3_last_insert_rowid](https://www.sqlite.org/c3ref/last_insert_rowid.html) under + /// the hood. + pub fn last_insert_rowid(&self) -> i64 { + self.db.borrow_mut().last_insert_rowid() + } + + /// Convenience method to execute a query that is expected to return a + /// single row. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Result,Connection, NO_PARAMS}; + /// fn preferred_locale(conn: &Connection) -> Result<String> { + /// conn.query_row( + /// "SELECT value FROM preferences WHERE name='locale'", + /// NO_PARAMS, + /// |row| row.get(0), + /// ) + /// } + /// ``` + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call `.optional()` on the result of + /// this to get a `Result<Option<T>>`. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn query_row<T, P, F>(&self, sql: &str, params: P, f: F) -> Result<T> + where + P: IntoIterator, + P::Item: ToSql, + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; + stmt.query_row(params, f) + } + + /// Convenience method to execute a query with named parameter(s) that is + /// expected to return a single row. + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call `.optional()` on the result of + /// this to get a `Result<Option<T>>`. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn query_row_named<T, F>(&self, sql: &str, params: &[(&str, &dyn ToSql)], f: F) -> Result<T> + where + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; + stmt.query_row_named(params, f) + } + + /// Convenience method to execute a query that is expected to return a + /// single row, and execute a mapping via `f` on that returned row with + /// the possibility of failure. The `Result` type of `f` must implement + /// `std::convert::From<Error>`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Result,Connection, NO_PARAMS}; + /// fn preferred_locale(conn: &Connection) -> Result<String> { + /// conn.query_row_and_then( + /// "SELECT value FROM preferences WHERE name='locale'", + /// NO_PARAMS, + /// |row| row.get(0), + /// ) + /// } + /// ``` + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn query_row_and_then<T, E, P, F>(&self, sql: &str, params: P, f: F) -> Result<T, E> + where + P: IntoIterator, + P::Item: ToSql, + F: FnOnce(&Row<'_>) -> Result<T, E>, + E: convert::From<Error>, + { + let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; + let mut rows = stmt.query(params)?; + + rows.get_expected_row().map_err(E::from).and_then(|r| f(&r)) + } + + /// Prepare a SQL statement for execution. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert_new_people(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("INSERT INTO People (name) VALUES (?)")?; + /// stmt.execute(&["Joe Smith"])?; + /// stmt.execute(&["Bob Jones"])?; + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn prepare(&self, sql: &str) -> Result<Statement<'_>> { + self.db.borrow_mut().prepare(self, sql) + } + + /// Close the SQLite connection. + /// + /// This is functionally equivalent to the `Drop` implementation for + /// `Connection` except that on failure, it returns an error and the + /// connection itself (presumably so closing can be attempted again). + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn close(self) -> Result<(), (Connection, Error)> { + self.flush_prepared_statement_cache(); + let r = self.db.borrow_mut().close(); + r.map_err(move |err| (self, err)) + } + + /// `feature = "load_extension"` Enable loading of SQLite extensions. + /// Strongly consider using `LoadExtensionGuard` instead of this function. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # use std::path::{Path}; + /// fn load_my_extension(conn: &Connection) -> Result<()> { + /// conn.load_extension_enable()?; + /// conn.load_extension(Path::new("my_sqlite_extension"), None)?; + /// conn.load_extension_disable() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[cfg(feature = "load_extension")] + pub fn load_extension_enable(&self) -> Result<()> { + self.db.borrow_mut().enable_load_extension(1) + } + + /// `feature = "load_extension"` Disable loading of SQLite extensions. + /// + /// See `load_extension_enable` for an example. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[cfg(feature = "load_extension")] + pub fn load_extension_disable(&self) -> Result<()> { + self.db.borrow_mut().enable_load_extension(0) + } + + /// `feature = "load_extension"` Load the SQLite extension at `dylib_path`. + /// `dylib_path` is passed through to `sqlite3_load_extension`, which may + /// attempt OS-specific modifications if the file cannot be loaded directly. + /// + /// If `entry_point` is `None`, SQLite will attempt to find the entry + /// point. If it is not `None`, the entry point will be passed through + /// to `sqlite3_load_extension`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, LoadExtensionGuard}; + /// # use std::path::{Path}; + /// fn load_my_extension(conn: &Connection) -> Result<()> { + /// let _guard = LoadExtensionGuard::new(conn)?; + /// + /// conn.load_extension("my_sqlite_extension", None) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + #[cfg(feature = "load_extension")] + pub fn load_extension<P: AsRef<Path>>( + &self, + dylib_path: P, + entry_point: Option<&str>, + ) -> Result<()> { + self.db + .borrow_mut() + .load_extension(dylib_path.as_ref(), entry_point) + } + + /// Get access to the underlying SQLite database connection handle. + /// + /// # Warning + /// + /// You should not need to use this function. If you do need to, please + /// [open an issue on the rusqlite repository](https://github.com/rusqlite/rusqlite/issues) and describe + /// your use case. + /// + /// # Safety + /// + /// This function is unsafe because it gives you raw access + /// to the SQLite connection, and what you do with it could impact the + /// safety of this `Connection`. + pub unsafe fn handle(&self) -> *mut ffi::sqlite3 { + self.db.borrow().db() + } + + /// Create a `Connection` from a raw handle. + /// + /// The underlying SQLite database connection handle will not be closed when + /// the returned connection is dropped/closed. + /// + /// # Safety + /// + /// This function is unsafe because improper use may impact the Connection. + pub unsafe fn from_handle(db: *mut ffi::sqlite3) -> Result<Connection> { + let db_path = db_filename(db); + let db = InnerConnection::new(db, false); + Ok(Connection { + db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), + path: db_path, + }) + } + + /// Get access to a handle that can be used to interrupt long running + /// queries from another thread. + pub fn get_interrupt_handle(&self) -> InterruptHandle { + self.db.borrow().get_interrupt_handle() + } + + fn decode_result(&self, code: c_int) -> Result<()> { + self.db.borrow_mut().decode_result(code) + } + + /// Return the number of rows modified, inserted or deleted by the most + /// recently completed INSERT, UPDATE or DELETE statement on the database + /// connection. + fn changes(&self) -> usize { + self.db.borrow_mut().changes() + } + + /// Test for auto-commit mode. + /// Autocommit mode is on by default. + pub fn is_autocommit(&self) -> bool { + self.db.borrow().is_autocommit() + } + + /// Determine if all associated prepared statements have been reset. + #[cfg(feature = "modern_sqlite")] // 3.8.6 + pub fn is_busy(&self) -> bool { + self.db.borrow().is_busy() + } +} + +impl fmt::Debug for Connection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection") + .field("path", &self.path) + .finish() + } +} + +bitflags::bitflags! { + /// Flags for opening SQLite database connections. + /// See [sqlite3_open_v2](http://www.sqlite.org/c3ref/open.html) for details. + #[repr(C)] + pub struct OpenFlags: ::std::os::raw::c_int { + /// The database is opened in read-only mode. + /// If the database does not already exist, an error is returned. + const SQLITE_OPEN_READ_ONLY = ffi::SQLITE_OPEN_READONLY; + /// The database is opened for reading and writing if possible, + /// or reading only if the file is write protected by the operating system. + /// In either case the database must already exist, otherwise an error is returned. + const SQLITE_OPEN_READ_WRITE = ffi::SQLITE_OPEN_READWRITE; + /// The database is created if it does not already exist + const SQLITE_OPEN_CREATE = ffi::SQLITE_OPEN_CREATE; + /// The filename can be interpreted as a URI if this flag is set. + const SQLITE_OPEN_URI = 0x0000_0040; + /// The database will be opened as an in-memory database. + const SQLITE_OPEN_MEMORY = 0x0000_0080; + /// The new database connection will use the "multi-thread" threading mode. + const SQLITE_OPEN_NO_MUTEX = ffi::SQLITE_OPEN_NOMUTEX; + /// The new database connection will use the "serialized" threading mode. + const SQLITE_OPEN_FULL_MUTEX = ffi::SQLITE_OPEN_FULLMUTEX; + /// The database is opened shared cache enabled. + const SQLITE_OPEN_SHARED_CACHE = 0x0002_0000; + /// The database is opened shared cache disabled. + const SQLITE_OPEN_PRIVATE_CACHE = 0x0004_0000; + /// The database filename is not allowed to be a symbolic link. + const SQLITE_OPEN_NOFOLLOW = 0x0100_0000; + } +} + +impl Default for OpenFlags { + fn default() -> OpenFlags { + OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_NO_MUTEX + | OpenFlags::SQLITE_OPEN_URI + } +} + +/// rusqlite's check for a safe SQLite threading mode requires SQLite 3.7.0 or +/// later. If you are running against a SQLite older than that, rusqlite +/// attempts to ensure safety by performing configuration and initialization of +/// SQLite itself the first time you +/// attempt to open a connection. By default, rusqlite panics if that +/// initialization fails, since that could mean SQLite has been initialized in +/// single-thread mode. +/// +/// If you are encountering that panic _and_ can ensure that SQLite has been +/// initialized in either multi-thread or serialized mode, call this function +/// prior to attempting to open a connection and rusqlite's initialization +/// process will by skipped. +/// +/// # Safety +/// +/// This function is unsafe because if you call it and SQLite has actually been +/// configured to run in single-thread mode, +/// you may enounter memory errors or data corruption or any number of terrible +/// things that should not be possible when you're using Rust. +pub unsafe fn bypass_sqlite_initialization() { + BYPASS_SQLITE_INIT.store(true, Ordering::Relaxed); +} + +/// rusqlite performs a one-time check that the runtime SQLite version is at +/// least as new as the version of SQLite found when rusqlite was built. +/// Bypassing this check may be dangerous; e.g., if you use features of SQLite +/// that are not present in the runtime version. +/// +/// # Safety +/// +/// If you are sure the runtime version is compatible with the +/// build-time version for your usage, you can bypass the version check by +/// calling this function before your first connection attempt. +pub unsafe fn bypass_sqlite_version_check() { + #[cfg(not(feature = "bundled"))] + inner_connection::BYPASS_VERSION_CHECK.store(true, Ordering::Relaxed); +} + +/// Allows interrupting a long-running computation. +pub struct InterruptHandle { + db_lock: Arc<Mutex<*mut ffi::sqlite3>>, +} + +unsafe impl Send for InterruptHandle {} +unsafe impl Sync for InterruptHandle {} + +impl InterruptHandle { + /// Interrupt the query currently executing on another thread. This will + /// cause that query to fail with a `SQLITE3_INTERRUPT` error. + pub fn interrupt(&self) { + let db_handle = self.db_lock.lock().unwrap(); + if !db_handle.is_null() { + unsafe { ffi::sqlite3_interrupt(*db_handle) } + } + } +} + +#[cfg(feature = "modern_sqlite")] // 3.7.10 +unsafe fn db_filename(db: *mut ffi::sqlite3) -> Option<PathBuf> { + let db_name = DatabaseName::Main.to_cstring().unwrap(); + let db_filename = ffi::sqlite3_db_filename(db, db_name.as_ptr()); + if db_filename.is_null() { + None + } else { + CStr::from_ptr(db_filename).to_str().ok().map(PathBuf::from) + } +} +#[cfg(not(feature = "modern_sqlite"))] +unsafe fn db_filename(_: *mut ffi::sqlite3) -> Option<PathBuf> { + None +} + +#[cfg(doctest)] +doc_comment::doctest!("../README.md"); + +#[cfg(test)] +mod test { + use super::*; + use crate::ffi; + use fallible_iterator::FallibleIterator; + use std::error::Error as StdError; + use std::fmt; + + // this function is never called, but is still type checked; in + // particular, calls with specific instantiations will require + // that those types are `Send`. + #[allow(dead_code, unconditional_recursion)] + fn ensure_send<T: Send>() { + ensure_send::<Connection>(); + ensure_send::<InterruptHandle>(); + } + + #[allow(dead_code, unconditional_recursion)] + fn ensure_sync<T: Sync>() { + ensure_sync::<InterruptHandle>(); + } + + pub fn checked_memory_handle() -> Connection { + Connection::open_in_memory().unwrap() + } + + #[test] + fn test_concurrent_transactions_busy_commit() { + use std::time::Duration; + let tmp = tempfile::tempdir().unwrap(); + let path = tmp.path().join("transactions.db3"); + + Connection::open(&path) + .expect("create temp db") + .execute_batch( + " + BEGIN; CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); END;", + ) + .expect("create temp db"); + + let mut db1 = + Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE).unwrap(); + let mut db2 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY).unwrap(); + + db1.busy_timeout(Duration::from_millis(0)).unwrap(); + db2.busy_timeout(Duration::from_millis(0)).unwrap(); + + { + let tx1 = db1.transaction().unwrap(); + let tx2 = db2.transaction().unwrap(); + + // SELECT first makes sqlite lock with a shared lock + tx1.query_row("SELECT x FROM foo LIMIT 1", NO_PARAMS, |_| Ok(())) + .unwrap(); + tx2.query_row("SELECT x FROM foo LIMIT 1", NO_PARAMS, |_| Ok(())) + .unwrap(); + + tx1.execute("INSERT INTO foo VALUES(?1)", &[1]).unwrap(); + let _ = tx2.execute("INSERT INTO foo VALUES(?1)", &[2]); + + let _ = tx1.commit(); + let _ = tx2.commit(); + } + + let _ = db1 + .transaction() + .expect("commit should have closed transaction"); + let _ = db2 + .transaction() + .expect("commit should have closed transaction"); + } + + #[test] + fn test_persistence() { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("test.db3"); + + { + let db = Connection::open(&path).unwrap(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + db.execute_batch(sql).unwrap(); + } + + let path_string = path.to_str().unwrap(); + let db = Connection::open(&path_string).unwrap(); + let the_answer: Result<i64> = db.query_row("SELECT x FROM foo", NO_PARAMS, |r| r.get(0)); + + assert_eq!(42i64, the_answer.unwrap()); + } + + #[test] + fn test_open() { + assert!(Connection::open_in_memory().is_ok()); + + let db = checked_memory_handle(); + assert!(db.close().is_ok()); + } + + #[test] + fn test_open_failure() { + let filename = "no_such_file.db"; + let result = Connection::open_with_flags(filename, OpenFlags::SQLITE_OPEN_READ_ONLY); + assert!(!result.is_ok()); + let err = result.err().unwrap(); + if let Error::SqliteFailure(e, Some(msg)) = err { + assert_eq!(ErrorCode::CannotOpen, e.code); + assert_eq!(ffi::SQLITE_CANTOPEN, e.extended_code); + assert!( + msg.contains(filename), + "error message '{}' does not contain '{}'", + msg, + filename + ); + } else { + panic!("SqliteFailure expected"); + } + } + + #[cfg(unix)] + #[test] + fn test_invalid_unicode_file_names() { + use std::ffi::OsStr; + use std::fs::File; + use std::os::unix::ffi::OsStrExt; + let temp_dir = tempfile::tempdir().unwrap(); + + let path = temp_dir.path(); + if File::create(path.join(OsStr::from_bytes(&[0xFE]))).is_err() { + // Skip test, filesystem doesn't support invalid Unicode + return; + } + let db_path = path.join(OsStr::from_bytes(&[0xFF])); + { + let db = Connection::open(&db_path).unwrap(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(42); + END;"; + db.execute_batch(sql).unwrap(); + } + + let db = Connection::open(&db_path).unwrap(); + let the_answer: Result<i64> = db.query_row("SELECT x FROM foo", NO_PARAMS, |r| r.get(0)); + + assert_eq!(42i64, the_answer.unwrap()); + } + + #[test] + fn test_close_retry() { + let db = checked_memory_handle(); + + // force the DB to be busy by preparing a statement; this must be done at the + // FFI level to allow us to call .close() without dropping the prepared + // statement first. + let raw_stmt = { + use super::str_to_cstring; + use std::os::raw::c_int; + use std::ptr; + + let raw_db = db.db.borrow_mut().db; + let sql = "SELECT 1"; + let mut raw_stmt: *mut ffi::sqlite3_stmt = ptr::null_mut(); + let cstring = str_to_cstring(sql).unwrap(); + let rc = unsafe { + ffi::sqlite3_prepare_v2( + raw_db, + cstring.as_ptr(), + (sql.len() + 1) as c_int, + &mut raw_stmt, + ptr::null_mut(), + ) + }; + assert_eq!(rc, ffi::SQLITE_OK); + raw_stmt + }; + + // now that we have an open statement, trying (and retrying) to close should + // fail. + let (db, _) = db.close().unwrap_err(); + let (db, _) = db.close().unwrap_err(); + let (db, _) = db.close().unwrap_err(); + + // finalize the open statement so a final close will succeed + assert_eq!(ffi::SQLITE_OK, unsafe { ffi::sqlite3_finalize(raw_stmt) }); + + db.close().unwrap(); + } + + #[test] + fn test_open_with_flags() { + for bad_flags in &[ + OpenFlags::empty(), + OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_READ_WRITE, + OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_CREATE, + ] { + assert!(Connection::open_in_memory_with_flags(*bad_flags).is_err()); + } + } + + #[test] + fn test_execute_batch() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + INSERT INTO foo VALUES(3); + INSERT INTO foo VALUES(4); + END;"; + db.execute_batch(sql).unwrap(); + + db.execute_batch("UPDATE foo SET x = 3 WHERE x < 3") + .unwrap(); + + assert!(db.execute_batch("INVALID SQL").is_err()); + } + + #[test] + fn test_execute() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(x INTEGER)").unwrap(); + + assert_eq!( + 1, + db.execute("INSERT INTO foo(x) VALUES (?)", &[1i32]) + .unwrap() + ); + assert_eq!( + 1, + db.execute("INSERT INTO foo(x) VALUES (?)", &[2i32]) + .unwrap() + ); + + assert_eq!( + 3i32, + db.query_row::<i32, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap() + ); + } + + #[test] + #[cfg(feature = "extra_check")] + fn test_execute_select() { + let db = checked_memory_handle(); + let err = db.execute("SELECT 1 WHERE 1 < ?", &[1i32]).unwrap_err(); + if err != Error::ExecuteReturnedResults { + panic!("Unexpected error: {}", err); + } + } + + #[test] + #[cfg(feature = "extra_check")] + fn test_execute_multiple() { + let db = checked_memory_handle(); + let err = db + .execute( + "CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)", + NO_PARAMS, + ) + .unwrap_err(); + match err { + Error::MultipleStatement => (), + _ => panic!("Unexpected error: {}", err), + } + } + + #[test] + fn test_prepare_column_names() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(x INTEGER);").unwrap(); + + let stmt = db.prepare("SELECT * FROM foo").unwrap(); + assert_eq!(stmt.column_count(), 1); + assert_eq!(stmt.column_names(), vec!["x"]); + + let stmt = db.prepare("SELECT x AS a, x AS b FROM foo").unwrap(); + assert_eq!(stmt.column_count(), 2); + assert_eq!(stmt.column_names(), vec!["a", "b"]); + } + + #[test] + fn test_prepare_execute() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(x INTEGER);").unwrap(); + + let mut insert_stmt = db.prepare("INSERT INTO foo(x) VALUES(?)").unwrap(); + assert_eq!(insert_stmt.execute(&[1i32]).unwrap(), 1); + assert_eq!(insert_stmt.execute(&[2i32]).unwrap(), 1); + assert_eq!(insert_stmt.execute(&[3i32]).unwrap(), 1); + + assert_eq!(insert_stmt.execute(&["hello".to_string()]).unwrap(), 1); + assert_eq!(insert_stmt.execute(&["goodbye".to_string()]).unwrap(), 1); + assert_eq!(insert_stmt.execute(&[types::Null]).unwrap(), 1); + + let mut update_stmt = db.prepare("UPDATE foo SET x=? WHERE x<?").unwrap(); + assert_eq!(update_stmt.execute(&[3i32, 3i32]).unwrap(), 2); + assert_eq!(update_stmt.execute(&[3i32, 3i32]).unwrap(), 0); + assert_eq!(update_stmt.execute(&[8i32, 8i32]).unwrap(), 3); + } + + #[test] + fn test_prepare_query() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(x INTEGER);").unwrap(); + + let mut insert_stmt = db.prepare("INSERT INTO foo(x) VALUES(?)").unwrap(); + assert_eq!(insert_stmt.execute(&[1i32]).unwrap(), 1); + assert_eq!(insert_stmt.execute(&[2i32]).unwrap(), 1); + assert_eq!(insert_stmt.execute(&[3i32]).unwrap(), 1); + + let mut query = db + .prepare("SELECT x FROM foo WHERE x < ? ORDER BY x DESC") + .unwrap(); + { + let mut rows = query.query(&[4i32]).unwrap(); + let mut v = Vec::<i32>::new(); + + while let Some(row) = rows.next().unwrap() { + v.push(row.get(0).unwrap()); + } + + assert_eq!(v, [3i32, 2, 1]); + } + + { + let mut rows = query.query(&[3i32]).unwrap(); + let mut v = Vec::<i32>::new(); + + while let Some(row) = rows.next().unwrap() { + v.push(row.get(0).unwrap()); + } + + assert_eq!(v, [2i32, 1]); + } + } + + #[test] + fn test_query_map() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql).unwrap(); + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); + let results: Result<Vec<String>> = query + .query(NO_PARAMS) + .unwrap() + .map(|row| row.get(1)) + .collect(); + + assert_eq!(results.unwrap().concat(), "hello, world!"); + } + + #[test] + fn test_query_row() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + INSERT INTO foo VALUES(3); + INSERT INTO foo VALUES(4); + END;"; + db.execute_batch(sql).unwrap(); + + assert_eq!( + 10i64, + db.query_row::<i64, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap() + ); + + let result: Result<i64> = + db.query_row("SELECT x FROM foo WHERE x > 5", NO_PARAMS, |r| r.get(0)); + match result.unwrap_err() { + Error::QueryReturnedNoRows => (), + err => panic!("Unexpected error {}", err), + } + + let bad_query_result = db.query_row("NOT A PROPER QUERY; test123", NO_PARAMS, |_| Ok(())); + + assert!(bad_query_result.is_err()); + } + + #[test] + fn test_optional() { + let db = checked_memory_handle(); + + let result: Result<i64> = db.query_row("SELECT 1 WHERE 0 <> 0", NO_PARAMS, |r| r.get(0)); + let result = result.optional(); + match result.unwrap() { + None => (), + _ => panic!("Unexpected result"), + } + + let result: Result<i64> = db.query_row("SELECT 1 WHERE 0 == 0", NO_PARAMS, |r| r.get(0)); + let result = result.optional(); + match result.unwrap() { + Some(1) => (), + _ => panic!("Unexpected result"), + } + + let bad_query_result: Result<i64> = + db.query_row("NOT A PROPER QUERY", NO_PARAMS, |r| r.get(0)); + let bad_query_result = bad_query_result.optional(); + assert!(bad_query_result.is_err()); + } + + #[test] + fn test_pragma_query_row() { + let db = checked_memory_handle(); + + assert_eq!( + "memory", + db.query_row::<String, _, _>("PRAGMA journal_mode", NO_PARAMS, |r| r.get(0)) + .unwrap() + ); + assert_eq!( + "off", + db.query_row::<String, _, _>("PRAGMA journal_mode=off", NO_PARAMS, |r| r.get(0)) + .unwrap() + ); + } + + #[test] + fn test_prepare_failures() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(x INTEGER);").unwrap(); + + let err = db.prepare("SELECT * FROM does_not_exist").unwrap_err(); + assert!(format!("{}", err).contains("does_not_exist")); + } + + #[test] + fn test_last_insert_rowid() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(x INTEGER PRIMARY KEY)") + .unwrap(); + db.execute_batch("INSERT INTO foo DEFAULT VALUES").unwrap(); + + assert_eq!(db.last_insert_rowid(), 1); + + let mut stmt = db.prepare("INSERT INTO foo DEFAULT VALUES").unwrap(); + for _ in 0i32..9 { + stmt.execute(NO_PARAMS).unwrap(); + } + assert_eq!(db.last_insert_rowid(), 10); + } + + #[test] + fn test_is_autocommit() { + let db = checked_memory_handle(); + assert!( + db.is_autocommit(), + "autocommit expected to be active by default" + ); + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn test_is_busy() { + let db = checked_memory_handle(); + assert!(!db.is_busy()); + let mut stmt = db.prepare("PRAGMA schema_version").unwrap(); + assert!(!db.is_busy()); + { + let mut rows = stmt.query(NO_PARAMS).unwrap(); + assert!(!db.is_busy()); + let row = rows.next().unwrap(); + assert!(db.is_busy()); + assert!(row.is_some()); + } + assert!(!db.is_busy()); + } + + #[test] + fn test_statement_debugging() { + let db = checked_memory_handle(); + let query = "SELECT 12345"; + let stmt = db.prepare(query).unwrap(); + + assert!(format!("{:?}", stmt).contains(query)); + } + + #[test] + fn test_notnull_constraint_error() { + // extended error codes for constraints were added in SQLite 3.7.16; if we're + // running on our bundled version, we know the extended error code exists. + #[cfg(feature = "modern_sqlite")] + fn check_extended_code(extended_code: c_int) { + assert_eq!(extended_code, ffi::SQLITE_CONSTRAINT_NOTNULL); + } + #[cfg(not(feature = "modern_sqlite"))] + fn check_extended_code(_extended_code: c_int) {} + + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(x NOT NULL)").unwrap(); + + let result = db.execute("INSERT INTO foo (x) VALUES (NULL)", NO_PARAMS); + assert!(result.is_err()); + + match result.unwrap_err() { + Error::SqliteFailure(err, _) => { + assert_eq!(err.code, ErrorCode::ConstraintViolation); + check_extended_code(err.extended_code); + } + err => panic!("Unexpected error {}", err), + } + } + + #[test] + fn test_version_string() { + let n = version_number(); + let major = n / 1_000_000; + let minor = (n % 1_000_000) / 1_000; + let patch = n % 1_000; + + assert!(version().contains(&format!("{}.{}.{}", major, minor, patch))); + } + + #[test] + #[cfg(feature = "functions")] + fn test_interrupt() { + let db = checked_memory_handle(); + + let interrupt_handle = db.get_interrupt_handle(); + + db.create_scalar_function( + "interrupt", + 0, + crate::functions::FunctionFlags::default(), + move |_| { + interrupt_handle.interrupt(); + Ok(0) + }, + ) + .unwrap(); + + let mut stmt = db + .prepare("SELECT interrupt() FROM (SELECT 1 UNION SELECT 2 UNION SELECT 3)") + .unwrap(); + + let result: Result<Vec<i32>> = stmt.query(NO_PARAMS).unwrap().map(|r| r.get(0)).collect(); + + match result.unwrap_err() { + Error::SqliteFailure(err, _) => { + assert_eq!(err.code, ErrorCode::OperationInterrupted); + } + err => { + panic!("Unexpected error {}", err); + } + } + } + + #[test] + fn test_interrupt_close() { + let db = checked_memory_handle(); + let handle = db.get_interrupt_handle(); + handle.interrupt(); + db.close().unwrap(); + handle.interrupt(); + + // Look at it's internals to see if we cleared it out properly. + let db_guard = handle.db_lock.lock().unwrap(); + assert!(db_guard.is_null()); + // It would be nice to test that we properly handle close/interrupt + // running at the same time, but it seems impossible to do with any + // degree of reliability. + } + + #[test] + fn test_get_raw() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(i, x);").unwrap(); + let vals = ["foobar", "1234", "qwerty"]; + let mut insert_stmt = db.prepare("INSERT INTO foo(i, x) VALUES(?, ?)").unwrap(); + for (i, v) in vals.iter().enumerate() { + let i_to_insert = i as i64; + assert_eq!(insert_stmt.execute(params![i_to_insert, v]).unwrap(), 1); + } + + let mut query = db.prepare("SELECT i, x FROM foo").unwrap(); + let mut rows = query.query(NO_PARAMS).unwrap(); + + while let Some(row) = rows.next().unwrap() { + let i = row.get_raw(0).as_i64().unwrap(); + let expect = vals[i as usize]; + let x = row.get_raw("x").as_str().unwrap(); + assert_eq!(x, expect); + } + } + + #[test] + fn test_from_handle() { + let db = checked_memory_handle(); + let handle = unsafe { db.handle() }; + { + let db = unsafe { Connection::from_handle(handle) }.unwrap(); + db.execute_batch("PRAGMA VACUUM").unwrap(); + } + db.close().unwrap(); + } + + mod query_and_then_tests { + + use super::*; + + #[derive(Debug)] + enum CustomError { + SomeError, + Sqlite(Error), + } + + impl fmt::Display for CustomError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match *self { + CustomError::SomeError => write!(f, "my custom error"), + CustomError::Sqlite(ref se) => write!(f, "my custom error: {}", se), + } + } + } + + impl StdError for CustomError { + fn description(&self) -> &str { + "my custom error" + } + + fn cause(&self) -> Option<&dyn StdError> { + match *self { + CustomError::SomeError => None, + CustomError::Sqlite(ref se) => Some(se), + } + } + } + + impl From<Error> for CustomError { + fn from(se: Error) -> CustomError { + CustomError::Sqlite(se) + } + } + + type CustomResult<T> = Result<T, CustomError>; + + #[test] + fn test_query_and_then() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql).unwrap(); + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); + let results: Result<Vec<String>> = query + .query_and_then(NO_PARAMS, |row| row.get(1)) + .unwrap() + .collect(); + + assert_eq!(results.unwrap().concat(), "hello, world!"); + } + + #[test] + fn test_query_and_then_fails() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql).unwrap(); + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); + let bad_type: Result<Vec<f64>> = query + .query_and_then(NO_PARAMS, |row| row.get(1)) + .unwrap() + .collect(); + + match bad_type.unwrap_err() { + Error::InvalidColumnType(..) => (), + err => panic!("Unexpected error {}", err), + } + + let bad_idx: Result<Vec<String>> = query + .query_and_then(NO_PARAMS, |row| row.get(3)) + .unwrap() + .collect(); + + match bad_idx.unwrap_err() { + Error::InvalidColumnIndex(_) => (), + err => panic!("Unexpected error {}", err), + } + } + + #[test] + fn test_query_and_then_custom_error() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql).unwrap(); + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); + let results: CustomResult<Vec<String>> = query + .query_and_then(NO_PARAMS, |row| row.get(1).map_err(CustomError::Sqlite)) + .unwrap() + .collect(); + + assert_eq!(results.unwrap().concat(), "hello, world!"); + } + + #[test] + fn test_query_and_then_custom_error_fails() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql).unwrap(); + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); + let bad_type: CustomResult<Vec<f64>> = query + .query_and_then(NO_PARAMS, |row| row.get(1).map_err(CustomError::Sqlite)) + .unwrap() + .collect(); + + match bad_type.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnType(..)) => (), + err => panic!("Unexpected error {}", err), + } + + let bad_idx: CustomResult<Vec<String>> = query + .query_and_then(NO_PARAMS, |row| row.get(3).map_err(CustomError::Sqlite)) + .unwrap() + .collect(); + + match bad_idx.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnIndex(_)) => (), + err => panic!("Unexpected error {}", err), + } + + let non_sqlite_err: CustomResult<Vec<String>> = query + .query_and_then(NO_PARAMS, |_| Err(CustomError::SomeError)) + .unwrap() + .collect(); + + match non_sqlite_err.unwrap_err() { + CustomError::SomeError => (), + err => panic!("Unexpected error {}", err), + } + } + + #[test] + fn test_query_row_and_then_custom_error() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql).unwrap(); + + let query = "SELECT x, y FROM foo ORDER BY x DESC"; + let results: CustomResult<String> = db.query_row_and_then(query, NO_PARAMS, |row| { + row.get(1).map_err(CustomError::Sqlite) + }); + + assert_eq!(results.unwrap(), "hello"); + } + + #[test] + fn test_query_row_and_then_custom_error_fails() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql).unwrap(); + + let query = "SELECT x, y FROM foo ORDER BY x DESC"; + let bad_type: CustomResult<f64> = db.query_row_and_then(query, NO_PARAMS, |row| { + row.get(1).map_err(CustomError::Sqlite) + }); + + match bad_type.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnType(..)) => (), + err => panic!("Unexpected error {}", err), + } + + let bad_idx: CustomResult<String> = db.query_row_and_then(query, NO_PARAMS, |row| { + row.get(3).map_err(CustomError::Sqlite) + }); + + match bad_idx.unwrap_err() { + CustomError::Sqlite(Error::InvalidColumnIndex(_)) => (), + err => panic!("Unexpected error {}", err), + } + + let non_sqlite_err: CustomResult<String> = + db.query_row_and_then(query, NO_PARAMS, |_| Err(CustomError::SomeError)); + + match non_sqlite_err.unwrap_err() { + CustomError::SomeError => (), + err => panic!("Unexpected error {}", err), + } + } + + #[test] + fn test_dynamic() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + END;"; + db.execute_batch(sql).unwrap(); + + db.query_row("SELECT * FROM foo", params![], |r| { + assert_eq!(2, r.column_count()); + Ok(()) + }) + .unwrap(); + } + #[test] + fn test_dyn_box() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo(x INTEGER);").unwrap(); + let b: Box<dyn ToSql> = Box::new(5); + db.execute("INSERT INTO foo VALUES(?)", &[b]).unwrap(); + db.query_row("SELECT x FROM foo", params![], |r| { + assert_eq!(5, r.get_unwrap::<_, i32>(0)); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_params() { + let db = checked_memory_handle(); + db.query_row( + "SELECT + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, + ?, ?, ?, ?;", + params![ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + ], + |r| { + assert_eq!(1, r.get_unwrap::<_, i32>(0)); + Ok(()) + }, + ) + .unwrap(); + } + + #[test] + #[cfg(not(feature = "extra_check"))] + fn test_alter_table() { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE x(t);").unwrap(); + // `execute_batch` should be used but `execute` should also work + db.execute("ALTER TABLE x RENAME TO y;", params![]).unwrap(); + } + } +} diff --git a/third_party/rust/rusqlite/src/limits.rs b/third_party/rust/rusqlite/src/limits.rs new file mode 100644 index 0000000000..238ce563fc --- /dev/null +++ b/third_party/rust/rusqlite/src/limits.rs @@ -0,0 +1,72 @@ +//! `feature = "limits"` Run-Time Limits + +use std::os::raw::c_int; + +use crate::ffi; +pub use crate::ffi::Limit; + +use crate::Connection; + +impl Connection { + /// `feature = "limits"` Returns the current value of a limit. + pub fn limit(&self, limit: Limit) -> i32 { + let c = self.db.borrow(); + unsafe { ffi::sqlite3_limit(c.db(), limit as c_int, -1) } + } + + /// `feature = "limits"` Changes the limit to `new_val`, returning the prior + /// value of the limit. + pub fn set_limit(&self, limit: Limit, new_val: i32) -> i32 { + let c = self.db.borrow_mut(); + unsafe { ffi::sqlite3_limit(c.db(), limit as c_int, new_val) } + } +} + +#[cfg(test)] +mod test { + use crate::ffi::Limit; + use crate::Connection; + + #[test] + fn test_limit() { + let db = Connection::open_in_memory().unwrap(); + db.set_limit(Limit::SQLITE_LIMIT_LENGTH, 1024); + assert_eq!(1024, db.limit(Limit::SQLITE_LIMIT_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_SQL_LENGTH, 1024); + assert_eq!(1024, db.limit(Limit::SQLITE_LIMIT_SQL_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_COLUMN, 64); + assert_eq!(64, db.limit(Limit::SQLITE_LIMIT_COLUMN)); + + db.set_limit(Limit::SQLITE_LIMIT_EXPR_DEPTH, 256); + assert_eq!(256, db.limit(Limit::SQLITE_LIMIT_EXPR_DEPTH)); + + db.set_limit(Limit::SQLITE_LIMIT_COMPOUND_SELECT, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_COMPOUND_SELECT)); + + db.set_limit(Limit::SQLITE_LIMIT_FUNCTION_ARG, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_FUNCTION_ARG)); + + db.set_limit(Limit::SQLITE_LIMIT_ATTACHED, 2); + assert_eq!(2, db.limit(Limit::SQLITE_LIMIT_ATTACHED)); + + db.set_limit(Limit::SQLITE_LIMIT_LIKE_PATTERN_LENGTH, 128); + assert_eq!(128, db.limit(Limit::SQLITE_LIMIT_LIKE_PATTERN_LENGTH)); + + db.set_limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER, 99); + assert_eq!(99, db.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)); + + // SQLITE_LIMIT_TRIGGER_DEPTH was added in SQLite 3.6.18. + if crate::version_number() >= 3_006_018 { + db.set_limit(Limit::SQLITE_LIMIT_TRIGGER_DEPTH, 32); + assert_eq!(32, db.limit(Limit::SQLITE_LIMIT_TRIGGER_DEPTH)); + } + + // SQLITE_LIMIT_WORKER_THREADS was added in SQLite 3.8.7. + if crate::version_number() >= 3_008_007 { + db.set_limit(Limit::SQLITE_LIMIT_WORKER_THREADS, 2); + assert_eq!(2, db.limit(Limit::SQLITE_LIMIT_WORKER_THREADS)); + } + } +} diff --git a/third_party/rust/rusqlite/src/load_extension_guard.rs b/third_party/rust/rusqlite/src/load_extension_guard.rs new file mode 100644 index 0000000000..f4f67d1edd --- /dev/null +++ b/third_party/rust/rusqlite/src/load_extension_guard.rs @@ -0,0 +1,36 @@ +use crate::{Connection, Result}; + +/// `feature = "load_extension"` RAII guard temporarily enabling SQLite +/// extensions to be loaded. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result, LoadExtensionGuard}; +/// # use std::path::{Path}; +/// fn load_my_extension(conn: &Connection) -> Result<()> { +/// let _guard = LoadExtensionGuard::new(conn)?; +/// +/// conn.load_extension(Path::new("my_sqlite_extension"), None) +/// } +/// ``` +pub struct LoadExtensionGuard<'conn> { + conn: &'conn Connection, +} + +impl LoadExtensionGuard<'_> { + /// Attempt to enable loading extensions. Loading extensions will be + /// disabled when this guard goes out of scope. Cannot be meaningfully + /// nested. + pub fn new(conn: &Connection) -> Result<LoadExtensionGuard<'_>> { + conn.load_extension_enable() + .map(|_| LoadExtensionGuard { conn }) + } +} + +#[allow(unused_must_use)] +impl Drop for LoadExtensionGuard<'_> { + fn drop(&mut self) { + self.conn.load_extension_disable(); + } +} diff --git a/third_party/rust/rusqlite/src/pragma.rs b/third_party/rust/rusqlite/src/pragma.rs new file mode 100644 index 0000000000..4855154dba --- /dev/null +++ b/third_party/rust/rusqlite/src/pragma.rs @@ -0,0 +1,444 @@ +//! Pragma helpers + +use std::ops::Deref; + +use crate::error::Error; +use crate::ffi; +use crate::types::{ToSql, ToSqlOutput, ValueRef}; +use crate::{Connection, DatabaseName, Result, Row, NO_PARAMS}; + +pub struct Sql { + buf: String, +} + +impl Sql { + pub fn new() -> Sql { + Sql { buf: String::new() } + } + + pub fn push_pragma( + &mut self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + ) -> Result<()> { + self.push_keyword("PRAGMA")?; + self.push_space(); + if let Some(schema_name) = schema_name { + self.push_schema_name(schema_name); + self.push_dot(); + } + self.push_keyword(pragma_name) + } + + pub fn push_keyword(&mut self, keyword: &str) -> Result<()> { + if !keyword.is_empty() && is_identifier(keyword) { + self.buf.push_str(keyword); + Ok(()) + } else { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Invalid keyword \"{}\"", keyword)), + )) + } + } + + pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) { + match schema_name { + DatabaseName::Main => self.buf.push_str("main"), + DatabaseName::Temp => self.buf.push_str("temp"), + DatabaseName::Attached(s) => self.push_identifier(s), + }; + } + + pub fn push_identifier(&mut self, s: &str) { + if is_identifier(s) { + self.buf.push_str(s); + } else { + self.wrap_and_escape(s, '"'); + } + } + + pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> { + let value = value.to_sql()?; + let value = match value { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + }; + match value { + ValueRef::Integer(i) => { + self.push_int(i); + } + ValueRef::Real(r) => { + self.push_real(r); + } + ValueRef::Text(s) => { + let s = std::str::from_utf8(s)?; + self.push_string_literal(s); + } + _ => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + }; + Ok(()) + } + + pub fn push_string_literal(&mut self, s: &str) { + self.wrap_and_escape(s, '\''); + } + + pub fn push_int(&mut self, i: i64) { + self.buf.push_str(&i.to_string()); + } + + pub fn push_real(&mut self, f: f64) { + self.buf.push_str(&f.to_string()); + } + + pub fn push_space(&mut self) { + self.buf.push(' '); + } + + pub fn push_dot(&mut self) { + self.buf.push('.'); + } + + pub fn push_equal_sign(&mut self) { + self.buf.push('='); + } + + pub fn open_brace(&mut self) { + self.buf.push('('); + } + + pub fn close_brace(&mut self) { + self.buf.push(')'); + } + + pub fn as_str(&self) -> &str { + &self.buf + } + + fn wrap_and_escape(&mut self, s: &str, quote: char) { + self.buf.push(quote); + let chars = s.chars(); + for ch in chars { + // escape `quote` by doubling it + if ch == quote { + self.buf.push(ch); + } + self.buf.push(ch) + } + self.buf.push(quote); + } +} + +impl Deref for Sql { + type Target = str; + + fn deref(&self) -> &str { + self.as_str() + } +} + +impl Connection { + /// Query the current value of `pragma_name`. + /// + /// Some pragmas will return multiple rows/values which cannot be retrieved + /// with this method. + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT user_version FROM pragma_user_version;` + pub fn pragma_query_value<T, F>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + f: F, + ) -> Result<T> + where + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut query = Sql::new(); + query.push_pragma(schema_name, pragma_name)?; + self.query_row(&query, NO_PARAMS, f) + } + + /// Query the current rows/values of `pragma_name`. + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT * FROM pragma_collation_list;` + pub fn pragma_query<F>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + mut f: F, + ) -> Result<()> + where + F: FnMut(&Row<'_>) -> Result<()>, + { + let mut query = Sql::new(); + query.push_pragma(schema_name, pragma_name)?; + let mut stmt = self.prepare(&query)?; + let mut rows = stmt.query(NO_PARAMS)?; + while let Some(result_row) = rows.next()? { + let row = result_row; + f(&row)?; + } + Ok(()) + } + + /// Query the current value(s) of `pragma_name` associated to + /// `pragma_value`. + /// + /// This method can be used with query-only pragmas which need an argument + /// (e.g. `table_info('one_tbl')`) or pragmas which returns value(s) + /// (e.g. `integrity_check`). + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT * FROM pragma_table_info(?);` + pub fn pragma<F>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + pragma_value: &dyn ToSql, + mut f: F, + ) -> Result<()> + where + F: FnMut(&Row<'_>) -> Result<()>, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.open_brace(); + sql.push_value(pragma_value)?; + sql.close_brace(); + let mut stmt = self.prepare(&sql)?; + let mut rows = stmt.query(NO_PARAMS)?; + while let Some(result_row) = rows.next()? { + let row = result_row; + f(&row)?; + } + Ok(()) + } + + /// Set a new value to `pragma_name`. + /// + /// Some pragmas will return the updated value which cannot be retrieved + /// with this method. + pub fn pragma_update( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + pragma_value: &dyn ToSql, + ) -> Result<()> { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.push_equal_sign(); + sql.push_value(pragma_value)?; + self.execute_batch(&sql) + } + + /// Set a new value to `pragma_name` and return the updated value. + /// + /// Only few pragmas automatically return the updated value. + pub fn pragma_update_and_check<F, T>( + &self, + schema_name: Option<DatabaseName<'_>>, + pragma_name: &str, + pragma_value: &dyn ToSql, + f: F, + ) -> Result<T> + where + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.push_equal_sign(); + sql.push_value(pragma_value)?; + self.query_row(&sql, NO_PARAMS, f) + } +} + +fn is_identifier(s: &str) -> bool { + let chars = s.char_indices(); + for (i, ch) in chars { + if i == 0 { + if !is_identifier_start(ch) { + return false; + } + } else if !is_identifier_continue(ch) { + return false; + } + } + true +} + +fn is_identifier_start(c: char) -> bool { + (c >= 'A' && c <= 'Z') || c == '_' || (c >= 'a' && c <= 'z') || c > '\x7F' +} + +fn is_identifier_continue(c: char) -> bool { + c == '$' + || (c >= '0' && c <= '9') + || (c >= 'A' && c <= 'Z') + || c == '_' + || (c >= 'a' && c <= 'z') + || c > '\x7F' +} + +#[cfg(test)] +mod test { + use super::Sql; + use crate::pragma; + use crate::{Connection, DatabaseName}; + + #[test] + fn pragma_query_value() { + let db = Connection::open_in_memory().unwrap(); + let user_version: i32 = db + .pragma_query_value(None, "user_version", |row| row.get(0)) + .unwrap(); + assert_eq!(0, user_version); + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn pragma_func_query_value() { + use crate::NO_PARAMS; + + let db = Connection::open_in_memory().unwrap(); + let user_version: i32 = db + .query_row( + "SELECT user_version FROM pragma_user_version", + NO_PARAMS, + |row| row.get(0), + ) + .unwrap(); + assert_eq!(0, user_version); + } + + #[test] + fn pragma_query_no_schema() { + let db = Connection::open_in_memory().unwrap(); + let mut user_version = -1; + db.pragma_query(None, "user_version", |row| { + user_version = row.get(0)?; + Ok(()) + }) + .unwrap(); + assert_eq!(0, user_version); + } + + #[test] + fn pragma_query_with_schema() { + let db = Connection::open_in_memory().unwrap(); + let mut user_version = -1; + db.pragma_query(Some(DatabaseName::Main), "user_version", |row| { + user_version = row.get(0)?; + Ok(()) + }) + .unwrap(); + assert_eq!(0, user_version); + } + + #[test] + fn pragma() { + let db = Connection::open_in_memory().unwrap(); + let mut columns = Vec::new(); + db.pragma(None, "table_info", &"sqlite_master", |row| { + let column: String = row.get(1)?; + columns.push(column); + Ok(()) + }) + .unwrap(); + assert_eq!(5, columns.len()); + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn pragma_func() { + let db = Connection::open_in_memory().unwrap(); + let mut table_info = db.prepare("SELECT * FROM pragma_table_info(?)").unwrap(); + let mut columns = Vec::new(); + let mut rows = table_info.query(&["sqlite_master"]).unwrap(); + + while let Some(row) = rows.next().unwrap() { + let row = row; + let column: String = row.get(1).unwrap(); + columns.push(column); + } + assert_eq!(5, columns.len()); + } + + #[test] + fn pragma_update() { + let db = Connection::open_in_memory().unwrap(); + db.pragma_update(None, "user_version", &1).unwrap(); + } + + #[test] + fn pragma_update_and_check() { + let db = Connection::open_in_memory().unwrap(); + let journal_mode: String = db + .pragma_update_and_check(None, "journal_mode", &"OFF", |row| row.get(0)) + .unwrap(); + assert_eq!("off", &journal_mode); + } + + #[test] + fn is_identifier() { + assert!(pragma::is_identifier("full")); + assert!(pragma::is_identifier("r2d2")); + assert!(!pragma::is_identifier("sp ce")); + assert!(!pragma::is_identifier("semi;colon")); + } + + #[test] + fn double_quote() { + let mut sql = Sql::new(); + sql.push_schema_name(DatabaseName::Attached(r#"schema";--"#)); + assert_eq!(r#""schema"";--""#, sql.as_str()); + } + + #[test] + fn wrap_and_escape() { + let mut sql = Sql::new(); + sql.push_string_literal("value'; --"); + assert_eq!("'value''; --'", sql.as_str()); + } + + #[test] + fn locking_mode() { + let db = Connection::open_in_memory().unwrap(); + let r = db.pragma_update(None, "locking_mode", &"exclusive"); + if cfg!(feature = "extra_check") { + r.unwrap_err(); + } else { + r.unwrap(); + } + } +} diff --git a/third_party/rust/rusqlite/src/raw_statement.rs b/third_party/rust/rusqlite/src/raw_statement.rs new file mode 100644 index 0000000000..c02dcd9528 --- /dev/null +++ b/third_party/rust/rusqlite/src/raw_statement.rs @@ -0,0 +1,185 @@ +use super::ffi; +use super::unlock_notify; +use super::StatementStatus; +#[cfg(feature = "modern_sqlite")] +use crate::util::SqliteMallocString; +use std::ffi::CStr; +use std::os::raw::c_int; +use std::ptr; +use std::sync::Arc; + +// Private newtype for raw sqlite3_stmts that finalize themselves when dropped. +#[derive(Debug)] +pub struct RawStatement { + ptr: *mut ffi::sqlite3_stmt, + tail: usize, + // Cached indices of named parameters, computed on the fly. + cache: crate::util::ParamIndexCache, + // Cached SQL (trimmed) that we use as the key when we're in the statement + // cache. This is None for statements which didn't come from the statement + // cache. + // + // This is probably the same as `self.sql()` in most cases, but we don't + // care either way -- It's a better cache key as it is anyway since it's the + // actual source we got from rust. + // + // One example of a case where the result of `sqlite_sql` and the value in + // `statement_cache_key` might differ is if the statement has a `tail`. + statement_cache_key: Option<Arc<str>>, +} + +impl RawStatement { + pub unsafe fn new(stmt: *mut ffi::sqlite3_stmt, tail: usize) -> RawStatement { + RawStatement { + ptr: stmt, + tail, + cache: Default::default(), + statement_cache_key: None, + } + } + + pub fn is_null(&self) -> bool { + self.ptr.is_null() + } + + pub(crate) fn set_statement_cache_key(&mut self, p: impl Into<Arc<str>>) { + self.statement_cache_key = Some(p.into()); + } + + pub(crate) fn statement_cache_key(&self) -> Option<Arc<str>> { + self.statement_cache_key.clone() + } + + pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { + self.ptr + } + + pub fn column_count(&self) -> usize { + // Note: Can't cache this as it changes if the schema is altered. + unsafe { ffi::sqlite3_column_count(self.ptr) as usize } + } + + pub fn column_type(&self, idx: usize) -> c_int { + unsafe { ffi::sqlite3_column_type(self.ptr, idx as c_int) } + } + + #[cfg(feature = "column_decltype")] + pub fn column_decltype(&self, idx: usize) -> Option<&CStr> { + unsafe { + let decltype = ffi::sqlite3_column_decltype(self.ptr, idx as c_int); + if decltype.is_null() { + None + } else { + Some(CStr::from_ptr(decltype)) + } + } + } + + pub fn column_name(&self, idx: usize) -> Option<&CStr> { + let idx = idx as c_int; + if idx < 0 || idx >= self.column_count() as c_int { + return None; + } + unsafe { + let ptr = ffi::sqlite3_column_name(self.ptr, idx); + // If ptr is null here, it's an OOM, so there's probably nothing + // meaningful we can do. Just assert instead of returning None. + assert!( + !ptr.is_null(), + "Null pointer from sqlite3_column_name: Out of memory?" + ); + Some(CStr::from_ptr(ptr)) + } + } + + pub fn step(&self) -> c_int { + if cfg!(feature = "unlock_notify") { + let db = unsafe { ffi::sqlite3_db_handle(self.ptr) }; + let mut rc; + loop { + rc = unsafe { ffi::sqlite3_step(self.ptr) }; + if unsafe { !unlock_notify::is_locked(db, rc) } { + break; + } + rc = unsafe { unlock_notify::wait_for_unlock_notify(db) }; + if rc != ffi::SQLITE_OK { + break; + } + self.reset(); + } + rc + } else { + unsafe { ffi::sqlite3_step(self.ptr) } + } + } + + pub fn reset(&self) -> c_int { + unsafe { ffi::sqlite3_reset(self.ptr) } + } + + pub fn bind_parameter_count(&self) -> usize { + unsafe { ffi::sqlite3_bind_parameter_count(self.ptr) as usize } + } + + pub fn bind_parameter_index(&self, name: &str) -> Option<usize> { + self.cache.get_or_insert_with(name, |param_cstr| { + let r = unsafe { ffi::sqlite3_bind_parameter_index(self.ptr, param_cstr.as_ptr()) }; + match r { + 0 => None, + i => Some(i as usize), + } + }) + } + + pub fn clear_bindings(&self) -> c_int { + unsafe { ffi::sqlite3_clear_bindings(self.ptr) } + } + + pub fn sql(&self) -> Option<&CStr> { + if self.ptr.is_null() { + None + } else { + Some(unsafe { CStr::from_ptr(ffi::sqlite3_sql(self.ptr)) }) + } + } + + pub fn finalize(mut self) -> c_int { + self.finalize_() + } + + fn finalize_(&mut self) -> c_int { + let r = unsafe { ffi::sqlite3_finalize(self.ptr) }; + self.ptr = ptr::null_mut(); + r + } + + #[cfg(all(feature = "extra_check", feature = "modern_sqlite"))] // 3.7.4 + pub fn readonly(&self) -> bool { + unsafe { ffi::sqlite3_stmt_readonly(self.ptr) != 0 } + } + + #[cfg(feature = "modern_sqlite")] // 3.14.0 + pub(crate) fn expanded_sql(&self) -> Option<SqliteMallocString> { + unsafe { SqliteMallocString::from_raw(ffi::sqlite3_expanded_sql(self.ptr)) } + } + + pub fn get_status(&self, status: StatementStatus, reset: bool) -> i32 { + assert!(!self.ptr.is_null()); + unsafe { ffi::sqlite3_stmt_status(self.ptr, status as i32, reset as i32) } + } + + #[cfg(feature = "extra_check")] + pub fn has_tail(&self) -> bool { + self.tail != 0 + } + + pub fn tail(&self) -> usize { + self.tail + } +} + +impl Drop for RawStatement { + fn drop(&mut self) { + self.finalize_(); + } +} diff --git a/third_party/rust/rusqlite/src/row.rs b/third_party/rust/rusqlite/src/row.rs new file mode 100644 index 0000000000..36aa1a6d11 --- /dev/null +++ b/third_party/rust/rusqlite/src/row.rs @@ -0,0 +1,570 @@ +use fallible_iterator::FallibleIterator; +use fallible_streaming_iterator::FallibleStreamingIterator; +use std::convert; + +use super::{Error, Result, Statement}; +use crate::types::{FromSql, FromSqlError, ValueRef}; + +/// An handle for the resulting rows of a query. +#[must_use = "Rows is lazy and will do nothing unless consumed"] +pub struct Rows<'stmt> { + pub(crate) stmt: Option<&'stmt Statement<'stmt>>, + row: Option<Row<'stmt>>, +} + +impl<'stmt> Rows<'stmt> { + fn reset(&mut self) { + if let Some(stmt) = self.stmt.take() { + stmt.reset(); + } + } + + /// Attempt to get the next row from the query. Returns `Ok(Some(Row))` if + /// there is another row, `Err(...)` if there was an error + /// getting the next row, and `Ok(None)` if all rows have been retrieved. + /// + /// ## Note + /// + /// This interface is not compatible with Rust's `Iterator` trait, because + /// the lifetime of the returned row is tied to the lifetime of `self`. + /// This is a fallible "streaming iterator". For a more natural interface, + /// consider using `query_map` or `query_and_then` instead, which + /// return types that implement `Iterator`. + #[allow(clippy::should_implement_trait)] // cannot implement Iterator + pub fn next(&mut self) -> Result<Option<&Row<'stmt>>> { + self.advance()?; + Ok((*self).get()) + } + + /// Map over this `Rows`, converting it to a [`Map`], which + /// implements `FallibleIterator`. + /// ```rust,no_run + /// use fallible_iterator::FallibleIterator; + /// # use rusqlite::{Result, Statement, NO_PARAMS}; + /// fn query(stmt: &mut Statement) -> Result<Vec<i64>> { + /// let rows = stmt.query(NO_PARAMS)?; + /// rows.map(|r| r.get(0)).collect() + /// } + /// ``` + // FIXME Hide FallibleStreamingIterator::map + pub fn map<F, B>(self, f: F) -> Map<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result<B>, + { + Map { rows: self, f } + } + + /// Map over this `Rows`, converting it to a [`MappedRows`], which + /// implements `Iterator`. + pub fn mapped<F, B>(self, f: F) -> MappedRows<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result<B>, + { + MappedRows { rows: self, map: f } + } + + /// Map over this `Rows` with a fallible function, converting it to a + /// [`AndThenRows`], which implements `Iterator` (instead of + /// `FallibleStreamingIterator`). + pub fn and_then<F, T, E>(self, f: F) -> AndThenRows<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result<T, E>, + { + AndThenRows { rows: self, map: f } + } +} + +impl<'stmt> Rows<'stmt> { + pub(crate) fn new(stmt: &'stmt Statement<'stmt>) -> Rows<'stmt> { + Rows { + stmt: Some(stmt), + row: None, + } + } + + pub(crate) fn get_expected_row(&mut self) -> Result<&Row<'stmt>> { + match self.next()? { + Some(row) => Ok(row), + None => Err(Error::QueryReturnedNoRows), + } + } +} + +impl Drop for Rows<'_> { + fn drop(&mut self) { + self.reset(); + } +} + +/// `F` is used to tranform the _streaming_ iterator into a _fallible_ iterator. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct Map<'stmt, F> { + rows: Rows<'stmt>, + f: F, +} + +impl<F, B> FallibleIterator for Map<'_, F> +where + F: FnMut(&Row<'_>) -> Result<B>, +{ + type Error = Error; + type Item = B; + + fn next(&mut self) -> Result<Option<B>> { + match self.rows.next()? { + Some(v) => Ok(Some((self.f)(v)?)), + None => Ok(None), + } + } +} + +/// An iterator over the mapped resulting rows of a query. +/// +/// `F` is used to tranform the _streaming_ iterator into a _standard_ iterator. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct MappedRows<'stmt, F> { + rows: Rows<'stmt>, + map: F, +} + +impl<'stmt, T, F> MappedRows<'stmt, F> +where + F: FnMut(&Row<'_>) -> Result<T>, +{ + pub(crate) fn new(rows: Rows<'stmt>, f: F) -> MappedRows<'stmt, F> { + MappedRows { rows, map: f } + } +} + +impl<T, F> Iterator for MappedRows<'_, F> +where + F: FnMut(&Row<'_>) -> Result<T>, +{ + type Item = Result<T>; + + fn next(&mut self) -> Option<Result<T>> { + let map = &mut self.map; + self.rows + .next() + .transpose() + .map(|row_result| row_result.and_then(|row| (map)(&row))) + } +} + +/// An iterator over the mapped resulting rows of a query, with an Error type +/// unifying with Error. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct AndThenRows<'stmt, F> { + rows: Rows<'stmt>, + map: F, +} + +impl<'stmt, T, E, F> AndThenRows<'stmt, F> +where + F: FnMut(&Row<'_>) -> Result<T, E>, +{ + pub(crate) fn new(rows: Rows<'stmt>, f: F) -> AndThenRows<'stmt, F> { + AndThenRows { rows, map: f } + } +} + +impl<T, E, F> Iterator for AndThenRows<'_, F> +where + E: convert::From<Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, +{ + type Item = Result<T, E>; + + fn next(&mut self) -> Option<Self::Item> { + let map = &mut self.map; + self.rows + .next() + .transpose() + .map(|row_result| row_result.map_err(E::from).and_then(|row| (map)(&row))) + } +} + +/// `FallibleStreamingIterator` differs from the standard library's `Iterator` +/// in two ways: +/// * each call to `next` (sqlite3_step) can fail. +/// * returned `Row` is valid until `next` is called again or `Statement` is +/// reset or finalized. +/// +/// While these iterators cannot be used with Rust `for` loops, `while let` +/// loops offer a similar level of ergonomics: +/// ```rust,no_run +/// # use rusqlite::{Result, Statement, NO_PARAMS}; +/// fn query(stmt: &mut Statement) -> Result<()> { +/// let mut rows = stmt.query(NO_PARAMS)?; +/// while let Some(row) = rows.next()? { +/// // scan columns value +/// } +/// Ok(()) +/// } +/// ``` +impl<'stmt> FallibleStreamingIterator for Rows<'stmt> { + type Error = Error; + type Item = Row<'stmt>; + + fn advance(&mut self) -> Result<()> { + match self.stmt { + Some(ref stmt) => match stmt.step() { + Ok(true) => { + self.row = Some(Row { stmt }); + Ok(()) + } + Ok(false) => { + self.reset(); + self.row = None; + Ok(()) + } + Err(e) => { + self.reset(); + self.row = None; + Err(e) + } + }, + None => { + self.row = None; + Ok(()) + } + } + } + + fn get(&self) -> Option<&Row<'stmt>> { + self.row.as_ref() + } +} + +/// A single result row of a query. +pub struct Row<'stmt> { + pub(crate) stmt: &'stmt Statement<'stmt>, +} + +impl<'stmt> Row<'stmt> { + /// Get the value of a particular column of the result row. + /// + /// ## Failure + /// + /// Panics if calling `row.get(idx)` would return an error, + /// including: + /// + /// * If the underlying SQLite column type is not a valid type as a source + /// for `T` + /// * If the underlying SQLite integral value is outside the range + /// representable by `T` + /// * If `idx` is outside the range of columns in the returned query + pub fn get_unwrap<I: RowIndex, T: FromSql>(&self, idx: I) -> T { + self.get(idx).unwrap() + } + + /// Get the value of a particular column of the result row. + /// + /// ## Failure + /// + /// Returns an `Error::InvalidColumnType` if the underlying SQLite column + /// type is not a valid type as a source for `T`. + /// + /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid + /// column range for this row. + /// + /// Returns an `Error::InvalidColumnName` if `idx` is not a valid column + /// name for this row. + /// + /// If the result type is i128 (which requires the `i128_blob` feature to be + /// enabled), and the underlying SQLite column is a blob whose size is not + /// 16 bytes, `Error::InvalidColumnType` will also be returned. + pub fn get<I: RowIndex, T: FromSql>(&self, idx: I) -> Result<T> { + let idx = idx.idx(self.stmt)?; + let value = self.stmt.value_ref(idx); + FromSql::column_result(value).map_err(|err| match err { + FromSqlError::InvalidType => Error::InvalidColumnType( + idx, + self.stmt.column_name_unwrap(idx).into(), + value.data_type(), + ), + FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i), + FromSqlError::Other(err) => { + Error::FromSqlConversionFailure(idx as usize, value.data_type(), err) + } + #[cfg(feature = "i128_blob")] + FromSqlError::InvalidI128Size(_) => Error::InvalidColumnType( + idx, + self.stmt.column_name_unwrap(idx).into(), + value.data_type(), + ), + #[cfg(feature = "uuid")] + FromSqlError::InvalidUuidSize(_) => Error::InvalidColumnType( + idx, + self.stmt.column_name_unwrap(idx).into(), + value.data_type(), + ), + }) + } + + /// Get the value of a particular column of the result row as a `ValueRef`, + /// allowing data to be read out of a row without copying. + /// + /// This `ValueRef` is valid only as long as this Row, which is enforced by + /// it's lifetime. This means that while this method is completely safe, + /// it can be somewhat difficult to use, and most callers will be better + /// served by `get` or `get`. + /// + /// ## Failure + /// + /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid + /// column range for this row. + /// + /// Returns an `Error::InvalidColumnName` if `idx` is not a valid column + /// name for this row. + pub fn get_raw_checked<I: RowIndex>(&self, idx: I) -> Result<ValueRef<'_>> { + let idx = idx.idx(self.stmt)?; + // Narrowing from `ValueRef<'stmt>` (which `self.stmt.value_ref(idx)` + // returns) to `ValueRef<'a>` is needed because it's only valid until + // the next call to sqlite3_step. + let val_ref = self.stmt.value_ref(idx); + Ok(val_ref) + } + + /// Get the value of a particular column of the result row as a `ValueRef`, + /// allowing data to be read out of a row without copying. + /// + /// This `ValueRef` is valid only as long as this Row, which is enforced by + /// it's lifetime. This means that while this method is completely safe, + /// it can be difficult to use, and most callers will be better served by + /// `get` or `get`. + /// + /// ## Failure + /// + /// Panics if calling `row.get_raw_checked(idx)` would return an error, + /// including: + /// + /// * If `idx` is outside the range of columns in the returned query. + /// * If `idx` is not a valid column name for this row. + pub fn get_raw<I: RowIndex>(&self, idx: I) -> ValueRef<'_> { + self.get_raw_checked(idx).unwrap() + } +} + +/// A trait implemented by types that can index into columns of a row. +pub trait RowIndex { + /// Returns the index of the appropriate column, or `None` if no such + /// column exists. + fn idx(&self, stmt: &Statement<'_>) -> Result<usize>; +} + +impl RowIndex for usize { + #[inline] + fn idx(&self, stmt: &Statement<'_>) -> Result<usize> { + if *self >= stmt.column_count() { + Err(Error::InvalidColumnIndex(*self)) + } else { + Ok(*self) + } + } +} + +impl RowIndex for &'_ str { + #[inline] + fn idx(&self, stmt: &Statement<'_>) -> Result<usize> { + stmt.column_index(*self) + } +} + +macro_rules! tuple_try_from_row { + ($($field:ident),*) => { + impl<'a, $($field,)*> convert::TryFrom<&'a Row<'a>> for ($($field,)*) where $($field: FromSql,)* { + type Error = crate::Error; + + // we end with index += 1, which rustc warns about + // unused_variables and unused_mut are allowed for () + #[allow(unused_assignments, unused_variables, unused_mut)] + fn try_from(row: &'a Row<'a>) -> Result<Self> { + let mut index = 0; + $( + #[allow(non_snake_case)] + let $field = row.get::<_, $field>(index)?; + index += 1; + )* + Ok(($($field,)*)) + } + } + } +} + +macro_rules! tuples_try_from_row { + () => { + // not very useful, but maybe some other macro users will find this helpful + tuple_try_from_row!(); + }; + ($first:ident $(, $remaining:ident)*) => { + tuple_try_from_row!($first $(, $remaining)*); + tuples_try_from_row!($($remaining),*); + }; +} + +tuples_try_from_row!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P); + +#[cfg(test)] +mod tests { + #![allow(clippy::redundant_closure)] // false positives due to lifetime issues; clippy issue #5594 + + #[test] + fn test_try_from_row_for_tuple_1() { + use crate::{Connection, ToSql}; + use std::convert::TryFrom; + + let conn = Connection::open_in_memory().expect("failed to create in-memoory database"); + conn.execute( + "CREATE TABLE test (a INTEGER)", + std::iter::empty::<&dyn ToSql>(), + ) + .expect("failed to create table"); + conn.execute( + "INSERT INTO test VALUES (42)", + std::iter::empty::<&dyn ToSql>(), + ) + .expect("failed to insert value"); + let val = conn + .query_row( + "SELECT a FROM test", + std::iter::empty::<&dyn ToSql>(), + |row| <(u32,)>::try_from(row), + ) + .expect("failed to query row"); + assert_eq!(val, (42,)); + let fail = conn.query_row( + "SELECT a FROM test", + std::iter::empty::<&dyn ToSql>(), + |row| <(u32, u32)>::try_from(row), + ); + assert!(fail.is_err()); + } + + #[test] + fn test_try_from_row_for_tuple_2() { + use crate::{Connection, ToSql}; + use std::convert::TryFrom; + + let conn = Connection::open_in_memory().expect("failed to create in-memoory database"); + conn.execute( + "CREATE TABLE test (a INTEGER, b INTEGER)", + std::iter::empty::<&dyn ToSql>(), + ) + .expect("failed to create table"); + conn.execute( + "INSERT INTO test VALUES (42, 47)", + std::iter::empty::<&dyn ToSql>(), + ) + .expect("failed to insert value"); + let val = conn + .query_row( + "SELECT a, b FROM test", + std::iter::empty::<&dyn ToSql>(), + |row| <(u32, u32)>::try_from(row), + ) + .expect("failed to query row"); + assert_eq!(val, (42, 47)); + let fail = conn.query_row( + "SELECT a, b FROM test", + std::iter::empty::<&dyn ToSql>(), + |row| <(u32, u32, u32)>::try_from(row), + ); + assert!(fail.is_err()); + } + + #[test] + fn test_try_from_row_for_tuple_16() { + use crate::{Connection, ToSql}; + use std::convert::TryFrom; + + let create_table = "CREATE TABLE test ( + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER, + e INTEGER, + f INTEGER, + g INTEGER, + h INTEGER, + i INTEGER, + j INTEGER, + k INTEGER, + l INTEGER, + m INTEGER, + n INTEGER, + o INTEGER, + p INTEGER + )"; + + let insert_values = "INSERT INTO test VALUES ( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15 + )"; + + type BigTuple = ( + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + u32, + ); + + let conn = Connection::open_in_memory().expect("failed to create in-memoory database"); + conn.execute(create_table, std::iter::empty::<&dyn ToSql>()) + .expect("failed to create table"); + conn.execute(insert_values, std::iter::empty::<&dyn ToSql>()) + .expect("failed to insert value"); + let val = conn + .query_row( + "SELECT * FROM test", + std::iter::empty::<&dyn ToSql>(), + |row| BigTuple::try_from(row), + ) + .expect("failed to query row"); + // Debug is not implemented for tuples of 16 + assert_eq!(val.0, 0); + assert_eq!(val.1, 1); + assert_eq!(val.2, 2); + assert_eq!(val.3, 3); + assert_eq!(val.4, 4); + assert_eq!(val.5, 5); + assert_eq!(val.6, 6); + assert_eq!(val.7, 7); + assert_eq!(val.8, 8); + assert_eq!(val.9, 9); + assert_eq!(val.10, 10); + assert_eq!(val.11, 11); + assert_eq!(val.12, 12); + assert_eq!(val.13, 13); + assert_eq!(val.14, 14); + assert_eq!(val.15, 15); + + // We don't test one bigger because it's unimplemented + } +} diff --git a/third_party/rust/rusqlite/src/session.rs b/third_party/rust/rusqlite/src/session.rs new file mode 100644 index 0000000000..97ae3a5ba2 --- /dev/null +++ b/third_party/rust/rusqlite/src/session.rs @@ -0,0 +1,918 @@ +//! `feature = "session"` [Session Extension](https://sqlite.org/sessionintro.html) +#![allow(non_camel_case_types)] + +use std::ffi::CStr; +use std::io::{Read, Write}; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int, c_uchar, c_void}; +use std::panic::{catch_unwind, RefUnwindSafe}; +use std::ptr; +use std::slice::{from_raw_parts, from_raw_parts_mut}; + +use fallible_streaming_iterator::FallibleStreamingIterator; + +use crate::error::error_from_sqlite_code; +use crate::ffi; +use crate::hooks::Action; +use crate::types::ValueRef; +use crate::{errmsg_to_string, str_to_cstring, Connection, DatabaseName, Result}; + +// https://sqlite.org/session.html + +/// `feature = "session"` An instance of this object is a session that can be +/// used to record changes to a database. +pub struct Session<'conn> { + phantom: PhantomData<&'conn Connection>, + s: *mut ffi::sqlite3_session, + filter: Option<Box<dyn Fn(&str) -> bool>>, +} + +impl Session<'_> { + /// Create a new session object + pub fn new<'conn>(db: &'conn Connection) -> Result<Session<'conn>> { + Session::new_with_name(db, DatabaseName::Main) + } + + /// Create a new session object + pub fn new_with_name<'conn>( + db: &'conn Connection, + name: DatabaseName<'_>, + ) -> Result<Session<'conn>> { + let name = name.to_cstring()?; + + let db = db.db.borrow_mut().db; + + let mut s: *mut ffi::sqlite3_session = ptr::null_mut(); + check!(unsafe { ffi::sqlite3session_create(db, name.as_ptr(), &mut s) }); + + Ok(Session { + phantom: PhantomData, + s, + filter: None, + }) + } + + /// Set a table filter + pub fn table_filter<F>(&mut self, filter: Option<F>) + where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure<F>( + p_arg: *mut c_void, + tbl_str: *const c_char, + ) -> c_int + where + F: Fn(&str) -> bool + RefUnwindSafe, + { + use std::str; + + let boxed_filter: *mut F = p_arg as *mut F; + let tbl_name = { + let c_slice = CStr::from_ptr(tbl_str).to_bytes(); + str::from_utf8(c_slice) + }; + if let Ok(true) = + catch_unwind(|| (*boxed_filter)(tbl_name.expect("non-utf8 table name"))) + { + 1 + } else { + 0 + } + } + + match filter { + Some(filter) => { + let boxed_filter = Box::new(filter); + unsafe { + ffi::sqlite3session_table_filter( + self.s, + Some(call_boxed_closure::<F>), + &*boxed_filter as *const F as *mut _, + ); + } + self.filter = Some(boxed_filter); + } + _ => { + unsafe { ffi::sqlite3session_table_filter(self.s, None, ptr::null_mut()) } + self.filter = None; + } + }; + } + + /// Attach a table. `None` means all tables. + pub fn attach(&mut self, table: Option<&str>) -> Result<()> { + let table = if let Some(table) = table { + Some(str_to_cstring(table)?) + } else { + None + }; + let table = table.as_ref().map(|s| s.as_ptr()).unwrap_or(ptr::null()); + unsafe { check!(ffi::sqlite3session_attach(self.s, table)) }; + Ok(()) + } + + /// Generate a Changeset + pub fn changeset(&mut self) -> Result<Changeset> { + let mut n = 0; + let mut cs: *mut c_void = ptr::null_mut(); + check!(unsafe { ffi::sqlite3session_changeset(self.s, &mut n, &mut cs) }); + Ok(Changeset { cs, n }) + } + + /// Write the set of changes represented by this session to `output`. + pub fn changeset_strm(&mut self, output: &mut dyn Write) -> Result<()> { + let output_ref = &output; + check!(unsafe { + ffi::sqlite3session_changeset_strm( + self.s, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }); + Ok(()) + } + + /// Generate a Patchset + pub fn patchset(&mut self) -> Result<Changeset> { + let mut n = 0; + let mut ps: *mut c_void = ptr::null_mut(); + check!(unsafe { ffi::sqlite3session_patchset(self.s, &mut n, &mut ps) }); + // TODO Validate: same struct + Ok(Changeset { cs: ps, n }) + } + + /// Write the set of patches represented by this session to `output`. + pub fn patchset_strm(&mut self, output: &mut dyn Write) -> Result<()> { + let output_ref = &output; + check!(unsafe { + ffi::sqlite3session_patchset_strm( + self.s, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }); + Ok(()) + } + + /// Load the difference between tables. + pub fn diff(&mut self, from: DatabaseName<'_>, table: &str) -> Result<()> { + let from = from.to_cstring()?; + let table = str_to_cstring(table)?; + let table = table.as_ptr(); + unsafe { + let mut errmsg = ptr::null_mut(); + let r = + ffi::sqlite3session_diff(self.s, from.as_ptr(), table, &mut errmsg as *mut *mut _); + if r != ffi::SQLITE_OK { + let errmsg: *mut c_char = errmsg; + let message = errmsg_to_string(&*errmsg); + ffi::sqlite3_free(errmsg as *mut ::std::os::raw::c_void); + return Err(error_from_sqlite_code(r, Some(message))); + } + } + Ok(()) + } + + /// Test if a changeset has recorded any changes + pub fn is_empty(&self) -> bool { + unsafe { ffi::sqlite3session_isempty(self.s) != 0 } + } + + /// Query the current state of the session + pub fn is_enabled(&self) -> bool { + unsafe { ffi::sqlite3session_enable(self.s, -1) != 0 } + } + + /// Enable or disable the recording of changes + pub fn set_enabled(&mut self, enabled: bool) { + unsafe { + ffi::sqlite3session_enable(self.s, if enabled { 1 } else { 0 }); + } + } + + /// Query the current state of the indirect flag + pub fn is_indirect(&self) -> bool { + unsafe { ffi::sqlite3session_indirect(self.s, -1) != 0 } + } + + /// Set or clear the indirect change flag + pub fn set_indirect(&mut self, indirect: bool) { + unsafe { + ffi::sqlite3session_indirect(self.s, if indirect { 1 } else { 0 }); + } + } +} + +impl Drop for Session<'_> { + fn drop(&mut self) { + if self.filter.is_some() { + self.table_filter(None::<fn(&str) -> bool>); + } + unsafe { ffi::sqlite3session_delete(self.s) }; + } +} + +/// `feature = "session"` Invert a changeset +pub fn invert_strm(input: &mut dyn Read, output: &mut dyn Write) -> Result<()> { + let input_ref = &input; + let output_ref = &output; + check!(unsafe { + ffi::sqlite3changeset_invert_strm( + Some(x_input), + input_ref as *const &mut dyn Read as *mut c_void, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }); + Ok(()) +} + +/// `feature = "session"` Combine two changesets +pub fn concat_strm( + input_a: &mut dyn Read, + input_b: &mut dyn Read, + output: &mut dyn Write, +) -> Result<()> { + let input_a_ref = &input_a; + let input_b_ref = &input_b; + let output_ref = &output; + check!(unsafe { + ffi::sqlite3changeset_concat_strm( + Some(x_input), + input_a_ref as *const &mut dyn Read as *mut c_void, + Some(x_input), + input_b_ref as *const &mut dyn Read as *mut c_void, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }); + Ok(()) +} + +/// `feature = "session"` Changeset or Patchset +pub struct Changeset { + cs: *mut c_void, + n: c_int, +} + +impl Changeset { + /// Invert a changeset + pub fn invert(&self) -> Result<Changeset> { + let mut n = 0; + let mut cs = ptr::null_mut(); + check!(unsafe { + ffi::sqlite3changeset_invert(self.n, self.cs, &mut n, &mut cs as *mut *mut _) + }); + Ok(Changeset { cs, n }) + } + + /// Create an iterator to traverse a changeset + pub fn iter(&self) -> Result<ChangesetIter<'_>> { + let mut it = ptr::null_mut(); + check!(unsafe { ffi::sqlite3changeset_start(&mut it as *mut *mut _, self.n, self.cs) }); + Ok(ChangesetIter { + phantom: PhantomData, + it, + item: None, + }) + } + + /// Concatenate two changeset objects + pub fn concat(a: &Changeset, b: &Changeset) -> Result<Changeset> { + let mut n = 0; + let mut cs = ptr::null_mut(); + check!(unsafe { + ffi::sqlite3changeset_concat(a.n, a.cs, b.n, b.cs, &mut n, &mut cs as *mut *mut _) + }); + Ok(Changeset { cs, n }) + } +} + +impl Drop for Changeset { + fn drop(&mut self) { + unsafe { + ffi::sqlite3_free(self.cs); + } + } +} + +/// `feature = "session"` Cursor for iterating over the elements of a changeset +/// or patchset. +pub struct ChangesetIter<'changeset> { + phantom: PhantomData<&'changeset Changeset>, + it: *mut ffi::sqlite3_changeset_iter, + item: Option<ChangesetItem>, +} + +impl ChangesetIter<'_> { + /// Create an iterator on `input` + pub fn start_strm<'input>(input: &&'input mut dyn Read) -> Result<ChangesetIter<'input>> { + let mut it = ptr::null_mut(); + check!(unsafe { + ffi::sqlite3changeset_start_strm( + &mut it as *mut *mut _, + Some(x_input), + input as *const &mut dyn Read as *mut c_void, + ) + }); + Ok(ChangesetIter { + phantom: PhantomData, + it, + item: None, + }) + } +} + +impl FallibleStreamingIterator for ChangesetIter<'_> { + type Error = crate::error::Error; + type Item = ChangesetItem; + + fn advance(&mut self) -> Result<()> { + let rc = unsafe { ffi::sqlite3changeset_next(self.it) }; + match rc { + ffi::SQLITE_ROW => { + self.item = Some(ChangesetItem { it: self.it }); + Ok(()) + } + ffi::SQLITE_DONE => { + self.item = None; + Ok(()) + } + code => Err(error_from_sqlite_code(code, None)), + } + } + + fn get(&self) -> Option<&ChangesetItem> { + self.item.as_ref() + } +} + +/// `feature = "session"` +pub struct Operation<'item> { + table_name: &'item str, + number_of_columns: i32, + code: Action, + indirect: bool, +} + +impl Operation<'_> { + /// Returns the table name. + pub fn table_name(&self) -> &str { + self.table_name + } + + /// Returns the number of columns in table + pub fn number_of_columns(&self) -> i32 { + self.number_of_columns + } + + /// Returns the action code. + pub fn code(&self) -> Action { + self.code + } + + /// Returns `true` for an 'indirect' change. + pub fn indirect(&self) -> bool { + self.indirect + } +} + +impl Drop for ChangesetIter<'_> { + fn drop(&mut self) { + unsafe { + ffi::sqlite3changeset_finalize(self.it); + } + } +} + +/// `feature = "session"` An item passed to a conflict-handler by +/// `Connection::apply`, or an item generated by `ChangesetIter::next`. +// TODO enum ? Delete, Insert, Update, ... +pub struct ChangesetItem { + it: *mut ffi::sqlite3_changeset_iter, +} + +impl ChangesetItem { + /// Obtain conflicting row values + /// + /// May only be called with an `SQLITE_CHANGESET_DATA` or + /// `SQLITE_CHANGESET_CONFLICT` conflict handler callback. + pub fn conflict(&self, col: usize) -> Result<ValueRef<'_>> { + unsafe { + let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); + check!(ffi::sqlite3changeset_conflict( + self.it, + col as i32, + &mut p_value, + )); + Ok(ValueRef::from_value(p_value)) + } + } + + /// Determine the number of foreign key constraint violations + /// + /// May only be called with an `SQLITE_CHANGESET_FOREIGN_KEY` conflict + /// handler callback. + pub fn fk_conflicts(&self) -> Result<i32> { + unsafe { + let mut p_out = 0; + check!(ffi::sqlite3changeset_fk_conflicts(self.it, &mut p_out)); + Ok(p_out) + } + } + + /// Obtain new.* Values + /// + /// May only be called if the type of change is either `SQLITE_UPDATE` or + /// `SQLITE_INSERT`. + pub fn new_value(&self, col: usize) -> Result<ValueRef<'_>> { + unsafe { + let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); + check!(ffi::sqlite3changeset_new(self.it, col as i32, &mut p_value,)); + Ok(ValueRef::from_value(p_value)) + } + } + + /// Obtain old.* Values + /// + /// May only be called if the type of change is either `SQLITE_DELETE` or + /// `SQLITE_UPDATE`. + pub fn old_value(&self, col: usize) -> Result<ValueRef<'_>> { + unsafe { + let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); + check!(ffi::sqlite3changeset_old(self.it, col as i32, &mut p_value,)); + Ok(ValueRef::from_value(p_value)) + } + } + + /// Obtain the current operation + pub fn op(&self) -> Result<Operation<'_>> { + let mut number_of_columns = 0; + let mut code = 0; + let mut indirect = 0; + let tab = unsafe { + let mut pz_tab: *const c_char = ptr::null(); + check!(ffi::sqlite3changeset_op( + self.it, + &mut pz_tab, + &mut number_of_columns, + &mut code, + &mut indirect + )); + CStr::from_ptr(pz_tab) + }; + let table_name = tab.to_str()?; + Ok(Operation { + table_name, + number_of_columns, + code: Action::from(code), + indirect: indirect != 0, + }) + } + + /// Obtain the primary key definition of a table + pub fn pk(&self) -> Result<&[u8]> { + let mut number_of_columns = 0; + unsafe { + let mut pks: *mut c_uchar = ptr::null_mut(); + check!(ffi::sqlite3changeset_pk( + self.it, + &mut pks, + &mut number_of_columns + )); + Ok(from_raw_parts(pks, number_of_columns as usize)) + } + } +} + +/// `feature = "session"` Used to combine two or more changesets or +/// patchsets +pub struct Changegroup { + cg: *mut ffi::sqlite3_changegroup, +} + +impl Changegroup { + /// Create a new change group. + pub fn new() -> Result<Self> { + let mut cg = ptr::null_mut(); + check!(unsafe { ffi::sqlite3changegroup_new(&mut cg) }); + Ok(Changegroup { cg }) + } + + /// Add a changeset + pub fn add(&mut self, cs: &Changeset) -> Result<()> { + check!(unsafe { ffi::sqlite3changegroup_add(self.cg, cs.n, cs.cs) }); + Ok(()) + } + + /// Add a changeset read from `input` to this change group. + pub fn add_stream(&mut self, input: &mut dyn Read) -> Result<()> { + let input_ref = &input; + check!(unsafe { + ffi::sqlite3changegroup_add_strm( + self.cg, + Some(x_input), + input_ref as *const &mut dyn Read as *mut c_void, + ) + }); + Ok(()) + } + + /// Obtain a composite Changeset + pub fn output(&mut self) -> Result<Changeset> { + let mut n = 0; + let mut output: *mut c_void = ptr::null_mut(); + check!(unsafe { ffi::sqlite3changegroup_output(self.cg, &mut n, &mut output) }); + Ok(Changeset { cs: output, n }) + } + + /// Write the combined set of changes to `output`. + pub fn output_strm(&mut self, output: &mut dyn Write) -> Result<()> { + let output_ref = &output; + check!(unsafe { + ffi::sqlite3changegroup_output_strm( + self.cg, + Some(x_output), + output_ref as *const &mut dyn Write as *mut c_void, + ) + }); + Ok(()) + } +} + +impl Drop for Changegroup { + fn drop(&mut self) { + unsafe { + ffi::sqlite3changegroup_delete(self.cg); + } + } +} + +impl Connection { + /// `feature = "session"` Apply a changeset to a database + pub fn apply<F, C>(&self, cs: &Changeset, filter: Option<F>, conflict: C) -> Result<()> + where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, + { + let db = self.db.borrow_mut().db; + + let filtered = filter.is_some(); + let tuple = &mut (filter, conflict); + check!(unsafe { + if filtered { + ffi::sqlite3changeset_apply( + db, + cs.n, + cs.cs, + Some(call_filter::<F, C>), + Some(call_conflict::<F, C>), + tuple as *mut (Option<F>, C) as *mut c_void, + ) + } else { + ffi::sqlite3changeset_apply( + db, + cs.n, + cs.cs, + None, + Some(call_conflict::<F, C>), + tuple as *mut (Option<F>, C) as *mut c_void, + ) + } + }); + Ok(()) + } + + /// `feature = "session"` Apply a changeset to a database + pub fn apply_strm<F, C>( + &self, + input: &mut dyn Read, + filter: Option<F>, + conflict: C, + ) -> Result<()> + where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, + { + let input_ref = &input; + let db = self.db.borrow_mut().db; + + let filtered = filter.is_some(); + let tuple = &mut (filter, conflict); + check!(unsafe { + if filtered { + ffi::sqlite3changeset_apply_strm( + db, + Some(x_input), + input_ref as *const &mut dyn Read as *mut c_void, + Some(call_filter::<F, C>), + Some(call_conflict::<F, C>), + tuple as *mut (Option<F>, C) as *mut c_void, + ) + } else { + ffi::sqlite3changeset_apply_strm( + db, + Some(x_input), + input_ref as *const &mut dyn Read as *mut c_void, + None, + Some(call_conflict::<F, C>), + tuple as *mut (Option<F>, C) as *mut c_void, + ) + } + }); + Ok(()) + } +} + +/// `feature = "session"` Constants passed to the conflict handler +/// See [here](https://sqlite.org/session.html#SQLITE_CHANGESET_CONFLICT) for details. +#[allow(missing_docs)] +#[repr(i32)] +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub enum ConflictType { + UNKNOWN = -1, + SQLITE_CHANGESET_DATA = ffi::SQLITE_CHANGESET_DATA, + SQLITE_CHANGESET_NOTFOUND = ffi::SQLITE_CHANGESET_NOTFOUND, + SQLITE_CHANGESET_CONFLICT = ffi::SQLITE_CHANGESET_CONFLICT, + SQLITE_CHANGESET_CONSTRAINT = ffi::SQLITE_CHANGESET_CONSTRAINT, + SQLITE_CHANGESET_FOREIGN_KEY = ffi::SQLITE_CHANGESET_FOREIGN_KEY, +} +impl From<i32> for ConflictType { + fn from(code: i32) -> ConflictType { + match code { + ffi::SQLITE_CHANGESET_DATA => ConflictType::SQLITE_CHANGESET_DATA, + ffi::SQLITE_CHANGESET_NOTFOUND => ConflictType::SQLITE_CHANGESET_NOTFOUND, + ffi::SQLITE_CHANGESET_CONFLICT => ConflictType::SQLITE_CHANGESET_CONFLICT, + ffi::SQLITE_CHANGESET_CONSTRAINT => ConflictType::SQLITE_CHANGESET_CONSTRAINT, + ffi::SQLITE_CHANGESET_FOREIGN_KEY => ConflictType::SQLITE_CHANGESET_FOREIGN_KEY, + _ => ConflictType::UNKNOWN, + } + } +} + +/// `feature = "session"` Constants returned by the conflict handler +/// See [here](https://sqlite.org/session.html#SQLITE_CHANGESET_ABORT) for details. +#[allow(missing_docs)] +#[repr(i32)] +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub enum ConflictAction { + SQLITE_CHANGESET_OMIT = ffi::SQLITE_CHANGESET_OMIT, + SQLITE_CHANGESET_REPLACE = ffi::SQLITE_CHANGESET_REPLACE, + SQLITE_CHANGESET_ABORT = ffi::SQLITE_CHANGESET_ABORT, +} + +unsafe extern "C" fn call_filter<F, C>(p_ctx: *mut c_void, tbl_str: *const c_char) -> c_int +where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, +{ + use std::str; + + let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C); + let tbl_name = { + let c_slice = CStr::from_ptr(tbl_str).to_bytes(); + str::from_utf8(c_slice) + }; + match *tuple { + (Some(ref filter), _) => { + if let Ok(true) = catch_unwind(|| filter(tbl_name.expect("illegal table name"))) { + 1 + } else { + 0 + } + } + _ => unimplemented!(), + } +} + +unsafe extern "C" fn call_conflict<F, C>( + p_ctx: *mut c_void, + e_conflict: c_int, + p: *mut ffi::sqlite3_changeset_iter, +) -> c_int +where + F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, +{ + let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C); + let conflict_type = ConflictType::from(e_conflict); + let item = ChangesetItem { it: p }; + if let Ok(action) = catch_unwind(|| (*tuple).1(conflict_type, item)) { + action as c_int + } else { + ffi::SQLITE_CHANGESET_ABORT + } +} + +unsafe extern "C" fn x_input(p_in: *mut c_void, data: *mut c_void, len: *mut c_int) -> c_int { + if p_in.is_null() { + return ffi::SQLITE_MISUSE; + } + let bytes: &mut [u8] = from_raw_parts_mut(data as *mut u8, *len as usize); + let input = p_in as *mut &mut dyn Read; + match (*input).read(bytes) { + Ok(n) => { + *len = n as i32; // TODO Validate: n = 0 may not mean the reader will always no longer be able to + // produce bytes. + ffi::SQLITE_OK + } + Err(_) => ffi::SQLITE_IOERR_READ, // TODO check if err is a (ru)sqlite Error => propagate + } +} + +unsafe extern "C" fn x_output(p_out: *mut c_void, data: *const c_void, len: c_int) -> c_int { + if p_out.is_null() { + return ffi::SQLITE_MISUSE; + } + // The sessions module never invokes an xOutput callback with the third + // parameter set to a value less than or equal to zero. + let bytes: &[u8] = from_raw_parts(data as *const u8, len as usize); + let output = p_out as *mut &mut dyn Write; + match (*output).write_all(bytes) { + Ok(_) => ffi::SQLITE_OK, + Err(_) => ffi::SQLITE_IOERR_WRITE, // TODO check if err is a (ru)sqlite Error => propagate + } +} + +#[cfg(test)] +mod test { + use fallible_streaming_iterator::FallibleStreamingIterator; + use std::io::Read; + use std::sync::atomic::{AtomicBool, Ordering}; + + use super::{Changeset, ChangesetIter, ConflictAction, ConflictType, Session}; + use crate::hooks::Action; + use crate::Connection; + + fn one_changeset() -> Changeset { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);") + .unwrap(); + + let mut session = Session::new(&db).unwrap(); + assert!(session.is_empty()); + + session.attach(None).unwrap(); + db.execute("INSERT INTO foo (t) VALUES (?);", &["bar"]) + .unwrap(); + + session.changeset().unwrap() + } + + fn one_changeset_strm() -> Vec<u8> { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);") + .unwrap(); + + let mut session = Session::new(&db).unwrap(); + assert!(session.is_empty()); + + session.attach(None).unwrap(); + db.execute("INSERT INTO foo (t) VALUES (?);", &["bar"]) + .unwrap(); + + let mut output = Vec::new(); + session.changeset_strm(&mut output).unwrap(); + output + } + + #[test] + fn test_changeset() { + let changeset = one_changeset(); + let mut iter = changeset.iter().unwrap(); + let item = iter.next().unwrap(); + assert!(item.is_some()); + + let item = item.unwrap(); + let op = item.op().unwrap(); + assert_eq!("foo", op.table_name()); + assert_eq!(1, op.number_of_columns()); + assert_eq!(Action::SQLITE_INSERT, op.code()); + assert_eq!(false, op.indirect()); + + let pk = item.pk().unwrap(); + assert_eq!(&[1], pk); + + let new_value = item.new_value(0).unwrap(); + assert_eq!(Ok("bar"), new_value.as_str()); + } + + #[test] + fn test_changeset_strm() { + let output = one_changeset_strm(); + assert!(!output.is_empty()); + assert_eq!(14, output.len()); + + let input: &mut dyn Read = &mut output.as_slice(); + let mut iter = ChangesetIter::start_strm(&input).unwrap(); + let item = iter.next().unwrap(); + assert!(item.is_some()); + } + + #[test] + fn test_changeset_apply() { + let changeset = one_changeset(); + + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);") + .unwrap(); + + lazy_static::lazy_static! { + static ref CALLED: AtomicBool = AtomicBool::new(false); + } + db.apply( + &changeset, + None::<fn(&str) -> bool>, + |_conflict_type, _item| { + CALLED.store(true, Ordering::Relaxed); + ConflictAction::SQLITE_CHANGESET_OMIT + }, + ) + .unwrap(); + + assert!(!CALLED.load(Ordering::Relaxed)); + let check = db + .query_row("SELECT 1 FROM foo WHERE t = ?", &["bar"], |row| { + row.get::<_, i32>(0) + }) + .unwrap(); + assert_eq!(1, check); + + // conflict expected when same changeset applied again on the same db + db.apply( + &changeset, + None::<fn(&str) -> bool>, + |conflict_type, item| { + CALLED.store(true, Ordering::Relaxed); + assert_eq!(ConflictType::SQLITE_CHANGESET_CONFLICT, conflict_type); + let conflict = item.conflict(0).unwrap(); + assert_eq!(Ok("bar"), conflict.as_str()); + ConflictAction::SQLITE_CHANGESET_OMIT + }, + ) + .unwrap(); + assert!(CALLED.load(Ordering::Relaxed)); + } + + #[test] + fn test_changeset_apply_strm() { + let output = one_changeset_strm(); + + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);") + .unwrap(); + + let mut input = output.as_slice(); + db.apply_strm( + &mut input, + None::<fn(&str) -> bool>, + |_conflict_type, _item| ConflictAction::SQLITE_CHANGESET_OMIT, + ) + .unwrap(); + + let check = db + .query_row("SELECT 1 FROM foo WHERE t = ?", &["bar"], |row| { + row.get::<_, i32>(0) + }) + .unwrap(); + assert_eq!(1, check); + } + + #[test] + fn test_session_empty() { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);") + .unwrap(); + + let mut session = Session::new(&db).unwrap(); + assert!(session.is_empty()); + + session.attach(None).unwrap(); + db.execute("INSERT INTO foo (t) VALUES (?);", &["bar"]) + .unwrap(); + + assert!(!session.is_empty()); + } + + #[test] + fn test_session_set_enabled() { + let db = Connection::open_in_memory().unwrap(); + + let mut session = Session::new(&db).unwrap(); + assert!(session.is_enabled()); + session.set_enabled(false); + assert!(!session.is_enabled()); + } + + #[test] + fn test_session_set_indirect() { + let db = Connection::open_in_memory().unwrap(); + + let mut session = Session::new(&db).unwrap(); + assert!(!session.is_indirect()); + session.set_indirect(true); + assert!(session.is_indirect()); + } +} diff --git a/third_party/rust/rusqlite/src/statement.rs b/third_party/rust/rusqlite/src/statement.rs new file mode 100644 index 0000000000..648a9b7777 --- /dev/null +++ b/third_party/rust/rusqlite/src/statement.rs @@ -0,0 +1,1250 @@ +use std::iter::IntoIterator; +use std::os::raw::{c_int, c_void}; +#[cfg(feature = "array")] +use std::rc::Rc; +use std::slice::from_raw_parts; +use std::{convert, fmt, mem, ptr, str}; + +use super::ffi; +use super::{len_as_c_int, str_for_sqlite}; +use super::{ + AndThenRows, Connection, Error, MappedRows, RawStatement, Result, Row, Rows, ValueRef, +}; +use crate::types::{ToSql, ToSqlOutput}; +#[cfg(feature = "array")] +use crate::vtab::array::{free_array, ARRAY_TYPE}; + +/// A prepared statement. +pub struct Statement<'conn> { + conn: &'conn Connection, + pub(crate) stmt: RawStatement, +} + +impl Statement<'_> { + /// Execute the prepared statement. + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn update_rows(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("UPDATE foo SET bar = 'baz' WHERE qux = ?")?; + /// + /// stmt.execute(&[1i32])?; + /// stmt.execute(&[2i32])?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails, the executed statement + /// returns rows (in which case `query` should be used instead), or the + /// underlying SQLite call fails. + pub fn execute<P>(&mut self, params: P) -> Result<usize> + where + P: IntoIterator, + P::Item: ToSql, + { + self.bind_parameters(params)?; + self.execute_with_bound_parameters() + } + + /// Execute the prepared statement with named parameter(s). If any + /// parameters that were in the prepared statement are not included in + /// `params`, they will continue to use the most-recently bound value + /// from a previous call to `execute_named`, or `NULL` if they have + /// never been bound. + /// + /// On success, returns the number of rows that were changed or inserted or + /// deleted (via `sqlite3_changes`). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert(conn: &Connection) -> Result<usize> { + /// let mut stmt = conn.prepare("INSERT INTO test (name) VALUES (:name)")?; + /// stmt.execute_named(&[(":name", &"one")]) + /// } + /// ``` + /// + /// Note, the `named_params` macro is provided for syntactic convenience, + /// and so the above example could also be written as: + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, named_params}; + /// fn insert(conn: &Connection) -> Result<usize> { + /// let mut stmt = conn.prepare("INSERT INTO test (name) VALUES (:name)")?; + /// stmt.execute_named(named_params!{":name": "one"}) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails, the executed statement + /// returns rows (in which case `query` should be used instead), or the + /// underlying SQLite call fails. + pub fn execute_named(&mut self, params: &[(&str, &dyn ToSql)]) -> Result<usize> { + self.bind_parameters_named(params)?; + self.execute_with_bound_parameters() + } + + /// Execute an INSERT and return the ROWID. + /// + /// # Note + /// + /// This function is a convenience wrapper around `execute()` intended for + /// queries that insert a single item. It is possible to misuse this + /// function in a way that it cannot detect, such as by calling it on a + /// statement which _updates_ a single + /// item rather than inserting one. Please don't do that. + /// + /// # Failure + /// + /// Will return `Err` if no row is inserted or many rows are inserted. + pub fn insert<P>(&mut self, params: P) -> Result<i64> + where + P: IntoIterator, + P::Item: ToSql, + { + let changes = self.execute(params)?; + match changes { + 1 => Ok(self.conn.last_insert_rowid()), + _ => Err(Error::StatementChangedRows(changes)), + } + } + + /// Execute the prepared statement, returning a handle to the resulting + /// rows. + /// + /// Due to lifetime restricts, the rows handle returned by `query` does not + /// implement the `Iterator` trait. Consider using `query_map` or + /// `query_and_then` instead, which do. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, NO_PARAMS}; + /// fn get_names(conn: &Connection) -> Result<Vec<String>> { + /// let mut stmt = conn.prepare("SELECT name FROM people")?; + /// let mut rows = stmt.query(NO_PARAMS)?; + /// + /// let mut names = Vec::new(); + /// while let Some(row) = rows.next()? { + /// names.push(row.get(0)?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + pub fn query<P>(&mut self, params: P) -> Result<Rows<'_>> + where + P: IntoIterator, + P::Item: ToSql, + { + self.check_readonly()?; + self.bind_parameters(params)?; + Ok(Rows::new(self)) + } + + /// Execute the prepared statement with named parameter(s), returning a + /// handle for the resulting rows. If any parameters that were in the + /// prepared statement are not included in `params`, they will continue + /// to use the most-recently bound value from a previous + /// call to `query_named`, or `NULL` if they have never been bound. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = :name")?; + /// let mut rows = stmt.query_named(&[(":name", &"one")])?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// Note, the `named_params!` macro is provided for syntactic convenience, + /// and so the above example could also be written as: + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, named_params}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test where name = :name")?; + /// let mut rows = stmt.query_named(named_params!{ ":name": "one" })?; + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails. + pub fn query_named(&mut self, params: &[(&str, &dyn ToSql)]) -> Result<Rows<'_>> { + self.check_readonly()?; + self.bind_parameters_named(params)?; + Ok(Rows::new(self)) + } + + /// Executes the prepared statement and maps a function over the resulting + /// rows, returning an iterator over the mapped function results. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result, NO_PARAMS}; + /// fn get_names(conn: &Connection) -> Result<Vec<String>> { + /// let mut stmt = conn.prepare("SELECT name FROM people")?; + /// let rows = stmt.query_map(NO_PARAMS, |row| row.get(0))?; + /// + /// let mut names = Vec::new(); + /// for name_result in rows { + /// names.push(name_result?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// `f` is used to tranform the _streaming_ iterator into a _standard_ + /// iterator. + /// + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + pub fn query_map<T, P, F>(&mut self, params: P, f: F) -> Result<MappedRows<'_, F>> + where + P: IntoIterator, + P::Item: ToSql, + F: FnMut(&Row<'_>) -> Result<T>, + { + let rows = self.query(params)?; + Ok(MappedRows::new(rows, f)) + } + + /// Execute the prepared statement with named parameter(s), returning an + /// iterator over the result of calling the mapping function over the + /// query's rows. If any parameters that were in the prepared statement + /// are not included in `params`, they will continue to use the + /// most-recently bound value from a previous call to `query_named`, + /// or `NULL` if they have never been bound. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn get_names(conn: &Connection) -> Result<Vec<String>> { + /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = :id")?; + /// let rows = stmt.query_map_named(&[(":id", &"one")], |row| row.get(0))?; + /// + /// let mut names = Vec::new(); + /// for name_result in rows { + /// names.push(name_result?); + /// } + /// + /// Ok(names) + /// } + /// ``` + /// `f` is used to tranform the _streaming_ iterator into a _standard_ + /// iterator. + /// + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + pub fn query_map_named<T, F>( + &mut self, + params: &[(&str, &dyn ToSql)], + f: F, + ) -> Result<MappedRows<'_, F>> + where + F: FnMut(&Row<'_>) -> Result<T>, + { + let rows = self.query_named(params)?; + Ok(MappedRows::new(rows, f)) + } + + /// Executes the prepared statement and maps a function over the resulting + /// rows, where the function returns a `Result` with `Error` type + /// implementing `std::convert::From<Error>` (so errors can be unified). + /// + /// # Failure + /// + /// Will return `Err` if binding parameters fails. + pub fn query_and_then<T, E, P, F>(&mut self, params: P, f: F) -> Result<AndThenRows<'_, F>> + where + P: IntoIterator, + P::Item: ToSql, + E: convert::From<Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, + { + let rows = self.query(params)?; + Ok(AndThenRows::new(rows, f)) + } + + /// Execute the prepared statement with named parameter(s), returning an + /// iterator over the result of calling the mapping function over the + /// query's rows. If any parameters that were in the prepared statement + /// are not included in + /// `params`, they will + /// continue to use the most-recently bound value from a previous call + /// to `query_named`, or `NULL` if they have never been bound. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// struct Person { + /// name: String, + /// }; + /// + /// fn name_to_person(name: String) -> Result<Person> { + /// // ... check for valid name + /// Ok(Person { name: name }) + /// } + /// + /// fn get_names(conn: &Connection) -> Result<Vec<Person>> { + /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = :id")?; + /// let rows = + /// stmt.query_and_then_named(&[(":id", &"one")], |row| name_to_person(row.get(0)?))?; + /// + /// let mut persons = Vec::new(); + /// for person_result in rows { + /// persons.push(person_result?); + /// } + /// + /// Ok(persons) + /// } + /// ``` + /// + /// ## Failure + /// + /// Will return `Err` if binding parameters fails. + pub fn query_and_then_named<T, E, F>( + &mut self, + params: &[(&str, &dyn ToSql)], + f: F, + ) -> Result<AndThenRows<'_, F>> + where + E: convert::From<Error>, + F: FnMut(&Row<'_>) -> Result<T, E>, + { + let rows = self.query_named(params)?; + Ok(AndThenRows::new(rows, f)) + } + + /// Return `true` if a query in the SQL statement it executes returns one + /// or more rows and `false` if the SQL returns an empty set. + pub fn exists<P>(&mut self, params: P) -> Result<bool> + where + P: IntoIterator, + P::Item: ToSql, + { + let mut rows = self.query(params)?; + let exists = rows.next()?.is_some(); + Ok(exists) + } + + /// Convenience method to execute a query that is expected to return a + /// single row. + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call `.optional()` on the result of + /// this to get a `Result<Option<T>>`. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn query_row<T, P, F>(&mut self, params: P, f: F) -> Result<T> + where + P: IntoIterator, + P::Item: ToSql, + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut rows = self.query(params)?; + + rows.get_expected_row().and_then(|r| f(&r)) + } + + /// Convenience method to execute a query with named parameter(s) that is + /// expected to return a single row. + /// + /// If the query returns more than one row, all rows except the first are + /// ignored. + /// + /// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the + /// query truly is optional, you can call `.optional()` on the result of + /// this to get a `Result<Option<T>>`. + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string + /// or if the underlying SQLite call fails. + pub fn query_row_named<T, F>(&mut self, params: &[(&str, &dyn ToSql)], f: F) -> Result<T> + where + F: FnOnce(&Row<'_>) -> Result<T>, + { + let mut rows = self.query_named(params)?; + + rows.get_expected_row().and_then(|r| f(&r)) + } + + /// Consumes the statement. + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn finalize(mut self) -> Result<()> { + self.finalize_() + } + + /// Return the (one-based) index of an SQL parameter given its name. + /// + /// Note that the initial ":" or "$" or "@" or "?" used to specify the + /// parameter is included as part of the name. + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn example(conn: &Connection) -> Result<()> { + /// let stmt = conn.prepare("SELECT * FROM test WHERE name = :example")?; + /// let index = stmt.parameter_index(":example")?; + /// assert_eq!(index, Some(1)); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return Err if `name` is invalid. Will return Ok(None) if the name + /// is valid but not a bound parameter of this statement. + pub fn parameter_index(&self, name: &str) -> Result<Option<usize>> { + Ok(self.stmt.bind_parameter_index(name)) + } + + fn bind_parameters<P>(&mut self, params: P) -> Result<()> + where + P: IntoIterator, + P::Item: ToSql, + { + let expected = self.stmt.bind_parameter_count(); + let mut index = 0; + for p in params.into_iter() { + index += 1; // The leftmost SQL parameter has an index of 1. + if index > expected { + break; + } + self.bind_parameter(&p, index)?; + } + if index != expected { + Err(Error::InvalidParameterCount(index, expected)) + } else { + Ok(()) + } + } + + fn bind_parameters_named(&mut self, params: &[(&str, &dyn ToSql)]) -> Result<()> { + for &(name, value) in params { + if let Some(i) = self.parameter_index(name)? { + self.bind_parameter(value, i)?; + } else { + return Err(Error::InvalidParameterName(name.into())); + } + } + Ok(()) + } + + /// Return the number of parameters that can be bound to this statement. + pub fn parameter_count(&self) -> usize { + self.stmt.bind_parameter_count() + } + + /// Low level API to directly bind a parameter to a given index. + /// + /// Note that the index is one-based, that is, the first parameter index is + /// 1 and not 0. This is consistent with the SQLite API and the values given + /// to parameters bound as `?NNN`. + /// + /// The valid values for `one_based_col_index` begin at `1`, and end at + /// [`Statement::parameter_count`], inclusive. + /// + /// # Caveats + /// + /// This should not generally be used, but is available for special cases + /// such as: + /// + /// - binding parameters where a gap exists. + /// - binding named and positional parameters in the same query. + /// - separating parameter binding from query execution. + /// + /// Statements that have had their parameters bound this way should be + /// queried or executed by [`Statement::raw_query`] or + /// [`Statement::raw_execute`]. Other functions are not guaranteed to work. + /// + /// # Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn query(conn: &Connection) -> Result<()> { + /// let mut stmt = conn.prepare("SELECT * FROM test WHERE name = :name AND value > ?2")?; + /// let name_index = stmt.parameter_index(":name")?.expect("No such parameter"); + /// stmt.raw_bind_parameter(name_index, "foo")?; + /// stmt.raw_bind_parameter(2, 100)?; + /// let mut rows = stmt.raw_query(); + /// while let Some(row) = rows.next()? { + /// // ... + /// } + /// Ok(()) + /// } + /// ``` + pub fn raw_bind_parameter<T: ToSql>( + &mut self, + one_based_col_index: usize, + param: T, + ) -> Result<()> { + // This is the same as `bind_parameter` but slightly more ergonomic and + // correctly takes `&mut self`. + self.bind_parameter(¶m, one_based_col_index) + } + + /// Low level API to execute a statement given that all parameters were + /// bound explicitly with the [`Statement::raw_bind_parameter`] API. + /// + /// # Caveats + /// + /// Any unbound parameters will have `NULL` as their value. + /// + /// This should not generally be used outside of special cases, and + /// functions in the [`Statement::execute`] family should be preferred. + /// + /// # Failure + /// + /// Will return `Err` if the executed statement returns rows (in which case + /// `query` should be used instead), or the underlying SQLite call fails. + pub fn raw_execute(&mut self) -> Result<usize> { + self.execute_with_bound_parameters() + } + + /// Low level API to get `Rows` for this query given that all parameters + /// were bound explicitly with the [`Statement::raw_bind_parameter`] API. + /// + /// # Caveats + /// + /// Any unbound parameters will have `NULL` as their value. + /// + /// This should not generally be used outside of special cases, and + /// functions in the [`Statement::query`] family should be preferred. + /// + /// Note that if the SQL does not return results, [`Statement::raw_execute`] + /// should be used instead. + pub fn raw_query(&mut self) -> Rows<'_> { + Rows::new(self) + } + + fn bind_parameter(&self, param: &dyn ToSql, col: usize) -> Result<()> { + let value = param.to_sql()?; + + let ptr = unsafe { self.stmt.ptr() }; + let value = match value { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(len) => { + return self + .conn + .decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, len) }); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(a) => { + return self.conn.decode_result(unsafe { + ffi::sqlite3_bind_pointer( + ptr, + col as c_int, + Rc::into_raw(a) as *mut c_void, + ARRAY_TYPE, + Some(free_array), + ) + }); + } + }; + self.conn.decode_result(match value { + ValueRef::Null => unsafe { ffi::sqlite3_bind_null(ptr, col as c_int) }, + ValueRef::Integer(i) => unsafe { ffi::sqlite3_bind_int64(ptr, col as c_int, i) }, + ValueRef::Real(r) => unsafe { ffi::sqlite3_bind_double(ptr, col as c_int, r) }, + ValueRef::Text(s) => unsafe { + let (c_str, len, destructor) = str_for_sqlite(s)?; + ffi::sqlite3_bind_text(ptr, col as c_int, c_str, len, destructor) + }, + ValueRef::Blob(b) => unsafe { + let length = len_as_c_int(b.len())?; + if length == 0 { + ffi::sqlite3_bind_zeroblob(ptr, col as c_int, 0) + } else { + ffi::sqlite3_bind_blob( + ptr, + col as c_int, + b.as_ptr() as *const c_void, + length, + ffi::SQLITE_TRANSIENT(), + ) + } + }, + }) + } + + fn execute_with_bound_parameters(&mut self) -> Result<usize> { + self.check_update()?; + let r = self.stmt.step(); + self.stmt.reset(); + match r { + ffi::SQLITE_DONE => Ok(self.conn.changes()), + ffi::SQLITE_ROW => Err(Error::ExecuteReturnedResults), + _ => Err(self.conn.decode_result(r).unwrap_err()), + } + } + + fn finalize_(&mut self) -> Result<()> { + let mut stmt = unsafe { RawStatement::new(ptr::null_mut(), 0) }; + mem::swap(&mut stmt, &mut self.stmt); + self.conn.decode_result(stmt.finalize()) + } + + #[cfg(not(feature = "modern_sqlite"))] + #[inline] + fn check_readonly(&self) -> Result<()> { + Ok(()) + } + + #[cfg(feature = "modern_sqlite")] + #[inline] + fn check_readonly(&self) -> Result<()> { + /*if !self.stmt.readonly() { does not work for PRAGMA + return Err(Error::InvalidQuery); + }*/ + Ok(()) + } + + #[cfg(all(feature = "modern_sqlite", feature = "extra_check"))] + #[inline] + fn check_update(&self) -> Result<()> { + // sqlite3_column_count works for DML but not for DDL (ie ALTER) + if self.column_count() > 0 && self.stmt.readonly() { + return Err(Error::ExecuteReturnedResults); + } + Ok(()) + } + + #[cfg(all(not(feature = "modern_sqlite"), feature = "extra_check"))] + #[inline] + fn check_update(&self) -> Result<()> { + // sqlite3_column_count works for DML but not for DDL (ie ALTER) + if self.column_count() > 0 { + return Err(Error::ExecuteReturnedResults); + } + Ok(()) + } + + #[cfg(not(feature = "extra_check"))] + #[inline] + fn check_update(&self) -> Result<()> { + Ok(()) + } + + /// Returns a string containing the SQL text of prepared statement with + /// bound parameters expanded. + #[cfg(feature = "modern_sqlite")] + pub fn expanded_sql(&self) -> Option<String> { + self.stmt + .expanded_sql() + .map(|s| s.to_string_lossy().to_string()) + } + + /// Get the value for one of the status counters for this statement. + pub fn get_status(&self, status: StatementStatus) -> i32 { + self.stmt.get_status(status, false) + } + + /// Reset the value of one of the status counters for this statement, + /// returning the value it had before resetting. + pub fn reset_status(&self, status: StatementStatus) -> i32 { + self.stmt.get_status(status, true) + } + + #[cfg(feature = "extra_check")] + pub(crate) fn check_no_tail(&self) -> Result<()> { + if self.stmt.has_tail() { + Err(Error::MultipleStatement) + } else { + Ok(()) + } + } + + #[cfg(not(feature = "extra_check"))] + #[inline] + pub(crate) fn check_no_tail(&self) -> Result<()> { + Ok(()) + } + + /// Safety: This is unsafe, because using `sqlite3_stmt` after the + /// connection has closed is illegal, but `RawStatement` does not enforce + /// this, as it loses our protective `'conn` lifetime bound. + pub(crate) unsafe fn into_raw(mut self) -> RawStatement { + let mut stmt = RawStatement::new(ptr::null_mut(), 0); + mem::swap(&mut stmt, &mut self.stmt); + stmt + } +} + +impl fmt::Debug for Statement<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let sql = if self.stmt.is_null() { + Ok("") + } else { + str::from_utf8(self.stmt.sql().unwrap().to_bytes()) + }; + f.debug_struct("Statement") + .field("conn", self.conn) + .field("stmt", &self.stmt) + .field("sql", &sql) + .finish() + } +} + +impl Drop for Statement<'_> { + #[allow(unused_must_use)] + fn drop(&mut self) { + self.finalize_(); + } +} + +impl Statement<'_> { + pub(super) fn new(conn: &Connection, stmt: RawStatement) -> Statement<'_> { + Statement { conn, stmt } + } + + pub(super) fn value_ref(&self, col: usize) -> ValueRef<'_> { + let raw = unsafe { self.stmt.ptr() }; + + match self.stmt.column_type(col) { + ffi::SQLITE_NULL => ValueRef::Null, + ffi::SQLITE_INTEGER => { + ValueRef::Integer(unsafe { ffi::sqlite3_column_int64(raw, col as c_int) }) + } + ffi::SQLITE_FLOAT => { + ValueRef::Real(unsafe { ffi::sqlite3_column_double(raw, col as c_int) }) + } + ffi::SQLITE_TEXT => { + let s = unsafe { + // Quoting from "Using SQLite" book: + // To avoid problems, an application should first extract the desired type using + // a sqlite3_column_xxx() function, and then call the + // appropriate sqlite3_column_bytes() function. + let text = ffi::sqlite3_column_text(raw, col as c_int); + let len = ffi::sqlite3_column_bytes(raw, col as c_int); + assert!( + !text.is_null(), + "unexpected SQLITE_TEXT column type with NULL data" + ); + from_raw_parts(text as *const u8, len as usize) + }; + + ValueRef::Text(s) + } + ffi::SQLITE_BLOB => { + let (blob, len) = unsafe { + ( + ffi::sqlite3_column_blob(raw, col as c_int), + ffi::sqlite3_column_bytes(raw, col as c_int), + ) + }; + + assert!( + len >= 0, + "unexpected negative return from sqlite3_column_bytes" + ); + if len > 0 { + assert!( + !blob.is_null(), + "unexpected SQLITE_BLOB column type with NULL data" + ); + ValueRef::Blob(unsafe { from_raw_parts(blob as *const u8, len as usize) }) + } else { + // The return value from sqlite3_column_blob() for a zero-length BLOB + // is a NULL pointer. + ValueRef::Blob(&[]) + } + } + _ => unreachable!("sqlite3_column_type returned invalid value"), + } + } + + pub(super) fn step(&self) -> Result<bool> { + match self.stmt.step() { + ffi::SQLITE_ROW => Ok(true), + ffi::SQLITE_DONE => Ok(false), + code => Err(self.conn.decode_result(code).unwrap_err()), + } + } + + pub(super) fn reset(&self) -> c_int { + self.stmt.reset() + } +} + +/// Prepared statement status counters. +/// +/// See https://www.sqlite.org/c3ref/c_stmtstatus_counter.html +/// for explanations of each. +/// +/// Note that depending on your version of SQLite, all of these +/// may not be available. +#[repr(i32)] +#[derive(Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum StatementStatus { + /// Equivalent to SQLITE_STMTSTATUS_FULLSCAN_STEP + FullscanStep = 1, + /// Equivalent to SQLITE_STMTSTATUS_SORT + Sort = 2, + /// Equivalent to SQLITE_STMTSTATUS_AUTOINDEX + AutoIndex = 3, + /// Equivalent to SQLITE_STMTSTATUS_VM_STEP + VmStep = 4, + /// Equivalent to SQLITE_STMTSTATUS_REPREPARE + RePrepare = 5, + /// Equivalent to SQLITE_STMTSTATUS_RUN + Run = 6, + /// Equivalent to SQLITE_STMTSTATUS_MEMUSED + MemUsed = 99, +} + +#[cfg(test)] +mod test { + use crate::types::ToSql; + use crate::{Connection, Error, Result, NO_PARAMS}; + + #[test] + fn test_execute_named() { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo(x INTEGER)").unwrap(); + + assert_eq!( + db.execute_named("INSERT INTO foo(x) VALUES (:x)", &[(":x", &1i32)]) + .unwrap(), + 1 + ); + assert_eq!( + db.execute_named("INSERT INTO foo(x) VALUES (:x)", &[(":x", &2i32)]) + .unwrap(), + 1 + ); + + assert_eq!( + 3i32, + db.query_row_named::<i32, _>( + "SELECT SUM(x) FROM foo WHERE x > :x", + &[(":x", &0i32)], + |r| r.get(0) + ) + .unwrap() + ); + } + + #[test] + fn test_stmt_execute_named() { + let db = Connection::open_in_memory().unwrap(); + let sql = "CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag \ + INTEGER)"; + db.execute_batch(sql).unwrap(); + + let mut stmt = db + .prepare("INSERT INTO test (name) VALUES (:name)") + .unwrap(); + stmt.execute_named(&[(":name", &"one")]).unwrap(); + + let mut stmt = db + .prepare("SELECT COUNT(*) FROM test WHERE name = :name") + .unwrap(); + assert_eq!( + 1i32, + stmt.query_row_named::<i32, _>(&[(":name", &"one")], |r| r.get(0)) + .unwrap() + ); + } + + #[test] + fn test_query_named() { + let db = Connection::open_in_memory().unwrap(); + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + "#; + db.execute_batch(sql).unwrap(); + + let mut stmt = db + .prepare("SELECT id FROM test where name = :name") + .unwrap(); + let mut rows = stmt.query_named(&[(":name", &"one")]).unwrap(); + + let id: Result<i32> = rows.next().unwrap().unwrap().get(0); + assert_eq!(Ok(1), id); + } + + #[test] + fn test_query_map_named() { + let db = Connection::open_in_memory().unwrap(); + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + "#; + db.execute_batch(sql).unwrap(); + + let mut stmt = db + .prepare("SELECT id FROM test where name = :name") + .unwrap(); + let mut rows = stmt + .query_map_named(&[(":name", &"one")], |row| { + let id: Result<i32> = row.get(0); + id.map(|i| 2 * i) + }) + .unwrap(); + + let doubled_id: i32 = rows.next().unwrap().unwrap(); + assert_eq!(2, doubled_id); + } + + #[test] + fn test_query_and_then_named() { + let db = Connection::open_in_memory().unwrap(); + let sql = r#" + CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER); + INSERT INTO test(id, name) VALUES (1, "one"); + INSERT INTO test(id, name) VALUES (2, "one"); + "#; + db.execute_batch(sql).unwrap(); + + let mut stmt = db + .prepare("SELECT id FROM test where name = :name ORDER BY id ASC") + .unwrap(); + let mut rows = stmt + .query_and_then_named(&[(":name", &"one")], |row| { + let id: i32 = row.get(0)?; + if id == 1 { + Ok(id) + } else { + Err(Error::SqliteSingleThreadedMode) + } + }) + .unwrap(); + + // first row should be Ok + let doubled_id: i32 = rows.next().unwrap().unwrap(); + assert_eq!(1, doubled_id); + + // second row should be Err + #[allow(clippy::match_wild_err_arm)] + match rows.next().unwrap() { + Ok(_) => panic!("invalid Ok"), + Err(Error::SqliteSingleThreadedMode) => (), + Err(_) => panic!("invalid Err"), + } + } + + #[test] + fn test_unbound_parameters_are_null() { + let db = Connection::open_in_memory().unwrap(); + let sql = "CREATE TABLE test (x TEXT, y TEXT)"; + db.execute_batch(sql).unwrap(); + + let mut stmt = db + .prepare("INSERT INTO test (x, y) VALUES (:x, :y)") + .unwrap(); + stmt.execute_named(&[(":x", &"one")]).unwrap(); + + let result: Option<String> = db + .query_row("SELECT y FROM test WHERE x = 'one'", NO_PARAMS, |row| { + row.get(0) + }) + .unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_raw_binding() -> Result<()> { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE test (name TEXT, value INTEGER)")?; + { + let mut stmt = db.prepare("INSERT INTO test (name, value) VALUES (:name, ?3)")?; + + let name_idx = stmt.parameter_index(":name")?.unwrap(); + stmt.raw_bind_parameter(name_idx, "example")?; + stmt.raw_bind_parameter(3, 50i32)?; + let n = stmt.raw_execute()?; + assert_eq!(n, 1); + } + + { + let mut stmt = db.prepare("SELECT name, value FROM test WHERE value = ?2")?; + stmt.raw_bind_parameter(2, 50)?; + let mut rows = stmt.raw_query(); + { + let row = rows.next()?.unwrap(); + let name: String = row.get(0)?; + assert_eq!(name, "example"); + let value: i32 = row.get(1)?; + assert_eq!(value, 50); + } + assert!(rows.next()?.is_none()); + } + + Ok(()) + } + + #[test] + fn test_unbound_parameters_are_reused() { + let db = Connection::open_in_memory().unwrap(); + let sql = "CREATE TABLE test (x TEXT, y TEXT)"; + db.execute_batch(sql).unwrap(); + + let mut stmt = db + .prepare("INSERT INTO test (x, y) VALUES (:x, :y)") + .unwrap(); + stmt.execute_named(&[(":x", &"one")]).unwrap(); + stmt.execute_named(&[(":y", &"two")]).unwrap(); + + let result: String = db + .query_row("SELECT x FROM test WHERE y = 'two'", NO_PARAMS, |row| { + row.get(0) + }) + .unwrap(); + assert_eq!(result, "one"); + } + + #[test] + fn test_insert() { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo(x INTEGER UNIQUE)") + .unwrap(); + let mut stmt = db + .prepare("INSERT OR IGNORE INTO foo (x) VALUES (?)") + .unwrap(); + assert_eq!(stmt.insert(&[1i32]).unwrap(), 1); + assert_eq!(stmt.insert(&[2i32]).unwrap(), 2); + match stmt.insert(&[1i32]).unwrap_err() { + Error::StatementChangedRows(0) => (), + err => panic!("Unexpected error {}", err), + } + let mut multi = db + .prepare("INSERT INTO foo (x) SELECT 3 UNION ALL SELECT 4") + .unwrap(); + match multi.insert(NO_PARAMS).unwrap_err() { + Error::StatementChangedRows(2) => (), + err => panic!("Unexpected error {}", err), + } + } + + #[test] + fn test_insert_different_tables() { + // Test for https://github.com/rusqlite/rusqlite/issues/171 + let db = Connection::open_in_memory().unwrap(); + db.execute_batch( + r" + CREATE TABLE foo(x INTEGER); + CREATE TABLE bar(x INTEGER); + ", + ) + .unwrap(); + + assert_eq!( + db.prepare("INSERT INTO foo VALUES (10)") + .unwrap() + .insert(NO_PARAMS) + .unwrap(), + 1 + ); + assert_eq!( + db.prepare("INSERT INTO bar VALUES (10)") + .unwrap() + .insert(NO_PARAMS) + .unwrap(), + 1 + ); + } + + #[test] + fn test_exists() { + let db = Connection::open_in_memory().unwrap(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER); + INSERT INTO foo VALUES(1); + INSERT INTO foo VALUES(2); + END;"; + db.execute_batch(sql).unwrap(); + let mut stmt = db.prepare("SELECT 1 FROM foo WHERE x = ?").unwrap(); + assert!(stmt.exists(&[1i32]).unwrap()); + assert!(stmt.exists(&[2i32]).unwrap()); + assert!(!stmt.exists(&[0i32]).unwrap()); + } + + #[test] + fn test_query_row() { + let db = Connection::open_in_memory().unwrap(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + INSERT INTO foo VALUES(2, 4); + END;"; + db.execute_batch(sql).unwrap(); + let mut stmt = db.prepare("SELECT y FROM foo WHERE x = ?").unwrap(); + let y: Result<i64> = stmt.query_row(&[1i32], |r| r.get(0)); + assert_eq!(3i64, y.unwrap()); + } + + #[test] + fn test_query_by_column_name() { + let db = Connection::open_in_memory().unwrap(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + END;"; + db.execute_batch(sql).unwrap(); + let mut stmt = db.prepare("SELECT y FROM foo").unwrap(); + let y: Result<i64> = stmt.query_row(NO_PARAMS, |r| r.get("y")); + assert_eq!(3i64, y.unwrap()); + } + + #[test] + fn test_query_by_column_name_ignore_case() { + let db = Connection::open_in_memory().unwrap(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y INTEGER); + INSERT INTO foo VALUES(1, 3); + END;"; + db.execute_batch(sql).unwrap(); + let mut stmt = db.prepare("SELECT y as Y FROM foo").unwrap(); + let y: Result<i64> = stmt.query_row(NO_PARAMS, |r| r.get("y")); + assert_eq!(3i64, y.unwrap()); + } + + #[test] + #[cfg(feature = "modern_sqlite")] + fn test_expanded_sql() { + let db = Connection::open_in_memory().unwrap(); + let stmt = db.prepare("SELECT ?").unwrap(); + stmt.bind_parameter(&1, 1).unwrap(); + assert_eq!(Some("SELECT 1".to_owned()), stmt.expanded_sql()); + } + + #[test] + fn test_bind_parameters() { + let db = Connection::open_in_memory().unwrap(); + // dynamic slice: + db.query_row( + "SELECT ?1, ?2, ?3", + &[&1u8 as &dyn ToSql, &"one", &Some("one")], + |row| row.get::<_, u8>(0), + ) + .unwrap(); + // existing collection: + let data = vec![1, 2, 3]; + db.query_row("SELECT ?1, ?2, ?3", &data, |row| row.get::<_, u8>(0)) + .unwrap(); + db.query_row("SELECT ?1, ?2, ?3", data.as_slice(), |row| { + row.get::<_, u8>(0) + }) + .unwrap(); + db.query_row("SELECT ?1, ?2, ?3", data, |row| row.get::<_, u8>(0)) + .unwrap(); + + use std::collections::BTreeSet; + let data: BTreeSet<String> = ["one", "two", "three"] + .iter() + .map(|s| (*s).to_string()) + .collect(); + db.query_row("SELECT ?1, ?2, ?3", &data, |row| row.get::<_, String>(0)) + .unwrap(); + + let data = [0; 3]; + db.query_row("SELECT ?1, ?2, ?3", &data, |row| row.get::<_, u8>(0)) + .unwrap(); + db.query_row("SELECT ?1, ?2, ?3", data.iter(), |row| row.get::<_, u8>(0)) + .unwrap(); + } + + #[test] + fn test_empty_stmt() { + let conn = Connection::open_in_memory().unwrap(); + let mut stmt = conn.prepare("").unwrap(); + assert_eq!(0, stmt.column_count()); + assert!(stmt.parameter_index("test").is_ok()); + assert!(stmt.step().is_err()); + stmt.reset(); + assert!(stmt.execute(NO_PARAMS).is_err()); + } + + #[test] + fn test_comment_stmt() { + let conn = Connection::open_in_memory().unwrap(); + conn.prepare("/*SELECT 1;*/").unwrap(); + } + + #[test] + fn test_comment_and_sql_stmt() { + let conn = Connection::open_in_memory().unwrap(); + let stmt = conn.prepare("/*...*/ SELECT 1;").unwrap(); + assert_eq!(1, stmt.column_count()); + } + + #[test] + fn test_semi_colon_stmt() { + let conn = Connection::open_in_memory().unwrap(); + let stmt = conn.prepare(";").unwrap(); + assert_eq!(0, stmt.column_count()); + } + + #[test] + fn test_utf16_conversion() { + let db = Connection::open_in_memory().unwrap(); + db.pragma_update(None, "encoding", &"UTF-16le").unwrap(); + let encoding: String = db + .pragma_query_value(None, "encoding", |row| row.get(0)) + .unwrap(); + assert_eq!("UTF-16le", encoding); + db.execute_batch("CREATE TABLE foo(x TEXT)").unwrap(); + let expected = "テスト"; + db.execute("INSERT INTO foo(x) VALUES (?)", &[&expected]) + .unwrap(); + let actual: String = db + .query_row("SELECT x FROM foo", NO_PARAMS, |row| row.get(0)) + .unwrap(); + assert_eq!(expected, actual); + } + + #[test] + fn test_nul_byte() { + let db = Connection::open_in_memory().unwrap(); + let expected = "a\x00b"; + let actual: String = db + .query_row("SELECT ?", &[&expected], |row| row.get(0)) + .unwrap(); + assert_eq!(expected, actual); + } +} diff --git a/third_party/rust/rusqlite/src/trace.rs b/third_party/rust/rusqlite/src/trace.rs new file mode 100644 index 0000000000..76e0969c4f --- /dev/null +++ b/third_party/rust/rusqlite/src/trace.rs @@ -0,0 +1,180 @@ +//! `feature = "trace"` Tracing and profiling functions. Error and warning log. + +use std::ffi::{CStr, CString}; +use std::mem; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; +use std::time::Duration; + +use super::ffi; +use crate::error::error_from_sqlite_code; +use crate::{Connection, Result}; + +/// `feature = "trace"` Set up the process-wide SQLite error logging callback. +/// +/// # Safety +/// +/// This function is marked unsafe for two reasons: +/// +/// * The function is not threadsafe. No other SQLite calls may be made while +/// `config_log` is running, and multiple threads may not call `config_log` +/// simultaneously. +/// * The provided `callback` itself function has two requirements: +/// * It must not invoke any SQLite calls. +/// * It must be threadsafe if SQLite is used in a multithreaded way. +/// +/// cf [The Error And Warning Log](http://sqlite.org/errlog.html). +pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> Result<()> { + extern "C" fn log_callback(p_arg: *mut c_void, err: c_int, msg: *const c_char) { + let c_slice = unsafe { CStr::from_ptr(msg).to_bytes() }; + let callback: fn(c_int, &str) = unsafe { mem::transmute(p_arg) }; + + let s = String::from_utf8_lossy(c_slice); + let _ = catch_unwind(|| callback(err, &s)); + } + + let rc = match callback { + Some(f) => ffi::sqlite3_config( + ffi::SQLITE_CONFIG_LOG, + log_callback as extern "C" fn(_, _, _), + f as *mut c_void, + ), + None => { + let nullptr: *mut c_void = ptr::null_mut(); + ffi::sqlite3_config(ffi::SQLITE_CONFIG_LOG, nullptr, nullptr) + } + }; + + if rc == ffi::SQLITE_OK { + Ok(()) + } else { + Err(error_from_sqlite_code(rc, None)) + } +} + +/// `feature = "trace"` Write a message into the error log established by +/// `config_log`. +pub fn log(err_code: c_int, msg: &str) { + let msg = CString::new(msg).expect("SQLite log messages cannot contain embedded zeroes"); + unsafe { + ffi::sqlite3_log(err_code, b"%s\0" as *const _ as *const c_char, msg.as_ptr()); + } +} + +impl Connection { + /// `feature = "trace"` Register or clear a callback function that can be + /// used for tracing the execution of SQL statements. + /// + /// Prepared statement placeholders are replaced/logged with their assigned + /// values. There can only be a single tracer defined for each database + /// connection. Setting a new tracer clears the old one. + pub fn trace(&mut self, trace_fn: Option<fn(&str)>) { + unsafe extern "C" fn trace_callback(p_arg: *mut c_void, z_sql: *const c_char) { + let trace_fn: fn(&str) = mem::transmute(p_arg); + let c_slice = CStr::from_ptr(z_sql).to_bytes(); + let s = String::from_utf8_lossy(c_slice); + let _ = catch_unwind(|| trace_fn(&s)); + } + + let c = self.db.borrow_mut(); + match trace_fn { + Some(f) => unsafe { + ffi::sqlite3_trace(c.db(), Some(trace_callback), f as *mut c_void); + }, + None => unsafe { + ffi::sqlite3_trace(c.db(), None, ptr::null_mut()); + }, + } + } + + /// `feature = "trace"` Register or clear a callback function that can be + /// used for profiling the execution of SQL statements. + /// + /// There can only be a single profiler defined for each database + /// connection. Setting a new profiler clears the old one. + pub fn profile(&mut self, profile_fn: Option<fn(&str, Duration)>) { + unsafe extern "C" fn profile_callback( + p_arg: *mut c_void, + z_sql: *const c_char, + nanoseconds: u64, + ) { + let profile_fn: fn(&str, Duration) = mem::transmute(p_arg); + let c_slice = CStr::from_ptr(z_sql).to_bytes(); + let s = String::from_utf8_lossy(c_slice); + const NANOS_PER_SEC: u64 = 1_000_000_000; + + let duration = Duration::new( + nanoseconds / NANOS_PER_SEC, + (nanoseconds % NANOS_PER_SEC) as u32, + ); + let _ = catch_unwind(|| profile_fn(&s, duration)); + } + + let c = self.db.borrow_mut(); + match profile_fn { + Some(f) => unsafe { + ffi::sqlite3_profile(c.db(), Some(profile_callback), f as *mut c_void) + }, + None => unsafe { ffi::sqlite3_profile(c.db(), None, ptr::null_mut()) }, + }; + } +} + +#[cfg(test)] +mod test { + use lazy_static::lazy_static; + use std::sync::Mutex; + use std::time::Duration; + + use crate::Connection; + + #[test] + fn test_trace() { + lazy_static! { + static ref TRACED_STMTS: Mutex<Vec<String>> = Mutex::new(Vec::new()); + } + fn tracer(s: &str) { + let mut traced_stmts = TRACED_STMTS.lock().unwrap(); + traced_stmts.push(s.to_owned()); + } + + let mut db = Connection::open_in_memory().unwrap(); + db.trace(Some(tracer)); + { + let _ = db.query_row("SELECT ?", &[1i32], |_| Ok(())); + let _ = db.query_row("SELECT ?", &["hello"], |_| Ok(())); + } + db.trace(None); + { + let _ = db.query_row("SELECT ?", &[2i32], |_| Ok(())); + let _ = db.query_row("SELECT ?", &["goodbye"], |_| Ok(())); + } + + let traced_stmts = TRACED_STMTS.lock().unwrap(); + assert_eq!(traced_stmts.len(), 2); + assert_eq!(traced_stmts[0], "SELECT 1"); + assert_eq!(traced_stmts[1], "SELECT 'hello'"); + } + + #[test] + fn test_profile() { + lazy_static! { + static ref PROFILED: Mutex<Vec<(String, Duration)>> = Mutex::new(Vec::new()); + } + fn profiler(s: &str, d: Duration) { + let mut profiled = PROFILED.lock().unwrap(); + profiled.push((s.to_owned(), d)); + } + + let mut db = Connection::open_in_memory().unwrap(); + db.profile(Some(profiler)); + db.execute_batch("PRAGMA application_id = 1").unwrap(); + db.profile(None); + db.execute_batch("PRAGMA application_id = 2").unwrap(); + + let profiled = PROFILED.lock().unwrap(); + assert_eq!(profiled.len(), 1); + assert_eq!(profiled[0].0, "PRAGMA application_id = 1"); + } +} diff --git a/third_party/rust/rusqlite/src/transaction.rs b/third_party/rust/rusqlite/src/transaction.rs new file mode 100644 index 0000000000..5e649b755e --- /dev/null +++ b/third_party/rust/rusqlite/src/transaction.rs @@ -0,0 +1,673 @@ +use crate::{Connection, Result}; +use std::ops::Deref; + +/// Options for transaction behavior. See [BEGIN +/// TRANSACTION](http://www.sqlite.org/lang_transaction.html) for details. +#[derive(Copy, Clone)] +#[non_exhaustive] +pub enum TransactionBehavior { + /// DEFERRED means that the transaction does not actually start until the + /// database is first accessed. + Deferred, + /// IMMEDIATE cause the database connection to start a new write + /// immediately, without waiting for a writes statement. + Immediate, + /// EXCLUSIVE prevents other database connections from reading the database + /// while the transaction is underway. + Exclusive, +} + +/// Options for how a Transaction or Savepoint should behave when it is dropped. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum DropBehavior { + /// Roll back the changes. This is the default. + Rollback, + + /// Commit the changes. + Commit, + + /// Do not commit or roll back changes - this will leave the transaction or + /// savepoint open, so should be used with care. + Ignore, + + /// Panic. Used to enforce intentional behavior during development. + Panic, +} + +/// Represents a transaction on a database connection. +/// +/// ## Note +/// +/// Transactions will roll back by default. Use `commit` method to explicitly +/// commit the transaction, or use `set_drop_behavior` to change what happens +/// when the transaction is dropped. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } +/// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } +/// fn perform_queries(conn: &mut Connection) -> Result<()> { +/// let tx = conn.transaction()?; +/// +/// do_queries_part_1(&tx)?; // tx causes rollback if this fails +/// do_queries_part_2(&tx)?; // tx causes rollback if this fails +/// +/// tx.commit() +/// } +/// ``` +#[derive(Debug)] +pub struct Transaction<'conn> { + conn: &'conn Connection, + drop_behavior: DropBehavior, +} + +/// Represents a savepoint on a database connection. +/// +/// ## Note +/// +/// Savepoints will roll back by default. Use `commit` method to explicitly +/// commit the savepoint, or use `set_drop_behavior` to change what happens +/// when the savepoint is dropped. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } +/// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } +/// fn perform_queries(conn: &mut Connection) -> Result<()> { +/// let sp = conn.savepoint()?; +/// +/// do_queries_part_1(&sp)?; // sp causes rollback if this fails +/// do_queries_part_2(&sp)?; // sp causes rollback if this fails +/// +/// sp.commit() +/// } +/// ``` +pub struct Savepoint<'conn> { + conn: &'conn Connection, + name: String, + depth: u32, + drop_behavior: DropBehavior, + committed: bool, +} + +impl Transaction<'_> { + /// Begin a new transaction. Cannot be nested; see `savepoint` for nested + /// transactions. + /// + /// Even though we don't mutate the connection, we take a `&mut Connection` + /// so as to prevent nested transactions on the same connection. For cases + /// where this is unacceptable, [`Transaction::new_unchecked`] is available. + pub fn new(conn: &mut Connection, behavior: TransactionBehavior) -> Result<Transaction<'_>> { + Self::new_unchecked(conn, behavior) + } + + /// Begin a new transaction, failing if a transaction is open. + /// + /// If a transaction is already open, this will return an error. Where + /// possible, [`Transaction::new`] should be preferred, as it provides a + /// compile-time guarantee that transactions are not nested. + pub fn new_unchecked( + conn: &Connection, + behavior: TransactionBehavior, + ) -> Result<Transaction<'_>> { + let query = match behavior { + TransactionBehavior::Deferred => "BEGIN DEFERRED", + TransactionBehavior::Immediate => "BEGIN IMMEDIATE", + TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE", + }; + conn.execute_batch(query).map(move |_| Transaction { + conn, + drop_behavior: DropBehavior::Rollback, + }) + } + + /// Starts a new [savepoint](http://www.sqlite.org/lang_savepoint.html), allowing nested + /// transactions. + /// + /// ## Note + /// + /// Just like outer level transactions, savepoint transactions rollback by + /// default. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn perform_queries_part_1_succeeds(_conn: &Connection) -> bool { true } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let mut tx = conn.transaction()?; + /// + /// { + /// let sp = tx.savepoint()?; + /// if perform_queries_part_1_succeeds(&sp) { + /// sp.commit()?; + /// } + /// // otherwise, sp will rollback + /// } + /// + /// tx.commit() + /// } + /// ``` + pub fn savepoint(&mut self) -> Result<Savepoint<'_>> { + Savepoint::with_depth(self.conn, 1) + } + + /// Create a new savepoint with a custom savepoint name. See `savepoint()`. + pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> { + Savepoint::with_depth_and_name(self.conn, 1, name) + } + + /// Get the current setting for what happens to the transaction when it is + /// dropped. + pub fn drop_behavior(&self) -> DropBehavior { + self.drop_behavior + } + + /// Configure the transaction to perform the specified action when it is + /// dropped. + pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) { + self.drop_behavior = drop_behavior + } + + /// A convenience method which consumes and commits a transaction. + pub fn commit(mut self) -> Result<()> { + self.commit_() + } + + fn commit_(&mut self) -> Result<()> { + self.conn.execute_batch("COMMIT")?; + Ok(()) + } + + /// A convenience method which consumes and rolls back a transaction. + pub fn rollback(mut self) -> Result<()> { + self.rollback_() + } + + fn rollback_(&mut self) -> Result<()> { + self.conn.execute_batch("ROLLBACK")?; + Ok(()) + } + + /// Consumes the transaction, committing or rolling back according to the + /// current setting (see `drop_behavior`). + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur. + pub fn finish(mut self) -> Result<()> { + self.finish_() + } + + fn finish_(&mut self) -> Result<()> { + if self.conn.is_autocommit() { + return Ok(()); + } + match self.drop_behavior() { + DropBehavior::Commit => self.commit_().or_else(|_| self.rollback_()), + DropBehavior::Rollback => self.rollback_(), + DropBehavior::Ignore => Ok(()), + DropBehavior::Panic => panic!("Transaction dropped unexpectedly."), + } + } +} + +impl Deref for Transaction<'_> { + type Target = Connection; + + fn deref(&self) -> &Connection { + self.conn + } +} + +#[allow(unused_must_use)] +impl Drop for Transaction<'_> { + fn drop(&mut self) { + self.finish_(); + } +} + +impl Savepoint<'_> { + fn with_depth_and_name<T: Into<String>>( + conn: &Connection, + depth: u32, + name: T, + ) -> Result<Savepoint<'_>> { + let name = name.into(); + conn.execute_batch(&format!("SAVEPOINT {}", name)) + .map(|_| Savepoint { + conn, + name, + depth, + drop_behavior: DropBehavior::Rollback, + committed: false, + }) + } + + fn with_depth(conn: &Connection, depth: u32) -> Result<Savepoint<'_>> { + let name = format!("_rusqlite_sp_{}", depth); + Savepoint::with_depth_and_name(conn, depth, name) + } + + /// Begin a new savepoint. Can be nested. + pub fn new(conn: &mut Connection) -> Result<Savepoint<'_>> { + Savepoint::with_depth(conn, 0) + } + + /// Begin a new savepoint with a user-provided savepoint name. + pub fn with_name<T: Into<String>>(conn: &mut Connection, name: T) -> Result<Savepoint<'_>> { + Savepoint::with_depth_and_name(conn, 0, name) + } + + /// Begin a nested savepoint. + pub fn savepoint(&mut self) -> Result<Savepoint<'_>> { + Savepoint::with_depth(self.conn, self.depth + 1) + } + + /// Begin a nested savepoint with a user-provided savepoint name. + pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> { + Savepoint::with_depth_and_name(self.conn, self.depth + 1, name) + } + + /// Get the current setting for what happens to the savepoint when it is + /// dropped. + pub fn drop_behavior(&self) -> DropBehavior { + self.drop_behavior + } + + /// Configure the savepoint to perform the specified action when it is + /// dropped. + pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) { + self.drop_behavior = drop_behavior + } + + /// A convenience method which consumes and commits a savepoint. + pub fn commit(mut self) -> Result<()> { + self.commit_() + } + + fn commit_(&mut self) -> Result<()> { + self.conn.execute_batch(&format!("RELEASE {}", self.name))?; + self.committed = true; + Ok(()) + } + + /// A convenience method which rolls back a savepoint. + /// + /// ## Note + /// + /// Unlike `Transaction`s, savepoints remain active after they have been + /// rolled back, and can be rolled back again or committed. + pub fn rollback(&mut self) -> Result<()> { + self.conn + .execute_batch(&format!("ROLLBACK TO {}", self.name)) + } + + /// Consumes the savepoint, committing or rolling back according to the + /// current setting (see `drop_behavior`). + /// + /// Functionally equivalent to the `Drop` implementation, but allows + /// callers to see any errors that occur. + pub fn finish(mut self) -> Result<()> { + self.finish_() + } + + fn finish_(&mut self) -> Result<()> { + if self.committed { + return Ok(()); + } + match self.drop_behavior() { + DropBehavior::Commit => self.commit_().or_else(|_| self.rollback()), + DropBehavior::Rollback => self.rollback(), + DropBehavior::Ignore => Ok(()), + DropBehavior::Panic => panic!("Savepoint dropped unexpectedly."), + } + } +} + +impl Deref for Savepoint<'_> { + type Target = Connection; + + fn deref(&self) -> &Connection { + self.conn + } +} + +#[allow(unused_must_use)] +impl Drop for Savepoint<'_> { + fn drop(&mut self) { + self.finish_(); + } +} + +impl Connection { + /// Begin a new transaction with the default behavior (DEFERRED). + /// + /// The transaction defaults to rolling back when it is dropped. If you + /// want the transaction to commit, you must call `commit` or + /// `set_drop_behavior(DropBehavior::Commit)`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let tx = conn.transaction()?; + /// + /// do_queries_part_1(&tx)?; // tx causes rollback if this fails + /// do_queries_part_2(&tx)?; // tx causes rollback if this fails + /// + /// tx.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn transaction(&mut self) -> Result<Transaction<'_>> { + Transaction::new(self, TransactionBehavior::Deferred) + } + + /// Begin a new transaction with a specified behavior. + /// + /// See `transaction`. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn transaction_with_behavior( + &mut self, + behavior: TransactionBehavior, + ) -> Result<Transaction<'_>> { + Transaction::new(self, behavior) + } + + /// Begin a new transaction with the default behavior (DEFERRED). + /// + /// Attempt to open a nested transaction will result in a SQLite error. + /// `Connection::transaction` prevents this at compile time by taking `&mut + /// self`, but `Connection::unchecked_transaction()` may be used to defer + /// the checking until runtime. + /// + /// See [`Connection::transaction`] and [`Transaction::new_unchecked`] + /// (which can be used if the default transaction behavior is undesirable). + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # use std::rc::Rc; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: Rc<Connection>) -> Result<()> { + /// let tx = conn.unchecked_transaction()?; + /// + /// do_queries_part_1(&tx)?; // tx causes rollback if this fails + /// do_queries_part_2(&tx)?; // tx causes rollback if this fails + /// + /// tx.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. The specific + /// error returned if transactions are nested is currently unspecified. + pub fn unchecked_transaction(&self) -> Result<Transaction<'_>> { + Transaction::new_unchecked(self, TransactionBehavior::Deferred) + } + + /// Begin a new savepoint with the default behavior (DEFERRED). + /// + /// The savepoint defaults to rolling back when it is dropped. If you want + /// the savepoint to commit, you must call `commit` or + /// `set_drop_behavior(DropBehavior::Commit)`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: &mut Connection) -> Result<()> { + /// let sp = conn.savepoint()?; + /// + /// do_queries_part_1(&sp)?; // sp causes rollback if this fails + /// do_queries_part_2(&sp)?; // sp causes rollback if this fails + /// + /// sp.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn savepoint(&mut self) -> Result<Savepoint<'_>> { + Savepoint::new(self) + } + + /// Begin a new savepoint with a specified name. + /// + /// See `savepoint`. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> { + Savepoint::with_name(self, name) + } +} + +#[cfg(test)] +mod test { + use super::DropBehavior; + use crate::{Connection, Error, NO_PARAMS}; + + fn checked_memory_handle() -> Connection { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo (x INTEGER)").unwrap(); + db + } + + #[test] + fn test_drop() { + let mut db = checked_memory_handle(); + { + let tx = db.transaction().unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + // default: rollback + } + { + let mut tx = db.transaction().unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + tx.set_drop_behavior(DropBehavior::Commit) + } + { + let tx = db.transaction().unwrap(); + assert_eq!( + 2i32, + tx.query_row::<i32, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap() + ); + } + } + fn assert_nested_tx_error(e: crate::Error) { + if let Error::SqliteFailure(e, Some(m)) = &e { + assert_eq!(e.extended_code, crate::ffi::SQLITE_ERROR); + // FIXME: Not ideal... + assert_eq!(e.code, crate::ErrorCode::Unknown); + assert!(m.contains("transaction")); + } else { + panic!("Unexpected error type: {:?}", e); + } + } + + #[test] + fn test_unchecked_nesting() { + let db = checked_memory_handle(); + + { + let tx = db.unchecked_transaction().unwrap(); + let e = tx.unchecked_transaction().unwrap_err(); + assert_nested_tx_error(e); + // default: rollback + } + { + let tx = db.unchecked_transaction().unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + // Ensure this doesn't interfere with ongoing transaction + let e = tx.unchecked_transaction().unwrap_err(); + assert_nested_tx_error(e); + + tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + tx.commit().unwrap(); + } + + assert_eq!( + 2i32, + db.query_row::<i32, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap() + ); + } + + #[test] + fn test_explicit_rollback_commit() { + let mut db = checked_memory_handle(); + { + let mut tx = db.transaction().unwrap(); + { + let mut sp = tx.savepoint().unwrap(); + sp.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + sp.rollback().unwrap(); + sp.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + sp.commit().unwrap(); + } + tx.commit().unwrap(); + } + { + let tx = db.transaction().unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(4)").unwrap(); + tx.commit().unwrap(); + } + { + let tx = db.transaction().unwrap(); + assert_eq!( + 6i32, + tx.query_row::<i32, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap() + ); + } + } + + #[test] + fn test_savepoint() { + let mut db = checked_memory_handle(); + { + let mut tx = db.transaction().unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + assert_current_sum(1, &tx); + tx.set_drop_behavior(DropBehavior::Commit); + { + let mut sp1 = tx.savepoint().unwrap(); + sp1.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + assert_current_sum(3, &sp1); + // will rollback sp1 + { + let mut sp2 = sp1.savepoint().unwrap(); + sp2.execute_batch("INSERT INTO foo VALUES(4)").unwrap(); + assert_current_sum(7, &sp2); + // will rollback sp2 + { + let sp3 = sp2.savepoint().unwrap(); + sp3.execute_batch("INSERT INTO foo VALUES(8)").unwrap(); + assert_current_sum(15, &sp3); + sp3.commit().unwrap(); + // committed sp3, but will be erased by sp2 rollback + } + assert_current_sum(15, &sp2); + } + assert_current_sum(3, &sp1); + } + assert_current_sum(1, &tx); + } + assert_current_sum(1, &db); + } + + #[test] + fn test_ignore_drop_behavior() { + let mut db = checked_memory_handle(); + + let mut tx = db.transaction().unwrap(); + { + let mut sp1 = tx.savepoint().unwrap(); + insert(1, &sp1); + sp1.rollback().unwrap(); + insert(2, &sp1); + { + let mut sp2 = sp1.savepoint().unwrap(); + sp2.set_drop_behavior(DropBehavior::Ignore); + insert(4, &sp2); + } + assert_current_sum(6, &sp1); + sp1.commit().unwrap(); + } + assert_current_sum(6, &tx); + } + + #[test] + fn test_savepoint_names() { + let mut db = checked_memory_handle(); + + { + let mut sp1 = db.savepoint_with_name("my_sp").unwrap(); + insert(1, &sp1); + assert_current_sum(1, &sp1); + { + let mut sp2 = sp1.savepoint_with_name("my_sp").unwrap(); + sp2.set_drop_behavior(DropBehavior::Commit); + insert(2, &sp2); + assert_current_sum(3, &sp2); + sp2.rollback().unwrap(); + assert_current_sum(1, &sp2); + insert(4, &sp2); + } + assert_current_sum(5, &sp1); + sp1.rollback().unwrap(); + { + let mut sp2 = sp1.savepoint_with_name("my_sp").unwrap(); + sp2.set_drop_behavior(DropBehavior::Ignore); + insert(8, &sp2); + } + assert_current_sum(8, &sp1); + sp1.commit().unwrap(); + } + assert_current_sum(8, &db); + } + + #[test] + fn test_rc() { + use std::rc::Rc; + let mut conn = Connection::open_in_memory().unwrap(); + let rc_txn = Rc::new(conn.transaction().unwrap()); + + // This will compile only if Transaction is Debug + Rc::try_unwrap(rc_txn).unwrap(); + } + + fn insert(x: i32, conn: &Connection) { + conn.execute("INSERT INTO foo VALUES(?)", &[x]).unwrap(); + } + + fn assert_current_sum(x: i32, conn: &Connection) { + let i = conn + .query_row::<i32, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(x, i); + } +} diff --git a/third_party/rust/rusqlite/src/types/chrono.rs b/third_party/rust/rusqlite/src/types/chrono.rs new file mode 100644 index 0000000000..3cba1e9a6d --- /dev/null +++ b/third_party/rust/rusqlite/src/types/chrono.rs @@ -0,0 +1,280 @@ +//! Convert most of the [Time Strings](http://sqlite.org/lang_datefunc.html) to chrono types. + +use std::borrow::Cow; + +use chrono::{DateTime, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; + +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; + +/// ISO 8601 calendar date without timezone => "YYYY-MM-DD" +impl ToSql for NaiveDate { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.format("%Y-%m-%d").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "YYYY-MM-DD" => ISO 8601 calendar date without timezone. +impl FromSql for NaiveDate { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value + .as_str() + .and_then(|s| match NaiveDate::parse_from_str(s, "%Y-%m-%d") { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + }) + } +} + +/// ISO 8601 time without timezone => "HH:MM:SS.SSS" +impl ToSql for NaiveTime { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.format("%H:%M:%S%.f").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "HH:MM"/"HH:MM:SS"/"HH:MM:SS.SSS" => ISO 8601 time without timezone. +impl FromSql for NaiveTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + let fmt = match s.len() { + 5 => "%H:%M", + 8 => "%H:%M:%S", + _ => "%H:%M:%S%.f", + }; + match NaiveTime::parse_from_str(s, fmt) { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + } + }) + } +} + +/// ISO 8601 combined date and time without timezone => +/// "YYYY-MM-DDTHH:MM:SS.SSS" +impl ToSql for NaiveDateTime { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let date_str = self.format("%Y-%m-%dT%H:%M:%S%.f").to_string(); + Ok(ToSqlOutput::from(date_str)) + } +} + +/// "YYYY-MM-DD HH:MM:SS"/"YYYY-MM-DD HH:MM:SS.SSS" => ISO 8601 combined date +/// and time without timezone. ("YYYY-MM-DDTHH:MM:SS"/"YYYY-MM-DDTHH:MM:SS.SSS" +/// also supported) +impl FromSql for NaiveDateTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + let fmt = if s.len() >= 11 && s.as_bytes()[10] == b'T' { + "%Y-%m-%dT%H:%M:%S%.f" + } else { + "%Y-%m-%d %H:%M:%S%.f" + }; + + match NaiveDateTime::parse_from_str(s, fmt) { + Ok(dt) => Ok(dt), + Err(err) => Err(FromSqlError::Other(Box::new(err))), + } + }) + } +} + +/// Date and time with time zone => UTC RFC3339 timestamp +/// ("YYYY-MM-DDTHH:MM:SS.SSS+00:00"). +impl<Tz: TimeZone> ToSql for DateTime<Tz> { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self.with_timezone(&Utc).to_rfc3339())) + } +} + +/// RFC3339 ("YYYY-MM-DDTHH:MM:SS.SSS[+-]HH:MM") into `DateTime<Utc>`. +impl FromSql for DateTime<Utc> { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + { + // Try to parse value as rfc3339 first. + let s = value.as_str()?; + + // If timestamp looks space-separated, make a copy and replace it with 'T'. + let s = if s.len() >= 11 && s.as_bytes()[10] == b' ' { + let mut s = s.to_string(); + unsafe { + let sbytes = s.as_mut_vec(); + sbytes[10] = b'T'; + } + Cow::Owned(s) + } else { + Cow::Borrowed(s) + }; + + if let Ok(dt) = DateTime::parse_from_rfc3339(&s) { + return Ok(dt.with_timezone(&Utc)); + } + } + + // Couldn't parse as rfc3339 - fall back to NaiveDateTime. + NaiveDateTime::column_result(value).map(|dt| Utc.from_utc_datetime(&dt)) + } +} + +/// RFC3339 ("YYYY-MM-DDTHH:MM:SS.SSS[+-]HH:MM") into `DateTime<Local>`. +impl FromSql for DateTime<Local> { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + let utc_dt = DateTime::<Utc>::column_result(value)?; + Ok(utc_dt.with_timezone(&Local)) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result, NO_PARAMS}; + use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; + + fn checked_memory_handle() -> Connection { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo (t TEXT, i INTEGER, f FLOAT, b BLOB)") + .unwrap(); + db + } + + #[test] + fn test_naive_date() { + let db = checked_memory_handle(); + let date = NaiveDate::from_ymd(2016, 2, 23); + db.execute("INSERT INTO foo (t) VALUES (?)", &[&date]) + .unwrap(); + + let s: String = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!("2016-02-23", s); + let t: NaiveDate = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(date, t); + } + + #[test] + fn test_naive_time() { + let db = checked_memory_handle(); + let time = NaiveTime::from_hms(23, 56, 4); + db.execute("INSERT INTO foo (t) VALUES (?)", &[&time]) + .unwrap(); + + let s: String = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!("23:56:04", s); + let v: NaiveTime = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(time, v); + } + + #[test] + fn test_naive_date_time() { + let db = checked_memory_handle(); + let date = NaiveDate::from_ymd(2016, 2, 23); + let time = NaiveTime::from_hms(23, 56, 4); + let dt = NaiveDateTime::new(date, time); + + db.execute("INSERT INTO foo (t) VALUES (?)", &[&dt]) + .unwrap(); + + let s: String = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!("2016-02-23T23:56:04", s); + let v: NaiveDateTime = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(dt, v); + + db.execute("UPDATE foo set b = datetime(t)", NO_PARAMS) + .unwrap(); // "YYYY-MM-DD HH:MM:SS" + let hms: NaiveDateTime = db + .query_row("SELECT b FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(dt, hms); + } + + #[test] + fn test_date_time_utc() { + let db = checked_memory_handle(); + let date = NaiveDate::from_ymd(2016, 2, 23); + let time = NaiveTime::from_hms_milli(23, 56, 4, 789); + let dt = NaiveDateTime::new(date, time); + let utc = Utc.from_utc_datetime(&dt); + + db.execute("INSERT INTO foo (t) VALUES (?)", &[&utc]) + .unwrap(); + + let s: String = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!("2016-02-23T23:56:04.789+00:00", s); + + let v1: DateTime<Utc> = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(utc, v1); + + let v2: DateTime<Utc> = db + .query_row("SELECT '2016-02-23 23:56:04.789'", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(utc, v2); + + let v3: DateTime<Utc> = db + .query_row("SELECT '2016-02-23 23:56:04'", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(utc - Duration::milliseconds(789), v3); + + let v4: DateTime<Utc> = db + .query_row("SELECT '2016-02-23 23:56:04.789+00:00'", NO_PARAMS, |r| { + r.get(0) + }) + .unwrap(); + assert_eq!(utc, v4); + } + + #[test] + fn test_date_time_local() { + let db = checked_memory_handle(); + let date = NaiveDate::from_ymd(2016, 2, 23); + let time = NaiveTime::from_hms_milli(23, 56, 4, 789); + let dt = NaiveDateTime::new(date, time); + let local = Local.from_local_datetime(&dt).single().unwrap(); + + db.execute("INSERT INTO foo (t) VALUES (?)", &[&local]) + .unwrap(); + + // Stored string should be in UTC + let s: String = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert!(s.ends_with("+00:00")); + + let v: DateTime<Local> = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(local, v); + } + + #[test] + fn test_sqlite_functions() { + let db = checked_memory_handle(); + let result: Result<NaiveTime> = + db.query_row("SELECT CURRENT_TIME", NO_PARAMS, |r| r.get(0)); + assert!(result.is_ok()); + let result: Result<NaiveDate> = + db.query_row("SELECT CURRENT_DATE", NO_PARAMS, |r| r.get(0)); + assert!(result.is_ok()); + let result: Result<NaiveDateTime> = + db.query_row("SELECT CURRENT_TIMESTAMP", NO_PARAMS, |r| r.get(0)); + assert!(result.is_ok()); + let result: Result<DateTime<Utc>> = + db.query_row("SELECT CURRENT_TIMESTAMP", NO_PARAMS, |r| r.get(0)); + assert!(result.is_ok()); + } +} diff --git a/third_party/rust/rusqlite/src/types/from_sql.rs b/third_party/rust/rusqlite/src/types/from_sql.rs new file mode 100644 index 0000000000..3fe74b42fc --- /dev/null +++ b/third_party/rust/rusqlite/src/types/from_sql.rs @@ -0,0 +1,270 @@ +use super::{Value, ValueRef}; +use std::error::Error; +use std::fmt; + +/// Enum listing possible errors from `FromSql` trait. +#[derive(Debug)] +#[non_exhaustive] +pub enum FromSqlError { + /// Error when an SQLite value is requested, but the type of the result + /// cannot be converted to the requested Rust type. + InvalidType, + + /// Error when the i64 value returned by SQLite cannot be stored into the + /// requested type. + OutOfRange(i64), + + /// `feature = "i128_blob"` Error returned when reading an `i128` from a + /// blob with a size other than 16. Only available when the `i128_blob` + /// feature is enabled. + #[cfg(feature = "i128_blob")] + InvalidI128Size(usize), + + /// `feature = "uuid"` Error returned when reading a `uuid` from a blob with + /// a size other than 16. Only available when the `uuid` feature is enabled. + #[cfg(feature = "uuid")] + InvalidUuidSize(usize), + + /// An error case available for implementors of the `FromSql` trait. + Other(Box<dyn Error + Send + Sync + 'static>), +} + +impl PartialEq for FromSqlError { + fn eq(&self, other: &FromSqlError) -> bool { + match (self, other) { + (FromSqlError::InvalidType, FromSqlError::InvalidType) => true, + (FromSqlError::OutOfRange(n1), FromSqlError::OutOfRange(n2)) => n1 == n2, + #[cfg(feature = "i128_blob")] + (FromSqlError::InvalidI128Size(s1), FromSqlError::InvalidI128Size(s2)) => s1 == s2, + #[cfg(feature = "uuid")] + (FromSqlError::InvalidUuidSize(s1), FromSqlError::InvalidUuidSize(s2)) => s1 == s2, + (..) => false, + } + } +} + +impl fmt::Display for FromSqlError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + FromSqlError::InvalidType => write!(f, "Invalid type"), + FromSqlError::OutOfRange(i) => write!(f, "Value {} out of range", i), + #[cfg(feature = "i128_blob")] + FromSqlError::InvalidI128Size(s) => { + write!(f, "Cannot read 128bit value out of {} byte blob", s) + } + #[cfg(feature = "uuid")] + FromSqlError::InvalidUuidSize(s) => { + write!(f, "Cannot read UUID value out of {} byte blob", s) + } + FromSqlError::Other(ref err) => err.fmt(f), + } + } +} + +impl Error for FromSqlError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + if let FromSqlError::Other(ref err) = self { + Some(&**err) + } else { + None + } + } +} + +/// Result type for implementors of the `FromSql` trait. +pub type FromSqlResult<T> = Result<T, FromSqlError>; + +/// A trait for types that can be created from a SQLite value. +/// +/// Note that `FromSql` and `ToSql` are defined for most integral types, but +/// not `u64` or `usize`. This is intentional; SQLite returns integers as +/// signed 64-bit values, which cannot fully represent the range of these +/// types. Rusqlite would have to +/// decide how to handle negative values: return an error or reinterpret as a +/// very large postive numbers, neither of which +/// is guaranteed to be correct for everyone. Callers can work around this by +/// fetching values as i64 and then doing the interpretation themselves or by +/// defining a newtype and implementing `FromSql`/`ToSql` for it. +pub trait FromSql: Sized { + /// Converts SQLite value into Rust value. + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self>; +} + +impl FromSql for isize { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + i64::column_result(value).and_then(|i| { + if i < isize::min_value() as i64 || i > isize::max_value() as i64 { + Err(FromSqlError::OutOfRange(i)) + } else { + Ok(i as isize) + } + }) + } +} + +macro_rules! from_sql_integral( + ($t:ident) => ( + impl FromSql for $t { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + i64::column_result(value).and_then(|i| { + if i < i64::from($t::min_value()) || i > i64::from($t::max_value()) { + Err(FromSqlError::OutOfRange(i)) + } else { + Ok(i as $t) + } + }) + } + } + ) +); + +from_sql_integral!(i8); +from_sql_integral!(i16); +from_sql_integral!(i32); +from_sql_integral!(u8); +from_sql_integral!(u16); +from_sql_integral!(u32); + +impl FromSql for i64 { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_i64() + } +} + +impl FromSql for f64 { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Integer(i) => Ok(i as f64), + ValueRef::Real(f) => Ok(f), + _ => Err(FromSqlError::InvalidType), + } + } +} + +impl FromSql for bool { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + i64::column_result(value).map(|i| !matches!(i, 0)) + } +} + +impl FromSql for String { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().map(ToString::to_string) + } +} + +impl FromSql for Box<str> { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().map(Into::into) + } +} + +impl FromSql for std::rc::Rc<str> { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().map(Into::into) + } +} + +impl FromSql for std::sync::Arc<str> { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().map(Into::into) + } +} + +impl FromSql for Vec<u8> { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_blob().map(|b| b.to_vec()) + } +} + +#[cfg(feature = "i128_blob")] +impl FromSql for i128 { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + use byteorder::{BigEndian, ByteOrder}; + + value.as_blob().and_then(|bytes| { + if bytes.len() == 16 { + Ok(BigEndian::read_i128(bytes) ^ (1i128 << 127)) + } else { + Err(FromSqlError::InvalidI128Size(bytes.len())) + } + }) + } +} + +#[cfg(feature = "uuid")] +impl FromSql for uuid::Uuid { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value + .as_blob() + .and_then(|bytes| { + uuid::Builder::from_slice(bytes) + .map_err(|_| FromSqlError::InvalidUuidSize(bytes.len())) + }) + .map(|mut builder| builder.build()) + } +} + +impl<T: FromSql> FromSql for Option<T> { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Null => Ok(None), + _ => FromSql::column_result(value).map(Some), + } + } +} + +impl FromSql for Value { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + Ok(value.into()) + } +} + +#[cfg(test)] +mod test { + use super::FromSql; + use crate::{Connection, Error}; + + fn checked_memory_handle() -> Connection { + Connection::open_in_memory().unwrap() + } + + #[test] + fn test_integral_ranges() { + let db = checked_memory_handle(); + + fn check_ranges<T>(db: &Connection, out_of_range: &[i64], in_range: &[i64]) + where + T: Into<i64> + FromSql + ::std::fmt::Debug, + { + for n in out_of_range { + let err = db + .query_row("SELECT ?", &[n], |r| r.get::<_, T>(0)) + .unwrap_err(); + match err { + Error::IntegralValueOutOfRange(_, value) => assert_eq!(*n, value), + _ => panic!("unexpected error: {}", err), + } + } + for n in in_range { + assert_eq!( + *n, + db.query_row("SELECT ?", &[n], |r| r.get::<_, T>(0)) + .unwrap() + .into() + ); + } + } + + check_ranges::<i8>(&db, &[-129, 128], &[-128, 0, 1, 127]); + check_ranges::<i16>(&db, &[-32769, 32768], &[-32768, -1, 0, 1, 32767]); + check_ranges::<i32>( + &db, + &[-2_147_483_649, 2_147_483_648], + &[-2_147_483_648, -1, 0, 1, 2_147_483_647], + ); + check_ranges::<u8>(&db, &[-2, -1, 256], &[0, 1, 255]); + check_ranges::<u16>(&db, &[-2, -1, 65536], &[0, 1, 65535]); + check_ranges::<u32>(&db, &[-2, -1, 4_294_967_296], &[0, 1, 4_294_967_295]); + } +} diff --git a/third_party/rust/rusqlite/src/types/mod.rs b/third_party/rust/rusqlite/src/types/mod.rs new file mode 100644 index 0000000000..85d8ef2826 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/mod.rs @@ -0,0 +1,372 @@ +//! Traits dealing with SQLite data types. +//! +//! SQLite uses a [dynamic type system](https://www.sqlite.org/datatype3.html). Implementations of +//! the `ToSql` and `FromSql` traits are provided for the basic types that +//! SQLite provides methods for: +//! +//! * Integers (`i32` and `i64`; SQLite uses `i64` internally, so getting an +//! `i32` will truncate if the value is too large or too small). +//! * Reals (`f64`) +//! * Strings (`String` and `&str`) +//! * Blobs (`Vec<u8>` and `&[u8]`) +//! +//! Additionally, if the `time` feature is enabled, implementations are +//! provided for `time::OffsetDateTime` that use the RFC 3339 date/time format, +//! `"%Y-%m-%dT%H:%M:%S.%fZ"`, to store time values as strings. These values +//! can be parsed by SQLite's builtin +//! [datetime](https://www.sqlite.org/lang_datefunc.html) functions. If you +//! want different storage for datetimes, you can use a newtype. +//! +#![cfg_attr( + feature = "time", + doc = r##" +For example, to store datetimes as `i64`s counting the number of seconds since +the Unix epoch: + +``` +use rusqlite::types::{FromSql, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use rusqlite::Result; + +pub struct DateTimeSql(pub time::OffsetDateTime); + +impl FromSql for DateTimeSql { + fn column_result(value: ValueRef) -> FromSqlResult<Self> { + i64::column_result(value).map(|as_i64| { + DateTimeSql(time::OffsetDateTime::from_unix_timestamp(as_i64)) + }) + } +} + +impl ToSql for DateTimeSql { + fn to_sql(&self) -> Result<ToSqlOutput> { + Ok(self.0.timestamp().into()) + } +} +``` + +"## +)] +//! `ToSql` and `FromSql` are also implemented for `Option<T>` where `T` +//! implements `ToSql` or `FromSql` for the cases where you want to know if a +//! value was NULL (which gets translated to `None`). + +pub use self::from_sql::{FromSql, FromSqlError, FromSqlResult}; +pub use self::to_sql::{ToSql, ToSqlOutput}; +pub use self::value::Value; +pub use self::value_ref::ValueRef; + +use std::fmt; + +#[cfg(feature = "chrono")] +mod chrono; +mod from_sql; +#[cfg(feature = "serde_json")] +mod serde_json; +#[cfg(feature = "time")] +mod time; +mod to_sql; +#[cfg(feature = "url")] +mod url; +mod value; +mod value_ref; + +/// Empty struct that can be used to fill in a query parameter as `NULL`. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # use rusqlite::types::{Null}; +/// +/// fn insert_null(conn: &Connection) -> Result<usize> { +/// conn.execute("INSERT INTO people (name) VALUES (?)", &[Null]) +/// } +/// ``` +#[derive(Copy, Clone)] +pub struct Null; + +/// SQLite data types. +/// See [Fundamental Datatypes](https://sqlite.org/c3ref/c_blob.html). +#[derive(Clone, Debug, PartialEq)] +pub enum Type { + /// NULL + Null, + /// 64-bit signed integer + Integer, + /// 64-bit IEEE floating point number + Real, + /// String + Text, + /// BLOB + Blob, +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Type::Null => write!(f, "Null"), + Type::Integer => write!(f, "Integer"), + Type::Real => write!(f, "Real"), + Type::Text => write!(f, "Text"), + Type::Blob => write!(f, "Blob"), + } + } +} + +#[cfg(test)] +mod test { + use super::Value; + use crate::{Connection, Error, NO_PARAMS}; + use std::f64::EPSILON; + use std::os::raw::{c_double, c_int}; + + fn checked_memory_handle() -> Connection { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo (b BLOB, t TEXT, i INTEGER, f FLOAT, n)") + .unwrap(); + db + } + + #[test] + fn test_blob() { + let db = checked_memory_handle(); + + let v1234 = vec![1u8, 2, 3, 4]; + db.execute("INSERT INTO foo(b) VALUES (?)", &[&v1234]) + .unwrap(); + + let v: Vec<u8> = db + .query_row("SELECT b FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(v, v1234); + } + + #[test] + fn test_empty_blob() { + let db = checked_memory_handle(); + + let empty = vec![]; + db.execute("INSERT INTO foo(b) VALUES (?)", &[&empty]) + .unwrap(); + + let v: Vec<u8> = db + .query_row("SELECT b FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(v, empty); + } + + #[test] + fn test_str() { + let db = checked_memory_handle(); + + let s = "hello, world!"; + db.execute("INSERT INTO foo(t) VALUES (?)", &[&s]).unwrap(); + + let from: String = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(from, s); + } + + #[test] + fn test_string() { + let db = checked_memory_handle(); + + let s = "hello, world!"; + db.execute("INSERT INTO foo(t) VALUES (?)", &[s.to_owned()]) + .unwrap(); + + let from: String = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(from, s); + } + + #[test] + fn test_value() { + let db = checked_memory_handle(); + + db.execute("INSERT INTO foo(i) VALUES (?)", &[Value::Integer(10)]) + .unwrap(); + + assert_eq!( + 10i64, + db.query_row::<i64, _, _>("SELECT i FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap() + ); + } + + #[test] + fn test_option() { + let db = checked_memory_handle(); + + let s = Some("hello, world!"); + let b = Some(vec![1u8, 2, 3, 4]); + + db.execute("INSERT INTO foo(t) VALUES (?)", &[&s]).unwrap(); + db.execute("INSERT INTO foo(b) VALUES (?)", &[&b]).unwrap(); + + let mut stmt = db + .prepare("SELECT t, b FROM foo ORDER BY ROWID ASC") + .unwrap(); + let mut rows = stmt.query(NO_PARAMS).unwrap(); + + { + let row1 = rows.next().unwrap().unwrap(); + let s1: Option<String> = row1.get_unwrap(0); + let b1: Option<Vec<u8>> = row1.get_unwrap(1); + assert_eq!(s.unwrap(), s1.unwrap()); + assert!(b1.is_none()); + } + + { + let row2 = rows.next().unwrap().unwrap(); + let s2: Option<String> = row2.get_unwrap(0); + let b2: Option<Vec<u8>> = row2.get_unwrap(1); + assert!(s2.is_none()); + assert_eq!(b, b2); + } + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_mismatched_types() { + fn is_invalid_column_type(err: Error) -> bool { + matches!(err, Error::InvalidColumnType(..)) + } + + let db = checked_memory_handle(); + + db.execute( + "INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", + NO_PARAMS, + ) + .unwrap(); + + let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo").unwrap(); + let mut rows = stmt.query(NO_PARAMS).unwrap(); + + let row = rows.next().unwrap().unwrap(); + + // check the correct types come back as expected + assert_eq!(vec![1, 2], row.get::<_, Vec<u8>>(0).unwrap()); + assert_eq!("text", row.get::<_, String>(1).unwrap()); + assert_eq!(1, row.get::<_, c_int>(2).unwrap()); + assert!((1.5 - row.get::<_, c_double>(3).unwrap()).abs() < EPSILON); + assert!(row.get::<_, Option<c_int>>(4).unwrap().is_none()); + assert!(row.get::<_, Option<c_double>>(4).unwrap().is_none()); + assert!(row.get::<_, Option<String>>(4).unwrap().is_none()); + + // check some invalid types + + // 0 is actually a blob (Vec<u8>) + assert!(is_invalid_column_type( + row.get::<_, c_int>(0).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, c_int>(0).err().unwrap() + )); + assert!(is_invalid_column_type(row.get::<_, i64>(0).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(0).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, String>(0).err().unwrap() + )); + #[cfg(feature = "time")] + assert!(is_invalid_column_type( + row.get::<_, time::OffsetDateTime>(0).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, Option<c_int>>(0).err().unwrap() + )); + + // 1 is actually a text (String) + assert!(is_invalid_column_type( + row.get::<_, c_int>(1).err().unwrap() + )); + assert!(is_invalid_column_type(row.get::<_, i64>(1).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(1).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, Vec<u8>>(1).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, Option<c_int>>(1).err().unwrap() + )); + + // 2 is actually an integer + assert!(is_invalid_column_type( + row.get::<_, String>(2).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, Vec<u8>>(2).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, Option<String>>(2).err().unwrap() + )); + + // 3 is actually a float (c_double) + assert!(is_invalid_column_type( + row.get::<_, c_int>(3).err().unwrap() + )); + assert!(is_invalid_column_type(row.get::<_, i64>(3).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, String>(3).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, Vec<u8>>(3).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, Option<c_int>>(3).err().unwrap() + )); + + // 4 is actually NULL + assert!(is_invalid_column_type( + row.get::<_, c_int>(4).err().unwrap() + )); + assert!(is_invalid_column_type(row.get::<_, i64>(4).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(4).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, String>(4).err().unwrap() + )); + assert!(is_invalid_column_type( + row.get::<_, Vec<u8>>(4).err().unwrap() + )); + #[cfg(feature = "time")] + assert!(is_invalid_column_type( + row.get::<_, time::OffsetDateTime>(4).err().unwrap() + )); + } + + #[test] + fn test_dynamic_type() { + use super::Value; + let db = checked_memory_handle(); + + db.execute( + "INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", + NO_PARAMS, + ) + .unwrap(); + + let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo").unwrap(); + let mut rows = stmt.query(NO_PARAMS).unwrap(); + + let row = rows.next().unwrap().unwrap(); + assert_eq!(Value::Blob(vec![1, 2]), row.get::<_, Value>(0).unwrap()); + assert_eq!( + Value::Text(String::from("text")), + row.get::<_, Value>(1).unwrap() + ); + assert_eq!(Value::Integer(1), row.get::<_, Value>(2).unwrap()); + match row.get::<_, Value>(3).unwrap() { + Value::Real(val) => assert!((1.5 - val).abs() < EPSILON), + x => panic!("Invalid Value {:?}", x), + } + assert_eq!(Value::Null, row.get::<_, Value>(4).unwrap()); + } +} diff --git a/third_party/rust/rusqlite/src/types/serde_json.rs b/third_party/rust/rusqlite/src/types/serde_json.rs new file mode 100644 index 0000000000..abaecda076 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/serde_json.rs @@ -0,0 +1,60 @@ +//! `ToSql` and `FromSql` implementation for JSON `Value`. + +use serde_json::Value; + +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; + +/// Serialize JSON `Value` to text. +impl ToSql for Value { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(serde_json::to_string(self).unwrap())) + } +} + +/// Deserialize text/blob to JSON `Value`. +impl FromSql for Value { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Text(s) => serde_json::from_slice(s), + ValueRef::Blob(b) => serde_json::from_slice(b), + _ => return Err(FromSqlError::InvalidType), + } + .map_err(|err| FromSqlError::Other(Box::new(err))) + } +} + +#[cfg(test)] +mod test { + use crate::types::ToSql; + use crate::{Connection, NO_PARAMS}; + + fn checked_memory_handle() -> Connection { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo (t TEXT, b BLOB)") + .unwrap(); + db + } + + #[test] + fn test_json_value() { + let db = checked_memory_handle(); + + let json = r#"{"foo": 13, "bar": "baz"}"#; + let data: serde_json::Value = serde_json::from_str(json).unwrap(); + db.execute( + "INSERT INTO foo (t, b) VALUES (?, ?)", + &[&data as &dyn ToSql, &json.as_bytes()], + ) + .unwrap(); + + let t: serde_json::Value = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(data, t); + let b: serde_json::Value = db + .query_row("SELECT b FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + assert_eq!(data, b); + } +} diff --git a/third_party/rust/rusqlite/src/types/time.rs b/third_party/rust/rusqlite/src/types/time.rs new file mode 100644 index 0000000000..8589167b81 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/time.rs @@ -0,0 +1,82 @@ +//! `ToSql` and `FromSql` implementation for [`time::OffsetDateTime`]. +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; +use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; + +const CURRENT_TIMESTAMP_FMT: &str = "%Y-%m-%d %H:%M:%S"; +const SQLITE_DATETIME_FMT: &str = "%Y-%m-%dT%H:%M:%S.%NZ"; +const SQLITE_DATETIME_FMT_LEGACY: &str = "%Y-%m-%d %H:%M:%S:%N %z"; + +impl ToSql for OffsetDateTime { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + let time_string = self.to_offset(UtcOffset::UTC).format(SQLITE_DATETIME_FMT); + Ok(ToSqlOutput::from(time_string)) + } +} + +impl FromSql for OffsetDateTime { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + value.as_str().and_then(|s| { + match s.len() { + 19 => PrimitiveDateTime::parse(s, CURRENT_TIMESTAMP_FMT).map(|d| d.assume_utc()), + _ => PrimitiveDateTime::parse(s, SQLITE_DATETIME_FMT) + .map(|d| d.assume_utc()) + .or_else(|err| { + OffsetDateTime::parse(s, SQLITE_DATETIME_FMT_LEGACY).map_err(|_| err) + }), + } + .map_err(|err| FromSqlError::Other(Box::new(err))) + }) + } +} + +#[cfg(test)] +mod test { + use crate::{Connection, Result, NO_PARAMS}; + use std::time::Duration; + use time::OffsetDateTime; + + fn checked_memory_handle() -> Connection { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo (t TEXT, i INTEGER, f FLOAT)") + .unwrap(); + db + } + + #[test] + fn test_offset_date_time() { + let db = checked_memory_handle(); + + let mut ts_vec = vec![]; + + let make_datetime = + |secs, nanos| OffsetDateTime::from_unix_timestamp(secs) + Duration::from_nanos(nanos); + + ts_vec.push(make_datetime(10_000, 0)); //January 1, 1970 2:46:40 AM + ts_vec.push(make_datetime(10_000, 1000)); //January 1, 1970 2:46:40 AM (and one microsecond) + ts_vec.push(make_datetime(1_500_391_124, 1_000_000)); //July 18, 2017 + ts_vec.push(make_datetime(2_000_000_000, 2_000_000)); //May 18, 2033 + ts_vec.push(make_datetime(3_000_000_000, 999_999_999)); //January 24, 2065 + ts_vec.push(make_datetime(10_000_000_000, 0)); //November 20, 2286 + + for ts in ts_vec { + db.execute("INSERT INTO foo(t) VALUES (?)", &[&ts]).unwrap(); + + let from: OffsetDateTime = db + .query_row("SELECT t FROM foo", NO_PARAMS, |r| r.get(0)) + .unwrap(); + + db.execute("DELETE FROM foo", NO_PARAMS).unwrap(); + + assert_eq!(from, ts); + } + } + + #[test] + fn test_sqlite_functions() { + let db = checked_memory_handle(); + let result: Result<OffsetDateTime> = + db.query_row("SELECT CURRENT_TIMESTAMP", NO_PARAMS, |r| r.get(0)); + assert!(result.is_ok()); + } +} diff --git a/third_party/rust/rusqlite/src/types/to_sql.rs b/third_party/rust/rusqlite/src/types/to_sql.rs new file mode 100644 index 0000000000..937c0f80e8 --- /dev/null +++ b/third_party/rust/rusqlite/src/types/to_sql.rs @@ -0,0 +1,377 @@ +use super::{Null, Value, ValueRef}; +#[cfg(feature = "array")] +use crate::vtab::array::Array; +use crate::Result; +use std::borrow::Cow; + +/// `ToSqlOutput` represents the possible output types for implementors of the +/// `ToSql` trait. +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum ToSqlOutput<'a> { + /// A borrowed SQLite-representable value. + Borrowed(ValueRef<'a>), + + /// An owned SQLite-representable value. + Owned(Value), + + /// `feature = "blob"` A BLOB of the given length that is filled with + /// zeroes. + #[cfg(feature = "blob")] + ZeroBlob(i32), + + /// `feature = "array"` + #[cfg(feature = "array")] + Array(Array), +} + +// Generically allow any type that can be converted into a ValueRef +// to be converted into a ToSqlOutput as well. +impl<'a, T: ?Sized> From<&'a T> for ToSqlOutput<'a> +where + &'a T: Into<ValueRef<'a>>, +{ + fn from(t: &'a T) -> Self { + ToSqlOutput::Borrowed(t.into()) + } +} + +// We cannot also generically allow any type that can be converted +// into a Value to be converted into a ToSqlOutput because of +// coherence rules (https://github.com/rust-lang/rust/pull/46192), +// so we'll manually implement it for all the types we know can +// be converted into Values. +macro_rules! from_value( + ($t:ty) => ( + impl From<$t> for ToSqlOutput<'_> { + fn from(t: $t) -> Self { ToSqlOutput::Owned(t.into())} + } + ) +); +from_value!(String); +from_value!(Null); +from_value!(bool); +from_value!(i8); +from_value!(i16); +from_value!(i32); +from_value!(i64); +from_value!(isize); +from_value!(u8); +from_value!(u16); +from_value!(u32); +from_value!(f64); +from_value!(Vec<u8>); + +// It would be nice if we could avoid the heap allocation (of the `Vec`) that +// `i128` needs in `Into<Value>`, but it's probably fine for the moment, and not +// worth adding another case to Value. +#[cfg(feature = "i128_blob")] +from_value!(i128); + +#[cfg(feature = "uuid")] +from_value!(uuid::Uuid); + +impl ToSql for ToSqlOutput<'_> { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(match *self { + ToSqlOutput::Borrowed(v) => ToSqlOutput::Borrowed(v), + ToSqlOutput::Owned(ref v) => ToSqlOutput::Borrowed(ValueRef::from(v)), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(i) => ToSqlOutput::ZeroBlob(i), + #[cfg(feature = "array")] + ToSqlOutput::Array(ref a) => ToSqlOutput::Array(a.clone()), + }) + } +} + +/// A trait for types that can be converted into SQLite values. +pub trait ToSql { + /// Converts Rust value to SQLite value + fn to_sql(&self) -> Result<ToSqlOutput<'_>>; +} + +impl<T: ToSql + ToOwned + ?Sized> ToSql for Cow<'_, T> { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + self.as_ref().to_sql() + } +} + +impl<T: ToSql + ?Sized> ToSql for Box<T> { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + self.as_ref().to_sql() + } +} + +impl<T: ToSql + ?Sized> ToSql for std::rc::Rc<T> { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + self.as_ref().to_sql() + } +} + +impl<T: ToSql + ?Sized> ToSql for std::sync::Arc<T> { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + self.as_ref().to_sql() + } +} + +// We should be able to use a generic impl like this: +// +// impl<T: Copy> ToSql for T where T: Into<Value> { +// fn to_sql(&self) -> Result<ToSqlOutput> { +// Ok(ToSqlOutput::from((*self).into())) +// } +// } +// +// instead of the following macro, but this runs afoul of +// https://github.com/rust-lang/rust/issues/30191 and reports conflicting +// implementations even when there aren't any. + +macro_rules! to_sql_self( + ($t:ty) => ( + impl ToSql for $t { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(*self)) + } + } + ) +); + +to_sql_self!(Null); +to_sql_self!(bool); +to_sql_self!(i8); +to_sql_self!(i16); +to_sql_self!(i32); +to_sql_self!(i64); +to_sql_self!(isize); +to_sql_self!(u8); +to_sql_self!(u16); +to_sql_self!(u32); +to_sql_self!(f64); + +#[cfg(feature = "i128_blob")] +to_sql_self!(i128); + +#[cfg(feature = "uuid")] +to_sql_self!(uuid::Uuid); + +impl<T: ?Sized> ToSql for &'_ T +where + T: ToSql, +{ + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + (*self).to_sql() + } +} + +impl ToSql for String { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self.as_str())) + } +} + +impl ToSql for str { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self)) + } +} + +impl ToSql for Vec<u8> { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self.as_slice())) + } +} + +impl ToSql for [u8] { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self)) + } +} + +impl ToSql for Value { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self)) + } +} + +impl<T: ToSql> ToSql for Option<T> { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + match *self { + None => Ok(ToSqlOutput::from(Null)), + Some(ref t) => t.to_sql(), + } + } +} + +#[cfg(test)] +mod test { + use super::ToSql; + + fn is_to_sql<T: ToSql>() {} + + #[test] + fn test_integral_types() { + is_to_sql::<i8>(); + is_to_sql::<i16>(); + is_to_sql::<i32>(); + is_to_sql::<i64>(); + is_to_sql::<u8>(); + is_to_sql::<u16>(); + is_to_sql::<u32>(); + } + + #[test] + fn test_cow_str() { + use std::borrow::Cow; + let s = "str"; + let cow: Cow<str> = Cow::Borrowed(s); + let r = cow.to_sql(); + assert!(r.is_ok()); + let cow: Cow<str> = Cow::Owned::<str>(String::from(s)); + let r = cow.to_sql(); + assert!(r.is_ok()); + // Ensure this compiles. + let _p: &[&dyn ToSql] = crate::params![cow]; + } + + #[test] + fn test_box_dyn() { + let s: Box<dyn ToSql> = Box::new("Hello world!"); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = ToSql::to_sql(&s); + + assert!(r.is_ok()); + } + + #[test] + fn test_box_deref() { + let s: Box<str> = "Hello world!".into(); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + + assert!(r.is_ok()); + } + + #[test] + fn test_box_direct() { + let s: Box<str> = "Hello world!".into(); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = ToSql::to_sql(&s); + + assert!(r.is_ok()); + } + + #[test] + fn test_cells() { + use std::{rc::Rc, sync::Arc}; + + let source_str: Box<str> = "Hello world!".into(); + + let s: Rc<Box<str>> = Rc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Arc<Box<str>> = Arc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Arc<str> = Arc::from(&*source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Arc<dyn ToSql> = Arc::new(source_str.clone()); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Rc<str> = Rc::from(&*source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + + let s: Rc<dyn ToSql> = Rc::new(source_str); + let _s: &[&dyn ToSql] = crate::params![s]; + let r = s.to_sql(); + assert!(r.is_ok()); + } + + #[cfg(feature = "i128_blob")] + #[test] + fn test_i128() { + use crate::{Connection, NO_PARAMS}; + use std::i128; + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo (i128 BLOB, desc TEXT)") + .unwrap(); + db.execute( + " + INSERT INTO foo(i128, desc) VALUES + (?, 'zero'), + (?, 'neg one'), (?, 'neg two'), + (?, 'pos one'), (?, 'pos two'), + (?, 'min'), (?, 'max')", + &[0i128, -1i128, -2i128, 1i128, 2i128, i128::MIN, i128::MAX], + ) + .unwrap(); + + let mut stmt = db + .prepare("SELECT i128, desc FROM foo ORDER BY i128 ASC") + .unwrap(); + + let res = stmt + .query_map(NO_PARAMS, |row| { + Ok((row.get::<_, i128>(0)?, row.get::<_, String>(1)?)) + }) + .unwrap() + .collect::<Result<Vec<_>, _>>() + .unwrap(); + + assert_eq!( + res, + &[ + (i128::MIN, "min".to_owned()), + (-2, "neg two".to_owned()), + (-1, "neg one".to_owned()), + (0, "zero".to_owned()), + (1, "pos one".to_owned()), + (2, "pos two".to_owned()), + (i128::MAX, "max".to_owned()), + ] + ); + } + + #[cfg(feature = "uuid")] + #[test] + fn test_uuid() { + use crate::{params, Connection}; + use uuid::Uuid; + + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE foo (id BLOB CHECK(length(id) = 16), label TEXT);") + .unwrap(); + + let id = Uuid::new_v4(); + + db.execute( + "INSERT INTO foo (id, label) VALUES (?, ?)", + params![id, "target"], + ) + .unwrap(); + + let mut stmt = db + .prepare("SELECT id, label FROM foo WHERE id = ?") + .unwrap(); + + let mut rows = stmt.query(params![id]).unwrap(); + let row = rows.next().unwrap().unwrap(); + + let found_id: Uuid = row.get_unwrap(0); + let found_label: String = row.get_unwrap(1); + + assert_eq!(found_id, id); + assert_eq!(found_label, "target"); + } +} diff --git a/third_party/rust/rusqlite/src/types/url.rs b/third_party/rust/rusqlite/src/types/url.rs new file mode 100644 index 0000000000..1c9c63a17d --- /dev/null +++ b/third_party/rust/rusqlite/src/types/url.rs @@ -0,0 +1,81 @@ +//! `ToSql` and `FromSql` implementation for [`url::Url`]. +use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef}; +use crate::Result; +use url::Url; + +/// Serialize `Url` to text. +impl ToSql for Url { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::from(self.as_str())) + } +} + +/// Deserialize text to `Url`. +impl FromSql for Url { + fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> { + match value { + ValueRef::Text(s) => { + let s = std::str::from_utf8(s).map_err(|e| FromSqlError::Other(Box::new(e)))?; + Url::parse(s).map_err(|e| FromSqlError::Other(Box::new(e))) + } + _ => Err(FromSqlError::InvalidType), + } + } +} + +#[cfg(test)] +mod test { + use crate::{params, Connection, Error, Result}; + use url::{ParseError, Url}; + + fn checked_memory_handle() -> Connection { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("CREATE TABLE urls (i INTEGER, v TEXT)") + .unwrap(); + db + } + + fn get_url(db: &Connection, id: i64) -> Result<Url> { + db.query_row("SELECT v FROM urls WHERE i = ?", params![id], |r| r.get(0)) + } + + #[test] + fn test_sql_url() { + let db = &checked_memory_handle(); + + let url0 = Url::parse("http://www.example1.com").unwrap(); + let url1 = Url::parse("http://www.example1.com/👌").unwrap(); + let url2 = "http://www.example2.com/👌"; + + db.execute( + "INSERT INTO urls (i, v) VALUES (0, ?), (1, ?), (2, ?), (3, ?)", + // also insert a non-hex encoded url (which might be present if it was + // inserted separately) + params![url0, url1, url2, "illegal"], + ) + .unwrap(); + + assert_eq!(get_url(db, 0).unwrap(), url0); + + assert_eq!(get_url(db, 1).unwrap(), url1); + + // Should successfully read it, even though it wasn't inserted as an + // escaped url. + let out_url2: Url = get_url(db, 2).unwrap(); + assert_eq!(out_url2, Url::parse(url2).unwrap()); + + // Make sure the conversion error comes through correctly. + let err = get_url(db, 3).unwrap_err(); + match err { + Error::FromSqlConversionFailure(_, _, e) => { + assert_eq!( + *e.downcast::<ParseError>().unwrap(), + ParseError::RelativeUrlWithoutBase, + ); + } + e => { + panic!("Expected conversion failure, got {}", e); + } + } + } +} diff --git a/third_party/rust/rusqlite/src/types/value.rs b/third_party/rust/rusqlite/src/types/value.rs new file mode 100644 index 0000000000..64dc20354c --- /dev/null +++ b/third_party/rust/rusqlite/src/types/value.rs @@ -0,0 +1,122 @@ +use super::{Null, Type}; + +/// Owning [dynamic type value](http://sqlite.org/datatype3.html). Value's type is typically +/// dictated by SQLite (not by the caller). +/// +/// See [`ValueRef`](enum.ValueRef.html) for a non-owning dynamic type value. +#[derive(Clone, Debug, PartialEq)] +pub enum Value { + /// The value is a `NULL` value. + Null, + /// The value is a signed integer. + Integer(i64), + /// The value is a floating point number. + Real(f64), + /// The value is a text string. + Text(String), + /// The value is a blob of data + Blob(Vec<u8>), +} + +impl From<Null> for Value { + fn from(_: Null) -> Value { + Value::Null + } +} + +impl From<bool> for Value { + fn from(i: bool) -> Value { + Value::Integer(i as i64) + } +} + +impl From<isize> for Value { + fn from(i: isize) -> Value { + Value::Integer(i as i64) + } +} + +#[cfg(feature = "i128_blob")] +impl From<i128> for Value { + fn from(i: i128) -> Value { + use byteorder::{BigEndian, ByteOrder}; + let mut buf = vec![0u8; 16]; + // We store these biased (e.g. with the most significant bit flipped) + // so that comparisons with negative numbers work properly. + BigEndian::write_i128(&mut buf, i ^ (1i128 << 127)); + Value::Blob(buf) + } +} + +#[cfg(feature = "uuid")] +impl From<uuid::Uuid> for Value { + fn from(id: uuid::Uuid) -> Value { + Value::Blob(id.as_bytes().to_vec()) + } +} + +macro_rules! from_i64( + ($t:ty) => ( + impl From<$t> for Value { + fn from(i: $t) -> Value { + Value::Integer(i64::from(i)) + } + } + ) +); + +from_i64!(i8); +from_i64!(i16); +from_i64!(i32); +from_i64!(u8); +from_i64!(u16); +from_i64!(u32); + +impl From<i64> for Value { + fn from(i: i64) -> Value { + Value::Integer(i) + } +} + +impl From<f64> for Value { + fn from(f: f64) -> Value { + Value::Real(f) + } +} + +impl From<String> for Value { + fn from(s: String) -> Value { + Value::Text(s) + } +} + +impl From<Vec<u8>> for Value { + fn from(v: Vec<u8>) -> Value { + Value::Blob(v) + } +} + +impl<T> From<Option<T>> for Value +where + T: Into<Value>, +{ + fn from(v: Option<T>) -> Value { + match v { + Some(x) => x.into(), + None => Value::Null, + } + } +} + +impl Value { + /// Returns SQLite fundamental datatype. + pub fn data_type(&self) -> Type { + match *self { + Value::Null => Type::Null, + Value::Integer(_) => Type::Integer, + Value::Real(_) => Type::Real, + Value::Text(_) => Type::Text, + Value::Blob(_) => Type::Blob, + } + } +} diff --git a/third_party/rust/rusqlite/src/types/value_ref.rs b/third_party/rust/rusqlite/src/types/value_ref.rs new file mode 100644 index 0000000000..2f3243404e --- /dev/null +++ b/third_party/rust/rusqlite/src/types/value_ref.rs @@ -0,0 +1,171 @@ +use super::{Type, Value}; +use crate::types::{FromSqlError, FromSqlResult}; + +/// A non-owning [dynamic type value](http://sqlite.org/datatype3.html). Typically the +/// memory backing this value is owned by SQLite. +/// +/// See [`Value`](enum.Value.html) for an owning dynamic type value. +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ValueRef<'a> { + /// The value is a `NULL` value. + Null, + /// The value is a signed integer. + Integer(i64), + /// The value is a floating point number. + Real(f64), + /// The value is a text string. + Text(&'a [u8]), + /// The value is a blob of data + Blob(&'a [u8]), +} + +impl ValueRef<'_> { + /// Returns SQLite fundamental datatype. + pub fn data_type(&self) -> Type { + match *self { + ValueRef::Null => Type::Null, + ValueRef::Integer(_) => Type::Integer, + ValueRef::Real(_) => Type::Real, + ValueRef::Text(_) => Type::Text, + ValueRef::Blob(_) => Type::Blob, + } + } +} + +impl<'a> ValueRef<'a> { + /// If `self` is case `Integer`, returns the integral value. Otherwise, + /// returns `Err(Error::InvalidColumnType)`. + pub fn as_i64(&self) -> FromSqlResult<i64> { + match *self { + ValueRef::Integer(i) => Ok(i), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Real`, returns the floating point value. Otherwise, + /// returns `Err(Error::InvalidColumnType)`. + pub fn as_f64(&self) -> FromSqlResult<f64> { + match *self { + ValueRef::Real(f) => Ok(f), + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Text`, returns the string value. Otherwise, returns + /// `Err(Error::InvalidColumnType)`. + pub fn as_str(&self) -> FromSqlResult<&'a str> { + match *self { + ValueRef::Text(t) => { + std::str::from_utf8(t).map_err(|e| FromSqlError::Other(Box::new(e))) + } + _ => Err(FromSqlError::InvalidType), + } + } + + /// If `self` is case `Blob`, returns the byte slice. Otherwise, returns + /// `Err(Error::InvalidColumnType)`. + pub fn as_blob(&self) -> FromSqlResult<&'a [u8]> { + match *self { + ValueRef::Blob(b) => Ok(b), + _ => Err(FromSqlError::InvalidType), + } + } +} + +impl From<ValueRef<'_>> for Value { + fn from(borrowed: ValueRef<'_>) -> Value { + match borrowed { + ValueRef::Null => Value::Null, + ValueRef::Integer(i) => Value::Integer(i), + ValueRef::Real(r) => Value::Real(r), + ValueRef::Text(s) => { + let s = std::str::from_utf8(s).expect("invalid UTF-8"); + Value::Text(s.to_string()) + } + ValueRef::Blob(b) => Value::Blob(b.to_vec()), + } + } +} + +impl<'a> From<&'a str> for ValueRef<'a> { + fn from(s: &str) -> ValueRef<'_> { + ValueRef::Text(s.as_bytes()) + } +} + +impl<'a> From<&'a [u8]> for ValueRef<'a> { + fn from(s: &[u8]) -> ValueRef<'_> { + ValueRef::Blob(s) + } +} + +impl<'a> From<&'a Value> for ValueRef<'a> { + fn from(value: &'a Value) -> ValueRef<'a> { + match *value { + Value::Null => ValueRef::Null, + Value::Integer(i) => ValueRef::Integer(i), + Value::Real(r) => ValueRef::Real(r), + Value::Text(ref s) => ValueRef::Text(s.as_bytes()), + Value::Blob(ref b) => ValueRef::Blob(b), + } + } +} + +impl<'a, T> From<Option<T>> for ValueRef<'a> +where + T: Into<ValueRef<'a>>, +{ + fn from(s: Option<T>) -> ValueRef<'a> { + match s { + Some(x) => x.into(), + None => ValueRef::Null, + } + } +} + +#[cfg(any(feature = "functions", feature = "session", feature = "vtab"))] +impl<'a> ValueRef<'a> { + pub(crate) unsafe fn from_value(value: *mut crate::ffi::sqlite3_value) -> ValueRef<'a> { + use crate::ffi; + use std::slice::from_raw_parts; + + match ffi::sqlite3_value_type(value) { + ffi::SQLITE_NULL => ValueRef::Null, + ffi::SQLITE_INTEGER => ValueRef::Integer(ffi::sqlite3_value_int64(value)), + ffi::SQLITE_FLOAT => ValueRef::Real(ffi::sqlite3_value_double(value)), + ffi::SQLITE_TEXT => { + let text = ffi::sqlite3_value_text(value); + let len = ffi::sqlite3_value_bytes(value); + assert!( + !text.is_null(), + "unexpected SQLITE_TEXT value type with NULL data" + ); + let s = from_raw_parts(text as *const u8, len as usize); + ValueRef::Text(s) + } + ffi::SQLITE_BLOB => { + let (blob, len) = ( + ffi::sqlite3_value_blob(value), + ffi::sqlite3_value_bytes(value), + ); + + assert!( + len >= 0, + "unexpected negative return from sqlite3_value_bytes" + ); + if len > 0 { + assert!( + !blob.is_null(), + "unexpected SQLITE_BLOB value type with NULL data" + ); + ValueRef::Blob(from_raw_parts(blob as *const u8, len as usize)) + } else { + // The return value from sqlite3_value_blob() for a zero-length BLOB + // is a NULL pointer. + ValueRef::Blob(&[]) + } + } + _ => unreachable!("sqlite3_value_type returned invalid value"), + } + } +} diff --git a/third_party/rust/rusqlite/src/unlock_notify.rs b/third_party/rust/rusqlite/src/unlock_notify.rs new file mode 100644 index 0000000000..af2a06c607 --- /dev/null +++ b/third_party/rust/rusqlite/src/unlock_notify.rs @@ -0,0 +1,129 @@ +//! [Unlock Notification](http://sqlite.org/unlock_notify.html) + +use std::os::raw::c_int; +#[cfg(feature = "unlock_notify")] +use std::os::raw::c_void; +#[cfg(feature = "unlock_notify")] +use std::panic::catch_unwind; +#[cfg(feature = "unlock_notify")] +use std::sync::{Condvar, Mutex}; + +use crate::ffi; + +#[cfg(feature = "unlock_notify")] +struct UnlockNotification { + cond: Condvar, // Condition variable to wait on + mutex: Mutex<bool>, // Mutex to protect structure +} + +#[cfg(feature = "unlock_notify")] +#[allow(clippy::mutex_atomic)] +impl UnlockNotification { + fn new() -> UnlockNotification { + UnlockNotification { + cond: Condvar::new(), + mutex: Mutex::new(false), + } + } + + fn fired(&self) { + let mut flag = self.mutex.lock().unwrap(); + *flag = true; + self.cond.notify_one(); + } + + fn wait(&self) { + let mut fired = self.mutex.lock().unwrap(); + while !*fired { + fired = self.cond.wait(fired).unwrap(); + } + } +} + +/// This function is an unlock-notify callback +#[cfg(feature = "unlock_notify")] +unsafe extern "C" fn unlock_notify_cb(ap_arg: *mut *mut c_void, n_arg: c_int) { + use std::slice::from_raw_parts; + let args = from_raw_parts(ap_arg as *const &UnlockNotification, n_arg as usize); + for un in args { + let _ = catch_unwind(std::panic::AssertUnwindSafe(|| un.fired())); + } +} + +#[cfg(feature = "unlock_notify")] +pub unsafe fn is_locked(db: *mut ffi::sqlite3, rc: c_int) -> bool { + rc == ffi::SQLITE_LOCKED_SHAREDCACHE + || (rc & 0xFF) == ffi::SQLITE_LOCKED + && ffi::sqlite3_extended_errcode(db) == ffi::SQLITE_LOCKED_SHAREDCACHE +} + +/// This function assumes that an SQLite API call (either `sqlite3_prepare_v2()` +/// or `sqlite3_step()`) has just returned `SQLITE_LOCKED`. The argument is the +/// associated database connection. +/// +/// This function calls `sqlite3_unlock_notify()` to register for an +/// unlock-notify callback, then blocks until that callback is delivered +/// and returns `SQLITE_OK`. The caller should then retry the failed operation. +/// +/// Or, if `sqlite3_unlock_notify()` indicates that to block would deadlock +/// the system, then this function returns `SQLITE_LOCKED` immediately. In +/// this case the caller should not retry the operation and should roll +/// back the current transaction (if any). +#[cfg(feature = "unlock_notify")] +pub unsafe fn wait_for_unlock_notify(db: *mut ffi::sqlite3) -> c_int { + let un = UnlockNotification::new(); + /* Register for an unlock-notify callback. */ + let rc = ffi::sqlite3_unlock_notify( + db, + Some(unlock_notify_cb), + &un as *const UnlockNotification as *mut c_void, + ); + debug_assert!( + rc == ffi::SQLITE_LOCKED || rc == ffi::SQLITE_LOCKED_SHAREDCACHE || rc == ffi::SQLITE_OK + ); + if rc == ffi::SQLITE_OK { + un.wait(); + } + rc +} + +#[cfg(not(feature = "unlock_notify"))] +pub unsafe fn is_locked(_db: *mut ffi::sqlite3, _rc: c_int) -> bool { + unreachable!() +} + +#[cfg(not(feature = "unlock_notify"))] +pub unsafe fn wait_for_unlock_notify(_db: *mut ffi::sqlite3) -> c_int { + unreachable!() +} + +#[cfg(feature = "unlock_notify")] +#[cfg(test)] +mod test { + use crate::{Connection, OpenFlags, Result, Transaction, TransactionBehavior, NO_PARAMS}; + use std::sync::mpsc::sync_channel; + use std::thread; + use std::time; + + #[test] + fn test_unlock_notify() { + let url = "file::memory:?cache=shared"; + let flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_URI; + let db1 = Connection::open_with_flags(url, flags).unwrap(); + db1.execute_batch("CREATE TABLE foo (x)").unwrap(); + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db2 = Connection::open_with_flags(url, flags).unwrap(); + let tx2 = Transaction::new(&mut db2, TransactionBehavior::Immediate).unwrap(); + tx2.execute_batch("INSERT INTO foo VALUES (42)").unwrap(); + rx.send(1).unwrap(); + let ten_millis = time::Duration::from_millis(10); + thread::sleep(ten_millis); + tx2.commit().unwrap(); + }); + assert_eq!(tx.recv().unwrap(), 1); + let the_answer: Result<i64> = db1.query_row("SELECT x FROM foo", NO_PARAMS, |r| r.get(0)); + assert_eq!(42i64, the_answer.unwrap()); + child.join().unwrap(); + } +} diff --git a/third_party/rust/rusqlite/src/util/mod.rs b/third_party/rust/rusqlite/src/util/mod.rs new file mode 100644 index 0000000000..2b8dcfda1e --- /dev/null +++ b/third_party/rust/rusqlite/src/util/mod.rs @@ -0,0 +1,11 @@ +// Internal utilities +pub(crate) mod param_cache; +mod small_cstr; +pub(crate) use param_cache::ParamIndexCache; +pub(crate) use small_cstr::SmallCString; + +// Doesn't use any modern features or vtab stuff, but is only used by them. +#[cfg(any(feature = "modern_sqlite", feature = "vtab"))] +mod sqlite_string; +#[cfg(any(feature = "modern_sqlite", feature = "vtab"))] +pub(crate) use sqlite_string::SqliteMallocString; diff --git a/third_party/rust/rusqlite/src/util/param_cache.rs b/third_party/rust/rusqlite/src/util/param_cache.rs new file mode 100644 index 0000000000..6faced98af --- /dev/null +++ b/third_party/rust/rusqlite/src/util/param_cache.rs @@ -0,0 +1,60 @@ +use super::SmallCString; +use std::cell::RefCell; +use std::collections::BTreeMap; + +/// Maps parameter names to parameter indices. +#[derive(Default, Clone, Debug)] +// BTreeMap seems to do better here unless we want to pull in a custom hash +// function. +pub(crate) struct ParamIndexCache(RefCell<BTreeMap<SmallCString, usize>>); + +impl ParamIndexCache { + pub fn get_or_insert_with<F>(&self, s: &str, func: F) -> Option<usize> + where + F: FnOnce(&std::ffi::CStr) -> Option<usize>, + { + let mut cache = self.0.borrow_mut(); + // Avoid entry API, needs allocation to test membership. + if let Some(v) = cache.get(s) { + return Some(*v); + } + // If there's an internal nul in the name it couldn't have been a + // parameter, so early return here is ok. + let name = SmallCString::new(s).ok()?; + let val = func(&name)?; + cache.insert(name, val); + Some(val) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_cache() { + let p = ParamIndexCache::default(); + let v = p.get_or_insert_with("foo", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "foo"); + Some(3) + }); + assert_eq!(v, Some(3)); + let v = p.get_or_insert_with("foo", |_| { + panic!("shouldn't be called this time"); + }); + assert_eq!(v, Some(3)); + let v = p.get_or_insert_with("gar\0bage", |_| { + panic!("shouldn't be called here either"); + }); + assert_eq!(v, None); + let v = p.get_or_insert_with("bar", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "bar"); + None + }); + assert_eq!(v, None); + let v = p.get_or_insert_with("bar", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "bar"); + Some(30) + }); + assert_eq!(v, Some(30)); + } +} diff --git a/third_party/rust/rusqlite/src/util/small_cstr.rs b/third_party/rust/rusqlite/src/util/small_cstr.rs new file mode 100644 index 0000000000..bc4d9cbcba --- /dev/null +++ b/third_party/rust/rusqlite/src/util/small_cstr.rs @@ -0,0 +1,169 @@ +use smallvec::{smallvec, SmallVec}; +use std::ffi::{CStr, CString, NulError}; + +/// Similar to std::ffi::CString, but avoids heap allocating if the string is +/// small enough. Also guarantees it's input is UTF-8 -- used for cases where we +/// need to pass a NUL-terminated string to SQLite, and we have a `&str`. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct SmallCString(smallvec::SmallVec<[u8; 16]>); + +impl SmallCString { + #[inline] + pub fn new(s: &str) -> Result<Self, NulError> { + if s.as_bytes().contains(&0u8) { + return Err(Self::fabricate_nul_error(s)); + } + let mut buf = SmallVec::with_capacity(s.len() + 1); + buf.extend_from_slice(s.as_bytes()); + buf.push(0); + let res = Self(buf); + res.debug_checks(); + Ok(res) + } + + #[inline] + pub fn as_str(&self) -> &str { + self.debug_checks(); + // Constructor takes a &str so this is safe. + unsafe { std::str::from_utf8_unchecked(self.as_bytes_without_nul()) } + } + + /// Get the bytes not including the NUL terminator. E.g. the bytes which + /// make up our `str`: + /// - `SmallCString::new("foo").as_bytes_without_nul() == b"foo"` + /// - `SmallCString::new("foo").as_bytes_with_nul() == b"foo\0" + #[inline] + pub fn as_bytes_without_nul(&self) -> &[u8] { + self.debug_checks(); + &self.0[..self.len()] + } + + /// Get the bytes behind this str *including* the NUL terminator. This + /// should never return an empty slice. + #[inline] + pub fn as_bytes_with_nul(&self) -> &[u8] { + self.debug_checks(); + &self.0 + } + + #[inline] + #[cfg(debug_assertions)] + fn debug_checks(&self) { + debug_assert_ne!(self.0.len(), 0); + debug_assert_eq!(self.0[self.0.len() - 1], 0); + let strbytes = &self.0[..(self.0.len() - 1)]; + debug_assert!(!strbytes.contains(&0)); + debug_assert!(std::str::from_utf8(strbytes).is_ok()); + } + + #[inline] + #[cfg(not(debug_assertions))] + fn debug_checks(&self) {} + + #[inline] + pub fn len(&self) -> usize { + debug_assert_ne!(self.0.len(), 0); + self.0.len() - 1 + } + + #[inline] + #[allow(unused)] // clippy wants this function. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline] + pub fn as_cstr(&self) -> &CStr { + let bytes = self.as_bytes_with_nul(); + debug_assert!(CStr::from_bytes_with_nul(bytes).is_ok()); + unsafe { CStr::from_bytes_with_nul_unchecked(bytes) } + } + + #[cold] + fn fabricate_nul_error(b: &str) -> NulError { + CString::new(b).unwrap_err() + } +} + +impl Default for SmallCString { + #[inline] + fn default() -> Self { + Self(smallvec![0]) + } +} + +impl std::fmt::Debug for SmallCString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("SmallCString").field(&self.as_str()).finish() + } +} + +impl std::ops::Deref for SmallCString { + type Target = CStr; + #[inline] + fn deref(&self) -> &CStr { + self.as_cstr() + } +} + +impl PartialEq<SmallCString> for str { + #[inline] + fn eq(&self, s: &SmallCString) -> bool { + s.as_bytes_without_nul() == self.as_bytes() + } +} + +impl PartialEq<str> for SmallCString { + #[inline] + fn eq(&self, s: &str) -> bool { + self.as_bytes_without_nul() == s.as_bytes() + } +} + +impl std::borrow::Borrow<str> for SmallCString { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_small_cstring() { + // We don't go through the normal machinery for default, so make sure + // things work. + assert_eq!(SmallCString::default().0, SmallCString::new("").unwrap().0); + assert_eq!(SmallCString::new("foo").unwrap().len(), 3); + assert_eq!( + SmallCString::new("foo").unwrap().as_bytes_with_nul(), + b"foo\0" + ); + assert_eq!( + SmallCString::new("foo").unwrap().as_bytes_without_nul(), + b"foo", + ); + + assert_eq!(SmallCString::new("😀").unwrap().len(), 4); + assert_eq!( + SmallCString::new("😀").unwrap().0.as_slice(), + b"\xf0\x9f\x98\x80\0", + ); + assert_eq!( + SmallCString::new("😀").unwrap().as_bytes_without_nul(), + b"\xf0\x9f\x98\x80", + ); + + assert_eq!(SmallCString::new("").unwrap().len(), 0); + assert!(SmallCString::new("").unwrap().is_empty()); + + assert_eq!(SmallCString::new("").unwrap().0.as_slice(), b"\0"); + assert_eq!(SmallCString::new("").unwrap().as_bytes_without_nul(), b""); + + assert!(SmallCString::new("\0").is_err()); + assert!(SmallCString::new("\0abc").is_err()); + assert!(SmallCString::new("abc\0").is_err()); + } +} diff --git a/third_party/rust/rusqlite/src/util/sqlite_string.rs b/third_party/rust/rusqlite/src/util/sqlite_string.rs new file mode 100644 index 0000000000..18d462e498 --- /dev/null +++ b/third_party/rust/rusqlite/src/util/sqlite_string.rs @@ -0,0 +1,234 @@ +// This is used when either vtab or modern-sqlite is on. Different methods are +// used in each feature. Avoid having to track this for each function. We will +// still warn for anything that's not used by either, though. +#![cfg_attr( + not(all(feature = "vtab", feature = "modern-sqlite")), + allow(dead_code) +)] +use crate::ffi; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int}; +use std::ptr::NonNull; + +/// A string we own that's allocated on the SQLite heap. Automatically calls +/// `sqlite3_free` when dropped, unless `into_raw` (or `into_inner`) is called +/// on it. If constructed from a rust string, `sqlite3_malloc` is used. +/// +/// It has identical representation to a nonnull `*mut c_char`, so you can use +/// it transparently as one. It's nonnull, so Option<SqliteMallocString> can be +/// used for nullable ones (it's still just one pointer). +/// +/// Most strings shouldn't use this! Only places where the string needs to be +/// freed with `sqlite3_free`. This includes `sqlite3_extended_sql` results, +/// some error message pointers... Note that misuse is extremely dangerous! +/// +/// Note that this is *not* a lossless interface. Incoming strings with internal +/// NULs are modified, and outgoing strings which are non-UTF8 are modified. +/// This seems unavoidable -- it tries very hard to not panic. +#[repr(transparent)] +pub(crate) struct SqliteMallocString { + ptr: NonNull<c_char>, + _boo: PhantomData<Box<[c_char]>>, +} +// This is owned data for a primitive type, and thus it's safe to implement +// these. That said, nothing needs them, and they make things easier to misuse. + +// unsafe impl Send for SqliteMallocString {} +// unsafe impl Sync for SqliteMallocString {} + +impl SqliteMallocString { + /// SAFETY: Caller must be certain that `m` a nul-terminated c string + /// allocated by sqlite3_malloc, and that SQLite expects us to free it! + #[inline] + pub(crate) unsafe fn from_raw_nonnull(ptr: NonNull<c_char>) -> Self { + Self { + ptr, + _boo: PhantomData, + } + } + + /// SAFETY: Caller must be certain that `m` a nul-terminated c string + /// allocated by sqlite3_malloc, and that SQLite expects us to free it! + #[inline] + pub(crate) unsafe fn from_raw(ptr: *mut c_char) -> Option<Self> { + NonNull::new(ptr).map(|p| Self::from_raw_nonnull(p)) + } + + /// Get the pointer behind `self`. After this is called, we no longer manage + /// it. + #[inline] + pub(crate) fn into_inner(self) -> NonNull<c_char> { + let p = self.ptr; + std::mem::forget(self); + p + } + + /// Get the pointer behind `self`. After this is called, we no longer manage + /// it. + #[inline] + pub(crate) fn into_raw(self) -> *mut c_char { + self.into_inner().as_ptr() + } + + /// Borrow the pointer behind `self`. We still manage it when this function + /// returns. If you want to relinquish ownership, use `into_raw`. + #[inline] + pub(crate) fn as_ptr(&self) -> *const c_char { + self.ptr.as_ptr() + } + + #[inline] + pub(crate) fn as_cstr(&self) -> &std::ffi::CStr { + unsafe { std::ffi::CStr::from_ptr(self.as_ptr()) } + } + + #[inline] + pub(crate) fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + self.as_cstr().to_string_lossy() + } + + /// Convert `s` into a SQLite string. + /// + /// This should almost never be done except for cases like error messages or + /// other strings that SQLite frees. + /// + /// If `s` contains internal NULs, we'll replace them with + /// `NUL_REPLACE_CHAR`. + /// + /// Except for debug_asserts which may trigger during testing, this function + /// never panics. If we hit integer overflow or the allocation fails, we + /// call `handle_alloc_error` which aborts the program after calling a + /// global hook. + /// + /// This means it's safe to use in extern "C" functions even outside of + /// catch_unwind. + pub(crate) fn from_str(s: &str) -> Self { + use std::convert::TryFrom; + let s = if s.as_bytes().contains(&0) { + std::borrow::Cow::Owned(make_nonnull(s)) + } else { + std::borrow::Cow::Borrowed(s) + }; + debug_assert!(!s.as_bytes().contains(&0)); + let bytes: &[u8] = s.as_ref().as_bytes(); + let src_ptr: *const c_char = bytes.as_ptr().cast(); + let src_len = bytes.len(); + let maybe_len_plus_1 = s.len().checked_add(1).and_then(|v| c_int::try_from(v).ok()); + unsafe { + let res_ptr = maybe_len_plus_1 + .and_then(|len_to_alloc| { + // `>` because we added 1. + debug_assert!(len_to_alloc > 0); + debug_assert_eq!((len_to_alloc - 1) as usize, src_len); + NonNull::new(ffi::sqlite3_malloc(len_to_alloc) as *mut c_char) + }) + .unwrap_or_else(|| { + use std::alloc::{handle_alloc_error, Layout}; + // Report via handle_alloc_error so that it can be handled with any + // other allocation errors and properly diagnosed. + // + // This is safe: + // - `align` is never 0 + // - `align` is always a power of 2. + // - `size` needs no realignment because it's guaranteed to be + // aligned (everything is aligned to 1) + // - `size` is also never zero, although this function doesn't actually require it now. + let layout = Layout::from_size_align_unchecked(s.len().saturating_add(1), 1); + // Note: This call does not return. + handle_alloc_error(layout); + }); + let buf: *mut c_char = res_ptr.as_ptr() as *mut c_char; + src_ptr.copy_to_nonoverlapping(buf, src_len); + buf.add(src_len).write(0); + debug_assert_eq!(std::ffi::CStr::from_ptr(res_ptr.as_ptr()).to_bytes(), bytes); + Self::from_raw_nonnull(res_ptr) + } + } +} + +const NUL_REPLACE: &str = "␀"; + +#[cold] +fn make_nonnull(v: &str) -> String { + v.replace('\0', NUL_REPLACE) +} + +impl Drop for SqliteMallocString { + fn drop(&mut self) { + unsafe { ffi::sqlite3_free(self.ptr.as_ptr().cast()) }; + } +} + +impl std::fmt::Debug for SqliteMallocString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_string_lossy().fmt(f) + } +} + +impl std::fmt::Display for SqliteMallocString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_string_lossy().fmt(f) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_from_str() { + let to_check = [ + ("", ""), + ("\0", "␀"), + ("␀", "␀"), + ("\0bar", "␀bar"), + ("foo\0bar", "foo␀bar"), + ("foo\0", "foo␀"), + ("a\0b\0c\0\0d", "a␀b␀c␀␀d"), + ("foobar0123", "foobar0123"), + ]; + + for &(input, output) in &to_check { + let s = SqliteMallocString::from_str(input); + assert_eq!(s.to_string_lossy(), output); + assert_eq!(s.as_cstr().to_str().unwrap(), output); + } + } + + // This will trigger an asan error if into_raw still freed the ptr. + #[test] + fn test_lossy() { + let p = SqliteMallocString::from_str("abcd").into_raw(); + // Make invalid + let s = unsafe { + p.cast::<u8>().write(b'\xff'); + SqliteMallocString::from_raw(p).unwrap() + }; + assert_eq!(s.to_string_lossy().as_ref(), "\u{FFFD}bcd"); + } + + // This will trigger an asan error if into_raw still freed the ptr. + #[test] + fn test_into_raw() { + let mut v = vec![]; + for i in 0..1000 { + v.push(SqliteMallocString::from_str(&i.to_string()).into_raw()); + v.push(SqliteMallocString::from_str(&format!("abc {} 😀", i)).into_raw()); + } + unsafe { + for (i, s) in v.chunks_mut(2).enumerate() { + let s0 = std::mem::replace(&mut s[0], std::ptr::null_mut()); + let s1 = std::mem::replace(&mut s[1], std::ptr::null_mut()); + assert_eq!( + std::ffi::CStr::from_ptr(s0).to_str().unwrap(), + &i.to_string() + ); + assert_eq!( + std::ffi::CStr::from_ptr(s1).to_str().unwrap(), + &format!("abc {} 😀", i) + ); + let _ = SqliteMallocString::from_raw(s0).unwrap(); + let _ = SqliteMallocString::from_raw(s1).unwrap(); + } + } + } +} diff --git a/third_party/rust/rusqlite/src/version.rs b/third_party/rust/rusqlite/src/version.rs new file mode 100644 index 0000000000..215900b6f7 --- /dev/null +++ b/third_party/rust/rusqlite/src/version.rs @@ -0,0 +1,19 @@ +use crate::ffi; +use std::ffi::CStr; + +/// Returns the SQLite version as an integer; e.g., `3016002` for version +/// 3.16.2. +/// +/// See [`sqlite3_libversion_number()`](https://www.sqlite.org/c3ref/libversion.html). +pub fn version_number() -> i32 { + unsafe { ffi::sqlite3_libversion_number() } +} + +/// Returns the SQLite version as a string; e.g., `"3.16.2"` for version 3.16.2. +/// +/// See [`sqlite3_libversion()`](https://www.sqlite.org/c3ref/libversion.html). +pub fn version() -> &'static str { + let cstr = unsafe { CStr::from_ptr(ffi::sqlite3_libversion()) }; + cstr.to_str() + .expect("SQLite version string is not valid UTF8 ?!") +} diff --git a/third_party/rust/rusqlite/src/vtab/array.rs b/third_party/rust/rusqlite/src/vtab/array.rs new file mode 100644 index 0000000000..644b4687b6 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/array.rs @@ -0,0 +1,224 @@ +//! `feature = "array"` Array Virtual Table. +//! +//! Note: `rarray`, not `carray` is the name of the table valued function we +//! define. +//! +//! Port of [carray](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/carray.c) +//! C extension: https://www.sqlite.org/carray.html +//! +//! # Example +//! +//! ```rust,no_run +//! # use rusqlite::{types::Value, Connection, Result, params}; +//! # use std::rc::Rc; +//! fn example(db: &Connection) -> Result<()> { +//! // Note: This should be done once (usually when opening the DB). +//! rusqlite::vtab::array::load_module(&db)?; +//! let v = [1i64, 2, 3, 4]; +//! // Note: A `Rc<Vec<Value>>` must be used as the parameter. +//! let values = Rc::new(v.iter().copied().map(Value::from).collect::<Vec<Value>>()); +//! let mut stmt = db.prepare("SELECT value from rarray(?);")?; +//! let rows = stmt.query_map(params![values], |row| row.get::<_, i64>(0))?; +//! for value in rows { +//! println!("{}", value?); +//! } +//! Ok(()) +//! } +//! ``` + +use std::default::Default; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int, c_void}; +use std::rc::Rc; + +use crate::ffi; +use crate::types::{ToSql, ToSqlOutput, Value}; +use crate::vtab::{ + eponymous_only_module, Context, IndexConstraintOp, IndexInfo, VTab, VTabConnection, VTabCursor, + Values, +}; +use crate::{Connection, Result}; + +// http://sqlite.org/bindptr.html + +pub(crate) const ARRAY_TYPE: *const c_char = b"rarray\0" as *const u8 as *const c_char; + +pub(crate) unsafe extern "C" fn free_array(p: *mut c_void) { + let _: Array = Rc::from_raw(p as *const Vec<Value>); +} + +/// Array parameter / pointer +pub type Array = Rc<Vec<Value>>; + +impl ToSql for Array { + fn to_sql(&self) -> Result<ToSqlOutput<'_>> { + Ok(ToSqlOutput::Array(self.clone())) + } +} + +/// `feature = "array"` Register the "rarray" module. +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("rarray", eponymous_only_module::<ArrayTab>(), aux) +} + +// Column numbers +// const CARRAY_COLUMN_VALUE : c_int = 0; +const CARRAY_COLUMN_POINTER: c_int = 1; + +/// An instance of the Array virtual table +#[repr(C)] +struct ArrayTab { + /// Base class. Must be first + base: ffi::sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for ArrayTab { + type Aux = (); + type Cursor = ArrayTabCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + _args: &[&[u8]], + ) -> Result<(String, ArrayTab)> { + let vtab = ArrayTab { + base: ffi::sqlite3_vtab::default(), + }; + Ok(("CREATE TABLE x(value,pointer hidden)".to_owned(), vtab)) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + // Index of the pointer= constraint + let mut ptr_idx = None; + for (i, constraint) in info.constraints().enumerate() { + if !constraint.is_usable() { + continue; + } + if constraint.operator() != IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ { + continue; + } + if let CARRAY_COLUMN_POINTER = constraint.column() { + ptr_idx = Some(i); + } + } + if let Some(ptr_idx) = ptr_idx { + { + let mut constraint_usage = info.constraint_usage(ptr_idx); + constraint_usage.set_argv_index(1); + constraint_usage.set_omit(true); + } + info.set_estimated_cost(1f64); + info.set_estimated_rows(100); + info.set_idx_num(1); + } else { + info.set_estimated_cost(2_147_483_647f64); + info.set_estimated_rows(2_147_483_647); + info.set_idx_num(0); + } + Ok(()) + } + + fn open(&self) -> Result<ArrayTabCursor<'_>> { + Ok(ArrayTabCursor::new()) + } +} + +/// A cursor for the Array virtual table +#[repr(C)] +struct ArrayTabCursor<'vtab> { + /// Base class. Must be first + base: ffi::sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + /// Pointer to the array of values ("pointer") + ptr: Option<Array>, + phantom: PhantomData<&'vtab ArrayTab>, +} + +impl ArrayTabCursor<'_> { + fn new<'vtab>() -> ArrayTabCursor<'vtab> { + ArrayTabCursor { + base: ffi::sqlite3_vtab_cursor::default(), + row_id: 0, + ptr: None, + phantom: PhantomData, + } + } + + fn len(&self) -> i64 { + match self.ptr { + Some(ref a) => a.len() as i64, + _ => 0, + } + } +} +unsafe impl VTabCursor for ArrayTabCursor<'_> { + fn filter(&mut self, idx_num: c_int, _idx_str: Option<&str>, args: &Values<'_>) -> Result<()> { + if idx_num > 0 { + self.ptr = args.get_array(0)?; + } else { + self.ptr = None; + } + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.row_id > self.len() + } + + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()> { + match i { + CARRAY_COLUMN_POINTER => Ok(()), + _ => { + if let Some(ref array) = self.ptr { + let value = &array[(self.row_id - 1) as usize]; + ctx.set_result(&value) + } else { + Ok(()) + } + } + } + } + + fn rowid(&self) -> Result<i64> { + Ok(self.row_id) + } +} + +#[cfg(test)] +mod test { + use crate::types::Value; + use crate::vtab::array; + use crate::Connection; + use std::rc::Rc; + + #[test] + fn test_array_module() { + let db = Connection::open_in_memory().unwrap(); + array::load_module(&db).unwrap(); + + let v = vec![1i64, 2, 3, 4]; + let values: Vec<Value> = v.into_iter().map(Value::from).collect(); + let ptr = Rc::new(values); + { + let mut stmt = db.prepare("SELECT value from rarray(?);").unwrap(); + + let rows = stmt.query_map(&[&ptr], |row| row.get::<_, i64>(0)).unwrap(); + assert_eq!(2, Rc::strong_count(&ptr)); + let mut count = 0; + for (i, value) in rows.enumerate() { + assert_eq!(i as i64, value.unwrap() - 1); + count += 1; + } + assert_eq!(4, count); + } + assert_eq!(1, Rc::strong_count(&ptr)); + } +} diff --git a/third_party/rust/rusqlite/src/vtab/csvtab.rs b/third_party/rust/rusqlite/src/vtab/csvtab.rs new file mode 100644 index 0000000000..79ec5dab34 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/csvtab.rs @@ -0,0 +1,414 @@ +//! `feature = "csvtab"` CSV Virtual Table. +//! +//! Port of [csv](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/csv.c) C +//! extension: https://www.sqlite.org/csv.html +//! +//! # Example +//! +//! ```rust,no_run +//! # use rusqlite::{Connection, Result}; +//! fn example() -> Result<()> { +//! // Note: This should be done once (usually when opening the DB). +//! let db = Connection::open_in_memory()?; +//! rusqlite::vtab::csvtab::load_module(&db)?; +//! // Assum3e my_csv.csv +//! let schema = " +//! CREATE VIRTUAL TABLE my_csv_data +//! USING csv(filename = 'my_csv.csv') +//! "; +//! db.execute_batch(schema)?; +//! // Now the `my_csv_data` (virtual) table can be queried as normal... +//! Ok(()) +//! } +//! ``` +use std::fs::File; +use std::marker::PhantomData; +use std::os::raw::c_int; +use std::path::Path; +use std::str; + +use crate::ffi; +use crate::types::Null; +use crate::vtab::{ + dequote, escape_double_quote, parse_boolean, read_only_module, Context, CreateVTab, IndexInfo, + VTab, VTabConnection, VTabCursor, Values, +}; +use crate::{Connection, Error, Result}; + +/// `feature = "csvtab"` Register the "csv" module. +/// ```sql +/// CREATE VIRTUAL TABLE vtab USING csv( +/// filename=FILENAME -- Name of file containing CSV content +/// [, schema=SCHEMA] -- Alternative CSV schema. 'CREATE TABLE x(col1 TEXT NOT NULL, col2 INT, ...);' +/// [, header=YES|NO] -- First row of CSV defines the names of columns if "yes". Default "no". +/// [, columns=N] -- Assume the CSV file contains N columns. +/// [, delimiter=C] -- CSV delimiter. Default ','. +/// [, quote=C] -- CSV quote. Default '"'. 0 means no quote. +/// ); +/// ``` +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("csv", read_only_module::<CSVTab>(), aux) +} + +/// An instance of the CSV virtual table +#[repr(C)] +struct CSVTab { + /// Base class. Must be first + base: ffi::sqlite3_vtab, + /// Name of the CSV file + filename: String, + has_headers: bool, + delimiter: u8, + quote: u8, + /// Offset to start of data + offset_first_row: csv::Position, +} + +impl CSVTab { + fn reader(&self) -> Result<csv::Reader<File>, csv::Error> { + csv::ReaderBuilder::new() + .has_headers(self.has_headers) + .delimiter(self.delimiter) + .quote(self.quote) + .from_path(&self.filename) + } + + fn parameter(c_slice: &[u8]) -> Result<(&str, &str)> { + let arg = str::from_utf8(c_slice)?.trim(); + let mut split = arg.split('='); + if let Some(key) = split.next() { + if let Some(value) = split.next() { + let param = key.trim(); + let value = dequote(value); + return Ok((param, value)); + } + } + Err(Error::ModuleError(format!("illegal argument: '{}'", arg))) + } + + fn parse_byte(arg: &str) -> Option<u8> { + if arg.len() == 1 { + arg.bytes().next() + } else { + None + } + } +} + +unsafe impl<'vtab> VTab<'vtab> for CSVTab { + type Aux = (); + type Cursor = CSVTabCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + args: &[&[u8]], + ) -> Result<(String, CSVTab)> { + if args.len() < 4 { + return Err(Error::ModuleError("no CSV file specified".to_owned())); + } + + let mut vtab = CSVTab { + base: ffi::sqlite3_vtab::default(), + filename: "".to_owned(), + has_headers: false, + delimiter: b',', + quote: b'"', + offset_first_row: csv::Position::new(), + }; + let mut schema = None; + let mut n_col = None; + + let args = &args[3..]; + for c_slice in args { + let (param, value) = CSVTab::parameter(c_slice)?; + match param { + "filename" => { + if !Path::new(value).exists() { + return Err(Error::ModuleError(format!( + "file '{}' does not exist", + value + ))); + } + vtab.filename = value.to_owned(); + } + "schema" => { + schema = Some(value.to_owned()); + } + "columns" => { + if let Ok(n) = value.parse::<u16>() { + if n_col.is_some() { + return Err(Error::ModuleError( + "more than one 'columns' parameter".to_owned(), + )); + } else if n == 0 { + return Err(Error::ModuleError( + "must have at least one column".to_owned(), + )); + } + n_col = Some(n); + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'columns': {}", + value + ))); + } + } + "header" => { + if let Some(b) = parse_boolean(value) { + vtab.has_headers = b; + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'header': {}", + value + ))); + } + } + "delimiter" => { + if let Some(b) = CSVTab::parse_byte(value) { + vtab.delimiter = b; + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'delimiter': {}", + value + ))); + } + } + "quote" => { + if let Some(b) = CSVTab::parse_byte(value) { + if b == b'0' { + vtab.quote = 0; + } else { + vtab.quote = b; + } + } else { + return Err(Error::ModuleError(format!( + "unrecognized argument to 'quote': {}", + value + ))); + } + } + _ => { + return Err(Error::ModuleError(format!( + "unrecognized parameter '{}'", + param + ))); + } + } + } + + if vtab.filename.is_empty() { + return Err(Error::ModuleError("no CSV file specified".to_owned())); + } + + let mut cols: Vec<String> = Vec::new(); + if vtab.has_headers || (n_col.is_none() && schema.is_none()) { + let mut reader = vtab.reader()?; + if vtab.has_headers { + { + let headers = reader.headers()?; + // headers ignored if cols is not empty + if n_col.is_none() && schema.is_none() { + cols = headers + .into_iter() + .map(|header| escape_double_quote(&header).into_owned()) + .collect(); + } + } + vtab.offset_first_row = reader.position().clone(); + } else { + let mut record = csv::ByteRecord::new(); + if reader.read_byte_record(&mut record)? { + for (i, _) in record.iter().enumerate() { + cols.push(format!("c{}", i)); + } + } + } + } else if let Some(n_col) = n_col { + for i in 0..n_col { + cols.push(format!("c{}", i)); + } + } + + if cols.is_empty() && schema.is_none() { + return Err(Error::ModuleError("no column specified".to_owned())); + } + + if schema.is_none() { + let mut sql = String::from("CREATE TABLE x("); + for (i, col) in cols.iter().enumerate() { + sql.push('"'); + sql.push_str(col); + sql.push_str("\" TEXT"); + if i == cols.len() - 1 { + sql.push_str(");"); + } else { + sql.push_str(", "); + } + } + schema = Some(sql); + } + + Ok((schema.unwrap(), vtab)) + } + + // Only a forward full table scan is supported. + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + info.set_estimated_cost(1_000_000.); + Ok(()) + } + + fn open(&self) -> Result<CSVTabCursor<'_>> { + Ok(CSVTabCursor::new(self.reader()?)) + } +} + +impl CreateVTab<'_> for CSVTab {} + +/// A cursor for the CSV virtual table +#[repr(C)] +struct CSVTabCursor<'vtab> { + /// Base class. Must be first + base: ffi::sqlite3_vtab_cursor, + /// The CSV reader object + reader: csv::Reader<File>, + /// Current cursor position used as rowid + row_number: usize, + /// Values of the current row + cols: csv::StringRecord, + eof: bool, + phantom: PhantomData<&'vtab CSVTab>, +} + +impl CSVTabCursor<'_> { + fn new<'vtab>(reader: csv::Reader<File>) -> CSVTabCursor<'vtab> { + CSVTabCursor { + base: ffi::sqlite3_vtab_cursor::default(), + reader, + row_number: 0, + cols: csv::StringRecord::new(), + eof: false, + phantom: PhantomData, + } + } + + /// Accessor to the associated virtual table. + fn vtab(&self) -> &CSVTab { + unsafe { &*(self.base.pVtab as *const CSVTab) } + } +} + +unsafe impl VTabCursor for CSVTabCursor<'_> { + // Only a full table scan is supported. So `filter` simply rewinds to + // the beginning. + fn filter( + &mut self, + _idx_num: c_int, + _idx_str: Option<&str>, + _args: &Values<'_>, + ) -> Result<()> { + { + let offset_first_row = self.vtab().offset_first_row.clone(); + self.reader.seek(offset_first_row)?; + } + self.row_number = 0; + self.next() + } + + fn next(&mut self) -> Result<()> { + { + self.eof = self.reader.is_done(); + if self.eof { + return Ok(()); + } + + self.eof = !self.reader.read_record(&mut self.cols)?; + } + + self.row_number += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.eof + } + + fn column(&self, ctx: &mut Context, col: c_int) -> Result<()> { + if col < 0 || col as usize >= self.cols.len() { + return Err(Error::ModuleError(format!( + "column index out of bounds: {}", + col + ))); + } + if self.cols.is_empty() { + return ctx.set_result(&Null); + } + // TODO Affinity + ctx.set_result(&self.cols[col as usize].to_owned()) + } + + fn rowid(&self) -> Result<i64> { + Ok(self.row_number as i64) + } +} + +impl From<csv::Error> for Error { + fn from(err: csv::Error) -> Error { + Error::ModuleError(err.to_string()) + } +} + +#[cfg(test)] +mod test { + use crate::vtab::csvtab; + use crate::{Connection, Result, NO_PARAMS}; + use fallible_iterator::FallibleIterator; + + #[test] + fn test_csv_module() { + let db = Connection::open_in_memory().unwrap(); + csvtab::load_module(&db).unwrap(); + db.execute_batch("CREATE VIRTUAL TABLE vtab USING csv(filename='test.csv', header=yes)") + .unwrap(); + + { + let mut s = db.prepare("SELECT rowid, * FROM vtab").unwrap(); + { + let headers = s.column_names(); + assert_eq!(vec!["rowid", "colA", "colB", "colC"], headers); + } + + let ids: Result<Vec<i32>> = s + .query(NO_PARAMS) + .unwrap() + .map(|row| row.get::<_, i32>(0)) + .collect(); + let sum = ids.unwrap().iter().sum::<i32>(); + assert_eq!(sum, 15); + } + db.execute_batch("DROP TABLE vtab").unwrap(); + } + + #[test] + fn test_csv_cursor() { + let db = Connection::open_in_memory().unwrap(); + csvtab::load_module(&db).unwrap(); + db.execute_batch("CREATE VIRTUAL TABLE vtab USING csv(filename='test.csv', header=yes)") + .unwrap(); + + { + let mut s = db + .prepare( + "SELECT v1.rowid, v1.* FROM vtab v1 NATURAL JOIN vtab v2 WHERE \ + v1.rowid < v2.rowid", + ) + .unwrap(); + + let mut rows = s.query(NO_PARAMS).unwrap(); + let row = rows.next().unwrap().unwrap(); + assert_eq!(row.get_unwrap::<_, i32>(0), 2); + } + db.execute_batch("DROP TABLE vtab").unwrap(); + } +} diff --git a/third_party/rust/rusqlite/src/vtab/mod.rs b/third_party/rust/rusqlite/src/vtab/mod.rs new file mode 100644 index 0000000000..dc3bda6481 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/mod.rs @@ -0,0 +1,1076 @@ +//! `feature = "vtab"` Create virtual tables. +//! +//! Follow these steps to create your own virtual table: +//! 1. Write implemenation of `VTab` and `VTabCursor` traits. +//! 2. Create an instance of the `Module` structure specialized for `VTab` impl. +//! from step 1. +//! 3. Register your `Module` structure using `Connection.create_module`. +//! 4. Run a `CREATE VIRTUAL TABLE` command that specifies the new module in the +//! `USING` clause. +//! +//! (See [SQLite doc](http://sqlite.org/vtab.html)) +use std::borrow::Cow::{self, Borrowed, Owned}; +use std::marker::PhantomData; +use std::marker::Sync; +use std::os::raw::{c_char, c_int, c_void}; +use std::ptr; +use std::slice; + +use crate::context::set_result; +use crate::error::error_from_sqlite_code; +use crate::ffi; +pub use crate::ffi::{sqlite3_vtab, sqlite3_vtab_cursor}; +use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; +use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; + +// let conn: Connection = ...; +// let mod: Module = ...; // VTab builder +// conn.create_module("module", mod); +// +// conn.execute("CREATE VIRTUAL TABLE foo USING module(...)"); +// \-> Module::xcreate +// |-> let vtab: VTab = ...; // on the heap +// \-> conn.declare_vtab("CREATE TABLE foo (...)"); +// conn = Connection::open(...); +// \-> Module::xconnect +// |-> let vtab: VTab = ...; // on the heap +// \-> conn.declare_vtab("CREATE TABLE foo (...)"); +// +// conn.close(); +// \-> vtab.xdisconnect +// conn.execute("DROP TABLE foo"); +// \-> vtab.xDestroy +// +// let stmt = conn.prepare("SELECT ... FROM foo WHERE ..."); +// \-> vtab.xbestindex +// stmt.query().next(); +// \-> vtab.xopen +// |-> let cursor: VTabCursor = ...; // on the heap +// |-> cursor.xfilter or xnext +// |-> cursor.xeof +// \-> if not eof { cursor.column or xrowid } else { cursor.xclose } +// + +// db: *mut ffi::sqlite3 => VTabConnection +// module: *const ffi::sqlite3_module => Module +// aux: *mut c_void => Module::Aux +// ffi::sqlite3_vtab => VTab +// ffi::sqlite3_vtab_cursor => VTabCursor + +/// `feature = "vtab"` Virtual table module +/// +/// (See [SQLite doc](https://sqlite.org/c3ref/module.html)) +#[repr(transparent)] +pub struct Module<'vtab, T: VTab<'vtab>> { + base: ffi::sqlite3_module, + phantom: PhantomData<&'vtab T>, +} + +unsafe impl<'vtab, T: VTab<'vtab>> Send for Module<'vtab, T> {} +unsafe impl<'vtab, T: VTab<'vtab>> Sync for Module<'vtab, T> {} + +union ModuleZeroHack { + bytes: [u8; std::mem::size_of::<ffi::sqlite3_module>()], + module: ffi::sqlite3_module, +} + +// Used as a trailing initializer for sqlite3_module -- this way we avoid having +// the build fail if buildtime_bindgen is on. This is safe, as bindgen-generated +// structs are allowed to be zeroed. +const ZERO_MODULE: ffi::sqlite3_module = unsafe { + ModuleZeroHack { + bytes: [0u8; std::mem::size_of::<ffi::sqlite3_module>()], + } + .module +}; + +/// `feature = "vtab"` Create a read-only virtual table implementation. +/// +/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations). +pub fn read_only_module<'vtab, T: CreateVTab<'vtab>>() -> &'static Module<'vtab, T> { + // The xConnect and xCreate methods do the same thing, but they must be + // different so that the virtual table is not an eponymous virtual table. + &Module { + base: ffi::sqlite3_module { + // We don't use V3 + iVersion: 2, // We don't use V2 or V3 features in read_only_module types + xCreate: Some(rust_create::<T>), + xConnect: Some(rust_connect::<T>), + xBestIndex: Some(rust_best_index::<T>), + xDisconnect: Some(rust_disconnect::<T>), + xDestroy: Some(rust_destroy::<T>), + xOpen: Some(rust_open::<T>), + xClose: Some(rust_close::<T::Cursor>), + xFilter: Some(rust_filter::<T::Cursor>), + xNext: Some(rust_next::<T::Cursor>), + xEof: Some(rust_eof::<T::Cursor>), + xColumn: Some(rust_column::<T::Cursor>), + xRowid: Some(rust_rowid::<T::Cursor>), + xUpdate: None, + xBegin: None, + xSync: None, + xCommit: None, + xRollback: None, + xFindFunction: None, + xRename: None, + xSavepoint: None, + xRelease: None, + xRollbackTo: None, + ..ZERO_MODULE + }, + phantom: PhantomData::<&'vtab T>, + } +} + +/// `feature = "vtab"` Create an eponymous only virtual table implementation. +/// +/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations). +pub fn eponymous_only_module<'vtab, T: VTab<'vtab>>() -> &'static Module<'vtab, T> { + // A virtual table is eponymous if its xCreate method is the exact same function + // as the xConnect method For eponymous-only virtual tables, the xCreate + // method is NULL + &Module { + base: ffi::sqlite3_module { + // We don't use V3 + iVersion: 2, + xCreate: None, + xConnect: Some(rust_connect::<T>), + xBestIndex: Some(rust_best_index::<T>), + xDisconnect: Some(rust_disconnect::<T>), + xDestroy: None, + xOpen: Some(rust_open::<T>), + xClose: Some(rust_close::<T::Cursor>), + xFilter: Some(rust_filter::<T::Cursor>), + xNext: Some(rust_next::<T::Cursor>), + xEof: Some(rust_eof::<T::Cursor>), + xColumn: Some(rust_column::<T::Cursor>), + xRowid: Some(rust_rowid::<T::Cursor>), + xUpdate: None, + xBegin: None, + xSync: None, + xCommit: None, + xRollback: None, + xFindFunction: None, + xRename: None, + xSavepoint: None, + xRelease: None, + xRollbackTo: None, + ..ZERO_MODULE + }, + phantom: PhantomData::<&'vtab T>, + } +} + +/// `feature = "vtab"` +pub struct VTabConnection(*mut ffi::sqlite3); + +impl VTabConnection { + // TODO sqlite3_vtab_config (http://sqlite.org/c3ref/vtab_config.html) + + // TODO sqlite3_vtab_on_conflict (http://sqlite.org/c3ref/vtab_on_conflict.html) + + /// Get access to the underlying SQLite database connection handle. + /// + /// # Warning + /// + /// You should not need to use this function. If you do need to, please + /// [open an issue on the rusqlite repository](https://github.com/rusqlite/rusqlite/issues) and describe + /// your use case. + /// + /// # Safety + /// + /// This function is unsafe because it gives you raw access + /// to the SQLite connection, and what you do with it could impact the + /// safety of this `Connection`. + pub unsafe fn handle(&mut self) -> *mut ffi::sqlite3 { + self.0 + } +} + +/// `feature = "vtab"` Virtual table instance trait. +/// +/// # Safety +/// +/// The first item in a struct implementing VTab must be +/// `rusqlite::sqlite3_vtab`, and the struct must be `#[repr(C)]`. +/// +/// ```rust,ignore +/// #[repr(C)] +/// struct MyTab { +/// /// Base class. Must be first +/// base: ffi::sqlite3_vtab, +/// /* Virtual table implementations will typically add additional fields */ +/// } +/// ``` +/// +/// (See [SQLite doc](https://sqlite.org/c3ref/vtab.html)) +pub unsafe trait VTab<'vtab>: Sized { + /// Client data passed to `Connection::create_module`. + type Aux; + /// Specific cursor implementation + type Cursor: VTabCursor; + + /// Establish a new connection to an existing virtual table. + /// + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xconnect_method)) + fn connect( + db: &mut VTabConnection, + aux: Option<&Self::Aux>, + args: &[&[u8]], + ) -> Result<(String, Self)>; + + /// Determine the best way to access the virtual table. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xbestindex_method)) + fn best_index(&self, info: &mut IndexInfo) -> Result<()>; + + /// Create a new cursor used for accessing a virtual table. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xopen_method)) + fn open(&'vtab self) -> Result<Self::Cursor>; +} + +/// `feature = "vtab"` Non-eponymous virtual table instance trait. +/// +/// (See [SQLite doc](https://sqlite.org/c3ref/vtab.html)) +pub trait CreateVTab<'vtab>: VTab<'vtab> { + /// Create a new instance of a virtual table in response to a CREATE VIRTUAL + /// TABLE statement. The `db` parameter is a pointer to the SQLite + /// database connection that is executing the CREATE VIRTUAL TABLE + /// statement. + /// + /// Call `connect` by default. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xcreate_method)) + fn create( + db: &mut VTabConnection, + aux: Option<&Self::Aux>, + args: &[&[u8]], + ) -> Result<(String, Self)> { + Self::connect(db, aux, args) + } + + /// Destroy the underlying table implementation. This method undoes the work + /// of `create`. + /// + /// Do nothing by default. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xdestroy_method)) + fn destroy(&self) -> Result<()> { + Ok(()) + } +} + +/// `feature = "vtab"` Index constraint operator. +/// See [Virtual Table Constraint Operator Codes](https://sqlite.org/c3ref/c_index_constraint_eq.html) for details. +#[derive(Debug, PartialEq)] +#[allow(non_snake_case, non_camel_case_types, missing_docs)] +#[non_exhaustive] +pub enum IndexConstraintOp { + SQLITE_INDEX_CONSTRAINT_EQ, + SQLITE_INDEX_CONSTRAINT_GT, + SQLITE_INDEX_CONSTRAINT_LE, + SQLITE_INDEX_CONSTRAINT_LT, + SQLITE_INDEX_CONSTRAINT_GE, + SQLITE_INDEX_CONSTRAINT_MATCH, + SQLITE_INDEX_CONSTRAINT_LIKE, // 3.10.0 + SQLITE_INDEX_CONSTRAINT_GLOB, // 3.10.0 + SQLITE_INDEX_CONSTRAINT_REGEXP, // 3.10.0 + SQLITE_INDEX_CONSTRAINT_NE, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_ISNOT, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_ISNOTNULL, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_ISNULL, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_IS, // 3.21.0 + SQLITE_INDEX_CONSTRAINT_FUNCTION(u8), // 3.25.0 +} + +impl From<u8> for IndexConstraintOp { + fn from(code: u8) -> IndexConstraintOp { + match code { + 2 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ, + 4 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GT, + 8 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LE, + 16 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LT, + 32 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GE, + 64 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_MATCH, + 65 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LIKE, + 66 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GLOB, + 67 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_REGEXP, + 68 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_NE, + 69 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNOT, + 70 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNOTNULL, + 71 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNULL, + 72 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_IS, + v => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_FUNCTION(v), + } + } +} + +/// `feature = "vtab"` Pass information into and receive the reply from the +/// `VTab.best_index` method. +/// +/// (See [SQLite doc](http://sqlite.org/c3ref/index_info.html)) +pub struct IndexInfo(*mut ffi::sqlite3_index_info); + +impl IndexInfo { + /// Record WHERE clause constraints. + pub fn constraints(&self) -> IndexConstraintIter<'_> { + let constraints = + unsafe { slice::from_raw_parts((*self.0).aConstraint, (*self.0).nConstraint as usize) }; + IndexConstraintIter { + iter: constraints.iter(), + } + } + + /// Information about the ORDER BY clause. + pub fn order_bys(&self) -> OrderByIter<'_> { + let order_bys = + unsafe { slice::from_raw_parts((*self.0).aOrderBy, (*self.0).nOrderBy as usize) }; + OrderByIter { + iter: order_bys.iter(), + } + } + + /// Number of terms in the ORDER BY clause + pub fn num_of_order_by(&self) -> usize { + unsafe { (*self.0).nOrderBy as usize } + } + + /// Information about what parameters to pass to `VTabCursor.filter`. + pub fn constraint_usage(&mut self, constraint_idx: usize) -> IndexConstraintUsage<'_> { + let constraint_usages = unsafe { + slice::from_raw_parts_mut((*self.0).aConstraintUsage, (*self.0).nConstraint as usize) + }; + IndexConstraintUsage(&mut constraint_usages[constraint_idx]) + } + + /// Number used to identify the index + pub fn set_idx_num(&mut self, idx_num: c_int) { + unsafe { + (*self.0).idxNum = idx_num; + } + } + + /// True if output is already ordered + pub fn set_order_by_consumed(&mut self, order_by_consumed: bool) { + unsafe { + (*self.0).orderByConsumed = if order_by_consumed { 1 } else { 0 }; + } + } + + /// Estimated cost of using this index + pub fn set_estimated_cost(&mut self, estimated_ost: f64) { + unsafe { + (*self.0).estimatedCost = estimated_ost; + } + } + + /// Estimated number of rows returned. + #[cfg(feature = "modern_sqlite")] // SQLite >= 3.8.2 + pub fn set_estimated_rows(&mut self, estimated_rows: i64) { + unsafe { + (*self.0).estimatedRows = estimated_rows; + } + } + + // TODO idxFlags + // TODO colUsed + + // TODO sqlite3_vtab_collation (http://sqlite.org/c3ref/vtab_collation.html) +} + +/// `feature = "vtab"` +pub struct IndexConstraintIter<'a> { + iter: slice::Iter<'a, ffi::sqlite3_index_constraint>, +} + +impl<'a> Iterator for IndexConstraintIter<'a> { + type Item = IndexConstraint<'a>; + + fn next(&mut self) -> Option<IndexConstraint<'a>> { + self.iter.next().map(|raw| IndexConstraint(raw)) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +/// `feature = "vtab"` WHERE clause constraint. +pub struct IndexConstraint<'a>(&'a ffi::sqlite3_index_constraint); + +impl IndexConstraint<'_> { + /// Column constrained. -1 for ROWID + pub fn column(&self) -> c_int { + self.0.iColumn + } + + /// Constraint operator + pub fn operator(&self) -> IndexConstraintOp { + IndexConstraintOp::from(self.0.op) + } + + /// True if this constraint is usable + pub fn is_usable(&self) -> bool { + self.0.usable != 0 + } +} + +/// `feature = "vtab"` Information about what parameters to pass to +/// `VTabCursor.filter`. +pub struct IndexConstraintUsage<'a>(&'a mut ffi::sqlite3_index_constraint_usage); + +impl IndexConstraintUsage<'_> { + /// if `argv_index` > 0, constraint is part of argv to `VTabCursor.filter` + pub fn set_argv_index(&mut self, argv_index: c_int) { + self.0.argvIndex = argv_index; + } + + /// if `omit`, do not code a test for this constraint + pub fn set_omit(&mut self, omit: bool) { + self.0.omit = if omit { 1 } else { 0 }; + } +} + +/// `feature = "vtab"` +pub struct OrderByIter<'a> { + iter: slice::Iter<'a, ffi::sqlite3_index_info_sqlite3_index_orderby>, +} + +impl<'a> Iterator for OrderByIter<'a> { + type Item = OrderBy<'a>; + + fn next(&mut self) -> Option<OrderBy<'a>> { + self.iter.next().map(|raw| OrderBy(raw)) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +/// `feature = "vtab"` A column of the ORDER BY clause. +pub struct OrderBy<'a>(&'a ffi::sqlite3_index_info_sqlite3_index_orderby); + +impl OrderBy<'_> { + /// Column number + pub fn column(&self) -> c_int { + self.0.iColumn + } + + /// True for DESC. False for ASC. + pub fn is_order_by_desc(&self) -> bool { + self.0.desc != 0 + } +} + +/// `feature = "vtab"` Virtual table cursor trait. +/// +/// Implementations must be like: +/// ```rust,ignore +/// #[repr(C)] +/// struct MyTabCursor { +/// /// Base class. Must be first +/// base: ffi::sqlite3_vtab_cursor, +/// /* Virtual table implementations will typically add additional fields */ +/// } +/// ``` +/// +/// (See [SQLite doc](https://sqlite.org/c3ref/vtab_cursor.html)) +pub unsafe trait VTabCursor: Sized { + /// Begin a search of a virtual table. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xfilter_method)) + fn filter(&mut self, idx_num: c_int, idx_str: Option<&str>, args: &Values<'_>) -> Result<()>; + /// Advance cursor to the next row of a result set initiated by `filter`. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xnext_method)) + fn next(&mut self) -> Result<()>; + /// Must return `false` if the cursor currently points to a valid row of + /// data, or `true` otherwise. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xeof_method)) + fn eof(&self) -> bool; + /// Find the value for the `i`-th column of the current row. + /// `i` is zero-based so the first column is numbered 0. + /// May return its result back to SQLite using one of the specified `ctx`. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xcolumn_method)) + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()>; + /// Return the rowid of row that the cursor is currently pointing at. + /// (See [SQLite doc](https://sqlite.org/vtab.html#the_xrowid_method)) + fn rowid(&self) -> Result<i64>; +} + +/// `feature = "vtab"` Context is used by `VTabCursor.column` to specify the +/// cell value. +pub struct Context(*mut ffi::sqlite3_context); + +impl Context { + /// Set current cell value + pub fn set_result<T: ToSql>(&mut self, value: &T) -> Result<()> { + let t = value.to_sql()?; + unsafe { set_result(self.0, &t) }; + Ok(()) + } + + // TODO sqlite3_vtab_nochange (http://sqlite.org/c3ref/vtab_nochange.html) +} + +/// `feature = "vtab"` Wrapper to `VTabCursor.filter` arguments, the values +/// requested by `VTab.best_index`. +pub struct Values<'a> { + args: &'a [*mut ffi::sqlite3_value], +} + +impl Values<'_> { + /// Returns the number of values. + pub fn len(&self) -> usize { + self.args.len() + } + + /// Returns `true` if there is no value. + pub fn is_empty(&self) -> bool { + self.args.is_empty() + } + + /// Returns value at `idx` + pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> { + let arg = self.args[idx]; + let value = unsafe { ValueRef::from_value(arg) }; + FromSql::column_result(value).map_err(|err| match err { + FromSqlError::InvalidType => Error::InvalidFilterParameterType(idx, value.data_type()), + FromSqlError::Other(err) => { + Error::FromSqlConversionFailure(idx, value.data_type(), err) + } + FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i), + #[cfg(feature = "i128_blob")] + FromSqlError::InvalidI128Size(_) => { + Error::InvalidColumnType(idx, idx.to_string(), value.data_type()) + } + #[cfg(feature = "uuid")] + FromSqlError::InvalidUuidSize(_) => { + Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err)) + } + }) + } + + // `sqlite3_value_type` returns `SQLITE_NULL` for pointer. + // So it seems not possible to enhance `ValueRef::from_value`. + #[cfg(feature = "array")] + fn get_array(&self, idx: usize) -> Result<Option<array::Array>> { + use crate::types::Value; + let arg = self.args[idx]; + let ptr = unsafe { ffi::sqlite3_value_pointer(arg, array::ARRAY_TYPE) }; + if ptr.is_null() { + Ok(None) + } else { + Ok(Some(unsafe { + let rc = array::Array::from_raw(ptr as *const Vec<Value>); + let array = rc.clone(); + array::Array::into_raw(rc); // don't consume it + array + })) + } + } + + /// Turns `Values` into an iterator. + pub fn iter(&self) -> ValueIter<'_> { + ValueIter { + iter: self.args.iter(), + } + } +} + +impl<'a> IntoIterator for &'a Values<'a> { + type IntoIter = ValueIter<'a>; + type Item = ValueRef<'a>; + + fn into_iter(self) -> ValueIter<'a> { + self.iter() + } +} + +/// `Values` iterator. +pub struct ValueIter<'a> { + iter: slice::Iter<'a, *mut ffi::sqlite3_value>, +} + +impl<'a> Iterator for ValueIter<'a> { + type Item = ValueRef<'a>; + + fn next(&mut self) -> Option<ValueRef<'a>> { + self.iter + .next() + .map(|&raw| unsafe { ValueRef::from_value(raw) }) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +impl Connection { + /// `feature = "vtab"` Register a virtual table implementation. + /// + /// Step 3 of [Creating New Virtual Table + /// Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations). + pub fn create_module<'vtab, T: VTab<'vtab>>( + &self, + module_name: &str, + module: &'static Module<'vtab, T>, + aux: Option<T::Aux>, + ) -> Result<()> { + self.db.borrow_mut().create_module(module_name, module, aux) + } +} + +impl InnerConnection { + fn create_module<'vtab, T: VTab<'vtab>>( + &mut self, + module_name: &str, + module: &'static Module<'vtab, T>, + aux: Option<T::Aux>, + ) -> Result<()> { + let c_name = str_to_cstring(module_name)?; + let r = match aux { + Some(aux) => { + let boxed_aux: *mut T::Aux = Box::into_raw(Box::new(aux)); + unsafe { + ffi::sqlite3_create_module_v2( + self.db(), + c_name.as_ptr(), + &module.base, + boxed_aux as *mut c_void, + Some(free_boxed_value::<T::Aux>), + ) + } + } + None => unsafe { + ffi::sqlite3_create_module_v2( + self.db(), + c_name.as_ptr(), + &module.base, + ptr::null_mut(), + None, + ) + }, + }; + self.decode_result(r) + } +} + +/// `feature = "vtab"` Escape double-quote (`"`) character occurences by +/// doubling them (`""`). +pub fn escape_double_quote(identifier: &str) -> Cow<'_, str> { + if identifier.contains('"') { + // escape quote by doubling them + Owned(identifier.replace("\"", "\"\"")) + } else { + Borrowed(identifier) + } +} +/// `feature = "vtab"` Dequote string +pub fn dequote(s: &str) -> &str { + if s.len() < 2 { + return s; + } + match s.bytes().next() { + Some(b) if b == b'"' || b == b'\'' => match s.bytes().rev().next() { + Some(e) if e == b => &s[1..s.len() - 1], + _ => s, + }, + _ => s, + } +} +/// `feature = "vtab"` The boolean can be one of: +/// ```text +/// 1 yes true on +/// 0 no false off +/// ``` +pub fn parse_boolean(s: &str) -> Option<bool> { + if s.eq_ignore_ascii_case("yes") + || s.eq_ignore_ascii_case("on") + || s.eq_ignore_ascii_case("true") + || s.eq("1") + { + Some(true) + } else if s.eq_ignore_ascii_case("no") + || s.eq_ignore_ascii_case("off") + || s.eq_ignore_ascii_case("false") + || s.eq("0") + { + Some(false) + } else { + None + } +} + +// FIXME copy/paste from function.rs +unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) { + let _: Box<T> = Box::from_raw(p as *mut T); +} + +unsafe extern "C" fn rust_create<'vtab, T>( + db: *mut ffi::sqlite3, + aux: *mut c_void, + argc: c_int, + argv: *const *const c_char, + pp_vtab: *mut *mut ffi::sqlite3_vtab, + err_msg: *mut *mut c_char, +) -> c_int +where + T: CreateVTab<'vtab>, +{ + use std::ffi::CStr; + + let mut conn = VTabConnection(db); + let aux = aux as *mut T::Aux; + let args = slice::from_raw_parts(argv, argc as usize); + let vec = args + .iter() + .map(|&cs| CStr::from_ptr(cs).to_bytes()) // FIXME .to_str() -> Result<&str, Utf8Error> + .collect::<Vec<_>>(); + match T::create(&mut conn, aux.as_ref(), &vec[..]) { + Ok((sql, vtab)) => match ::std::ffi::CString::new(sql) { + Ok(c_sql) => { + let rc = ffi::sqlite3_declare_vtab(db, c_sql.as_ptr()); + if rc == ffi::SQLITE_OK { + let boxed_vtab: *mut T = Box::into_raw(Box::new(vtab)); + *pp_vtab = boxed_vtab as *mut ffi::sqlite3_vtab; + ffi::SQLITE_OK + } else { + let err = error_from_sqlite_code(rc, None); + *err_msg = alloc(&err.to_string()); + rc + } + } + Err(err) => { + *err_msg = alloc(&err.to_string()); + ffi::SQLITE_ERROR + } + }, + Err(Error::SqliteFailure(err, s)) => { + if let Some(s) = s { + *err_msg = alloc(&s); + } + err.extended_code + } + Err(err) => { + *err_msg = alloc(&err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +unsafe extern "C" fn rust_connect<'vtab, T>( + db: *mut ffi::sqlite3, + aux: *mut c_void, + argc: c_int, + argv: *const *const c_char, + pp_vtab: *mut *mut ffi::sqlite3_vtab, + err_msg: *mut *mut c_char, +) -> c_int +where + T: VTab<'vtab>, +{ + use std::ffi::CStr; + + let mut conn = VTabConnection(db); + let aux = aux as *mut T::Aux; + let args = slice::from_raw_parts(argv, argc as usize); + let vec = args + .iter() + .map(|&cs| CStr::from_ptr(cs).to_bytes()) // FIXME .to_str() -> Result<&str, Utf8Error> + .collect::<Vec<_>>(); + match T::connect(&mut conn, aux.as_ref(), &vec[..]) { + Ok((sql, vtab)) => match ::std::ffi::CString::new(sql) { + Ok(c_sql) => { + let rc = ffi::sqlite3_declare_vtab(db, c_sql.as_ptr()); + if rc == ffi::SQLITE_OK { + let boxed_vtab: *mut T = Box::into_raw(Box::new(vtab)); + *pp_vtab = boxed_vtab as *mut ffi::sqlite3_vtab; + ffi::SQLITE_OK + } else { + let err = error_from_sqlite_code(rc, None); + *err_msg = alloc(&err.to_string()); + rc + } + } + Err(err) => { + *err_msg = alloc(&err.to_string()); + ffi::SQLITE_ERROR + } + }, + Err(Error::SqliteFailure(err, s)) => { + if let Some(s) = s { + *err_msg = alloc(&s); + } + err.extended_code + } + Err(err) => { + *err_msg = alloc(&err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +unsafe extern "C" fn rust_best_index<'vtab, T>( + vtab: *mut ffi::sqlite3_vtab, + info: *mut ffi::sqlite3_index_info, +) -> c_int +where + T: VTab<'vtab>, +{ + let vt = vtab as *mut T; + let mut idx_info = IndexInfo(info); + match (*vt).best_index(&mut idx_info) { + Ok(_) => ffi::SQLITE_OK, + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg(vtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg(vtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +unsafe extern "C" fn rust_disconnect<'vtab, T>(vtab: *mut ffi::sqlite3_vtab) -> c_int +where + T: VTab<'vtab>, +{ + if vtab.is_null() { + return ffi::SQLITE_OK; + } + let vtab = vtab as *mut T; + let _: Box<T> = Box::from_raw(vtab); + ffi::SQLITE_OK +} + +unsafe extern "C" fn rust_destroy<'vtab, T>(vtab: *mut ffi::sqlite3_vtab) -> c_int +where + T: CreateVTab<'vtab>, +{ + if vtab.is_null() { + return ffi::SQLITE_OK; + } + let vt = vtab as *mut T; + match (*vt).destroy() { + Ok(_) => { + let _: Box<T> = Box::from_raw(vt); + ffi::SQLITE_OK + } + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg(vtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg(vtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +unsafe extern "C" fn rust_open<'vtab, T: 'vtab>( + vtab: *mut ffi::sqlite3_vtab, + pp_cursor: *mut *mut ffi::sqlite3_vtab_cursor, +) -> c_int +where + T: VTab<'vtab>, +{ + let vt = vtab as *mut T; + match (*vt).open() { + Ok(cursor) => { + let boxed_cursor: *mut T::Cursor = Box::into_raw(Box::new(cursor)); + *pp_cursor = boxed_cursor as *mut ffi::sqlite3_vtab_cursor; + ffi::SQLITE_OK + } + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg(vtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg(vtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +unsafe extern "C" fn rust_close<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int +where + C: VTabCursor, +{ + let cr = cursor as *mut C; + let _: Box<C> = Box::from_raw(cr); + ffi::SQLITE_OK +} + +unsafe extern "C" fn rust_filter<C>( + cursor: *mut ffi::sqlite3_vtab_cursor, + idx_num: c_int, + idx_str: *const c_char, + argc: c_int, + argv: *mut *mut ffi::sqlite3_value, +) -> c_int +where + C: VTabCursor, +{ + use std::ffi::CStr; + use std::str; + let idx_name = if idx_str.is_null() { + None + } else { + let c_slice = CStr::from_ptr(idx_str).to_bytes(); + Some(str::from_utf8_unchecked(c_slice)) + }; + let args = slice::from_raw_parts_mut(argv, argc as usize); + let values = Values { args }; + let cr = cursor as *mut C; + cursor_error(cursor, (*cr).filter(idx_num, idx_name, &values)) +} + +unsafe extern "C" fn rust_next<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int +where + C: VTabCursor, +{ + let cr = cursor as *mut C; + cursor_error(cursor, (*cr).next()) +} + +unsafe extern "C" fn rust_eof<C>(cursor: *mut ffi::sqlite3_vtab_cursor) -> c_int +where + C: VTabCursor, +{ + let cr = cursor as *mut C; + (*cr).eof() as c_int +} + +unsafe extern "C" fn rust_column<C>( + cursor: *mut ffi::sqlite3_vtab_cursor, + ctx: *mut ffi::sqlite3_context, + i: c_int, +) -> c_int +where + C: VTabCursor, +{ + let cr = cursor as *mut C; + let mut ctxt = Context(ctx); + result_error(ctx, (*cr).column(&mut ctxt, i)) +} + +unsafe extern "C" fn rust_rowid<C>( + cursor: *mut ffi::sqlite3_vtab_cursor, + p_rowid: *mut ffi::sqlite3_int64, +) -> c_int +where + C: VTabCursor, +{ + let cr = cursor as *mut C; + match (*cr).rowid() { + Ok(rowid) => { + *p_rowid = rowid; + ffi::SQLITE_OK + } + err => cursor_error(cursor, err), + } +} + +/// Virtual table cursors can set an error message by assigning a string to +/// `zErrMsg`. +unsafe fn cursor_error<T>(cursor: *mut ffi::sqlite3_vtab_cursor, result: Result<T>) -> c_int { + match result { + Ok(_) => ffi::SQLITE_OK, + Err(Error::SqliteFailure(err, s)) => { + if let Some(err_msg) = s { + set_err_msg((*cursor).pVtab, &err_msg); + } + err.extended_code + } + Err(err) => { + set_err_msg((*cursor).pVtab, &err.to_string()); + ffi::SQLITE_ERROR + } + } +} + +/// Virtual tables methods can set an error message by assigning a string to +/// `zErrMsg`. +unsafe fn set_err_msg(vtab: *mut ffi::sqlite3_vtab, err_msg: &str) { + if !(*vtab).zErrMsg.is_null() { + ffi::sqlite3_free((*vtab).zErrMsg as *mut c_void); + } + (*vtab).zErrMsg = alloc(err_msg); +} + +/// To raise an error, the `column` method should use this method to set the +/// error message and return the error code. +unsafe fn result_error<T>(ctx: *mut ffi::sqlite3_context, result: Result<T>) -> c_int { + match result { + Ok(_) => ffi::SQLITE_OK, + Err(Error::SqliteFailure(err, s)) => { + match err.extended_code { + ffi::SQLITE_TOOBIG => { + ffi::sqlite3_result_error_toobig(ctx); + } + ffi::SQLITE_NOMEM => { + ffi::sqlite3_result_error_nomem(ctx); + } + code => { + ffi::sqlite3_result_error_code(ctx, code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + }; + err.extended_code + } + Err(err) => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_ERROR); + if let Ok(cstr) = str_to_cstring(&err.to_string()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + ffi::SQLITE_ERROR + } + } +} + +// Space to hold this string must be obtained +// from an SQLite memory allocation function +fn alloc(s: &str) -> *mut c_char { + crate::util::SqliteMallocString::from_str(s).into_raw() +} + +#[cfg(feature = "array")] +pub mod array; +#[cfg(feature = "csvtab")] +pub mod csvtab; +#[cfg(feature = "series")] +pub mod series; // SQLite >= 3.9.0 + +#[cfg(test)] +mod test { + #[test] + fn test_dequote() { + assert_eq!("", super::dequote("")); + assert_eq!("'", super::dequote("'")); + assert_eq!("\"", super::dequote("\"")); + assert_eq!("'\"", super::dequote("'\"")); + assert_eq!("", super::dequote("''")); + assert_eq!("", super::dequote("\"\"")); + assert_eq!("x", super::dequote("'x'")); + assert_eq!("x", super::dequote("\"x\"")); + assert_eq!("x", super::dequote("x")); + } + #[test] + fn test_parse_boolean() { + assert_eq!(None, super::parse_boolean("")); + assert_eq!(Some(true), super::parse_boolean("1")); + assert_eq!(Some(true), super::parse_boolean("yes")); + assert_eq!(Some(true), super::parse_boolean("on")); + assert_eq!(Some(true), super::parse_boolean("true")); + assert_eq!(Some(false), super::parse_boolean("0")); + assert_eq!(Some(false), super::parse_boolean("no")); + assert_eq!(Some(false), super::parse_boolean("off")); + assert_eq!(Some(false), super::parse_boolean("false")); + } +} diff --git a/third_party/rust/rusqlite/src/vtab/series.rs b/third_party/rust/rusqlite/src/vtab/series.rs new file mode 100644 index 0000000000..ed67f16597 --- /dev/null +++ b/third_party/rust/rusqlite/src/vtab/series.rs @@ -0,0 +1,298 @@ +//! `feature = "series"` Generate series virtual table. +//! +//! Port of C [generate series +//! "function"](http://www.sqlite.org/cgi/src/finfo?name=ext/misc/series.c): +//! https://www.sqlite.org/series.html +use std::default::Default; +use std::marker::PhantomData; +use std::os::raw::c_int; + +use crate::ffi; +use crate::types::Type; +use crate::vtab::{ + eponymous_only_module, Context, IndexConstraintOp, IndexInfo, VTab, VTabConnection, VTabCursor, + Values, +}; +use crate::{Connection, Result}; + +/// `feature = "series"` Register the "generate_series" module. +pub fn load_module(conn: &Connection) -> Result<()> { + let aux: Option<()> = None; + conn.create_module("generate_series", eponymous_only_module::<SeriesTab>(), aux) +} + +// Column numbers +// const SERIES_COLUMN_VALUE : c_int = 0; +const SERIES_COLUMN_START: c_int = 1; +const SERIES_COLUMN_STOP: c_int = 2; +const SERIES_COLUMN_STEP: c_int = 3; + +bitflags::bitflags! { + #[repr(C)] + struct QueryPlanFlags: ::std::os::raw::c_int { + // start = $value -- constraint exists + const START = 1; + // stop = $value -- constraint exists + const STOP = 2; + // step = $value -- constraint exists + const STEP = 4; + // output in descending order + const DESC = 8; + // Both start and stop + const BOTH = QueryPlanFlags::START.bits | QueryPlanFlags::STOP.bits; + } +} + +/// An instance of the Series virtual table +#[repr(C)] +struct SeriesTab { + /// Base class. Must be first + base: ffi::sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for SeriesTab { + type Aux = (); + type Cursor = SeriesTabCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + _args: &[&[u8]], + ) -> Result<(String, SeriesTab)> { + let vtab = SeriesTab { + base: ffi::sqlite3_vtab::default(), + }; + Ok(( + "CREATE TABLE x(value,start hidden,stop hidden,step hidden)".to_owned(), + vtab, + )) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + // The query plan bitmask + let mut idx_num: QueryPlanFlags = QueryPlanFlags::empty(); + // Index of the start= constraint + let mut start_idx = None; + // Index of the stop= constraint + let mut stop_idx = None; + // Index of the step= constraint + let mut step_idx = None; + for (i, constraint) in info.constraints().enumerate() { + if !constraint.is_usable() { + continue; + } + if constraint.operator() != IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ { + continue; + } + match constraint.column() { + SERIES_COLUMN_START => { + start_idx = Some(i); + idx_num |= QueryPlanFlags::START; + } + SERIES_COLUMN_STOP => { + stop_idx = Some(i); + idx_num |= QueryPlanFlags::STOP; + } + SERIES_COLUMN_STEP => { + step_idx = Some(i); + idx_num |= QueryPlanFlags::STEP; + } + _ => {} + }; + } + + let mut num_of_arg = 0; + if let Some(start_idx) = start_idx { + num_of_arg += 1; + let mut constraint_usage = info.constraint_usage(start_idx); + constraint_usage.set_argv_index(num_of_arg); + constraint_usage.set_omit(true); + } + if let Some(stop_idx) = stop_idx { + num_of_arg += 1; + let mut constraint_usage = info.constraint_usage(stop_idx); + constraint_usage.set_argv_index(num_of_arg); + constraint_usage.set_omit(true); + } + if let Some(step_idx) = step_idx { + num_of_arg += 1; + let mut constraint_usage = info.constraint_usage(step_idx); + constraint_usage.set_argv_index(num_of_arg); + constraint_usage.set_omit(true); + } + if idx_num.contains(QueryPlanFlags::BOTH) { + // Both start= and stop= boundaries are available. + info.set_estimated_cost(f64::from( + 2 - if idx_num.contains(QueryPlanFlags::STEP) { + 1 + } else { + 0 + }, + )); + info.set_estimated_rows(1000); + let order_by_consumed = { + let mut order_bys = info.order_bys(); + if let Some(order_by) = order_bys.next() { + if order_by.is_order_by_desc() { + idx_num |= QueryPlanFlags::DESC; + } + true + } else { + false + } + }; + if order_by_consumed { + info.set_order_by_consumed(true); + } + } else { + info.set_estimated_cost(2_147_483_647f64); + info.set_estimated_rows(2_147_483_647); + } + info.set_idx_num(idx_num.bits()); + Ok(()) + } + + fn open(&self) -> Result<SeriesTabCursor<'_>> { + Ok(SeriesTabCursor::new()) + } +} + +/// A cursor for the Series virtual table +#[repr(C)] +struct SeriesTabCursor<'vtab> { + /// Base class. Must be first + base: ffi::sqlite3_vtab_cursor, + /// True to count down rather than up + is_desc: bool, + /// The rowid + row_id: i64, + /// Current value ("value") + value: i64, + /// Mimimum value ("start") + min_value: i64, + /// Maximum value ("stop") + max_value: i64, + /// Increment ("step") + step: i64, + phantom: PhantomData<&'vtab SeriesTab>, +} + +impl SeriesTabCursor<'_> { + fn new<'vtab>() -> SeriesTabCursor<'vtab> { + SeriesTabCursor { + base: ffi::sqlite3_vtab_cursor::default(), + is_desc: false, + row_id: 0, + value: 0, + min_value: 0, + max_value: 0, + step: 0, + phantom: PhantomData, + } + } +} +unsafe impl VTabCursor for SeriesTabCursor<'_> { + fn filter(&mut self, idx_num: c_int, _idx_str: Option<&str>, args: &Values<'_>) -> Result<()> { + let idx_num = QueryPlanFlags::from_bits_truncate(idx_num); + let mut i = 0; + if idx_num.contains(QueryPlanFlags::START) { + self.min_value = args.get(i)?; + i += 1; + } else { + self.min_value = 0; + } + if idx_num.contains(QueryPlanFlags::STOP) { + self.max_value = args.get(i)?; + i += 1; + } else { + self.max_value = 0xffff_ffff; + } + if idx_num.contains(QueryPlanFlags::STEP) { + self.step = args.get(i)?; + if self.step < 1 { + self.step = 1; + } + } else { + self.step = 1; + }; + for arg in args.iter() { + if arg.data_type() == Type::Null { + // If any of the constraints have a NULL value, then return no rows. + self.min_value = 1; + self.max_value = 0; + break; + } + } + self.is_desc = idx_num.contains(QueryPlanFlags::DESC); + if self.is_desc { + self.value = self.max_value; + if self.step > 0 { + self.value -= (self.max_value - self.min_value) % self.step; + } + } else { + self.value = self.min_value; + } + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + if self.is_desc { + self.value -= self.step; + } else { + self.value += self.step; + } + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + if self.is_desc { + self.value < self.min_value + } else { + self.value > self.max_value + } + } + + fn column(&self, ctx: &mut Context, i: c_int) -> Result<()> { + let x = match i { + SERIES_COLUMN_START => self.min_value, + SERIES_COLUMN_STOP => self.max_value, + SERIES_COLUMN_STEP => self.step, + _ => self.value, + }; + ctx.set_result(&x) + } + + fn rowid(&self) -> Result<i64> { + Ok(self.row_id) + } +} + +#[cfg(test)] +mod test { + use crate::ffi; + use crate::vtab::series; + use crate::{Connection, NO_PARAMS}; + + #[test] + fn test_series_module() { + let version = unsafe { ffi::sqlite3_libversion_number() }; + if version < 3_008_012 { + return; + } + + let db = Connection::open_in_memory().unwrap(); + series::load_module(&db).unwrap(); + + let mut s = db.prepare("SELECT * FROM generate_series(0,20,5)").unwrap(); + + let series = s.query_map(NO_PARAMS, |row| row.get::<_, i32>(0)).unwrap(); + + let mut expected = 0; + for value in series { + assert_eq!(expected, value.unwrap()); + expected += 5; + } + } +} diff --git a/third_party/rust/rusqlite/test.csv b/third_party/rust/rusqlite/test.csv new file mode 100644 index 0000000000..708f93f195 --- /dev/null +++ b/third_party/rust/rusqlite/test.csv @@ -0,0 +1,6 @@ +"colA","colB","colC" +1,2,3 +a,b,c +a,"b",c +"a","b","c .. z" +"a","b","c,d" diff --git a/third_party/rust/rusqlite/tests/config_log.rs b/third_party/rust/rusqlite/tests/config_log.rs new file mode 100644 index 0000000000..0c28bdf1df --- /dev/null +++ b/third_party/rust/rusqlite/tests/config_log.rs @@ -0,0 +1,34 @@ +//! This file contains unit tests for `rusqlite::trace::config_log`. This +//! function affects SQLite process-wide and so is not safe to run as a normal +//! #[test] in the library. + +#[cfg(feature = "trace")] +fn main() { + use lazy_static::lazy_static; + use std::os::raw::c_int; + use std::sync::Mutex; + + lazy_static! { + static ref LOGS_RECEIVED: Mutex<Vec<(c_int, String)>> = Mutex::new(Vec::new()); + } + + fn log_handler(err: c_int, message: &str) { + let mut logs_received = LOGS_RECEIVED.lock().unwrap(); + logs_received.push((err, message.to_owned())); + } + + use rusqlite::trace; + + unsafe { trace::config_log(Some(log_handler)) }.unwrap(); + trace::log(10, "First message from rusqlite"); + unsafe { trace::config_log(None) }.unwrap(); + trace::log(11, "Second message from rusqlite"); + + let logs_received = LOGS_RECEIVED.lock().unwrap(); + assert_eq!(logs_received.len(), 1); + assert_eq!(logs_received[0].0, 10); + assert_eq!(logs_received[0].1, "First message from rusqlite"); +} + +#[cfg(not(feature = "trace"))] +fn main() {} diff --git a/third_party/rust/rusqlite/tests/deny_single_threaded_sqlite_config.rs b/third_party/rust/rusqlite/tests/deny_single_threaded_sqlite_config.rs new file mode 100644 index 0000000000..f6afdd51c4 --- /dev/null +++ b/third_party/rust/rusqlite/tests/deny_single_threaded_sqlite_config.rs @@ -0,0 +1,21 @@ +//! Ensure we reject connections when SQLite is in single-threaded mode, as it +//! would violate safety if multiple Rust threads tried to use connections. + +use rusqlite::ffi; +use rusqlite::Connection; + +#[test] +#[should_panic] +fn test_error_when_singlethread_mode() { + // put SQLite into single-threaded mode + unsafe { + if ffi::sqlite3_config(ffi::SQLITE_CONFIG_SINGLETHREAD) != ffi::SQLITE_OK { + return; + } + if ffi::sqlite3_initialize() != ffi::SQLITE_OK { + return; + } + } + + let _ = Connection::open_in_memory().unwrap(); +} diff --git a/third_party/rust/rusqlite/tests/vtab.rs b/third_party/rust/rusqlite/tests/vtab.rs new file mode 100644 index 0000000000..4b31574732 --- /dev/null +++ b/third_party/rust/rusqlite/tests/vtab.rs @@ -0,0 +1,103 @@ +//! Ensure Virtual tables can be declared outside `rusqlite` crate. + +#[cfg(feature = "vtab")] +#[test] +fn test_dummy_module() { + use rusqlite::types::ToSql; + use rusqlite::vtab::{ + eponymous_only_module, sqlite3_vtab, sqlite3_vtab_cursor, Context, IndexInfo, VTab, + VTabConnection, VTabCursor, Values, + }; + use rusqlite::{version_number, Connection, Result}; + use std::marker::PhantomData; + use std::os::raw::c_int; + + let module = eponymous_only_module::<DummyTab>(); + + #[repr(C)] + struct DummyTab { + /// Base class. Must be first + base: sqlite3_vtab, + } + + unsafe impl<'vtab> VTab<'vtab> for DummyTab { + type Aux = (); + type Cursor = DummyTabCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + _args: &[&[u8]], + ) -> Result<(String, DummyTab)> { + let vtab = DummyTab { + base: sqlite3_vtab::default(), + }; + Ok(("CREATE TABLE x(value)".to_owned(), vtab)) + } + + fn best_index(&self, info: &mut IndexInfo) -> Result<()> { + info.set_estimated_cost(1.); + Ok(()) + } + + fn open(&'vtab self) -> Result<DummyTabCursor<'vtab>> { + Ok(DummyTabCursor::default()) + } + } + + #[derive(Default)] + #[repr(C)] + struct DummyTabCursor<'vtab> { + /// Base class. Must be first + base: sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + phantom: PhantomData<&'vtab DummyTab>, + } + + unsafe impl VTabCursor for DummyTabCursor<'_> { + fn filter( + &mut self, + _idx_num: c_int, + _idx_str: Option<&str>, + _args: &Values<'_>, + ) -> Result<()> { + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.row_id > 1 + } + + fn column(&self, ctx: &mut Context, _: c_int) -> Result<()> { + ctx.set_result(&self.row_id) + } + + fn rowid(&self) -> Result<i64> { + Ok(self.row_id) + } + } + + let db = Connection::open_in_memory().unwrap(); + + db.create_module::<DummyTab>("dummy", &module, None) + .unwrap(); + + let version = version_number(); + if version < 3_008_012 { + return; + } + + let mut s = db.prepare("SELECT * FROM dummy()").unwrap(); + + let dummy = s + .query_row(&[] as &[&dyn ToSql], |row| row.get::<_, i32>(0)) + .unwrap(); + assert_eq!(1, dummy); +} |