summaryrefslogtreecommitdiffstats
path: root/third_party/rust/rusqlite/src/collation.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/rusqlite/src/collation.rs')
-rw-r--r--third_party/rust/rusqlite/src/collation.rs215
1 files changed, 215 insertions, 0 deletions
diff --git a/third_party/rust/rusqlite/src/collation.rs b/third_party/rust/rusqlite/src/collation.rs
new file mode 100644
index 0000000000..c1fe3f7837
--- /dev/null
+++ b/third_party/rust/rusqlite/src/collation.rs
@@ -0,0 +1,215 @@
+//! Add, remove, or modify a collation
+use std::cmp::Ordering;
+use std::os::raw::{c_char, c_int, c_void};
+use std::panic::{catch_unwind, UnwindSafe};
+use std::ptr;
+use std::slice;
+
+use crate::ffi;
+use crate::{str_to_cstring, Connection, InnerConnection, Result};
+
+// FIXME copy/paste from function.rs
+unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
+ drop(Box::from_raw(p.cast::<T>()));
+}
+
+impl Connection {
+ /// Add or modify a collation.
+ #[inline]
+ pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
+ where
+ C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
+ {
+ self.db
+ .borrow_mut()
+ .create_collation(collation_name, x_compare)
+ }
+
+ /// Collation needed callback
+ #[inline]
+ pub fn collation_needed(
+ &self,
+ x_coll_needed: fn(&Connection, &str) -> Result<()>,
+ ) -> Result<()> {
+ self.db.borrow_mut().collation_needed(x_coll_needed)
+ }
+
+ /// Remove collation.
+ #[inline]
+ pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
+ self.db.borrow_mut().remove_collation(collation_name)
+ }
+}
+
+impl InnerConnection {
+ fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
+ where
+ C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
+ {
+ unsafe extern "C" fn call_boxed_closure<C>(
+ arg1: *mut c_void,
+ arg2: c_int,
+ arg3: *const c_void,
+ arg4: c_int,
+ arg5: *const c_void,
+ ) -> c_int
+ where
+ C: Fn(&str, &str) -> Ordering,
+ {
+ let r = catch_unwind(|| {
+ let boxed_f: *mut C = arg1.cast::<C>();
+ assert!(!boxed_f.is_null(), "Internal error - null function pointer");
+ let s1 = {
+ let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize);
+ String::from_utf8_lossy(c_slice)
+ };
+ let s2 = {
+ let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize);
+ String::from_utf8_lossy(c_slice)
+ };
+ (*boxed_f)(s1.as_ref(), s2.as_ref())
+ });
+ let t = match r {
+ Err(_) => {
+ return -1; // FIXME How ?
+ }
+ Ok(r) => r,
+ };
+
+ match t {
+ Ordering::Less => -1,
+ Ordering::Equal => 0,
+ Ordering::Greater => 1,
+ }
+ }
+
+ let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
+ let c_name = str_to_cstring(collation_name)?;
+ let flags = ffi::SQLITE_UTF8;
+ let r = unsafe {
+ ffi::sqlite3_create_collation_v2(
+ self.db(),
+ c_name.as_ptr(),
+ flags,
+ boxed_f.cast::<c_void>(),
+ Some(call_boxed_closure::<C>),
+ Some(free_boxed_value::<C>),
+ )
+ };
+ let res = self.decode_result(r);
+ // The xDestroy callback is not called if the sqlite3_create_collation_v2()
+ // function fails.
+ if res.is_err() {
+ drop(unsafe { Box::from_raw(boxed_f) });
+ }
+ res
+ }
+
+ fn collation_needed(
+ &mut self,
+ x_coll_needed: fn(&Connection, &str) -> Result<()>,
+ ) -> Result<()> {
+ use std::mem;
+ #[allow(clippy::needless_return)]
+ unsafe extern "C" fn collation_needed_callback(
+ arg1: *mut c_void,
+ arg2: *mut ffi::sqlite3,
+ e_text_rep: c_int,
+ arg3: *const c_char,
+ ) {
+ use std::ffi::CStr;
+ use std::str;
+
+ if e_text_rep != ffi::SQLITE_UTF8 {
+ // TODO: validate
+ return;
+ }
+
+ let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
+ let res = catch_unwind(|| {
+ let conn = Connection::from_handle(arg2).unwrap();
+ let collation_name = {
+ let c_slice = CStr::from_ptr(arg3).to_bytes();
+ str::from_utf8(c_slice).expect("illegal collation sequence name")
+ };
+ callback(&conn, collation_name)
+ });
+ if res.is_err() {
+ return; // FIXME How ?
+ }
+ }
+
+ let r = unsafe {
+ ffi::sqlite3_collation_needed(
+ self.db(),
+ x_coll_needed as *mut c_void,
+ Some(collation_needed_callback),
+ )
+ };
+ self.decode_result(r)
+ }
+
+ #[inline]
+ fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
+ let c_name = str_to_cstring(collation_name)?;
+ let r = unsafe {
+ ffi::sqlite3_create_collation_v2(
+ self.db(),
+ c_name.as_ptr(),
+ ffi::SQLITE_UTF8,
+ ptr::null_mut(),
+ None,
+ None,
+ )
+ };
+ self.decode_result(r)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::{Connection, Result};
+ use fallible_streaming_iterator::FallibleStreamingIterator;
+ use std::cmp::Ordering;
+ use unicase::UniCase;
+
+ fn unicase_compare(s1: &str, s2: &str) -> Ordering {
+ UniCase::new(s1).cmp(&UniCase::new(s2))
+ }
+
+ #[test]
+ fn test_unicase() -> Result<()> {
+ let db = Connection::open_in_memory()?;
+
+ db.create_collation("unicase", unicase_compare)?;
+
+ collate(db)
+ }
+
+ fn collate(db: Connection) -> Result<()> {
+ db.execute_batch(
+ "CREATE TABLE foo (bar);
+ INSERT INTO foo (bar) VALUES ('Maße');
+ INSERT INTO foo (bar) VALUES ('MASSE');",
+ )?;
+ let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
+ let rows = stmt.query([])?;
+ assert_eq!(rows.count()?, 1);
+ Ok(())
+ }
+
+ fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
+ if "unicase" == collation_name {
+ db.create_collation(collation_name, unicase_compare)
+ } else {
+ Ok(())
+ }
+ }
+
+ #[test]
+ fn test_collation_needed() -> Result<()> {
+ let db = Connection::open_in_memory()?;
+ db.collation_needed(collation_needed)?;
+ collate(db)
+ }
+}