summaryrefslogtreecommitdiffstats
path: root/third_party/rust/gfx-auxil/src/lib.rs
blob: 89a0109931a477f111e9961d3578afa1c277cc11 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
use std::{io, slice};
#[cfg(feature = "spirv_cross")]
use {
    hal::{device::ShaderError, pso},
    spirv_cross::spirv,
};

/// Fast hash map used internally.
///
/// A `std::collections::HashMap` whose hasher is `fxhash::FxHasher` (via
/// `BuildHasherDefault`) instead of the standard library's default hasher.
pub type FastHashMap<K, V> =
    std::collections::HashMap<K, V, std::hash::BuildHasherDefault<fxhash::FxHasher>>;
/// Fast hash set used internally.
///
/// A `std::collections::HashSet` whose hasher is `fxhash::FxHasher` (via
/// `BuildHasherDefault`) instead of the standard library's default hasher.
pub type FastHashSet<K> =
    std::collections::HashSet<K, std::hash::BuildHasherDefault<fxhash::FxHasher>>;

/// Identifies one shader stage.
///
/// The enum is `#[repr(u8)]` with implicit discriminants, so each variant's value is its
/// declaration index (`Vertex` = 0 through `Mesh` = 7); reordering or inserting variants
/// changes those values. `ShaderStage::to_flag` converts a stage into the corresponding
/// `hal::pso::ShaderStageFlags` value. With the `serde` feature enabled the type also derives
/// `Serialize` and `Deserialize`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(u8)]
pub enum ShaderStage {
    /// Vertex shader stage.
    Vertex,
    /// Hull shader stage.
    Hull,
    /// Domain shader stage.
    Domain,
    /// Geometry shader stage.
    Geometry,
    /// Fragment shader stage.
    Fragment,
    /// Compute shader stage.
    Compute,
    /// Task shader stage.
    Task,
    /// Mesh shader stage.
    Mesh,
}

impl ShaderStage {
    /// Returns the `hal::pso::ShaderStageFlags` constant corresponding to this stage.
    ///
    /// Each variant maps to the flag constant of the same name (e.g. `Hull` -> `HULL`).
    pub fn to_flag(self) -> hal::pso::ShaderStageFlags {
        use hal::pso::ShaderStageFlags as StageFlags;

        match self {
            Self::Vertex => StageFlags::VERTEX,
            Self::Hull => StageFlags::HULL,
            Self::Domain => StageFlags::DOMAIN,
            Self::Geometry => StageFlags::GEOMETRY,
            Self::Fragment => StageFlags::FRAGMENT,
            Self::Compute => StageFlags::COMPUTE,
            Self::Task => StageFlags::TASK,
            Self::Mesh => StageFlags::MESH,
        }
    }
}

