summaryrefslogtreecommitdiffstats
path: root/vendor/crypto-bigint/src/non_zero.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/crypto-bigint/src/non_zero.rs')
-rw-r--r--vendor/crypto-bigint/src/non_zero.rs393
1 files changed, 393 insertions, 0 deletions
diff --git a/vendor/crypto-bigint/src/non_zero.rs b/vendor/crypto-bigint/src/non_zero.rs
new file mode 100644
index 0000000..dd4294e
--- /dev/null
+++ b/vendor/crypto-bigint/src/non_zero.rs
@@ -0,0 +1,393 @@
+//! Wrapper type for non-zero integers.
+
+use crate::{CtChoice, Encoding, Integer, Limb, Uint, Zero};
+use core::{
+ fmt,
+ num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8},
+ ops::Deref,
+};
+use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
+
+#[cfg(feature = "generic-array")]
+use crate::{ArrayEncoding, ByteArray};
+
+#[cfg(feature = "rand_core")]
+use {crate::Random, rand_core::CryptoRngCore};
+
+#[cfg(feature = "serde")]
+use serdect::serde::{
+ de::{Error, Unexpected},
+ Deserialize, Deserializer, Serialize, Serializer,
+};
+
+/// Wrapper type for non-zero integers.
+#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
+pub struct NonZero<T: Zero>(T);
+
+impl NonZero<Limb> {
+ /// Creates a new non-zero limb in a const context.
+ /// The second return value is `FALSE` if `n` is zero, `TRUE` otherwise.
+ pub const fn const_new(n: Limb) -> (Self, CtChoice) {
+ (Self(n), n.ct_is_nonzero())
+ }
+}
+
+impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
+ /// Creates a new non-zero integer in a const context.
+ /// The second return value is `FALSE` if `n` is zero, `TRUE` otherwise.
+ pub const fn const_new(n: Uint<LIMBS>) -> (Self, CtChoice) {
+ (Self(n), n.ct_is_nonzero())
+ }
+}
+
+impl<T> NonZero<T>
+where
+ T: Zero,
+{
+ /// Create a new non-zero integer.
+ pub fn new(n: T) -> CtOption<Self> {
+ let is_zero = n.is_zero();
+ CtOption::new(Self(n), !is_zero)
+ }
+}
+
+impl<T> NonZero<T>
+where
+ T: Integer,
+{
+ /// The value `1`.
+ pub const ONE: Self = Self(T::ONE);
+
+ /// Maximum value this integer can express.
+ pub const MAX: Self = Self(T::MAX);
+}
+
+impl<T> NonZero<T>
+where
+ T: Encoding + Zero,
+{
+ /// Decode from big endian bytes.
+ pub fn from_be_bytes(bytes: T::Repr) -> CtOption<Self> {
+ Self::new(T::from_be_bytes(bytes))
+ }
+
+ /// Decode from little endian bytes.
+ pub fn from_le_bytes(bytes: T::Repr) -> CtOption<Self> {
+ Self::new(T::from_le_bytes(bytes))
+ }
+}
+
+#[cfg(feature = "generic-array")]
+impl<T> NonZero<T>
+where
+ T: ArrayEncoding + Zero,
+{
+ /// Decode a non-zero integer from big endian bytes.
+ pub fn from_be_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
+ Self::new(T::from_be_byte_array(bytes))
+ }
+
+ /// Decode a non-zero integer from big endian bytes.
+ pub fn from_le_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
+ Self::new(T::from_be_byte_array(bytes))
+ }
+}
+
+impl<T> AsRef<T> for NonZero<T>
+where
+ T: Zero,
+{
+ fn as_ref(&self) -> &T {
+ &self.0
+ }
+}
+
+impl<T> ConditionallySelectable for NonZero<T>
+where
+ T: ConditionallySelectable + Zero,
+{
+ fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
+ Self(T::conditional_select(&a.0, &b.0, choice))
+ }
+}
+
+impl<T> ConstantTimeEq for NonZero<T>
+where
+ T: Zero,
+{
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.0.ct_eq(&other.0)
+ }
+}
+
+impl<T> Deref for NonZero<T>
+where
+ T: Zero,
+{
+ type Target = T;
+
+ fn deref(&self) -> &T {
+ &self.0
+ }
+}
+
+#[cfg(feature = "rand_core")]
+impl<T> Random for NonZero<T>
+where
+ T: Random + Zero,
+{
+ /// Generate a random `NonZero<T>`.
+ fn random(mut rng: &mut impl CryptoRngCore) -> Self {
+ // Use rejection sampling to eliminate zero values.
+ // While this method isn't constant-time, the attacker shouldn't learn
+ // anything about unrelated outputs so long as `rng` is a CSRNG.
+ loop {
+ if let Some(result) = Self::new(T::random(&mut rng)).into() {
+ break result;
+ }
+ }
+ }
+}
+
+impl<T> fmt::Display for NonZero<T>
+where
+ T: fmt::Display + Zero,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Display::fmt(&self.0, f)
+ }
+}
+
+impl<T> fmt::Binary for NonZero<T>
+where
+ T: fmt::Binary + Zero,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Binary::fmt(&self.0, f)
+ }
+}
+
+impl<T> fmt::Octal for NonZero<T>
+where
+ T: fmt::Octal + Zero,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Octal::fmt(&self.0, f)
+ }
+}
+
+impl<T> fmt::LowerHex for NonZero<T>
+where
+ T: fmt::LowerHex + Zero,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::LowerHex::fmt(&self.0, f)
+ }
+}
+
+impl<T> fmt::UpperHex for NonZero<T>
+where
+ T: fmt::UpperHex + Zero,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::UpperHex::fmt(&self.0, f)
+ }
+}
+
+impl NonZero<Limb> {
+ /// Create a [`NonZero<Limb>`] from a [`NonZeroU8`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
+ pub const fn from_u8(n: NonZeroU8) -> Self {
+ Self(Limb::from_u8(n.get()))
+ }
+
+ /// Create a [`NonZero<Limb>`] from a [`NonZeroU16`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
+ pub const fn from_u16(n: NonZeroU16) -> Self {
+ Self(Limb::from_u16(n.get()))
+ }
+
+ /// Create a [`NonZero<Limb>`] from a [`NonZeroU32`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
+ pub const fn from_u32(n: NonZeroU32) -> Self {
+ Self(Limb::from_u32(n.get()))
+ }
+
+ /// Create a [`NonZero<Limb>`] from a [`NonZeroU64`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
+ #[cfg(target_pointer_width = "64")]
+ pub const fn from_u64(n: NonZeroU64) -> Self {
+ Self(Limb::from_u64(n.get()))
+ }
+}
+
+impl From<NonZeroU8> for NonZero<Limb> {
+ fn from(integer: NonZeroU8) -> Self {
+ Self::from_u8(integer)
+ }
+}
+
+impl From<NonZeroU16> for NonZero<Limb> {
+ fn from(integer: NonZeroU16) -> Self {
+ Self::from_u16(integer)
+ }
+}
+
+impl From<NonZeroU32> for NonZero<Limb> {
+ fn from(integer: NonZeroU32) -> Self {
+ Self::from_u32(integer)
+ }
+}
+
+#[cfg(target_pointer_width = "64")]
+impl From<NonZeroU64> for NonZero<Limb> {
+ fn from(integer: NonZeroU64) -> Self {
+ Self::from_u64(integer)
+ }
+}
+
+impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
+ /// Create a [`NonZero<Uint>`] from a [`Uint`] (const-friendly)
+ pub const fn from_uint(n: Uint<LIMBS>) -> Self {
+ let mut i = 0;
+ let mut found_non_zero = false;
+ while i < LIMBS {
+ if n.as_limbs()[i].0 != 0 {
+ found_non_zero = true;
+ }
+ i += 1;
+ }
+ assert!(found_non_zero, "found zero");
+ Self(n)
+ }
+
+ /// Create a [`NonZero<Uint>`] from a [`NonZeroU8`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
+ pub const fn from_u8(n: NonZeroU8) -> Self {
+ Self(Uint::from_u8(n.get()))
+ }
+
+ /// Create a [`NonZero<Uint>`] from a [`NonZeroU16`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
+ pub const fn from_u16(n: NonZeroU16) -> Self {
+ Self(Uint::from_u16(n.get()))
+ }
+
+ /// Create a [`NonZero<Uint>`] from a [`NonZeroU32`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
+ pub const fn from_u32(n: NonZeroU32) -> Self {
+ Self(Uint::from_u32(n.get()))
+ }
+
+ /// Create a [`NonZero<Uint>`] from a [`NonZeroU64`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
+ pub const fn from_u64(n: NonZeroU64) -> Self {
+ Self(Uint::from_u64(n.get()))
+ }
+
+ /// Create a [`NonZero<Uint>`] from a [`NonZeroU128`] (const-friendly)
+ // TODO(tarcieri): replace with `const impl From<NonZeroU128>` when stable
+ pub const fn from_u128(n: NonZeroU128) -> Self {
+ Self(Uint::from_u128(n.get()))
+ }
+}
+
+impl<const LIMBS: usize> From<NonZeroU8> for NonZero<Uint<LIMBS>> {
+ fn from(integer: NonZeroU8) -> Self {
+ Self::from_u8(integer)
+ }
+}
+
+impl<const LIMBS: usize> From<NonZeroU16> for NonZero<Uint<LIMBS>> {
+ fn from(integer: NonZeroU16) -> Self {
+ Self::from_u16(integer)
+ }
+}
+
+impl<const LIMBS: usize> From<NonZeroU32> for NonZero<Uint<LIMBS>> {
+ fn from(integer: NonZeroU32) -> Self {
+ Self::from_u32(integer)
+ }
+}
+
+impl<const LIMBS: usize> From<NonZeroU64> for NonZero<Uint<LIMBS>> {
+ fn from(integer: NonZeroU64) -> Self {
+ Self::from_u64(integer)
+ }
+}
+
+impl<const LIMBS: usize> From<NonZeroU128> for NonZero<Uint<LIMBS>> {
+ fn from(integer: NonZeroU128) -> Self {
+ Self::from_u128(integer)
+ }
+}
+
+#[cfg(feature = "serde")]
+impl<'de, T: Deserialize<'de> + Zero> Deserialize<'de> for NonZero<T> {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ let value: T = T::deserialize(deserializer)?;
+
+ if bool::from(value.is_zero()) {
+ Err(D::Error::invalid_value(
+ Unexpected::Other("zero"),
+ &"a non-zero value",
+ ))
+ } else {
+ Ok(Self(value))
+ }
+ }
+}
+
+#[cfg(feature = "serde")]
+impl<T: Serialize + Zero> Serialize for NonZero<T> {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ self.0.serialize(serializer)
+ }
+}
+
+#[cfg(all(test, feature = "serde"))]
+#[allow(clippy::unwrap_used)]
+mod tests {
+ use crate::{NonZero, U64};
+ use bincode::ErrorKind;
+
+ #[test]
+ fn serde() {
+ let test =
+ Option::<NonZero<U64>>::from(NonZero::new(U64::from_u64(0x0011223344556677))).unwrap();
+
+ let serialized = bincode::serialize(&test).unwrap();
+ let deserialized: NonZero<U64> = bincode::deserialize(&serialized).unwrap();
+
+ assert_eq!(test, deserialized);
+
+ let serialized = bincode::serialize(&U64::ZERO).unwrap();
+ assert!(matches!(
+ *bincode::deserialize::<NonZero<U64>>(&serialized).unwrap_err(),
+ ErrorKind::Custom(message) if message == "invalid value: zero, expected a non-zero value"
+ ));
+ }
+
+ #[test]
+ fn serde_owned() {
+ let test =
+ Option::<NonZero<U64>>::from(NonZero::new(U64::from_u64(0x0011223344556677))).unwrap();
+
+ let serialized = bincode::serialize(&test).unwrap();
+ let deserialized: NonZero<U64> = bincode::deserialize_from(serialized.as_slice()).unwrap();
+
+ assert_eq!(test, deserialized);
+
+ let serialized = bincode::serialize(&U64::ZERO).unwrap();
+ assert!(matches!(
+ *bincode::deserialize_from::<_, NonZero<U64>>(serialized.as_slice()).unwrap_err(),
+ ErrorKind::Custom(message) if message == "invalid value: zero, expected a non-zero value"
+ ));
+ }
+}