diff --git a/const-oid/src/arcs.rs b/const-oid/src/arcs.rs index ab3923ef7..59d340c43 100644 --- a/const-oid/src/arcs.rs +++ b/const-oid/src/arcs.rs @@ -58,7 +58,8 @@ impl<'a> Arcs<'a> { match self.cursor { // Indicates we're on the root arc None => { - let root = RootArcs::try_from(self.bytes[0])?; + let root_byte = *self.bytes.first().ok_or(Error::Empty)?; + let root = RootArcs::try_from(root_byte)?; self.cursor = Some(0); Ok(Some(root.first_arc())) } diff --git a/const-oid/src/db.rs b/const-oid/src/db.rs index b451db722..d2d2c0e62 100644 --- a/const-oid/src/db.rs +++ b/const-oid/src/db.rs @@ -58,7 +58,7 @@ impl<'a> Database<'a> { while i < self.0.len() { let lhs = self.0[i].0; - if lhs.buffer.eq(&oid.buffer) { + if lhs.ber.eq(&oid.ber) { return Some(self.0[i].1); } @@ -110,7 +110,7 @@ impl<'a> Iterator for Names<'a> { while i < self.database.0.len() { let lhs = self.database.0[i].0; - if lhs.buffer.eq(&self.oid.buffer) { + if lhs.ber.eq(&self.oid.ber) { self.position = i + 1; return Some(self.database.0[i].1); } diff --git a/const-oid/src/encoder.rs b/const-oid/src/encoder.rs index e4e8a4bff..5f9401aa6 100644 --- a/const-oid/src/encoder.rs +++ b/const-oid/src/encoder.rs @@ -5,29 +5,29 @@ use crate::{ Arc, Buffer, Error, ObjectIdentifier, Result, }; -/// BER/DER encoder +/// BER/DER encoder. #[derive(Debug)] pub(crate) struct Encoder { - /// Current state + /// Current state. state: State, - /// Bytes of the OID being encoded in-progress + /// Bytes of the OID being BER-encoded in-progress. bytes: [u8; MAX_SIZE], - /// Current position within the byte buffer + /// Current position within the byte buffer. cursor: usize, } -/// Current state of the encoder +/// Current state of the encoder. #[derive(Debug)] enum State { - /// Initial state - no arcs yet encoded + /// Initial state - no arcs yet encoded. Initial, - /// First arc parsed + /// First arc parsed. FirstArc(Arc), - /// Encoding base 128 body of the OID + /// Encoding base 128 body of the OID. Body, } @@ -45,8 +45,8 @@ impl Encoder { pub(crate) const fn extend(oid: ObjectIdentifier) -> Self { Self { state: State::Body, - bytes: oid.buffer.bytes, - cursor: oid.buffer.length as usize, + bytes: oid.ber.bytes, + cursor: oid.ber.length as usize, } } @@ -100,16 +100,16 @@ impl Encoder { /// Finish encoding an OID. pub(crate) const fn finish(self) -> Result> { - if self.cursor >= 2 { - let bytes = Buffer { - bytes: self.bytes, - length: self.cursor as u8, - }; - - Ok(ObjectIdentifier { buffer: bytes }) - } else { - Err(Error::NotEnoughArcs) + if self.cursor == 0 { + return Err(Error::Empty); } + + let ber = Buffer { + bytes: self.bytes, + length: self.cursor as u8, + }; + + Ok(ObjectIdentifier { ber }) } /// Encode a single byte of a Base 128 value. diff --git a/const-oid/src/error.rs b/const-oid/src/error.rs index 528ce785c..a6fa56a5e 100644 --- a/const-oid/src/error.rs +++ b/const-oid/src/error.rs @@ -37,9 +37,6 @@ pub enum Error { /// OID length is invalid (too short or too long). Length, - /// Minimum 3 arcs required. - NotEnoughArcs, - /// Trailing `.` character at end of input. TrailingDot, } @@ -56,7 +53,6 @@ impl Error { Error::DigitExpected { .. } => panic!("OID expected to start with digit"), Error::Empty => panic!("OID value is empty"), Error::Length => panic!("OID length invalid"), - Error::NotEnoughArcs => panic!("OID requires minimum of 3 arcs"), Error::TrailingDot => panic!("OID ends with invalid trailing '.'"), } } @@ -73,7 +69,6 @@ impl fmt::Display for Error { } Error::Empty => f.write_str("OID value is empty"), Error::Length => f.write_str("OID length invalid"), - Error::NotEnoughArcs => f.write_str("OID requires minimum of 3 arcs"), Error::TrailingDot => f.write_str("OID ends with invalid trailing '.'"), } } diff --git a/const-oid/src/lib.rs b/const-oid/src/lib.rs index d85d60165..0dcb557c1 100644 --- a/const-oid/src/lib.rs +++ b/const-oid/src/lib.rs @@ -1,4 +1,4 @@ -#![no_std] +//#![no_std] #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![doc = include_str!("../README.md")] #![doc( @@ -6,7 +6,7 @@ html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/media/6ee8e381/logo.svg" )] #![allow(clippy::len_without_is_empty)] -#![forbid(unsafe_code)] +#![deny(unsafe_code)] #![warn( clippy::arithmetic_side_effects, clippy::mod_module_files, @@ -43,7 +43,7 @@ pub use crate::{ }; use crate::encoder::Encoder; -use core::{fmt, str::FromStr}; +use core::{borrow::Borrow, fmt, ops::Deref, str::FromStr}; /// Default maximum size. /// @@ -68,7 +68,7 @@ const DEFAULT_MAX_SIZE: usize = 39; #[derive(Clone, Copy, Eq, Hash, PartialEq, PartialOrd, Ord)] pub struct ObjectIdentifier { /// Buffer containing BER/DER-serialized bytes (sans ASN.1 tag/length) - buffer: Buffer, + ber: Buffer, } impl ObjectIdentifier { @@ -122,56 +122,22 @@ impl ObjectIdentifier { /// Parse an OID from from its BER/DER encoding. pub fn from_bytes(ber_bytes: &[u8]) -> Result { - let len = ber_bytes.len(); - - match len { - 0 => return Err(Error::Empty), - 3..=Self::MAX_SIZE => (), - _ => return Err(Error::NotEnoughArcs), - } - - let mut bytes = [0u8; Self::MAX_SIZE]; - bytes[..len].copy_from_slice(ber_bytes); - - let bytes = Buffer { - bytes, - length: len as u8, - }; - - let oid = Self { buffer: bytes }; - - // Ensure arcs are well-formed - let mut arcs = oid.arcs(); - while arcs.try_next()?.is_some() {} - - Ok(oid) + ObjectIdentifierRef::from_bytes(ber_bytes)?.try_into() } } impl ObjectIdentifier { /// Get the BER/DER serialization of this OID as bytes. /// - /// Note that this encoding omits the tag/length, and only contains the value portion of the - /// encoded OID. + /// Note that this encoding omits the ASN.1 tag/length, and only contains the value portion of + /// the encoded OID. pub const fn as_bytes(&self) -> &[u8] { - self.buffer.as_bytes() + self.ber.as_bytes() } - /// Return the arc with the given index, if it exists. - pub fn arc(&self, index: usize) -> Option { - self.arcs().nth(index) - } - - /// Iterate over the arcs (a.k.a. nodes) of an [`ObjectIdentifier`]. - /// - /// Returns [`Arcs`], an iterator over [`Arc`] values. - pub fn arcs(&self) -> Arcs<'_> { - Arcs::new(self.buffer.as_ref()) - } - - /// Get the length of this [`ObjectIdentifier`] in arcs. - pub fn len(&self) -> usize { - self.arcs().count() + /// Borrow an [`ObjectIdentifierRef`] which corresponds to this [`ObjectIdentifier`]. + pub const fn as_oid_ref(&self) -> &ObjectIdentifierRef { + ObjectIdentifierRef::from_bytes_unchecked(self.as_bytes()) } /// Get the parent OID of this one (if applicable). @@ -196,7 +162,7 @@ impl ObjectIdentifier { } /// Does this OID start with the other OID? - pub const fn starts_with(&self, other: ObjectIdentifier) -> bool { + pub const fn starts_with(&self, other: ObjectIdentifier) -> bool { let len = other.as_bytes().len(); if self.as_bytes().len() < len { @@ -221,7 +187,21 @@ impl ObjectIdentifier { impl AsRef<[u8]> for ObjectIdentifier { fn as_ref(&self) -> &[u8] { - self.buffer.as_bytes() + self.as_bytes() + } +} + +impl Borrow for ObjectIdentifier { + fn borrow(&self) -> &ObjectIdentifierRef { + self.as_oid_ref() + } +} + +impl Deref for ObjectIdentifier { + type Target = ObjectIdentifierRef; + + fn deref(&self) -> &ObjectIdentifierRef { + self.as_oid_ref() } } @@ -241,6 +221,28 @@ impl TryFrom<&[u8]> for ObjectIdentifier { } } +impl TryFrom<&ObjectIdentifierRef> for ObjectIdentifier { + type Error = Error; + + fn try_from(oid_ref: &ObjectIdentifierRef) -> Result { + let len = oid_ref.as_bytes().len(); + + if len > MAX_SIZE { + return Err(Error::Length); + } + + let mut bytes = [0u8; MAX_SIZE]; + bytes[..len].copy_from_slice(oid_ref.as_bytes()); + + let ber = Buffer { + bytes, + length: len as u8, + }; + + Ok(Self { ber }) + } +} + impl fmt::Debug for ObjectIdentifier { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "ObjectIdentifier({})", self) @@ -249,19 +251,7 @@ impl fmt::Debug for ObjectIdentifier { impl fmt::Display for ObjectIdentifier { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let len = self.arcs().count(); - - for (i, arc) in self.arcs().enumerate() { - write!(f, "{}", arc)?; - - if let Some(j) = i.checked_add(1) { - if j < len { - write!(f, ".")?; - } - } - } - - Ok(()) + write!(f, "{}", self.as_oid_ref()) } } @@ -290,3 +280,102 @@ impl<'a> arbitrary::Arbitrary<'a> for ObjectIdentifier { (Arc::size_hint(depth).0.saturating_mul(3), None) } } + +/// OID reference type: wrapper for the BER serialization. +#[derive(Eq, Hash, PartialEq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct ObjectIdentifierRef { + /// BER/DER-serialized bytes (sans ASN.1 tag/length). + ber: [u8], +} + +impl ObjectIdentifierRef { + /// Create an [`ObjectIdentifierRef`], validating that the provided byte slice contains a valid + /// BER/DER encoding. + // TODO(tarcieri): `const fn` support + pub fn from_bytes(ber: &[u8]) -> Result<&Self> { + // Ensure arcs are well-formed + let mut arcs = Arcs::new(ber); + while arcs.try_next()?.is_some() {} + Ok(Self::from_bytes_unchecked(ber)) + } + + /// Create an [`ObjectIdentifierRef`] from the given byte slice without first checking that it + /// contains valid BER/DER. + pub(crate) const fn from_bytes_unchecked(ber: &[u8]) -> &Self { + debug_assert!(!ber.is_empty()); + + // SAFETY: `ObjectIdentifierRef` is a `repr(transparent)` newtype for `[u8]`. + #[allow(unsafe_code)] + unsafe { + &*(ber as *const [u8] as *const ObjectIdentifierRef) + } + } + + /// Get the BER/DER serialization of this OID as bytes. + /// + /// Note that this encoding omits the ASN.1 tag/length, and only contains the value portion of + /// the encoded OID. + pub const fn as_bytes(&self) -> &[u8] { + &self.ber + } + + /// Return the arc with the given index, if it exists. + pub fn arc(&self, index: usize) -> Option { + self.arcs().nth(index) + } + + /// Iterate over the arcs (a.k.a. nodes) of an [`ObjectIdentifier`]. + /// + /// Returns [`Arcs`], an iterator over [`Arc`] values. + pub fn arcs(&self) -> Arcs<'_> { + Arcs::new(self.ber.as_ref()) + } + + /// Get the length of this [`ObjectIdentifier`] in arcs. + pub fn len(&self) -> usize { + self.arcs().count() + } +} + +impl AsRef<[u8]> for ObjectIdentifierRef { + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} + +impl<'a, const MAX_SIZE: usize> From<&'a ObjectIdentifier> for &'a ObjectIdentifierRef { + fn from(oid: &'a ObjectIdentifier) -> &'a ObjectIdentifierRef { + oid.as_oid_ref() + } +} + +impl fmt::Debug for ObjectIdentifierRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ObjectIdentifierRef({})", self) + } +} + +impl fmt::Display for ObjectIdentifierRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let len = self.arcs().count(); + + for (i, arc) in self.arcs().enumerate() { + write!(f, "{}", arc)?; + + if let Some(j) = i.checked_add(1) { + if j < len { + write!(f, ".")?; + } + } + } + + Ok(()) + } +} + +impl PartialEq> for ObjectIdentifierRef { + fn eq(&self, other: &ObjectIdentifier) -> bool { + self.as_bytes().eq(other.as_bytes()) + } +} diff --git a/const-oid/tests/lib.rs b/const-oid/tests/oid.rs similarity index 92% rename from const-oid/tests/lib.rs rename to const-oid/tests/oid.rs index 7efb0f060..7172fc160 100644 --- a/const-oid/tests/lib.rs +++ b/const-oid/tests/oid.rs @@ -1,4 +1,4 @@ -//! `const-oid` crate tests +//! Tests for `ObjectIdentifier`. // TODO(tarcieri): test full set of OID encoding constraints specified here: // @@ -62,16 +62,6 @@ fn from_bytes() { // Empty assert_eq!(ObjectIdentifier::from_bytes(&[]), Err(Error::Empty)); - - // Truncated - assert_eq!( - ObjectIdentifier::from_bytes(&[42]), - Err(Error::NotEnoughArcs) - ); - assert_eq!( - ObjectIdentifier::from_bytes(&[42, 134]), - Err(Error::NotEnoughArcs) - ); } #[test] @@ -103,9 +93,6 @@ fn from_str() { assert_eq!(oid3.arc(6).unwrap(), 1); assert_eq!(oid3, EXAMPLE_OID_LARGE_ARC); - // Too short - assert_eq!("1.2".parse::(), Err(Error::NotEnoughArcs)); - // Truncated assert_eq!( "1.2.840.10045.2.".parse::(), @@ -145,12 +132,6 @@ fn try_from_u32_slice() { assert_eq!(oid2.arc(1).unwrap(), 16); assert_eq!(EXAMPLE_OID_2, oid2); - // Too short - assert_eq!( - ObjectIdentifier::from_arcs([1, 2]), - Err(Error::NotEnoughArcs) - ); - // Invalid first arc assert_eq!( ObjectIdentifier::from_arcs([3, 2, 840, 10045, 3, 1, 7]), @@ -171,13 +152,16 @@ fn as_bytes() { } #[test] -fn parse_empty() { - assert_eq!(ObjectIdentifier::new(""), Err(Error::Empty)); +fn as_oid_ref() { + assert_eq!( + EXAMPLE_OID_0.as_bytes(), + EXAMPLE_OID_0.as_oid_ref().as_bytes() + ); } #[test] -fn parse_not_enough_arcs() { - assert_eq!(ObjectIdentifier::new("1.2"), Err(Error::NotEnoughArcs)); +fn parse_empty() { + assert_eq!(ObjectIdentifier::new(""), Err(Error::Empty)); } #[test] @@ -201,6 +185,9 @@ fn parent() { let child = oid("1.2.3.4"); let parent = child.parent().unwrap(); assert_eq!(parent, oid("1.2.3")); + + let parent = parent.parent().unwrap(); + assert_eq!(parent, oid("1.2")); assert_eq!(parent.parent(), None); } diff --git a/const-oid/tests/oid_ref.rs b/const-oid/tests/oid_ref.rs new file mode 100644 index 000000000..6f7b782cb --- /dev/null +++ b/const-oid/tests/oid_ref.rs @@ -0,0 +1,71 @@ +//! Tests for `ObjectIdentifierRef`. + +use const_oid::{Error, ObjectIdentifier, ObjectIdentifierRef}; +use hex_literal::hex; + +/// Example OID value with a root arc of `0` (and large arc). +const EXAMPLE_OID_0_STR: &str = "0.9.2342.19200300.100.1.1"; +const EXAMPLE_OID_0_BER: &[u8] = &hex!("0992268993F22C640101"); +const EXAMPLE_OID_0: ObjectIdentifier = ObjectIdentifier::new_unwrap(EXAMPLE_OID_0_STR); + +/// Example OID value with a root arc of `1`. +const EXAMPLE_OID_1_STR: &str = "1.2.840.10045.2.1"; +const EXAMPLE_OID_1_BER: &[u8] = &hex!("2A8648CE3D0201"); +const EXAMPLE_OID_1: ObjectIdentifier = ObjectIdentifier::new_unwrap(EXAMPLE_OID_1_STR); + +/// Example OID value with a root arc of `2`. +const EXAMPLE_OID_2_STR: &str = "2.16.840.1.101.3.4.1.42"; +const EXAMPLE_OID_2_BER: &[u8] = &hex!("60864801650304012A"); +const EXAMPLE_OID_2: ObjectIdentifier = ObjectIdentifier::new_unwrap(EXAMPLE_OID_2_STR); + +/// Example OID value with a large arc +const EXAMPLE_OID_LARGE_ARC_STR: &str = "0.9.2342.19200300.100.1.1"; +const EXAMPLE_OID_LARGE_ARC_BER: &[u8] = &hex!("0992268993F22C640101"); +const EXAMPLE_OID_LARGE_ARC: ObjectIdentifier = + ObjectIdentifier::new_unwrap("0.9.2342.19200300.100.1.1"); + +#[test] +fn from_bytes() { + let oid0 = ObjectIdentifierRef::from_bytes(EXAMPLE_OID_0_BER).unwrap(); + assert_eq!(oid0.arc(0).unwrap(), 0); + assert_eq!(oid0.arc(1).unwrap(), 9); + assert_eq!(oid0, &EXAMPLE_OID_0); + + let oid1 = ObjectIdentifierRef::from_bytes(EXAMPLE_OID_1_BER).unwrap(); + assert_eq!(oid1.arc(0).unwrap(), 1); + assert_eq!(oid1.arc(1).unwrap(), 2); + assert_eq!(oid1, &EXAMPLE_OID_1); + + let oid2 = ObjectIdentifierRef::from_bytes(EXAMPLE_OID_2_BER).unwrap(); + assert_eq!(oid2.arc(0).unwrap(), 2); + assert_eq!(oid2.arc(1).unwrap(), 16); + assert_eq!(oid2, &EXAMPLE_OID_2); + + let oid3 = ObjectIdentifierRef::from_bytes(EXAMPLE_OID_LARGE_ARC_BER).unwrap(); + assert_eq!(oid3.arc(0).unwrap(), 0); + assert_eq!(oid3.arc(1).unwrap(), 9); + assert_eq!(oid3.arc(2).unwrap(), 2342); + assert_eq!(oid3.arc(3).unwrap(), 19200300); + assert_eq!(oid3.arc(4).unwrap(), 100); + assert_eq!(oid3.arc(5).unwrap(), 1); + assert_eq!(oid3.arc(6).unwrap(), 1); + assert_eq!(oid3, &EXAMPLE_OID_LARGE_ARC); + + // Empty + assert_eq!(ObjectIdentifierRef::from_bytes(&[]), Err(Error::Empty)); +} + +#[test] +fn display() { + let oid0 = ObjectIdentifierRef::from_bytes(EXAMPLE_OID_0_BER).unwrap(); + assert_eq!(oid0.to_string(), EXAMPLE_OID_0_STR); + + let oid1 = ObjectIdentifierRef::from_bytes(EXAMPLE_OID_1_BER).unwrap(); + assert_eq!(oid1.to_string(), EXAMPLE_OID_1_STR); + + let oid2 = ObjectIdentifierRef::from_bytes(EXAMPLE_OID_2_BER).unwrap(); + assert_eq!(oid2.to_string(), EXAMPLE_OID_2_STR); + + let oid3 = ObjectIdentifierRef::from_bytes(EXAMPLE_OID_LARGE_ARC_BER).unwrap(); + assert_eq!(oid3.to_string(), EXAMPLE_OID_LARGE_ARC_STR); +}