summaryrefslogtreecommitdiffstats
path: root/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval/shim/simd.rs
blob: ec74631048797c71db1134f02ffe6c52948d158a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
//! Shim implementation for simd intrinsics

use std::cmp::Ordering;

use crate::TyKind;

use super::*;

/// Decodes `$value` (a byte slice) as a little-endian `$ty` via `$ty::from_le_bytes`.
///
/// If the slice length does not equal the byte size of `$ty`, this returns
/// `Err(MirEvalError::TypeError("mismatched size"))` from the *enclosing* function.
macro_rules! from_bytes {
    ($ty:tt, $value:expr) => {{
        let Ok(bytes) = ($value).try_into() else {
            return Err(MirEvalError::TypeError("mismatched size"));
        };
        $ty::from_le_bytes(bytes)
    }};
}

/// Returns `Err(MirEvalError::NotSupported(..))` from the *enclosing* function.
///
/// The argument is a `format!` string literal; inline captures such as `{name}`
/// resolve to locals at the call site. The expansion diverges, so it can be used
/// as a match arm or inside a `let ... else` block.
macro_rules! not_supported {
    ($it: expr) => {{
        let message = format!($it);
        return Err(MirEvalError::NotSupported(message))
    }};
}

impl Evaluator<'_> {
    /// Determines the lane count and lane type of the SIMD vector type `ty`.
    ///
    /// Two layouts are recognized:
    /// - an ADT whose second generic argument is a const (e.g. `Simd<T, N>`): that const is
    ///   the lane count and the first generic argument is the lane type;
    /// - otherwise a struct, whose field count is the lane count and whose first field's type
    ///   is the lane type.
    ///
    /// Zero-lane vectors are rejected, so callers may divide a vector's byte size by the
    /// returned lane count without risking a division by zero.
    ///
    /// # Errors
    /// `TypeError` if `ty` is not an ADT, has no usable length or lane-type argument, has an
    /// unevaluatable or zero length; `NotSupported` for a struct without fields.
    fn detect_simd_ty(&self, ty: &Ty) -> Result<(usize, Ty)> {
        match ty.kind(Interner) {
            TyKind::Adt(id, subst) => {
                let len = match subst.as_slice(Interner).get(1).and_then(|it| it.constant(Interner))
                {
                    Some(len) => len,
                    _ => {
                        // No const length argument: fall back to "one field per lane".
                        if let AdtId::StructId(id) = id.0 {
                            let struct_data = self.db.struct_data(id);
                            let fields = struct_data.variant_data.fields();
                            let Some((first_field, _)) = fields.iter().next() else {
                                not_supported!("simd type with no field");
                            };
                            let field_ty = self.db.field_types(id.into())[first_field]
                                .clone()
                                .substitute(Interner, subst);
                            return Ok((fields.len(), field_ty));
                        }
                        return Err(MirEvalError::TypeError("simd type with no len param"));
                    }
                };
                let len = match try_const_usize(self.db, len) {
                    // Callers divide sizes by the lane count; never hand them a zero.
                    Some(0) => return Err(MirEvalError::TypeError("simd type with zero lanes")),
                    Some(len) => len as usize,
                    None => {
                        return Err(MirEvalError::TypeError(
                            "simd type with unevaluatable len param",
                        ))
                    }
                };
                let Some(ty) = subst.as_slice(Interner).get(0).and_then(|it| it.ty(Interner)) else {
                    return Err(MirEvalError::TypeError("simd type with no ty param"));
                };
                Ok((len, ty.clone()))
            }
            _ => Err(MirEvalError::TypeError("simd type which is not a struct")),
        }
    }

    /// Compares one lane of the left operand with the matching lane of the right operand.
    ///
    /// Float lanes (4 or 8 bytes) are decoded and compared with `partial_cmp`, so the result
    /// is `None` when either side is NaN and `-0.0` equals `0.0`. All other lanes are treated
    /// as little-endian integers: signed when the lane type is a builtin `Int`, else unsigned.
    ///
    /// # Errors
    /// `TypeError` if a float lane's two sides differ in size; `NotSupported` for float
    /// lanes of any other width.
    fn compare_simd_lane(lane_ty: &Ty, l: &[u8], r: &[u8]) -> Result<Option<Ordering>> {
        let lane_kind = lane_ty.as_builtin();
        if matches!(lane_kind, Some(BuiltinType::Float(_))) {
            return match l.len() {
                4 => {
                    let (l, r) = (from_bytes!(f32, l), from_bytes!(f32, r));
                    Ok(l.partial_cmp(&r))
                }
                8 => {
                    let (l, r) = (from_bytes!(f64, l), from_bytes!(f64, r));
                    Ok(l.partial_cmp(&r))
                }
                size => not_supported!("simd comparison of {size}-byte float lanes"),
            };
        }
        // Integer lanes are little-endian: scan from the most significant (last) byte
        // towards the least significant one; the first differing byte decides.
        let mut pairs = l.iter().zip(r).rev();
        let Some((&l_top, &r_top)) = pairs.next() else {
            return Ok(Some(Ordering::Equal));
        };
        if l_top != r_top {
            let ord = if matches!(lane_kind, Some(BuiltinType::Int(_))) {
                // Two's complement: the sign bit lives in the most significant byte.
                (l_top as i8).cmp(&(r_top as i8))
            } else {
                l_top.cmp(&r_top)
            };
            return Ok(Some(ord));
        }
        Ok(Some(pairs.map(|(a, b)| a.cmp(b)).find(|o| o.is_ne()).unwrap_or(Ordering::Equal)))
    }

    /// Executes the SIMD intrinsic `name` on `args` and writes the result to `destination`.
    ///
    /// `name` is matched without a `simd_` prefix (e.g. `"eq"`, `"bitmask"`) — NOTE(review):
    /// presumably the caller strips it; confirm. Supported intrinsics:
    /// - `and` / `or` / `xor`: byte-wise bit operations on the two operands;
    /// - `eq` / `ne` / `lt` / `le` / `gt` / `ge`: lane-wise comparisons; each result lane is
    ///   filled with `255` bytes when the comparison holds and `0` bytes otherwise;
    /// - `bitmask`: bit `i` of the result is set when lane `i` has any non-zero byte
    ///   (at most 64 lanes, at most an 8-byte destination);
    /// - `shuffle`: each `u32` of the index array selects a lane of `left` followed by `right`.
    ///
    /// `_generic_args`, `_locals` and `_span` are currently unused.
    ///
    /// # Errors
    /// `TypeError` for missing/malformed arguments, an empty shuffle index array or an
    /// out-of-bounds shuffle index; `NotSupported` for unknown intrinsics and lane shapes
    /// this shim cannot represent.
    pub(super) fn exec_simd_intrinsic(
        &mut self,
        name: &str,
        args: &[IntervalAndTy],
        _generic_args: &Substitution,
        destination: Interval,
        _locals: &Locals,
        _span: MirSpan,
    ) -> Result<()> {
        match name {
            "and" | "or" | "xor" => {
                let [left, right] = args else {
                    return Err(MirEvalError::TypeError("simd bit op args are not provided"));
                };
                let result = left
                    .get(self)?
                    .iter()
                    .zip(right.get(self)?)
                    .map(|(&it, &y)| match name {
                        "and" => it & y,
                        "or" => it | y,
                        "xor" => it ^ y,
                        _ => unreachable!(),
                    })
                    .collect::<Vec<_>>();
                destination.write_from_bytes(self, &result)
            }
            "eq" | "ne" | "lt" | "le" | "gt" | "ge" => {
                let [left, right] = args else {
                    return Err(MirEvalError::TypeError("simd args are not provided"));
                };
                let (len, ty) = self.detect_simd_ty(&left.ty)?;
                let size = left.interval.size / len;
                let dest_size = destination.size / len;
                let mut destination_bytes = Vec::with_capacity(destination.size);
                let vector = left.get(self)?.chunks(size).zip(right.get(self)?.chunks(size));
                for (l, r) in vector {
                    let holds = match Self::compare_simd_lane(&ty, l, r)? {
                        Some(Ordering::Less) => ["lt", "le", "ne"].contains(&name),
                        Some(Ordering::Equal) => ["ge", "le", "eq"].contains(&name),
                        Some(Ordering::Greater) => ["ge", "gt", "ne"].contains(&name),
                        // Unordered lanes (a NaN operand): only `ne` holds.
                        None => name == "ne",
                    };
                    // Mask lanes are all-ones for true and all-zeros for false.
                    let mask_byte: u8 = if holds { 255 } else { 0 };
                    destination_bytes.extend(std::iter::repeat(mask_byte).take(dest_size));
                }
                destination.write_from_bytes(self, &destination_bytes)
            }
            "bitmask" => {
                let [op] = args else {
                    return Err(MirEvalError::TypeError("simd_bitmask args are not provided"));
                };
                let (op_len, _) = self.detect_simd_ty(&op.ty)?;
                let dest_size = destination.size;
                // Bits are accumulated in a `u64` whose low bytes are copied out; more lanes
                // would overflow the shift and a larger destination would overrun the bytes.
                if op_len > u64::BITS as usize || dest_size > std::mem::size_of::<u64>() {
                    not_supported!("simd_bitmask of {op_len} lanes into {dest_size} bytes");
                }
                let lane_size = op.interval.size / op_len;
                let mut result: u64 = 0;
                for (i, lane) in op.get(self)?.chunks(lane_size).take(op_len).enumerate() {
                    if lane.iter().any(|&it| it != 0) {
                        result |= 1 << i;
                    }
                }
                destination.write_from_bytes(self, &result.to_le_bytes()[..dest_size])
            }
            "shuffle" => {
                let [left, right, index] = args else {
                    return Err(MirEvalError::TypeError("simd_shuffle args are not provided"));
                };
                let TyKind::Array(_, index_len) = index.ty.kind(Interner) else {
                    return Err(MirEvalError::TypeError(
                        "simd_shuffle index argument has non-array type",
                    ));
                };
                let index_len = match try_const_usize(self.db, index_len) {
                    Some(it) => it as usize,
                    None => {
                        return Err(MirEvalError::TypeError(
                            "simd type with unevaluatable len param",
                        ))
                    }
                };
                // An empty index array would make the per-index byte size a division by zero.
                let Some(index_size) = index.interval.size.checked_div(index_len) else {
                    return Err(MirEvalError::TypeError("simd_shuffle index argument is empty"));
                };
                let (left_len, _) = self.detect_simd_ty(&left.ty)?;
                let lane_size = left.interval.size / left_len;
                // Lanes of `left` followed by lanes of `right`, addressable in O(1).
                let lanes: Vec<&[u8]> = left
                    .get(self)?
                    .chunks(lane_size)
                    .chain(right.get(self)?.chunks(lane_size))
                    .collect();
                let mut result = Vec::with_capacity(destination.size);
                for index in index.get(self)?.chunks(index_size) {
                    let index = from_bytes!(u32, index) as usize;
                    let Some(lane) = lanes.get(index) else {
                        return Err(MirEvalError::TypeError("out of bound access in simd shuffle"));
                    };
                    result.extend_from_slice(lane);
                }
                destination.write_from_bytes(self, &result)
            }
            _ => not_supported!("unknown simd intrinsic {name}"),
        }
    }
}