summaryrefslogtreecommitdiffstats
path: root/third_party/rust/half/src/vec.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/half/src/vec.rs')
-rw-r--r--third_party/rust/half/src/vec.rs286
1 files changed, 286 insertions, 0 deletions
diff --git a/third_party/rust/half/src/vec.rs b/third_party/rust/half/src/vec.rs
new file mode 100644
index 0000000000..6967656e4d
--- /dev/null
+++ b/third_party/rust/half/src/vec.rs
@@ -0,0 +1,286 @@
+//! Contains utility functions and traits to convert between vectors of [`u16`] bits and [`f16`] or
+//! [`bf16`] vectors.
+//!
+//! The utility [`HalfBitsVecExt`] sealed extension trait is implemented for [`Vec<u16>`] vectors,
+//! while the utility [`HalfFloatVecExt`] sealed extension trait is implemented for both
+//! [`Vec<f16>`] and [`Vec<bf16>`] vectors. These traits provide efficient conversions and
+//! reinterpret casting of larger buffers of floating point values, and are automatically included
+//! in the [`prelude`][crate::prelude] module.
+//!
+//! This module is only available with the `std` or `alloc` feature.
+
+use super::{bf16, f16, slice::HalfFloatSliceExt};
+#[cfg(feature = "alloc")]
+use alloc::vec::Vec;
+use core::mem;
+
+/// Extensions to [`Vec<f16>`] and [`Vec<bf16>`] to support reinterpret operations.
+///
+/// This trait is sealed and cannot be implemented outside of this crate.
+pub trait HalfFloatVecExt: private::SealedHalfFloatVec {
+ /// Reinterprets a vector of [`f16`]or [`bf16`] numbers as a vector of [`u16`] bits.
+ ///
+ /// This is a zero-copy operation. The reinterpreted vector has the same memory location as
+ /// `self`.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// # use half::prelude::*;
+ /// let float_buffer = vec![f16::from_f32(1.), f16::from_f32(2.), f16::from_f32(3.)];
+ /// let int_buffer = float_buffer.reinterpret_into();
+ ///
+ /// assert_eq!(int_buffer, [f16::from_f32(1.).to_bits(), f16::from_f32(2.).to_bits(), f16::from_f32(3.).to_bits()]);
+ /// ```
+ fn reinterpret_into(self) -> Vec<u16>;
+
+ /// Converts all of the elements of a `[f32]` slice into a new [`f16`] or [`bf16`] vector.
+ ///
+ /// The conversion operation is vectorized over the slice, meaning the conversion may be more
+ /// efficient than converting individual elements on some hardware that supports SIMD
+ /// conversions. See [crate documentation][crate] for more information on hardware conversion
+ /// support.
+ ///
+ /// # Examples
+ /// ```rust
+ /// # use half::prelude::*;
+ /// let float_values = [1., 2., 3., 4.];
+ /// let vec: Vec<f16> = Vec::from_f32_slice(&float_values);
+ ///
+ /// assert_eq!(vec, vec![f16::from_f32(1.), f16::from_f32(2.), f16::from_f32(3.), f16::from_f32(4.)]);
+ /// ```
+ fn from_f32_slice(slice: &[f32]) -> Self;
+
+ /// Converts all of the elements of a `[f64]` slice into a new [`f16`] or [`bf16`] vector.
+ ///
+ /// The conversion operation is vectorized over the slice, meaning the conversion may be more
+ /// efficient than converting individual elements on some hardware that supports SIMD
+ /// conversions. See [crate documentation][crate] for more information on hardware conversion
+ /// support.
+ ///
+ /// # Examples
+ /// ```rust
+ /// # use half::prelude::*;
+ /// let float_values = [1., 2., 3., 4.];
+ /// let vec: Vec<f16> = Vec::from_f64_slice(&float_values);
+ ///
+ /// assert_eq!(vec, vec![f16::from_f64(1.), f16::from_f64(2.), f16::from_f64(3.), f16::from_f64(4.)]);
+ /// ```
+ fn from_f64_slice(slice: &[f64]) -> Self;
+}
+
+/// Extensions to [`Vec<u16>`] to support reinterpret operations.
+///
+/// This trait is sealed and cannot be implemented outside of this crate.
+pub trait HalfBitsVecExt: private::SealedHalfBitsVec {
+ /// Reinterprets a vector of [`u16`] bits as a vector of [`f16`] or [`bf16`] numbers.
+ ///
+ /// `H` is the type to cast to, and must be either the [`f16`] or [`bf16`] type.
+ ///
+ /// This is a zero-copy operation. The reinterpreted vector has the same memory location as
+ /// `self`.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// # use half::prelude::*;
+ /// let int_buffer = vec![f16::from_f32(1.).to_bits(), f16::from_f32(2.).to_bits(), f16::from_f32(3.).to_bits()];
+ /// let float_buffer = int_buffer.reinterpret_into::<f16>();
+ ///
+ /// assert_eq!(float_buffer, [f16::from_f32(1.), f16::from_f32(2.), f16::from_f32(3.)]);
+ /// ```
+ fn reinterpret_into<H>(self) -> Vec<H>
+ where
+ H: crate::private::SealedHalf;
+}
+
+mod private {
+ use crate::{bf16, f16};
+ #[cfg(feature = "alloc")]
+ use alloc::vec::Vec;
+
+ pub trait SealedHalfFloatVec {}
+ impl SealedHalfFloatVec for Vec<f16> {}
+ impl SealedHalfFloatVec for Vec<bf16> {}
+
+ pub trait SealedHalfBitsVec {}
+ impl SealedHalfBitsVec for Vec<u16> {}
+}
+
+impl HalfFloatVecExt for Vec<f16> {
+ #[inline]
+ fn reinterpret_into(mut self) -> Vec<u16> {
+ // An f16 array has same length and capacity as u16 array
+ let length = self.len();
+ let capacity = self.capacity();
+
+ // Actually reinterpret the contents of the Vec<f16> as u16,
+ // knowing that structs are represented as only their members in memory,
+ // which is the u16 part of `f16(u16)`
+ let pointer = self.as_mut_ptr() as *mut u16;
+
+ // Prevent running a destructor on the old Vec<u16>, so the pointer won't be deleted
+ mem::forget(self);
+
+ // Finally construct a new Vec<f16> from the raw pointer
+ // SAFETY: We are reconstructing full length and capacity of original vector,
+ // using its original pointer, and the size of elements are identical.
+ unsafe { Vec::from_raw_parts(pointer, length, capacity) }
+ }
+
+ fn from_f32_slice(slice: &[f32]) -> Self {
+ let mut vec = Vec::with_capacity(slice.len());
+ // SAFETY: convert will initialize every value in the vector without reading them,
+ // so this is safe to do instead of double initialize from resize, and we're setting it to
+ // same value as capacity.
+ unsafe { vec.set_len(slice.len()) };
+ vec.convert_from_f32_slice(slice);
+ vec
+ }
+
+ fn from_f64_slice(slice: &[f64]) -> Self {
+ let mut vec = Vec::with_capacity(slice.len());
+ // SAFETY: convert will initialize every value in the vector without reading them,
+ // so this is safe to do instead of double initialize from resize, and we're setting it to
+ // same value as capacity.
+ unsafe { vec.set_len(slice.len()) };
+ vec.convert_from_f64_slice(slice);
+ vec
+ }
+}
+
+impl HalfFloatVecExt for Vec<bf16> {
+ #[inline]
+ fn reinterpret_into(mut self) -> Vec<u16> {
+ // An f16 array has same length and capacity as u16 array
+ let length = self.len();
+ let capacity = self.capacity();
+
+ // Actually reinterpret the contents of the Vec<f16> as u16,
+ // knowing that structs are represented as only their members in memory,
+ // which is the u16 part of `f16(u16)`
+ let pointer = self.as_mut_ptr() as *mut u16;
+
+ // Prevent running a destructor on the old Vec<u16>, so the pointer won't be deleted
+ mem::forget(self);
+
+ // Finally construct a new Vec<f16> from the raw pointer
+ // SAFETY: We are reconstructing full length and capacity of original vector,
+ // using its original pointer, and the size of elements are identical.
+ unsafe { Vec::from_raw_parts(pointer, length, capacity) }
+ }
+
+ fn from_f32_slice(slice: &[f32]) -> Self {
+ let mut vec = Vec::with_capacity(slice.len());
+ // SAFETY: convert will initialize every value in the vector without reading them,
+ // so this is safe to do instead of double initialize from resize, and we're setting it to
+ // same value as capacity.
+ unsafe { vec.set_len(slice.len()) };
+ vec.convert_from_f32_slice(slice);
+ vec
+ }
+
+ fn from_f64_slice(slice: &[f64]) -> Self {
+ let mut vec = Vec::with_capacity(slice.len());
+ // SAFETY: convert will initialize every value in the vector without reading them,
+ // so this is safe to do instead of double initialize from resize, and we're setting it to
+ // same value as capacity.
+ unsafe { vec.set_len(slice.len()) };
+ vec.convert_from_f64_slice(slice);
+ vec
+ }
+}
+
+impl HalfBitsVecExt for Vec<u16> {
+ // This is safe because all traits are sealed
+ #[inline]
+ fn reinterpret_into<H>(mut self) -> Vec<H>
+ where
+ H: crate::private::SealedHalf,
+ {
+ // An f16 array has same length and capacity as u16 array
+ let length = self.len();
+ let capacity = self.capacity();
+
+ // Actually reinterpret the contents of the Vec<u16> as f16,
+ // knowing that structs are represented as only their members in memory,
+ // which is the u16 part of `f16(u16)`
+ let pointer = self.as_mut_ptr() as *mut H;
+
+ // Prevent running a destructor on the old Vec<u16>, so the pointer won't be deleted
+ mem::forget(self);
+
+ // Finally construct a new Vec<f16> from the raw pointer
+ // SAFETY: We are reconstructing full length and capacity of original vector,
+ // using its original pointer, and the size of elements are identical.
+ unsafe { Vec::from_raw_parts(pointer, length, capacity) }
+ }
+}
+
+#[doc(hidden)]
+#[deprecated(
+ since = "1.4.0",
+ note = "use `HalfBitsVecExt::reinterpret_into` instead"
+)]
+#[inline]
+pub fn from_bits(bits: Vec<u16>) -> Vec<f16> {
+ bits.reinterpret_into()
+}
+
+#[doc(hidden)]
+#[deprecated(
+ since = "1.4.0",
+ note = "use `HalfFloatVecExt::reinterpret_into` instead"
+)]
+#[inline]
+pub fn to_bits(numbers: Vec<f16>) -> Vec<u16> {
+ numbers.reinterpret_into()
+}
+
+#[cfg(test)]
+mod test {
+ use super::{HalfBitsVecExt, HalfFloatVecExt};
+ use crate::{bf16, f16};
+ #[cfg(all(feature = "alloc", not(feature = "std")))]
+ use alloc::vec;
+
+ #[test]
+ fn test_vec_conversions_f16() {
+ let numbers = vec![f16::E, f16::PI, f16::EPSILON, f16::FRAC_1_SQRT_2];
+ let bits = vec![
+ f16::E.to_bits(),
+ f16::PI.to_bits(),
+ f16::EPSILON.to_bits(),
+ f16::FRAC_1_SQRT_2.to_bits(),
+ ];
+ let bits_cloned = bits.clone();
+
+ // Convert from bits to numbers
+ let from_bits = bits.reinterpret_into::<f16>();
+ assert_eq!(&from_bits[..], &numbers[..]);
+
+ // Convert from numbers back to bits
+ let to_bits = from_bits.reinterpret_into();
+ assert_eq!(&to_bits[..], &bits_cloned[..]);
+ }
+
+ #[test]
+ fn test_vec_conversions_bf16() {
+ let numbers = vec![bf16::E, bf16::PI, bf16::EPSILON, bf16::FRAC_1_SQRT_2];
+ let bits = vec![
+ bf16::E.to_bits(),
+ bf16::PI.to_bits(),
+ bf16::EPSILON.to_bits(),
+ bf16::FRAC_1_SQRT_2.to_bits(),
+ ];
+ let bits_cloned = bits.clone();
+
+ // Convert from bits to numbers
+ let from_bits = bits.reinterpret_into::<bf16>();
+ assert_eq!(&from_bits[..], &numbers[..]);
+
+ // Convert from numbers back to bits
+ let to_bits = from_bits.reinterpret_into();
+ assert_eq!(&to_bits[..], &bits_cloned[..]);
+ }
+}