//! Generic scalar type with primitive functionality. use crate::{ bigint::{prelude::*, Limb, NonZero}, scalar::FromUintUnchecked, scalar::IsHigh, Curve, Error, FieldBytes, FieldBytesEncoding, Result, }; use base16ct::HexDisplay; use core::{ cmp::Ordering, fmt, ops::{Add, AddAssign, Neg, ShrAssign, Sub, SubAssign}, str, }; use generic_array::{typenum::Unsigned, GenericArray}; use rand_core::CryptoRngCore; use subtle::{ Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess, CtOption, }; use zeroize::DefaultIsZeroes; #[cfg(feature = "arithmetic")] use super::{CurveArithmetic, Scalar}; #[cfg(feature = "serde")] use serdect::serde::{de, ser, Deserialize, Serialize}; /// Generic scalar type with primitive functionality. /// /// This type provides a baseline level of scalar arithmetic functionality /// which is always available for all curves, regardless of if they implement /// any arithmetic traits. /// /// # `serde` support /// /// When the optional `serde` feature of this create is enabled, [`Serialize`] /// and [`Deserialize`] impls are provided for this type. /// /// The serialization is a fixed-width big endian encoding. When used with /// textual formats, the binary data is encoded as hexadecimal. // TODO(tarcieri): use `crypto-bigint`'s `Residue` type, expose more functionality? #[derive(Copy, Clone, Debug, Default)] pub struct ScalarPrimitive { /// Inner unsigned integer type. inner: C::Uint, } impl ScalarPrimitive where C: Curve, { /// Zero scalar. pub const ZERO: Self = Self { inner: C::Uint::ZERO, }; /// Multiplicative identity. pub const ONE: Self = Self { inner: C::Uint::ONE, }; /// Scalar modulus. pub const MODULUS: C::Uint = C::ORDER; /// Generate a random [`ScalarPrimitive`]. pub fn random(rng: &mut impl CryptoRngCore) -> Self { Self { inner: C::Uint::random_mod(rng, &NonZero::new(Self::MODULUS).unwrap()), } } /// Create a new scalar from [`Curve::Uint`]. pub fn new(uint: C::Uint) -> CtOption { CtOption::new(Self { inner: uint }, uint.ct_lt(&Self::MODULUS)) } /// Decode [`ScalarPrimitive`] from a serialized field element pub fn from_bytes(bytes: &FieldBytes) -> CtOption { Self::new(C::Uint::decode_field_bytes(bytes)) } /// Decode [`ScalarPrimitive`] from a big endian byte slice. pub fn from_slice(slice: &[u8]) -> Result { if slice.len() == C::FieldBytesSize::USIZE { Option::from(Self::from_bytes(GenericArray::from_slice(slice))).ok_or(Error) } else { Err(Error) } } /// Borrow the inner `C::Uint`. pub fn as_uint(&self) -> &C::Uint { &self.inner } /// Borrow the inner limbs as a slice. pub fn as_limbs(&self) -> &[Limb] { self.inner.as_ref() } /// Is this [`ScalarPrimitive`] value equal to zero? pub fn is_zero(&self) -> Choice { self.inner.is_zero() } /// Is this [`ScalarPrimitive`] value even? pub fn is_even(&self) -> Choice { self.inner.is_even() } /// Is this [`ScalarPrimitive`] value odd? pub fn is_odd(&self) -> Choice { self.inner.is_odd() } /// Encode [`ScalarPrimitive`] as a serialized field element. pub fn to_bytes(&self) -> FieldBytes { self.inner.encode_field_bytes() } /// Convert to a `C::Uint`. pub fn to_uint(&self) -> C::Uint { self.inner } } impl FromUintUnchecked for ScalarPrimitive where C: Curve, { type Uint = C::Uint; fn from_uint_unchecked(uint: C::Uint) -> Self { Self { inner: uint } } } #[cfg(feature = "arithmetic")] impl ScalarPrimitive where C: CurveArithmetic, { /// Convert [`ScalarPrimitive`] into a given curve's scalar type. pub(super) fn to_scalar(self) -> Scalar { Scalar::::from_uint_unchecked(self.inner) } } // TODO(tarcieri): better encapsulate this? impl AsRef<[Limb]> for ScalarPrimitive where C: Curve, { fn as_ref(&self) -> &[Limb] { self.as_limbs() } } impl ConditionallySelectable for ScalarPrimitive where C: Curve, { fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { Self { inner: C::Uint::conditional_select(&a.inner, &b.inner, choice), } } } impl ConstantTimeEq for ScalarPrimitive where C: Curve, { fn ct_eq(&self, other: &Self) -> Choice { self.inner.ct_eq(&other.inner) } } impl ConstantTimeLess for ScalarPrimitive where C: Curve, { fn ct_lt(&self, other: &Self) -> Choice { self.inner.ct_lt(&other.inner) } } impl ConstantTimeGreater for ScalarPrimitive where C: Curve, { fn ct_gt(&self, other: &Self) -> Choice { self.inner.ct_gt(&other.inner) } } impl DefaultIsZeroes for ScalarPrimitive {} impl Eq for ScalarPrimitive {} impl PartialEq for ScalarPrimitive where C: Curve, { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() } } impl PartialOrd for ScalarPrimitive where C: Curve, { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for ScalarPrimitive where C: Curve, { fn cmp(&self, other: &Self) -> Ordering { self.inner.cmp(&other.inner) } } impl From for ScalarPrimitive where C: Curve, { fn from(n: u64) -> Self { Self { inner: C::Uint::from(n), } } } impl Add> for ScalarPrimitive where C: Curve, { type Output = Self; fn add(self, other: Self) -> Self { self.add(&other) } } impl Add<&ScalarPrimitive> for ScalarPrimitive where C: Curve, { type Output = Self; fn add(self, other: &Self) -> Self { Self { inner: self.inner.add_mod(&other.inner, &Self::MODULUS), } } } impl AddAssign> for ScalarPrimitive where C: Curve, { fn add_assign(&mut self, other: Self) { *self = *self + other; } } impl AddAssign<&ScalarPrimitive> for ScalarPrimitive where C: Curve, { fn add_assign(&mut self, other: &Self) { *self = *self + other; } } impl Sub> for ScalarPrimitive where C: Curve, { type Output = Self; fn sub(self, other: Self) -> Self { self.sub(&other) } } impl Sub<&ScalarPrimitive> for ScalarPrimitive where C: Curve, { type Output = Self; fn sub(self, other: &Self) -> Self { Self { inner: self.inner.sub_mod(&other.inner, &Self::MODULUS), } } } impl SubAssign> for ScalarPrimitive where C: Curve, { fn sub_assign(&mut self, other: Self) { *self = *self - other; } } impl SubAssign<&ScalarPrimitive> for ScalarPrimitive where C: Curve, { fn sub_assign(&mut self, other: &Self) { *self = *self - other; } } impl Neg for ScalarPrimitive where C: Curve, { type Output = Self; fn neg(self) -> Self { Self { inner: self.inner.neg_mod(&Self::MODULUS), } } } impl Neg for &ScalarPrimitive where C: Curve, { type Output = ScalarPrimitive; fn neg(self) -> ScalarPrimitive { -*self } } impl ShrAssign for ScalarPrimitive where C: Curve, { fn shr_assign(&mut self, rhs: usize) { self.inner >>= rhs; } } impl IsHigh for ScalarPrimitive where C: Curve, { fn is_high(&self) -> Choice { let n_2 = C::ORDER >> 1; self.inner.ct_gt(&n_2) } } impl fmt::Display for ScalarPrimitive where C: Curve, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self:X}") } } impl fmt::LowerHex for ScalarPrimitive where C: Curve, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:x}", HexDisplay(&self.to_bytes())) } } impl fmt::UpperHex for ScalarPrimitive where C: Curve, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:X}", HexDisplay(&self.to_bytes())) } } impl str::FromStr for ScalarPrimitive where C: Curve, { type Err = Error; fn from_str(hex: &str) -> Result { let mut bytes = FieldBytes::::default(); base16ct::lower::decode(hex, &mut bytes)?; Self::from_slice(&bytes) } } #[cfg(feature = "serde")] impl Serialize for ScalarPrimitive where C: Curve, { fn serialize(&self, serializer: S) -> core::result::Result where S: ser::Serializer, { serdect::array::serialize_hex_upper_or_bin(&self.to_bytes(), serializer) } } #[cfg(feature = "serde")] impl<'de, C> Deserialize<'de> for ScalarPrimitive where C: Curve, { fn deserialize(deserializer: D) -> core::result::Result where D: de::Deserializer<'de>, { let mut bytes = FieldBytes::::default(); serdect::array::deserialize_hex_or_bin(&mut bytes, deserializer)?; Self::from_slice(&bytes).map_err(|_| de::Error::custom("scalar out of range")) } }