diff options
Diffstat (limited to 'third_party/rust/gfx-auxil/src/lib.rs')
-rw-r--r-- | third_party/rust/gfx-auxil/src/lib.rs | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/third_party/rust/gfx-auxil/src/lib.rs b/third_party/rust/gfx-auxil/src/lib.rs new file mode 100644 index 0000000000..89a0109931 --- /dev/null +++ b/third_party/rust/gfx-auxil/src/lib.rs @@ -0,0 +1,143 @@ +use std::{io, slice}; +#[cfg(feature = "spirv_cross")] +use { + hal::{device::ShaderError, pso}, + spirv_cross::spirv, +}; + +/// Fast hash map used internally. +pub type FastHashMap<K, V> = + std::collections::HashMap<K, V, std::hash::BuildHasherDefault<fxhash::FxHasher>>; +pub type FastHashSet<K> = + std::collections::HashSet<K, std::hash::BuildHasherDefault<fxhash::FxHasher>>; + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[repr(u8)] +pub enum ShaderStage { + Vertex, + Hull, + Domain, + Geometry, + Fragment, + Compute, + Task, + Mesh, +} + +impl ShaderStage { + pub fn to_flag(self) -> hal::pso::ShaderStageFlags { + use hal::pso::ShaderStageFlags as Ssf; + match self { + ShaderStage::Vertex => Ssf::VERTEX, + ShaderStage::Hull => Ssf::HULL, + ShaderStage::Domain => Ssf::DOMAIN, + ShaderStage::Geometry => Ssf::GEOMETRY, + ShaderStage::Fragment => Ssf::FRAGMENT, + ShaderStage::Compute => Ssf::COMPUTE, + ShaderStage::Task => Ssf::TASK, + ShaderStage::Mesh => Ssf::MESH, + } + } +} + +/// Safely read SPIR-V +/// +/// Converts to native endianness and returns correctly aligned storage without unnecessary +/// copying. Returns an `InvalidData` error if the input is trivially not SPIR-V. +/// +/// This function can also be used to convert an already in-memory `&[u8]` to a valid `Vec<u32>`, +/// but prefer working with `&[u32]` from the start whenever possible. +/// +/// # Examples +/// ```no_run +/// let mut file = std::fs::File::open("/path/to/shader.spv").unwrap(); +/// let words = gfx_auxil::read_spirv(&mut file).unwrap(); +/// ``` +/// ``` +/// const SPIRV: &[u8] = &[ +/// 0x03, 0x02, 0x23, 0x07, // ... +/// ]; +/// let words = gfx_auxil::read_spirv(std::io::Cursor::new(&SPIRV[..])).unwrap(); +/// ``` +pub fn read_spirv<R: io::Read + io::Seek>(mut x: R) -> io::Result<Vec<u32>> { + let size = x.seek(io::SeekFrom::End(0))?; + if size % 4 != 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "input length not divisible by 4", + )); + } + if size > usize::max_value() as u64 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "input too long")); + } + let words = (size / 4) as usize; + let mut result = Vec::<u32>::with_capacity(words); + x.seek(io::SeekFrom::Start(0))?; + unsafe { + // Writing all bytes through a pointer with less strict alignment when our type has no + // invalid bitpatterns is safe. + x.read_exact(slice::from_raw_parts_mut( + result.as_mut_ptr() as *mut u8, + words * 4, + ))?; + result.set_len(words); + } + const MAGIC_NUMBER: u32 = 0x07230203; + if result.len() > 0 && result[0] == MAGIC_NUMBER.swap_bytes() { + for word in &mut result { + *word = word.swap_bytes(); + } + } + if result.len() == 0 || result[0] != MAGIC_NUMBER { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "input missing SPIR-V magic number", + )); + } + Ok(result) +} + +#[cfg(feature = "spirv_cross")] +pub fn spirv_cross_specialize_ast<T>( + ast: &mut spirv::Ast<T>, + specialization: &pso::Specialization, +) -> Result<(), ShaderError> +where + T: spirv::Target, + spirv::Ast<T>: spirv::Compile<T> + spirv::Parse<T>, +{ + let spec_constants = ast.get_specialization_constants().map_err(|err| { + ShaderError::CompilationFailed(match err { + spirv_cross::ErrorCode::CompilationError(msg) => msg, + spirv_cross::ErrorCode::Unhandled => "Unexpected specialization constant error".into(), + }) + })?; + + for spec_constant in spec_constants { + if let Some(constant) = specialization + .constants + .iter() + .find(|c| c.id == spec_constant.constant_id) + { + // Override specialization constant values + let value = specialization.data + [constant.range.start as usize..constant.range.end as usize] + .iter() + .rev() + .fold(0u64, |u, &b| (u << 8) + b as u64); + + ast.set_scalar_constant(spec_constant.id, value) + .map_err(|err| { + ShaderError::CompilationFailed(match err { + spirv_cross::ErrorCode::CompilationError(msg) => msg, + spirv_cross::ErrorCode::Unhandled => { + "Unexpected specialization constant error".into() + } + }) + })?; + } + } + + Ok(()) +} |