//! OID encoder with `const` support. use crate::{ arcs::{ARC_MAX_FIRST, ARC_MAX_SECOND}, Arc, Error, ObjectIdentifier, Result, }; /// BER/DER encoder #[derive(Debug)] pub(crate) struct Encoder { /// Current state state: State, /// Bytes of the OID being encoded in-progress bytes: [u8; ObjectIdentifier::MAX_SIZE], /// Current position within the byte buffer cursor: usize, } /// Current state of the encoder #[derive(Debug)] enum State { /// Initial state - no arcs yet encoded Initial, /// First arc parsed FirstArc(Arc), /// Encoding base 128 body of the OID Body, } impl Encoder { /// Create a new encoder initialized to an empty default state. pub(crate) const fn new() -> Self { Self { state: State::Initial, bytes: [0u8; ObjectIdentifier::MAX_SIZE], cursor: 0, } } /// Extend an existing OID. pub(crate) const fn extend(oid: ObjectIdentifier) -> Self { Self { state: State::Body, bytes: oid.bytes, cursor: oid.length as usize, } } /// Encode an [`Arc`] as base 128 into the internal buffer. pub(crate) const fn arc(mut self, arc: Arc) -> Result { match self.state { State::Initial => { if arc > ARC_MAX_FIRST { return Err(Error::ArcInvalid { arc }); } self.state = State::FirstArc(arc); Ok(self) } // Ensured not to overflow by `ARC_MAX_SECOND` check #[allow(clippy::integer_arithmetic)] State::FirstArc(first_arc) => { if arc > ARC_MAX_SECOND { return Err(Error::ArcInvalid { arc }); } self.state = State::Body; self.bytes[0] = (first_arc * (ARC_MAX_SECOND + 1)) as u8 + arc as u8; self.cursor = 1; Ok(self) } // TODO(tarcieri): finer-grained overflow safety / checked arithmetic #[allow(clippy::integer_arithmetic)] State::Body => { // Total number of bytes in encoded arc - 1 let nbytes = base128_len(arc); // Shouldn't overflow on any 16-bit+ architectures if self.cursor + nbytes + 1 >= ObjectIdentifier::MAX_SIZE { return Err(Error::Length); } let new_cursor = self.cursor + nbytes + 1; // TODO(tarcieri): use `?` when stable in `const fn` match self.encode_base128_byte(arc, nbytes, false) { Ok(mut encoder) => { encoder.cursor = new_cursor; Ok(encoder) } Err(err) => Err(err), } } } } /// Finish encoding an OID. pub(crate) const fn finish(self) -> Result { if self.cursor >= 2 { Ok(ObjectIdentifier { bytes: self.bytes, length: self.cursor as u8, }) } else { Err(Error::NotEnoughArcs) } } /// Encode a single byte of a Base 128 value. const fn encode_base128_byte(mut self, mut n: u32, i: usize, continued: bool) -> Result { let mask = if continued { 0b10000000 } else { 0 }; // Underflow checked by branch #[allow(clippy::integer_arithmetic)] if n > 0x80 { self.bytes[checked_add!(self.cursor, i)] = (n & 0b1111111) as u8 | mask; n >>= 7; if i > 0 { self.encode_base128_byte(n, i.saturating_sub(1), true) } else { Err(Error::Base128) } } else { self.bytes[self.cursor] = n as u8 | mask; Ok(self) } } } /// Compute the length - 1 of an arc when encoded in base 128. const fn base128_len(arc: Arc) -> usize { match arc { 0..=0x7f => 0, 0x80..=0x3fff => 1, 0x4000..=0x1fffff => 2, 0x200000..=0x1fffffff => 3, _ => 4, } } #[cfg(test)] mod tests { use super::Encoder; use hex_literal::hex; /// OID `1.2.840.10045.2.1` encoded as ASN.1 BER/DER const EXAMPLE_OID_BER: &[u8] = &hex!("2A8648CE3D0201"); #[test] fn encode() { let encoder = Encoder::new(); let encoder = encoder.arc(1).unwrap(); let encoder = encoder.arc(2).unwrap(); let encoder = encoder.arc(840).unwrap(); let encoder = encoder.arc(10045).unwrap(); let encoder = encoder.arc(2).unwrap(); let encoder = encoder.arc(1).unwrap(); assert_eq!(&encoder.bytes[..encoder.cursor], EXAMPLE_OID_BER); } }