diff options
Diffstat (limited to 'third_party/rust/rust_decimal/src/postgres/driver.rs')
-rw-r--r-- | third_party/rust/rust_decimal/src/postgres/driver.rs | 369 |
1 files changed, 369 insertions, 0 deletions
diff --git a/third_party/rust/rust_decimal/src/postgres/driver.rs b/third_party/rust/rust_decimal/src/postgres/driver.rs new file mode 100644 index 0000000000..15574ee864 --- /dev/null +++ b/third_party/rust/rust_decimal/src/postgres/driver.rs @@ -0,0 +1,369 @@ +use crate::postgres::common::*; +use crate::Decimal; +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, BytesMut}; +use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type}; +use std::io::Cursor; + +impl<'a> FromSql<'a> for Decimal { + // Decimals are represented as follows: + // Header: + // u16 numGroups + // i16 weightFirstGroup (10000^weight) + // u16 sign (0x0000 = positive, 0x4000 = negative, 0xC000 = NaN) + // i16 dscale. Number of digits (in base 10) to print after decimal separator + // + // Pseudo code : + // const Decimals [ + // 0.0000000000000000000000000001, + // 0.000000000000000000000001, + // 0.00000000000000000001, + // 0.0000000000000001, + // 0.000000000001, + // 0.00000001, + // 0.0001, + // 1, + // 10000, + // 100000000, + // 1000000000000, + // 10000000000000000, + // 100000000000000000000, + // 1000000000000000000000000, + // 10000000000000000000000000000 + // ] + // overflow = false + // result = 0 + // for i = 0, weight = weightFirstGroup + 7; i < numGroups; i++, weight-- + // group = read.u16 + // if weight < 0 or weight > MaxNum + // overflow = true + // else + // result += Decimals[weight] * group + // sign == 0x4000 ? -result : result + + // So if we were to take the number: 3950.123456 + // + // Stored on Disk: + // 00 03 00 00 00 00 00 06 0F 6E 04 D2 15 E0 + // + // Number of groups: 00 03 + // Weight of first group: 00 00 + // Sign: 00 00 + // DScale: 00 06 + // + // 0F 6E = 3950 + // result = result + 3950 * 1; + // 04 D2 = 1234 + // result = result + 1234 * 0.0001; + // 15 E0 = 5600 + // result = result + 5600 * 0.00000001; + // + + fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<dyn std::error::Error + 'static + Sync + Send>> { + let mut raw = Cursor::new(raw); + let num_groups = raw.read_u16::<BigEndian>()?; + let weight = raw.read_i16::<BigEndian>()?; // 10000^weight + // Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN + let sign = raw.read_u16::<BigEndian>()?; + // Number of digits (in base 10) to print after decimal separator + let scale = raw.read_u16::<BigEndian>()?; + + // Read all of the groups + let mut groups = Vec::new(); + for _ in 0..num_groups as usize { + groups.push(raw.read_u16::<BigEndian>()?); + } + + Ok(Self::from_postgres(PostgresDecimal { + neg: sign == 0x4000, + weight, + scale, + digits: groups.into_iter(), + })) + } + + fn accepts(ty: &Type) -> bool { + matches!(*ty, Type::NUMERIC) + } +} + +impl ToSql for Decimal { + fn to_sql( + &self, + _: &Type, + out: &mut BytesMut, + ) -> Result<IsNull, Box<dyn std::error::Error + 'static + Sync + Send>> { + let PostgresDecimal { + neg, + weight, + scale, + digits, + } = self.to_postgres(); + + let num_digits = digits.len(); + + // Reserve bytes + out.reserve(8 + num_digits * 2); + + // Number of groups + out.put_u16(num_digits.try_into().unwrap()); + // Weight of first group + out.put_i16(weight); + // Sign + out.put_u16(if neg { 0x4000 } else { 0x0000 }); + // DScale + out.put_u16(scale); + // Now process the number + for digit in digits[0..num_digits].iter() { + out.put_i16(*digit); + } + + Ok(IsNull::No) + } + + fn accepts(ty: &Type) -> bool { + matches!(*ty, Type::NUMERIC) + } + + to_sql_checked!(); +} + +#[cfg(test)] +mod test { + use super::*; + use ::postgres::{Client, NoTls}; + use core::str::FromStr; + + /// Gets the URL for connecting to PostgreSQL for testing. Set the POSTGRES_URL + /// environment variable to change from the default of "postgres://postgres@localhost". + fn get_postgres_url() -> String { + if let Ok(url) = std::env::var("POSTGRES_URL") { + return url; + } + "postgres://postgres@localhost".to_string() + } + + pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[ + // precision, scale, sent, expected + (35, 6, "3950.123456", "3950.123456"), + (35, 2, "3950.123456", "3950.12"), + (35, 2, "3950.1256", "3950.13"), + (10, 2, "3950.123456", "3950.12"), + (35, 6, "3950", "3950.000000"), + (4, 0, "3950", "3950"), + (35, 6, "0.1", "0.100000"), + (35, 6, "0.01", "0.010000"), + (35, 6, "0.001", "0.001000"), + (35, 6, "0.0001", "0.000100"), + (35, 6, "0.00001", "0.000010"), + (35, 6, "0.000001", "0.000001"), + (35, 6, "1", "1.000000"), + (35, 6, "-100", "-100.000000"), + (35, 6, "-123.456", "-123.456000"), + (35, 6, "119996.25", "119996.250000"), + (35, 6, "1000000", "1000000.000000"), + (35, 6, "9999999.99999", "9999999.999990"), + (35, 6, "12340.56789", "12340.567890"), + // Scale is only 28 since that is the maximum we can represent. + (65, 30, "1.2", "1.2000000000000000000000000000"), + // Pi - rounded at scale 28 + ( + 65, + 30, + "3.141592653589793238462643383279", + "3.1415926535897932384626433833", + ), + ( + 65, + 34, + "3.1415926535897932384626433832795028", + "3.1415926535897932384626433833", + ), + // Unrounded number + ( + 65, + 34, + "1.234567890123456789012345678950000", + "1.2345678901234567890123456790", + ), + ( + 65, + 34, // No rounding due to 49999 after significant digits + "1.234567890123456789012345678949999", + "1.2345678901234567890123456789", + ), + // 0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF (96 bit) + (35, 0, "79228162514264337593543950335", "79228162514264337593543950335"), + // 0x0FFF_FFFF_FFFF_FFFF_FFFF_FFFF (95 bit) + (35, 1, "4951760157141521099596496895", "4951760157141521099596496895.0"), + // 0x1000_0000_0000_0000_0000_0000 + (35, 1, "4951760157141521099596496896", "4951760157141521099596496896.0"), + (35, 6, "18446744073709551615", "18446744073709551615.000000"), + (35, 6, "-18446744073709551615", "-18446744073709551615.000000"), + (35, 6, "0.10001", "0.100010"), + (35, 6, "0.12345", "0.123450"), + ]; + + #[test] + fn test_null() { + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + + // Test NULL + let result: Option<Decimal> = match client.query("SELECT NULL::numeric", &[]) { + Ok(x) => x.iter().next().unwrap().get(0), + Err(err) => panic!("{:#?}", err), + }; + assert_eq!(None, result); + } + + #[tokio::test] + #[cfg(feature = "tokio-pg")] + async fn async_test_null() { + use futures::future::FutureExt; + use tokio_postgres::connect; + + let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + + let statement = client.prepare(&"SELECT NULL::numeric").await.unwrap(); + let rows = client.query(&statement, &[]).await.unwrap(); + let result: Option<Decimal> = rows.iter().next().unwrap().get(0); + + assert_eq!(None, result); + } + + #[test] + fn read_numeric_type() { + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let result: Decimal = + match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) { + Ok(x) => x.iter().next().unwrap().get(0), + Err(err) => panic!("SELECT {}::NUMERIC({}, {}), error - {:#?}", sent, precision, scale, err), + }; + assert_eq!( + expected, + result.to_string(), + "NUMERIC({}, {}) sent: {}", + precision, + scale, + sent + ); + } + } + + #[tokio::test] + #[cfg(feature = "tokio-pg")] + async fn async_read_numeric_type() { + use futures::future::FutureExt; + use tokio_postgres::connect; + + let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let statement = client + .prepare(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale)) + .await + .unwrap(); + let rows = client.query(&statement, &[]).await.unwrap(); + let result: Decimal = rows.iter().next().unwrap().get(0); + + assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale); + } + } + + #[test] + fn write_numeric_type() { + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let number = Decimal::from_str(sent).unwrap(); + let result: Decimal = + match client.query(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale), &[&number]) { + Ok(x) => x.iter().next().unwrap().get(0), + Err(err) => panic!("{:#?}", err), + }; + assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale); + } + } + + #[tokio::test] + #[cfg(feature = "tokio-pg")] + async fn async_write_numeric_type() { + use futures::future::FutureExt; + use tokio_postgres::connect; + + let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let statement = client + .prepare(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale)) + .await + .unwrap(); + let number = Decimal::from_str(sent).unwrap(); + let rows = client.query(&statement, &[&number]).await.unwrap(); + let result: Decimal = rows.iter().next().unwrap().get(0); + + assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale); + } + } + + #[test] + fn numeric_overflow() { + let tests = [(4, 4, "3950.1234")]; + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + for &(precision, scale, sent) in tests.iter() { + match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) { + Ok(_) => panic!( + "Expected numeric overflow for {}::NUMERIC({}, {})", + sent, precision, scale + ), + Err(err) => { + assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code"); + } + }; + } + } + + #[tokio::test] + #[cfg(feature = "tokio-pg")] + async fn async_numeric_overflow() { + use futures::future::FutureExt; + use tokio_postgres::connect; + + let tests = [(4, 4, "3950.1234")]; + let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + + for &(precision, scale, sent) in tests.iter() { + let statement = client + .prepare(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale)) + .await + .unwrap(); + + match client.query(&statement, &[]).await { + Ok(_) => panic!( + "Expected numeric overflow for {}::NUMERIC({}, {})", + sent, precision, scale + ), + Err(err) => assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code"), + } + } + } +} |