summaryrefslogtreecommitdiffstats
path: root/third_party/rust/viaduct/src/headers.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/viaduct/src/headers.rs')
-rw-r--r--third_party/rust/viaduct/src/headers.rs414
1 files changed, 414 insertions, 0 deletions
diff --git a/third_party/rust/viaduct/src/headers.rs b/third_party/rust/viaduct/src/headers.rs
new file mode 100644
index 0000000000..d7b7f7aa11
--- /dev/null
+++ b/third_party/rust/viaduct/src/headers.rs
@@ -0,0 +1,414 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+pub use name::{HeaderName, InvalidHeaderName};
+use std::collections::HashMap;
+use std::iter::FromIterator;
+use std::str::FromStr;
+mod name;
+
+/// A single header. Headers have a name (case insensitive) and a value. The
+/// character set for header and values are both restrictive.
+/// - Names must only contain a-zA-Z0-9 and and ('!' | '#' | '$' | '%' | '&' |
+/// '\'' | '*' | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~') characters
+/// (the field-name token production defined at
+/// https://tools.ietf.org/html/rfc7230#section-3.2).
+/// For request headers, we expect these to all be specified statically,
+/// and so we panic if you provide an invalid one. (For response headers, we
+/// ignore headers with invalid names, but emit a warning).
+///
+/// Header names are case insensitive, and we have several pre-defined ones in
+/// the [`header_names`] module.
+///
+/// - Values may only contain printable ascii characters, and may not contain
+/// \r or \n. Strictly speaking, HTTP is more flexible for header values,
+/// however we don't need to support binary header values, and so we do not.
+///
+/// Note that typically you should not interact with this directly, and instead
+/// use the methods on [`Request`] or [`Headers`] to manipulate these.
+#[derive(Clone, Debug, PartialEq, PartialOrd, Hash, Eq, Ord)]
+pub struct Header {
+ pub(crate) name: HeaderName,
+ pub(crate) value: String,
+}
+
+// Trim `s` without copying if it can be avoided.
+fn trim_string<S: AsRef<str> + Into<String>>(s: S) -> String {
+ let sr = s.as_ref();
+ let trimmed = sr.trim();
+ if sr.len() != trimmed.len() {
+ trimmed.into()
+ } else {
+ s.into()
+ }
+}
+
+fn is_valid_header_value(value: &str) -> bool {
+ value.bytes().all(|b| (32..127).contains(&b) || b == b'\t')
+}
+
+impl Header {
+ pub fn new<Name, Value>(name: Name, value: Value) -> Result<Self, crate::Error>
+ where
+ Name: Into<HeaderName>,
+ Value: AsRef<str> + Into<String>,
+ {
+ let name = name.into();
+ let value = trim_string(value);
+ if !is_valid_header_value(&value) {
+ return Err(crate::Error::RequestHeaderError(name));
+ }
+ Ok(Self { name, value })
+ }
+
+ pub fn new_unchecked<Value>(name: HeaderName, value: Value) -> Self
+ where
+ Value: AsRef<str> + Into<String>,
+ {
+ Self {
+ name,
+ value: value.into(),
+ }
+ }
+
+ #[inline]
+ pub fn name(&self) -> &HeaderName {
+ &self.name
+ }
+
+ #[inline]
+ pub fn value(&self) -> &str {
+ &self.value
+ }
+
+ #[inline]
+ fn set_value<V: AsRef<str>>(&mut self, s: V) -> Result<(), crate::Error> {
+ let value = s.as_ref();
+ if !is_valid_header_value(value) {
+ Err(crate::Error::RequestHeaderError(self.name.clone()))
+ } else {
+ self.value.clear();
+ self.value.push_str(s.as_ref().trim());
+ Ok(())
+ }
+ }
+}
+
+impl std::fmt::Display for Header {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}: {}", self.name, self.value)
+ }
+}
+
+/// A list of headers.
+#[derive(Clone, Debug, PartialEq, Eq, Default)]
+pub struct Headers {
+ headers: Vec<Header>,
+}
+
+impl Headers {
+ /// Initialize an empty list of headers.
+ #[inline]
+ pub fn new() -> Self {
+ Default::default()
+ }
+
+ /// Initialize an empty list of headers backed by a vector with the provided
+ /// capacity.
+ pub fn with_capacity(c: usize) -> Self {
+ Self {
+ headers: Vec::with_capacity(c),
+ }
+ }
+
+ /// Convert this list of headers to a Vec<Header>
+ #[inline]
+ pub fn into_vec(self) -> Vec<Header> {
+ self.headers
+ }
+
+ /// Returns the number of headers.
+ #[inline]
+ pub fn len(&self) -> usize {
+ self.headers.len()
+ }
+
+ /// Returns true if `len()` is zero.
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.headers.is_empty()
+ }
+ /// Clear this set of headers.
+ #[inline]
+ pub fn clear(&mut self) {
+ self.headers.clear();
+ }
+
+ /// Insert or update a new header.
+ ///
+ /// This returns an error if you attempt to specify a header with an
+ /// invalid value (values must be printable ASCII and may not contain
+ /// \r or \n)
+ ///
+ /// ## Example
+ /// ```
+ /// # use viaduct::Headers;
+ /// # fn main() -> Result<(), viaduct::Error> {
+ /// let mut h = Headers::new();
+ /// h.insert("My-Cool-Header", "example")?;
+ /// assert_eq!(h.get("My-Cool-Header"), Some("example"));
+ ///
+ /// // Note: names are sensitive
+ /// assert_eq!(h.get("my-cool-header"), Some("example"));
+ ///
+ /// // Also note, constants for headers are in `viaduct::header_names`, and
+ /// // you can chain the result of this function.
+ /// h.insert(viaduct::header_names::CONTENT_TYPE, "something...")?
+ /// .insert("Something-Else", "etc")?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn insert<N, V>(&mut self, name: N, value: V) -> Result<&mut Self, crate::Error>
+ where
+ N: Into<HeaderName> + PartialEq<HeaderName>,
+ V: Into<String> + AsRef<str>,
+ {
+ if let Some(entry) = self.headers.iter_mut().find(|h| name == h.name) {
+ entry.set_value(value)?;
+ } else {
+ self.headers.push(Header::new(name, value)?);
+ }
+ Ok(self)
+ }
+
+ /// Insert the provided header unless a header is already specified.
+ /// Mostly used internally, e.g. to set "Content-Type: application/json"
+ /// in `Request::json()` unless it has been set specifically.
+ pub fn insert_if_missing<N, V>(&mut self, name: N, value: V) -> Result<&mut Self, crate::Error>
+ where
+ N: Into<HeaderName> + PartialEq<HeaderName>,
+ V: Into<String> + AsRef<str>,
+ {
+ if !self.headers.iter_mut().any(|h| name == h.name) {
+ self.headers.push(Header::new(name, value)?);
+ }
+ Ok(self)
+ }
+
+ /// Insert or update a header directly. Typically you will want to use
+ /// `insert` over this, as it performs less work if the header needs
+ /// updating instead of insertion.
+ pub fn insert_header(&mut self, new: Header) -> &mut Self {
+ if let Some(entry) = self.headers.iter_mut().find(|h| h.name == new.name) {
+ entry.value = new.value;
+ } else {
+ self.headers.push(new);
+ }
+ self
+ }
+
+ /// Add all the headers in the provided iterator to this list of headers.
+ pub fn extend<I>(&mut self, iter: I) -> &mut Self
+ where
+ I: IntoIterator<Item = Header>,
+ {
+ let it = iter.into_iter();
+ self.headers.reserve(it.size_hint().0);
+ for h in it {
+ self.insert_header(h);
+ }
+ self
+ }
+
+ /// Add all the headers in the provided iterator, unless any of them are Err.
+ pub fn try_extend<I, E>(&mut self, iter: I) -> Result<&mut Self, E>
+ where
+ I: IntoIterator<Item = Result<Header, E>>,
+ {
+ // Not the most efficient but avoids leaving us in an unspecified state
+ // if one returns Err.
+ self.extend(iter.into_iter().collect::<Result<Vec<_>, E>>()?);
+ Ok(self)
+ }
+
+ /// Get the header object with the requested name. Usually, you will
+ /// want to use `get()` or `get_as::<T>()` instead.
+ pub fn get_header<S>(&self, name: S) -> Option<&Header>
+ where
+ S: PartialEq<HeaderName>,
+ {
+ self.headers.iter().find(|h| name == h.name)
+ }
+
+ /// Get the value of the header with the provided name.
+ ///
+ /// See also `get_as`.
+ ///
+ /// ## Example
+ /// ```
+ /// # use viaduct::{Headers, header_names::CONTENT_TYPE};
+ /// # fn main() -> Result<(), viaduct::Error> {
+ /// let mut h = Headers::new();
+ /// h.insert(CONTENT_TYPE, "application/json")?;
+ /// assert_eq!(h.get(CONTENT_TYPE), Some("application/json"));
+ /// assert_eq!(h.get("Something-Else"), None);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn get<S>(&self, name: S) -> Option<&str>
+ where
+ S: PartialEq<HeaderName>,
+ {
+ self.get_header(name).map(|h| h.value.as_str())
+ }
+
+ /// Get the value of the header with the provided name, and
+ /// attempt to parse it using [`std::str::FromStr`].
+ ///
+ /// - If the header is missing, it returns None.
+ /// - If the header is present but parsing failed, returns
+ /// `Some(Err(<error returned by parsing>))`.
+ /// - Otherwise, returns `Some(Ok(result))`.
+ ///
+ /// Note that if `Option<Result<T, E>>` is inconvenient for you,
+ /// and you wish this returned `Result<Option<T>, E>`, you may use
+ /// the built-in `transpose()` method to convert between them.
+ ///
+ /// ```
+ /// # use viaduct::Headers;
+ /// # fn main() -> Result<(), viaduct::Error> {
+ /// let mut h = Headers::new();
+ /// h.insert("Example", "1234")?.insert("Illegal", "abcd")?;
+ /// let v: Option<Result<i64, _>> = h.get_as("Example");
+ /// assert_eq!(v, Some(Ok(1234)));
+ /// assert_eq!(h.get_as::<i64, _>("Example"), Some(Ok(1234)));
+ /// assert_eq!(h.get_as::<i64, _>("Illegal"), Some("abcd".parse::<i64>()));
+ /// assert_eq!(h.get_as::<i64, _>("Something-Else"), None);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn get_as<T, S>(&self, name: S) -> Option<Result<T, <T as FromStr>::Err>>
+ where
+ T: FromStr,
+ S: PartialEq<HeaderName>,
+ {
+ self.get(name).map(str::parse)
+ }
+ /// Get the value of the header with the provided name, and
+ /// attempt to parse it using [`std::str::FromStr`].
+ ///
+ /// This is a variant of `get_as` that returns None on error,
+ /// intended to be used for cases where missing and invalid
+ /// headers should be treated the same. (With `get_as` this
+ /// requires `h.get_as(...).and_then(|r| r.ok())`, which is
+ /// somewhat opaque.
+ pub fn try_get<T, S>(&self, name: S) -> Option<T>
+ where
+ T: FromStr,
+ S: PartialEq<HeaderName>,
+ {
+ self.get(name).and_then(|val| val.parse::<T>().ok())
+ }
+
+ /// Get an iterator over the headers in no particular order.
+ ///
+ /// Note that we also implement IntoIterator.
+ pub fn iter(&self) -> <&Headers as IntoIterator>::IntoIter {
+ self.into_iter()
+ }
+}
+
+impl std::iter::IntoIterator for Headers {
+ type IntoIter = <Vec<Header> as IntoIterator>::IntoIter;
+ type Item = Header;
+ fn into_iter(self) -> Self::IntoIter {
+ self.headers.into_iter()
+ }
+}
+
+impl<'a> std::iter::IntoIterator for &'a Headers {
+ type IntoIter = <&'a [Header] as IntoIterator>::IntoIter;
+ type Item = &'a Header;
+ fn into_iter(self) -> Self::IntoIter {
+ self.headers[..].iter()
+ }
+}
+
+impl FromIterator<Header> for Headers {
+ fn from_iter<T>(iter: T) -> Self
+ where
+ T: IntoIterator<Item = Header>,
+ {
+ let mut v = iter.into_iter().collect::<Vec<Header>>();
+ v.sort_by(|a, b| a.name.cmp(&b.name));
+ v.reverse();
+ v.dedup_by(|a, b| a.name == b.name);
+ Headers { headers: v }
+ }
+}
+
+#[allow(clippy::implicit_hasher)] // https://github.com/rust-lang/rust-clippy/issues/3899
+impl From<Headers> for HashMap<String, String> {
+ fn from(headers: Headers) -> HashMap<String, String> {
+ headers
+ .into_iter()
+ .map(|h| (String::from(h.name), h.value))
+ .collect()
+ }
+}
+
+pub mod consts {
+ use super::name::HeaderName;
+ macro_rules! def_header_consts {
+ ($(($NAME:ident, $string:literal)),* $(,)?) => {
+ $(pub const $NAME: HeaderName = HeaderName(std::borrow::Cow::Borrowed($string));)*
+ };
+ }
+
+ macro_rules! headers {
+ ($(($NAME:ident, $string:literal)),* $(,)?) => {
+ def_header_consts!($(($NAME, $string)),*);
+ // Unused except for tests.
+ const _ALL: &[&str] = &[$($string),*];
+ };
+ }
+
+ // Predefined header names, for convenience.
+ // Feel free to add to these.
+ headers!(
+ (ACCEPT_ENCODING, "accept-encoding"),
+ (ACCEPT, "accept"),
+ (AUTHORIZATION, "authorization"),
+ (CONTENT_TYPE, "content-type"),
+ (ETAG, "etag"),
+ (IF_NONE_MATCH, "if-none-match"),
+ (USER_AGENT, "user-agent"),
+ // non-standard, but it's convenient to have these.
+ (RETRY_AFTER, "retry-after"),
+ (X_IF_UNMODIFIED_SINCE, "x-if-unmodified-since"),
+ (X_KEYID, "x-keyid"),
+ (X_LAST_MODIFIED, "x-last-modified"),
+ (X_TIMESTAMP, "x-timestamp"),
+ (X_WEAVE_NEXT_OFFSET, "x-weave-next-offset"),
+ (X_WEAVE_RECORDS, "x-weave-records"),
+ (X_WEAVE_TIMESTAMP, "x-weave-timestamp"),
+ (X_WEAVE_BACKOFF, "x-weave-backoff"),
+ );
+
+ #[test]
+ fn test_predefined() {
+ for &name in _ALL {
+ assert!(
+ HeaderName::new(name).is_ok(),
+ "Invalid header name in predefined header constants: {}",
+ name
+ );
+ assert_eq!(
+ name.to_ascii_lowercase(),
+ name,
+ "Non-lowercase name in predefined header constants: {}",
+ name
+ );
+ }
+ }
+}