/// Safely read SPIR-V
///
/// Converts to native endianness and returns correctly aligned storage without unnecessary
/// copying. Returns an `InvalidData` error if the input is trivially not SPIR-V.
///
/// The whole stream is read starting from offset 0, regardless of the reader's position on
/// entry. If the first word is the byte-swapped SPIR-V magic number, every word is byte-swapped
/// into native order.
///
/// This function can also be used to convert an already in-memory `&[u8]` to a valid `Vec<u32>`,
/// but prefer working with `&[u32]` from the start whenever possible.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidData`] error if the input length is not a multiple of 4,
/// is too large to be held in a `Vec`, or the first word is not the SPIR-V magic number in
/// either byte order (which includes empty input). Errors from seeking in or reading `x` are
/// propagated unchanged.
///
/// # Examples
/// ```no_run
/// let mut file = std::fs::File::open("/path/to/shader.spv").unwrap();
/// let words = gfx_auxil::read_spirv(&mut file).unwrap();
/// ```
/// ```
/// const SPIRV: &[u8] = &[
///     0x03, 0x02, 0x23, 0x07, // ...
/// ];
/// let words = gfx_auxil::read_spirv(std::io::Cursor::new(&SPIRV[..])).unwrap();
/// ```
pub fn read_spirv<R: io::Read + io::Seek>(mut x: R) -> io::Result<Vec<u32>> {
    const MAGIC_NUMBER: u32 = 0x0723_0203;

    let size = x.seek(io::SeekFrom::End(0))?;
    if size % 4 != 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "input length not divisible by 4",
        ));
    }
    // A `Vec` can never hold more than `isize::MAX` bytes; reject such input with an error
    // instead of letting the allocation below panic with "capacity overflow".
    if size > isize::max_value() as u64 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "input too long"));
    }
    let words = (size / 4) as usize;

    // Zero-fill instead of reading into uninitialized capacity: `Read` implementations are
    // allowed to inspect the buffer they are given, so handing them uninitialized memory is
    // undefined behavior.
    let mut result = vec![0u32; words];
    x.seek(io::SeekFrom::Start(0))?;
    // SAFETY: `result` holds `words` initialized `u32`s, i.e. exactly `words * 4` initialized
    // bytes. `u8` has alignment 1 and every bit pattern is a valid `u32`, so reading and writing
    // that memory through a `&mut [u8]` is sound. `bytes` is the only view of `result` used
    // until `read_exact` returns.
    let bytes = unsafe { slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, words * 4) };
    x.read_exact(bytes)?;

    if result.first() == Some(&MAGIC_NUMBER.swap_bytes()) {
        // Opposite-endian module: convert every word to native byte order.
        for word in &mut result {
            *word = word.swap_bytes();
        }
    }
    if result.first() != Some(&MAGIC_NUMBER) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "input missing SPIR-V magic number",
        ));
    }
    Ok(result)
}

/// Maps a `spirv_cross` error raised while handling specialization constants onto
/// `ShaderError::CompilationFailed`.
#[cfg(feature = "spirv_cross")]
fn specialization_error(err: spirv_cross::ErrorCode) -> ShaderError {
    ShaderError::CompilationFailed(match err {
        spirv_cross::ErrorCode::CompilationError(msg) => msg,
        spirv_cross::ErrorCode::Unhandled => "Unexpected specialization constant error".into(),
    })
}

/// Overrides the specialization constants of `ast` with the values supplied in
/// `specialization`.
///
/// For each specialization constant reported by `ast`, the entry of
/// `specialization.constants` with the same constant ID (if any) selects a byte range of
/// `specialization.data`. Those bytes are decoded as a little-endian unsigned integer (first
/// byte least significant; if the range is longer than 8 bytes only the first 8 contribute)
/// and written with `set_scalar_constant`. Constants without a matching entry are left
/// unchanged.
///
/// # Errors
///
/// Returns `ShaderError::CompilationFailed` if `spirv_cross` fails to enumerate or set the
/// constants, or if an entry's byte range does not lie within `specialization.data`.
#[cfg(feature = "spirv_cross")]
pub fn spirv_cross_specialize_ast<T>(
    ast: &mut spirv::Ast<T>,
    specialization: &pso::Specialization,
) -> Result<(), ShaderError>
where
    T: spirv::Target,
    spirv::Ast<T>: spirv::Compile<T> + spirv::Parse<T>,
{
    let spec_constants = ast
        .get_specialization_constants()
        .map_err(specialization_error)?;

    for spec_constant in spec_constants {
        if let Some(constant) = specialization
            .constants
            .iter()
            .find(|c| c.id == spec_constant.constant_id)
        {
            let start = constant.range.start as usize;
            let end = constant.range.end as usize;
            // Report a malformed range as an error instead of panicking on the slice index.
            let bytes = specialization.data.get(start..end).ok_or_else(|| {
                ShaderError::CompilationFailed(format!(
                    "Specialization constant {} uses data range {}..{}, but only {} bytes of data were provided",
                    spec_constant.constant_id,
                    start,
                    end,
                    specialization.data.len()
                ))
            })?;

            // Override specialization constant values: the last byte ends up most significant.
            let value = bytes
                .iter()
                .rev()
                .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));

            ast.set_scalar_constant(spec_constant.id, value)
                .map_err(specialization_error)?;
        }
    }

    Ok(())
}