diff options
Diffstat (limited to 'third_party/rust/rust_decimal/src/str.rs')
-rw-r--r-- | third_party/rust/rust_decimal/src/str.rs | 993 |
1 files changed, 993 insertions, 0 deletions
diff --git a/third_party/rust/rust_decimal/src/str.rs b/third_party/rust/rust_decimal/src/str.rs new file mode 100644 index 0000000000..f3b89d31e0 --- /dev/null +++ b/third_party/rust/rust_decimal/src/str.rs @@ -0,0 +1,993 @@ +use crate::{ + constants::{BYTES_TO_OVERFLOW_U64, MAX_PRECISION, MAX_STR_BUFFER_SIZE, OVERFLOW_U96, WILL_OVERFLOW_U64}, + error::{tail_error, Error}, + ops::array::{add_by_internal_flattened, add_one_internal, div_by_u32, is_all_zero, mul_by_u32}, + Decimal, +}; + +use arrayvec::{ArrayString, ArrayVec}; + +use alloc::{string::String, vec::Vec}; +use core::fmt; + +// impl that doesn't allocate for serialization purposes. +pub(crate) fn to_str_internal( + value: &Decimal, + append_sign: bool, + precision: Option<usize>, +) -> (ArrayString<MAX_STR_BUFFER_SIZE>, Option<usize>) { + // Get the scale - where we need to put the decimal point + let scale = value.scale() as usize; + + // Convert to a string and manipulate that (neg at front, inject decimal) + let mut chars = ArrayVec::<_, MAX_STR_BUFFER_SIZE>::new(); + let mut working = value.mantissa_array3(); + while !is_all_zero(&working) { + let remainder = div_by_u32(&mut working, 10u32); + chars.push(char::from(b'0' + remainder as u8)); + } + while scale > chars.len() { + chars.push('0'); + } + + let (prec, additional) = match precision { + Some(prec) => { + let max: usize = MAX_PRECISION.into(); + if prec > max { + (max, Some(prec - max)) + } else { + (prec, None) + } + } + None => (scale, None), + }; + + let len = chars.len(); + let whole_len = len - scale; + let mut rep = ArrayString::new(); + // Append the negative sign if necessary while also keeping track of the length of an "empty" string representation + let empty_len = if append_sign && value.is_sign_negative() { + rep.push('-'); + 1 + } else { + 0 + }; + for i in 0..whole_len + prec { + if i == len - scale { + if i == 0 { + rep.push('0'); + } + rep.push('.'); + } + + if i >= len { + rep.push('0'); + } else { + let c = chars[len - i - 1]; + rep.push(c); + } + } + + // corner case for when we truncated everything in a low fractional + if rep.len() == empty_len { + rep.push('0'); + } + + (rep, additional) +} + +pub(crate) fn fmt_scientific_notation( + value: &Decimal, + exponent_symbol: &str, + f: &mut fmt::Formatter<'_>, +) -> fmt::Result { + #[cfg(not(feature = "std"))] + use alloc::string::ToString; + + // Get the scale - this is the e value. With multiples of 10 this may get bigger. + let mut exponent = -(value.scale() as isize); + + // Convert the integral to a string + let mut chars = Vec::new(); + let mut working = value.mantissa_array3(); + while !is_all_zero(&working) { + let remainder = div_by_u32(&mut working, 10u32); + chars.push(char::from(b'0' + remainder as u8)); + } + + // First of all, apply scientific notation rules. That is: + // 1. If non-zero digit comes first, move decimal point left so that e is a positive integer + // 2. If decimal point comes first, move decimal point right until after the first non-zero digit + // Since decimal notation naturally lends itself this way, we just need to inject the decimal + // point in the right place and adjust the exponent accordingly. + + let len = chars.len(); + let mut rep; + // We either are operating with a precision specified, or on defaults. Defaults will perform "smart" + // reduction of precision. + if let Some(precision) = f.precision() { + if len > 1 { + // If we're zero precision AND it's trailing zeros then strip them + if precision == 0 && chars.iter().take(len - 1).all(|c| *c == '0') { + rep = chars.iter().skip(len - 1).collect::<String>(); + } else { + // We may still be zero precision, however we aren't trailing zeros + if precision > 0 { + chars.insert(len - 1, '.'); + } + rep = chars + .iter() + .rev() + // Add on extra zeros according to the precision. At least one, since we added a decimal place. + .chain(core::iter::repeat(&'0')) + .take(if precision == 0 { 1 } else { 2 + precision }) + .collect::<String>(); + } + exponent += (len - 1) as isize; + } else if precision > 0 { + // We have precision that we want to add + chars.push('.'); + rep = chars + .iter() + .chain(core::iter::repeat(&'0')) + .take(2 + precision) + .collect::<String>(); + } else { + rep = chars.iter().collect::<String>(); + } + } else if len > 1 { + // If the number is just trailing zeros then we treat it like 0 precision + if chars.iter().take(len - 1).all(|c| *c == '0') { + rep = chars.iter().skip(len - 1).collect::<String>(); + } else { + // Otherwise, we need to insert a decimal place and make it a scientific number + chars.insert(len - 1, '.'); + rep = chars.iter().rev().collect::<String>(); + } + exponent += (len - 1) as isize; + } else { + rep = chars.iter().collect::<String>(); + } + + rep.push_str(exponent_symbol); + rep.push_str(&exponent.to_string()); + f.pad_integral(value.is_sign_positive(), "", &rep) +} + +// dedicated implementation for the most common case. +#[inline] +pub(crate) fn parse_str_radix_10(str: &str) -> Result<Decimal, Error> { + let bytes = str.as_bytes(); + if bytes.len() < BYTES_TO_OVERFLOW_U64 { + parse_str_radix_10_dispatch::<false, true>(bytes) + } else { + parse_str_radix_10_dispatch::<true, true>(bytes) + } +} + +#[inline] +pub(crate) fn parse_str_radix_10_exact(str: &str) -> Result<Decimal, Error> { + let bytes = str.as_bytes(); + if bytes.len() < BYTES_TO_OVERFLOW_U64 { + parse_str_radix_10_dispatch::<false, false>(bytes) + } else { + parse_str_radix_10_dispatch::<true, false>(bytes) + } +} + +#[inline] +fn parse_str_radix_10_dispatch<const BIG: bool, const ROUND: bool>(bytes: &[u8]) -> Result<Decimal, Error> { + match bytes { + [b, rest @ ..] => byte_dispatch_u64::<false, false, false, BIG, true, ROUND>(rest, 0, 0, *b), + [] => tail_error("Invalid decimal: empty"), + } +} + +#[inline] +fn overflow_64(val: u64) -> bool { + val >= WILL_OVERFLOW_U64 +} + +#[inline] +pub fn overflow_128(val: u128) -> bool { + val >= OVERFLOW_U96 +} + +/// Dispatch the next byte: +/// +/// * POINT - a decimal point has been seen +/// * NEG - we've encountered a `-` and the number is negative +/// * HAS - a digit has been encountered (when HAS is false it's invalid) +/// * BIG - a number that uses 96 bits instead of only 64 bits +/// * FIRST - true if it is the first byte in the string +#[inline] +fn dispatch_next<const POINT: bool, const NEG: bool, const HAS: bool, const BIG: bool, const ROUND: bool>( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result<Decimal, Error> { + if let Some((next, bytes)) = bytes.split_first() { + byte_dispatch_u64::<POINT, NEG, HAS, BIG, false, ROUND>(bytes, data64, scale, *next) + } else { + handle_data::<NEG, HAS>(data64 as u128, scale) + } +} + +#[inline(never)] +fn non_digit_dispatch_u64< + const POINT: bool, + const NEG: bool, + const HAS: bool, + const BIG: bool, + const FIRST: bool, + const ROUND: bool, +>( + bytes: &[u8], + data64: u64, + scale: u8, + b: u8, +) -> Result<Decimal, Error> { + match b { + b'-' if FIRST && !HAS => dispatch_next::<false, true, false, BIG, ROUND>(bytes, data64, scale), + b'+' if FIRST && !HAS => dispatch_next::<false, false, false, BIG, ROUND>(bytes, data64, scale), + b'_' if HAS => handle_separator::<POINT, NEG, BIG, ROUND>(bytes, data64, scale), + b => tail_invalid_digit(b), + } +} + +#[inline] +fn byte_dispatch_u64< + const POINT: bool, + const NEG: bool, + const HAS: bool, + const BIG: bool, + const FIRST: bool, + const ROUND: bool, +>( + bytes: &[u8], + data64: u64, + scale: u8, + b: u8, +) -> Result<Decimal, Error> { + match b { + b'0'..=b'9' => handle_digit_64::<POINT, NEG, BIG, ROUND>(bytes, data64, scale, b - b'0'), + b'.' if !POINT => handle_point::<NEG, HAS, BIG, ROUND>(bytes, data64, scale), + b => non_digit_dispatch_u64::<POINT, NEG, HAS, BIG, FIRST, ROUND>(bytes, data64, scale, b), + } +} + +#[inline(never)] +fn handle_digit_64<const POINT: bool, const NEG: bool, const BIG: bool, const ROUND: bool>( + bytes: &[u8], + data64: u64, + scale: u8, + digit: u8, +) -> Result<Decimal, Error> { + // we have already validated that we cannot overflow + let data64 = data64 * 10 + digit as u64; + let scale = if POINT { scale + 1 } else { 0 }; + + if let Some((next, bytes)) = bytes.split_first() { + let next = *next; + if POINT && BIG && scale >= 28 { + if ROUND { + maybe_round(data64 as u128, next, scale, POINT, NEG) + } else { + Err(Error::Underflow) + } + } else if BIG && overflow_64(data64) { + handle_full_128::<POINT, NEG, ROUND>(data64 as u128, bytes, scale, next) + } else { + byte_dispatch_u64::<POINT, NEG, true, BIG, false, ROUND>(bytes, data64, scale, next) + } + } else { + let data: u128 = data64 as u128; + + handle_data::<NEG, true>(data, scale) + } +} + +#[inline(never)] +fn handle_point<const NEG: bool, const HAS: bool, const BIG: bool, const ROUND: bool>( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result<Decimal, Error> { + dispatch_next::<true, NEG, HAS, BIG, ROUND>(bytes, data64, scale) +} + +#[inline(never)] +fn handle_separator<const POINT: bool, const NEG: bool, const BIG: bool, const ROUND: bool>( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result<Decimal, Error> { + dispatch_next::<POINT, NEG, true, BIG, ROUND>(bytes, data64, scale) +} + +#[inline(never)] +#[cold] +fn tail_invalid_digit(digit: u8) -> Result<Decimal, Error> { + match digit { + b'.' => tail_error("Invalid decimal: two decimal points"), + b'_' => tail_error("Invalid decimal: must start lead with a number"), + _ => tail_error("Invalid decimal: unknown character"), + } +} + +#[inline(never)] +#[cold] +fn handle_full_128<const POINT: bool, const NEG: bool, const ROUND: bool>( + mut data: u128, + bytes: &[u8], + scale: u8, + next_byte: u8, +) -> Result<Decimal, Error> { + let b = next_byte; + match b { + b'0'..=b'9' => { + let digit = u32::from(b - b'0'); + + // If the data is going to overflow then we should go into recovery mode + let next = (data * 10) + digit as u128; + if overflow_128(next) { + if !POINT { + return tail_error("Invalid decimal: overflow from too many digits"); + } + + if ROUND { + maybe_round(data, next_byte, scale, POINT, NEG) + } else { + Err(Error::Underflow) + } + } else { + data = next; + let scale = scale + POINT as u8; + if let Some((next, bytes)) = bytes.split_first() { + let next = *next; + if POINT && scale >= 28 { + if ROUND { + maybe_round(data, next, scale, POINT, NEG) + } else { + Err(Error::Underflow) + } + } else { + handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, next) + } + } else { + handle_data::<NEG, true>(data, scale) + } + } + } + b'.' if !POINT => { + // This call won't tail? + if let Some((next, bytes)) = bytes.split_first() { + handle_full_128::<true, NEG, ROUND>(data, bytes, scale, *next) + } else { + handle_data::<NEG, true>(data, scale) + } + } + b'_' => { + if let Some((next, bytes)) = bytes.split_first() { + handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, *next) + } else { + handle_data::<NEG, true>(data, scale) + } + } + b => tail_invalid_digit(b), + } +} + +#[inline(never)] +#[cold] +fn maybe_round( + mut data: u128, + next_byte: u8, + mut scale: u8, + point: bool, + negative: bool, +) -> Result<Decimal, crate::Error> { + let digit = match next_byte { + b'0'..=b'9' => u32::from(next_byte - b'0'), + b'_' => 0, // this should be an invalid string? + b'.' if point => 0, + b => return tail_invalid_digit(b), + }; + + // Round at midpoint + if digit >= 5 { + data += 1; + + // If the mantissa is now overflowing, round to the next + // next least significant digit and discard precision + if overflow_128(data) { + if scale == 0 { + return tail_error("Invalid decimal: overflow from mantissa after rounding"); + } + data += 4; + data /= 10; + scale -= 1; + } + } + + if negative { + handle_data::<true, true>(data, scale) + } else { + handle_data::<false, true>(data, scale) + } +} + +#[inline(never)] +fn tail_no_has() -> Result<Decimal, Error> { + tail_error("Invalid decimal: no digits found") +} + +#[inline] +fn handle_data<const NEG: bool, const HAS: bool>(data: u128, scale: u8) -> Result<Decimal, Error> { + debug_assert_eq!(data >> 96, 0); + if !HAS { + tail_no_has() + } else { + Ok(Decimal::from_parts( + data as u32, + (data >> 32) as u32, + (data >> 64) as u32, + NEG, + scale as u32, + )) + } +} + +pub(crate) fn parse_str_radix_n(str: &str, radix: u32) -> Result<Decimal, Error> { + if str.is_empty() { + return Err(Error::from("Invalid decimal: empty")); + } + if radix < 2 { + return Err(Error::from("Unsupported radix < 2")); + } + if radix > 36 { + // As per trait documentation + return Err(Error::from("Unsupported radix > 36")); + } + + let mut offset = 0; + let mut len = str.len(); + let bytes = str.as_bytes(); + let mut negative = false; // assume positive + + // handle the sign + if bytes[offset] == b'-' { + negative = true; // leading minus means negative + offset += 1; + len -= 1; + } else if bytes[offset] == b'+' { + // leading + allowed + offset += 1; + len -= 1; + } + + // should now be at numeric part of the significand + let mut digits_before_dot: i32 = -1; // digits before '.', -1 if no '.' + let mut coeff = ArrayVec::<_, 96>::new(); // integer significand array + + // Supporting different radix + let (max_n, max_alpha_lower, max_alpha_upper) = if radix <= 10 { + (b'0' + (radix - 1) as u8, 0, 0) + } else { + let adj = (radix - 11) as u8; + (b'9', adj + b'a', adj + b'A') + }; + + // Estimate the max precision. All in all, it needs to fit into 96 bits. + // Rather than try to estimate, I've included the constants directly in here. We could, + // perhaps, replace this with a formula if it's faster - though it does appear to be log2. + let estimated_max_precision = match radix { + 2 => 96, + 3 => 61, + 4 => 48, + 5 => 42, + 6 => 38, + 7 => 35, + 8 => 32, + 9 => 31, + 10 => 28, + 11 => 28, + 12 => 27, + 13 => 26, + 14 => 26, + 15 => 25, + 16 => 24, + 17 => 24, + 18 => 24, + 19 => 23, + 20 => 23, + 21 => 22, + 22 => 22, + 23 => 22, + 24 => 21, + 25 => 21, + 26 => 21, + 27 => 21, + 28 => 20, + 29 => 20, + 30 => 20, + 31 => 20, + 32 => 20, + 33 => 20, + 34 => 19, + 35 => 19, + 36 => 19, + _ => return Err(Error::from("Unsupported radix")), + }; + + let mut maybe_round = false; + while len > 0 { + let b = bytes[offset]; + match b { + b'0'..=b'9' => { + if b > max_n { + return Err(Error::from("Invalid decimal: invalid character")); + } + coeff.push(u32::from(b - b'0')); + offset += 1; + len -= 1; + + // If the coefficient is longer than the max, exit early + if coeff.len() as u32 > estimated_max_precision { + maybe_round = true; + break; + } + } + b'a'..=b'z' => { + if b > max_alpha_lower { + return Err(Error::from("Invalid decimal: invalid character")); + } + coeff.push(u32::from(b - b'a') + 10); + offset += 1; + len -= 1; + + if coeff.len() as u32 > estimated_max_precision { + maybe_round = true; + break; + } + } + b'A'..=b'Z' => { + if b > max_alpha_upper { + return Err(Error::from("Invalid decimal: invalid character")); + } + coeff.push(u32::from(b - b'A') + 10); + offset += 1; + len -= 1; + + if coeff.len() as u32 > estimated_max_precision { + maybe_round = true; + break; + } + } + b'.' => { + if digits_before_dot >= 0 { + return Err(Error::from("Invalid decimal: two decimal points")); + } + digits_before_dot = coeff.len() as i32; + offset += 1; + len -= 1; + } + b'_' => { + // Must start with a number... + if coeff.is_empty() { + return Err(Error::from("Invalid decimal: must start lead with a number")); + } + offset += 1; + len -= 1; + } + _ => return Err(Error::from("Invalid decimal: unknown character")), + } + } + + // If we exited before the end of the string then do some rounding if necessary + if maybe_round && offset < bytes.len() { + let next_byte = bytes[offset]; + let digit = match next_byte { + b'0'..=b'9' => { + if next_byte > max_n { + return Err(Error::from("Invalid decimal: invalid character")); + } + u32::from(next_byte - b'0') + } + b'a'..=b'z' => { + if next_byte > max_alpha_lower { + return Err(Error::from("Invalid decimal: invalid character")); + } + u32::from(next_byte - b'a') + 10 + } + b'A'..=b'Z' => { + if next_byte > max_alpha_upper { + return Err(Error::from("Invalid decimal: invalid character")); + } + u32::from(next_byte - b'A') + 10 + } + b'_' => 0, + b'.' => { + // Still an error if we have a second dp + if digits_before_dot >= 0 { + return Err(Error::from("Invalid decimal: two decimal points")); + } + 0 + } + _ => return Err(Error::from("Invalid decimal: unknown character")), + }; + + // Round at midpoint + let midpoint = if radix & 0x1 == 1 { radix / 2 } else { (radix + 1) / 2 }; + if digit >= midpoint { + let mut index = coeff.len() - 1; + loop { + let new_digit = coeff[index] + 1; + if new_digit <= 9 { + coeff[index] = new_digit; + break; + } else { + coeff[index] = 0; + if index == 0 { + coeff.insert(0, 1u32); + digits_before_dot += 1; + coeff.pop(); + break; + } + } + index -= 1; + } + } + } + + // here when no characters left + if coeff.is_empty() { + return Err(Error::from("Invalid decimal: no digits found")); + } + + let mut scale = if digits_before_dot >= 0 { + // we had a decimal place so set the scale + (coeff.len() as u32) - (digits_before_dot as u32) + } else { + 0 + }; + + // Parse this using specified radix + let mut data = [0u32, 0u32, 0u32]; + let mut tmp = [0u32, 0u32, 0u32]; + let len = coeff.len(); + for (i, digit) in coeff.iter().enumerate() { + // If the data is going to overflow then we should go into recovery mode + tmp[0] = data[0]; + tmp[1] = data[1]; + tmp[2] = data[2]; + let overflow = mul_by_u32(&mut tmp, radix); + if overflow > 0 { + // This means that we have more data to process, that we're not sure what to do with. + // This may or may not be an issue - depending on whether we're past a decimal point + // or not. + if (i as i32) < digits_before_dot && i + 1 < len { + return Err(Error::from("Invalid decimal: overflow from too many digits")); + } + + if *digit >= 5 { + let carry = add_one_internal(&mut data); + if carry > 0 { + // Highly unlikely scenario which is more indicative of a bug + return Err(Error::from("Invalid decimal: overflow when rounding")); + } + } + // We're also one less digit so reduce the scale + let diff = (len - i) as u32; + if diff > scale { + return Err(Error::from("Invalid decimal: overflow from scale mismatch")); + } + scale -= diff; + break; + } else { + data[0] = tmp[0]; + data[1] = tmp[1]; + data[2] = tmp[2]; + let carry = add_by_internal_flattened(&mut data, *digit); + if carry > 0 { + // Highly unlikely scenario which is more indicative of a bug + return Err(Error::from("Invalid decimal: overflow from carry")); + } + } + } + + Ok(Decimal::from_parts(data[0], data[1], data[2], negative, scale)) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::Decimal; + use arrayvec::ArrayString; + use core::{fmt::Write, str::FromStr}; + + #[test] + fn display_does_not_overflow_max_capacity() { + let num = Decimal::from_str("1.2").unwrap(); + let mut buffer = ArrayString::<64>::new(); + let _ = buffer.write_fmt(format_args!("{:.31}", num)).unwrap(); + assert_eq!("1.2000000000000000000000000000000", buffer.as_str()); + } + + #[test] + fn from_str_rounding_0() { + assert_eq!( + parse_str_radix_10("1.234").unwrap().unpack(), + Decimal::new(1234, 3).unpack() + ); + } + + #[test] + fn from_str_rounding_1() { + assert_eq!( + parse_str_radix_10("11111_11111_11111.11111_11111_11111") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(11_111_111_111_111_111_111_111_111_111, 14).unpack() + ); + } + + #[test] + fn from_str_rounding_2() { + assert_eq!( + parse_str_radix_10("11111_11111_11111.11111_11111_11115") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(11_111_111_111_111_111_111_111_111_112, 14).unpack() + ); + } + + #[test] + fn from_str_rounding_3() { + assert_eq!( + parse_str_radix_10("11111_11111_11111.11111_11111_11195") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(1_111_111_111_111_111_111_111_111_1120, 14).unpack() // was Decimal::from_i128_with_scale(1_111_111_111_111_111_111_111_111_112, 13) + ); + } + + #[test] + fn from_str_rounding_4() { + assert_eq!( + parse_str_radix_10("99999_99999_99999.99999_99999_99995") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(10_000_000_000_000_000_000_000_000_000, 13).unpack() // was Decimal::from_i128_with_scale(1_000_000_000_000_000_000_000_000_000, 12) + ); + } + + #[test] + fn from_str_no_rounding_0() { + assert_eq!( + parse_str_radix_10_exact("1.234").unwrap().unpack(), + Decimal::new(1234, 3).unpack() + ); + } + + #[test] + fn from_str_no_rounding_1() { + assert_eq!( + parse_str_radix_10_exact("11111_11111_11111.11111_11111_11111"), + Err(Error::Underflow) + ); + } + + #[test] + fn from_str_no_rounding_2() { + assert_eq!( + parse_str_radix_10_exact("11111_11111_11111.11111_11111_11115"), + Err(Error::Underflow) + ); + } + + #[test] + fn from_str_no_rounding_3() { + assert_eq!( + parse_str_radix_10_exact("11111_11111_11111.11111_11111_11195"), + Err(Error::Underflow) + ); + } + + #[test] + fn from_str_no_rounding_4() { + assert_eq!( + parse_str_radix_10_exact("99999_99999_99999.99999_99999_99995"), + Err(Error::Underflow) + ); + } + + #[test] + fn from_str_many_pointless_chars() { + assert_eq!( + parse_str_radix_10("00________________________________________________________________001.1") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(11, 1).unpack() + ); + } + + #[test] + fn from_str_leading_0s_1() { + assert_eq!( + parse_str_radix_10("00001.1").unwrap().unpack(), + Decimal::from_i128_with_scale(11, 1).unpack() + ); + } + + #[test] + fn from_str_leading_0s_2() { + assert_eq!( + parse_str_radix_10("00000_00000_00000_00000_00001.00001") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(100001, 5).unpack() + ); + } + + #[test] + fn from_str_leading_0s_3() { + assert_eq!( + parse_str_radix_10("0.00000_00000_00000_00000_00000_00100") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(1, 28).unpack() + ); + } + + #[test] + fn from_str_trailing_0s_1() { + assert_eq!( + parse_str_radix_10("0.00001_00000_00000").unwrap().unpack(), + Decimal::from_i128_with_scale(10_000_000_000, 15).unpack() + ); + } + + #[test] + fn from_str_trailing_0s_2() { + assert_eq!( + parse_str_radix_10("0.00001_00000_00000_00000_00000_00000") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(100_000_000_000_000_000_000_000, 28).unpack() + ); + } + + #[test] + fn from_str_overflow_1() { + assert_eq!( + parse_str_radix_10("99999_99999_99999_99999_99999_99999.99999"), + // The original implementation returned + // Ok(10000_00000_00000_00000_00000_0000) + // Which is a bug! + Err(Error::from("Invalid decimal: overflow from too many digits")) + ); + } + + #[test] + fn from_str_overflow_2() { + assert!( + parse_str_radix_10("99999_99999_99999_99999_99999_11111.11111").is_err(), + // The original implementation is 'overflow from scale mismatch' + // but we got rid of that now + ); + } + + #[test] + fn from_str_overflow_3() { + assert!( + parse_str_radix_10("99999_99999_99999_99999_99999_99994").is_err() // We could not get into 'overflow when rounding' or 'overflow from carry' + // in the original implementation because the rounding logic before prevented it + ); + } + + #[test] + fn from_str_overflow_4() { + assert_eq!( + // This does not overflow, moving the decimal point 1 more step would result in + // 'overflow from too many digits' + parse_str_radix_10("99999_99999_99999_99999_99999_999.99") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(10_000_000_000_000_000_000_000_000_000, 0).unpack() + ); + } + + #[test] + fn from_str_mantissa_overflow_1() { + // reminder: + assert_eq!(OVERFLOW_U96, 79_228_162_514_264_337_593_543_950_336); + assert_eq!( + parse_str_radix_10("79_228_162_514_264_337_593_543_950_33.56") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 0).unpack() + ); + // This is a mantissa of OVERFLOW_U96 - 1 just before reaching the last digit. + // Previously, this would return Err("overflow from mantissa after rounding") + // instead of successfully rounding. + } + + #[test] + fn from_str_mantissa_overflow_2() { + assert_eq!( + parse_str_radix_10("79_228_162_514_264_337_593_543_950_335.6"), + Err(Error::from("Invalid decimal: overflow from mantissa after rounding")) + ); + // this case wants to round to 79_228_162_514_264_337_593_543_950_340. + // (79_228_162_514_264_337_593_543_950_336 is OVERFLOW_U96 and too large + // to fit in 96 bits) which is also too large for the mantissa so fails. + } + + #[test] + fn from_str_mantissa_overflow_3() { + // this hits the other avoidable overflow case in maybe_round + assert_eq!( + parse_str_radix_10("7.92281625142643375935439503356").unwrap().unpack(), + Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 27).unpack() + ); + } + + #[ignore] + #[test] + fn from_str_mantissa_overflow_4() { + // Same test as above, however with underscores. This causes issues. + assert_eq!( + parse_str_radix_10("7.9_228_162_514_264_337_593_543_950_335_6") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 27).unpack() + ); + } + + #[test] + fn from_str_edge_cases_1() { + assert_eq!(parse_str_radix_10(""), Err(Error::from("Invalid decimal: empty"))); + } + + #[test] + fn from_str_edge_cases_2() { + assert_eq!( + parse_str_radix_10("0.1."), + Err(Error::from("Invalid decimal: two decimal points")) + ); + } + + #[test] + fn from_str_edge_cases_3() { + assert_eq!( + parse_str_radix_10("_"), + Err(Error::from("Invalid decimal: must start lead with a number")) + ); + } + + #[test] + fn from_str_edge_cases_4() { + assert_eq!( + parse_str_radix_10("1?2"), + Err(Error::from("Invalid decimal: unknown character")) + ); + } + + #[test] + fn from_str_edge_cases_5() { + assert_eq!( + parse_str_radix_10("."), + Err(Error::from("Invalid decimal: no digits found")) + ); + } + + #[test] + fn from_str_edge_cases_6() { + // Decimal::MAX + 0.99999 + assert_eq!( + parse_str_radix_10("79_228_162_514_264_337_593_543_950_335.99999"), + Err(Error::from("Invalid decimal: overflow from mantissa after rounding")) + ); + } +} |