diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
commit | 2aa4a82499d4becd2284cdb482213d541b8804dd (patch) | |
tree | b80bf8bf13c3766139fbacc530efd0dd9d54394c /third_party/rust/cranelift-codegen-meta/src/cdsl | |
parent | Initial commit. (diff) | |
download | firefox-upstream.tar.xz firefox-upstream.zip |
Adding upstream version 86.0.1. (refs: upstream/86.0.1, upstream)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/cranelift-codegen-meta/src/cdsl')
15 files changed, 7069 insertions, 0 deletions
diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/ast.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/ast.rs new file mode 100644 index 0000000000..82cdbad762 --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/ast.rs @@ -0,0 +1,753 @@ +use crate::cdsl::instructions::{InstSpec, Instruction, InstructionPredicate}; +use crate::cdsl::operands::{OperandKind, OperandKindFields}; +use crate::cdsl::types::ValueType; +use crate::cdsl::typevar::{TypeSetBuilder, TypeVar}; + +use cranelift_entity::{entity_impl, PrimaryMap, SparseMap, SparseMapValue}; + +use std::fmt; +use std::iter::IntoIterator; + +pub(crate) enum Expr { + Var(VarIndex), + Literal(Literal), +} + +impl Expr { + pub fn maybe_literal(&self) -> Option<&Literal> { + match &self { + Expr::Literal(lit) => Some(lit), + _ => None, + } + } + + pub fn maybe_var(&self) -> Option<VarIndex> { + if let Expr::Var(var) = &self { + Some(*var) + } else { + None + } + } + + pub fn unwrap_var(&self) -> VarIndex { + self.maybe_var() + .expect("tried to unwrap a non-Var content in Expr::unwrap_var") + } + + pub fn to_rust_code(&self, var_pool: &VarPool) -> String { + match self { + Expr::Var(var_index) => var_pool.get(*var_index).to_rust_code(), + Expr::Literal(literal) => literal.to_rust_code(), + } + } +} + +/// An AST definition associates a set of variables with the values produced by an expression. 
+pub(crate) struct Def { + pub apply: Apply, + pub defined_vars: Vec<VarIndex>, +} + +impl Def { + pub fn to_comment_string(&self, var_pool: &VarPool) -> String { + let results = self + .defined_vars + .iter() + .map(|&x| var_pool.get(x).name.as_str()) + .collect::<Vec<_>>(); + + let results = if results.len() == 1 { + results[0].to_string() + } else { + format!("({})", results.join(", ")) + }; + + format!("{} := {}", results, self.apply.to_comment_string(var_pool)) + } +} + +pub(crate) struct DefPool { + pool: PrimaryMap<DefIndex, Def>, +} + +impl DefPool { + pub fn new() -> Self { + Self { + pool: PrimaryMap::new(), + } + } + pub fn get(&self, index: DefIndex) -> &Def { + self.pool.get(index).unwrap() + } + pub fn next_index(&self) -> DefIndex { + self.pool.next_key() + } + pub fn create_inst(&mut self, apply: Apply, defined_vars: Vec<VarIndex>) -> DefIndex { + self.pool.push(Def { + apply, + defined_vars, + }) + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub(crate) struct DefIndex(u32); +entity_impl!(DefIndex); + +/// A definition which would lead to generate a block creation. +#[derive(Clone)] +pub(crate) struct Block { + /// Instruction index after which the block entry is set. + pub location: DefIndex, + /// Variable holding the new created block. 
+ pub name: VarIndex, +} + +pub(crate) struct BlockPool { + pool: SparseMap<DefIndex, Block>, +} + +impl SparseMapValue<DefIndex> for Block { + fn key(&self) -> DefIndex { + self.location + } +} + +impl BlockPool { + pub fn new() -> Self { + Self { + pool: SparseMap::new(), + } + } + pub fn get(&self, index: DefIndex) -> Option<&Block> { + self.pool.get(index) + } + pub fn create_block(&mut self, name: VarIndex, location: DefIndex) { + if self.pool.contains_key(location) { + panic!("Attempt to insert 2 blocks after the same instruction") + } + self.pool.insert(Block { location, name }); + } + pub fn is_empty(&self) -> bool { + self.pool.is_empty() + } +} + +// Implement IntoIterator such that we can iterate over blocks which are in the block pool. +impl<'a> IntoIterator for &'a BlockPool { + type Item = <&'a SparseMap<DefIndex, Block> as IntoIterator>::Item; + type IntoIter = <&'a SparseMap<DefIndex, Block> as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.pool.into_iter() + } +} + +#[derive(Clone, Debug)] +pub(crate) enum Literal { + /// A value of an enumerated immediate operand. + /// + /// Some immediate operand kinds like `intcc` and `floatcc` have an enumerated range of values + /// corresponding to a Rust enum type. An `Enumerator` object is an AST leaf node representing one + /// of the values. + Enumerator { + rust_type: &'static str, + value: &'static str, + }, + + /// A bitwise value of an immediate operand, used for bitwise exact floating point constants. + Bits { rust_type: &'static str, value: u64 }, + + /// A value of an integer immediate operand. + Int(i64), + + /// A empty list of variable set of arguments. 
+ EmptyVarArgs, +} + +impl Literal { + pub fn enumerator_for(kind: &OperandKind, value: &'static str) -> Self { + let value = match &kind.fields { + OperandKindFields::ImmEnum(values) => values.get(value).unwrap_or_else(|| { + panic!( + "nonexistent value '{}' in enumeration '{}'", + value, kind.rust_type + ) + }), + _ => panic!("enumerator is for enum values"), + }; + Literal::Enumerator { + rust_type: kind.rust_type, + value, + } + } + + pub fn bits(kind: &OperandKind, bits: u64) -> Self { + match kind.fields { + OperandKindFields::ImmValue => {} + _ => panic!("bits_of is for immediate scalar types"), + } + Literal::Bits { + rust_type: kind.rust_type, + value: bits, + } + } + + pub fn constant(kind: &OperandKind, value: i64) -> Self { + match kind.fields { + OperandKindFields::ImmValue => {} + _ => panic!("constant is for immediate scalar types"), + } + Literal::Int(value) + } + + pub fn empty_vararg() -> Self { + Literal::EmptyVarArgs + } + + pub fn to_rust_code(&self) -> String { + match self { + Literal::Enumerator { rust_type, value } => format!("{}::{}", rust_type, value), + Literal::Bits { rust_type, value } => format!("{}::with_bits({:#x})", rust_type, value), + Literal::Int(val) => val.to_string(), + Literal::EmptyVarArgs => "&[]".into(), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub(crate) enum PatternPosition { + Source, + Destination, +} + +/// A free variable. +/// +/// When variables are used in `XForms` with source and destination patterns, they are classified +/// as follows: +/// +/// Input values: Uses in the source pattern with no preceding def. These may appear as inputs in +/// the destination pattern too, but no new inputs can be introduced. +/// +/// Output values: Variables that are defined in both the source and destination pattern. These +/// values may have uses outside the source pattern, and the destination pattern must compute the +/// same value. 
+/// +/// Intermediate values: Values that are defined in the source pattern, but not in the destination +/// pattern. These may have uses outside the source pattern, so the defining instruction can't be +/// deleted immediately. +/// +/// Temporary values are defined only in the destination pattern. +pub(crate) struct Var { + pub name: String, + + /// The `Def` defining this variable in a source pattern. + pub src_def: Option<DefIndex>, + + /// The `Def` defining this variable in a destination pattern. + pub dst_def: Option<DefIndex>, + + /// TypeVar representing the type of this variable. + type_var: Option<TypeVar>, + + /// Is this the original type variable, or has it be redefined with set_typevar? + is_original_type_var: bool, +} + +impl Var { + fn new(name: String) -> Self { + Self { + name, + src_def: None, + dst_def: None, + type_var: None, + is_original_type_var: false, + } + } + + /// Is this an input value to the src pattern? + pub fn is_input(&self) -> bool { + self.src_def.is_none() && self.dst_def.is_none() + } + + /// Is this an output value, defined in both src and dst patterns? + pub fn is_output(&self) -> bool { + self.src_def.is_some() && self.dst_def.is_some() + } + + /// Is this an intermediate value, defined only in the src pattern? + pub fn is_intermediate(&self) -> bool { + self.src_def.is_some() && self.dst_def.is_none() + } + + /// Is this a temp value, defined only in the dst pattern? + pub fn is_temp(&self) -> bool { + self.src_def.is_none() && self.dst_def.is_some() + } + + /// Get the def of this variable according to the position. 
+ pub fn get_def(&self, position: PatternPosition) -> Option<DefIndex> { + match position { + PatternPosition::Source => self.src_def, + PatternPosition::Destination => self.dst_def, + } + } + + pub fn set_def(&mut self, position: PatternPosition, def: DefIndex) { + assert!( + self.get_def(position).is_none(), + format!("redefinition of variable {}", self.name) + ); + match position { + PatternPosition::Source => { + self.src_def = Some(def); + } + PatternPosition::Destination => { + self.dst_def = Some(def); + } + } + } + + /// Get the type variable representing the type of this variable. + pub fn get_or_create_typevar(&mut self) -> TypeVar { + match &self.type_var { + Some(tv) => tv.clone(), + None => { + // Create a new type var in which we allow all types. + let tv = TypeVar::new( + format!("typeof_{}", self.name), + format!("Type of the pattern variable {:?}", self), + TypeSetBuilder::all(), + ); + self.type_var = Some(tv.clone()); + self.is_original_type_var = true; + tv + } + } + } + pub fn get_typevar(&self) -> Option<TypeVar> { + self.type_var.clone() + } + pub fn set_typevar(&mut self, tv: TypeVar) { + self.is_original_type_var = if let Some(previous_tv) = &self.type_var { + *previous_tv == tv + } else { + false + }; + self.type_var = Some(tv); + } + + /// Check if this variable has a free type variable. If not, the type of this variable is + /// computed from the type of another variable. 
+ pub fn has_free_typevar(&self) -> bool { + match &self.type_var { + Some(tv) => tv.base.is_none() && self.is_original_type_var, + None => false, + } + } + + pub fn to_rust_code(&self) -> String { + self.name.clone() + } + fn rust_type(&self) -> String { + self.type_var.as_ref().unwrap().to_rust_code() + } +} + +impl fmt::Debug for Var { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + fmt.write_fmt(format_args!( + "Var({}{}{})", + self.name, + if self.src_def.is_some() { ", src" } else { "" }, + if self.dst_def.is_some() { ", dst" } else { "" } + )) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub(crate) struct VarIndex(u32); +entity_impl!(VarIndex); + +pub(crate) struct VarPool { + pool: PrimaryMap<VarIndex, Var>, +} + +impl VarPool { + pub fn new() -> Self { + Self { + pool: PrimaryMap::new(), + } + } + pub fn get(&self, index: VarIndex) -> &Var { + self.pool.get(index).unwrap() + } + pub fn get_mut(&mut self, index: VarIndex) -> &mut Var { + self.pool.get_mut(index).unwrap() + } + pub fn create(&mut self, name: impl Into<String>) -> VarIndex { + self.pool.push(Var::new(name.into())) + } +} + +/// Contains constants created in the AST that must be inserted into the true [ConstantPool] when +/// the legalizer code is generated. The constant data is named in the order it is inserted; +/// inserting data using [insert] will avoid duplicates. +/// +/// [ConstantPool]: ../../../cranelift_codegen/ir/constant/struct.ConstantPool.html +/// [insert]: ConstPool::insert +pub(crate) struct ConstPool { + pool: Vec<Vec<u8>>, +} + +impl ConstPool { + /// Create an empty constant pool. + pub fn new() -> Self { + Self { pool: vec![] } + } + + /// Create a name for a constant from its position in the pool. + fn create_name(position: usize) -> String { + format!("const{}", position) + } + + /// Insert constant data into the pool, returning the name of the variable used to reference it. 
+ /// This method will search for data that matches the new data and return the existing constant + /// name to avoid duplicates. + pub fn insert(&mut self, data: Vec<u8>) -> String { + let possible_position = self.pool.iter().position(|d| d == &data); + let position = if let Some(found_position) = possible_position { + found_position + } else { + let new_position = self.pool.len(); + self.pool.push(data); + new_position + }; + ConstPool::create_name(position) + } + + /// Iterate over the name/value pairs in the pool. + pub fn iter(&self) -> impl Iterator<Item = (String, &Vec<u8>)> { + self.pool + .iter() + .enumerate() + .map(|(i, v)| (ConstPool::create_name(i), v)) + } +} + +/// Apply an instruction to arguments. +/// +/// An `Apply` AST expression is created by using function call syntax on instructions. This +/// applies to both bound and unbound polymorphic instructions. +pub(crate) struct Apply { + pub inst: Instruction, + pub args: Vec<Expr>, + pub value_types: Vec<ValueType>, +} + +impl Apply { + pub fn new(target: InstSpec, args: Vec<Expr>) -> Self { + let (inst, value_types) = match target { + InstSpec::Inst(inst) => (inst, Vec::new()), + InstSpec::Bound(bound_inst) => (bound_inst.inst, bound_inst.value_types), + }; + + // Apply should only operate on concrete value types, not "any". + let value_types = value_types + .into_iter() + .map(|vt| vt.expect("shouldn't be Any")) + .collect(); + + // Basic check on number of arguments. + assert!( + inst.operands_in.len() == args.len(), + format!("incorrect number of arguments in instruction {}", inst.name) + ); + + // Check that the kinds of Literals arguments match the expected operand. + for &imm_index in &inst.imm_opnums { + let arg = &args[imm_index]; + if let Some(literal) = arg.maybe_literal() { + let op = &inst.operands_in[imm_index]; + match &op.kind.fields { + OperandKindFields::ImmEnum(values) => { + if let Literal::Enumerator { value, .. 
} = literal { + assert!( + values.iter().any(|(_key, v)| v == value), + "Nonexistent enum value '{}' passed to field of kind '{}' -- \ + did you use the right enum?", + value, + op.kind.rust_type + ); + } else { + panic!( + "Passed non-enum field value {:?} to field of kind {}", + literal, op.kind.rust_type + ); + } + } + OperandKindFields::ImmValue => match &literal { + Literal::Enumerator { value, .. } => panic!( + "Expected immediate value in immediate field of kind '{}', \ + obtained enum value '{}'", + op.kind.rust_type, value + ), + Literal::Bits { .. } | Literal::Int(_) | Literal::EmptyVarArgs => {} + }, + _ => { + panic!( + "Literal passed to non-literal field of kind {}", + op.kind.rust_type + ); + } + } + } + } + + Self { + inst, + args, + value_types, + } + } + + fn to_comment_string(&self, var_pool: &VarPool) -> String { + let args = self + .args + .iter() + .map(|arg| arg.to_rust_code(var_pool)) + .collect::<Vec<_>>() + .join(", "); + + let mut inst_and_bound_types = vec![self.inst.name.to_string()]; + inst_and_bound_types.extend(self.value_types.iter().map(|vt| vt.to_string())); + let inst_name = inst_and_bound_types.join("."); + + format!("{}({})", inst_name, args) + } + + pub fn inst_predicate(&self, var_pool: &VarPool) -> InstructionPredicate { + let mut pred = InstructionPredicate::new(); + for (format_field, &op_num) in self + .inst + .format + .imm_fields + .iter() + .zip(self.inst.imm_opnums.iter()) + { + let arg = &self.args[op_num]; + if arg.maybe_var().is_some() { + // Ignore free variables for now. + continue; + } + pred = pred.and(InstructionPredicate::new_is_field_equal_ast( + &*self.inst.format, + format_field, + arg.to_rust_code(var_pool), + )); + } + + // Add checks for any bound secondary type variables. We can't check the controlling type + // variable this way since it may not appear as the type of an operand. 
+ if self.value_types.len() > 1 { + let poly = self + .inst + .polymorphic_info + .as_ref() + .expect("must have polymorphic info if it has bounded types"); + for (bound_type, type_var) in + self.value_types[1..].iter().zip(poly.other_typevars.iter()) + { + pred = pred.and(InstructionPredicate::new_typevar_check( + &self.inst, type_var, bound_type, + )); + } + } + + pred + } + + /// Same as `inst_predicate()`, but also check the controlling type variable. + pub fn inst_predicate_with_ctrl_typevar(&self, var_pool: &VarPool) -> InstructionPredicate { + let mut pred = self.inst_predicate(var_pool); + + if !self.value_types.is_empty() { + let bound_type = &self.value_types[0]; + let poly = self.inst.polymorphic_info.as_ref().unwrap(); + let type_check = if poly.use_typevar_operand { + InstructionPredicate::new_typevar_check(&self.inst, &poly.ctrl_typevar, bound_type) + } else { + InstructionPredicate::new_ctrl_typevar_check(&bound_type) + }; + pred = pred.and(type_check); + } + + pred + } + + pub fn rust_builder(&self, defined_vars: &[VarIndex], var_pool: &VarPool) -> String { + let mut args = self + .args + .iter() + .map(|expr| expr.to_rust_code(var_pool)) + .collect::<Vec<_>>() + .join(", "); + + // Do we need to pass an explicit type argument? + if let Some(poly) = &self.inst.polymorphic_info { + if !poly.use_typevar_operand { + args = format!("{}, {}", var_pool.get(defined_vars[0]).rust_type(), args); + } + } + + format!("{}({})", self.inst.snake_name(), args) + } +} + +// Simple helpers for legalize actions construction. 
+ +pub(crate) enum DummyExpr { + Var(DummyVar), + Literal(Literal), + Constant(DummyConstant), + Apply(InstSpec, Vec<DummyExpr>), + Block(DummyVar), +} + +#[derive(Clone)] +pub(crate) struct DummyVar { + pub name: String, +} + +impl Into<DummyExpr> for DummyVar { + fn into(self) -> DummyExpr { + DummyExpr::Var(self) + } +} +impl Into<DummyExpr> for Literal { + fn into(self) -> DummyExpr { + DummyExpr::Literal(self) + } +} + +#[derive(Clone)] +pub(crate) struct DummyConstant(pub(crate) Vec<u8>); + +pub(crate) fn constant(data: Vec<u8>) -> DummyConstant { + DummyConstant(data) +} + +impl Into<DummyExpr> for DummyConstant { + fn into(self) -> DummyExpr { + DummyExpr::Constant(self) + } +} + +pub(crate) fn var(name: &str) -> DummyVar { + DummyVar { + name: name.to_owned(), + } +} + +pub(crate) struct DummyDef { + pub expr: DummyExpr, + pub defined_vars: Vec<DummyVar>, +} + +pub(crate) struct ExprBuilder { + expr: DummyExpr, +} + +impl ExprBuilder { + pub fn apply(inst: InstSpec, args: Vec<DummyExpr>) -> Self { + let expr = DummyExpr::Apply(inst, args); + Self { expr } + } + + pub fn assign_to(self, defined_vars: Vec<DummyVar>) -> DummyDef { + DummyDef { + expr: self.expr, + defined_vars, + } + } + + pub fn block(name: DummyVar) -> Self { + let expr = DummyExpr::Block(name); + Self { expr } + } +} + +macro_rules! def_rhs { + // inst(a, b, c) + ($inst:ident($($src:expr),*)) => { + ExprBuilder::apply($inst.into(), vec![$($src.clone().into()),*]) + }; + + // inst.type(a, b, c) + ($inst:ident.$type:ident($($src:expr),*)) => { + ExprBuilder::apply($inst.bind($type).into(), vec![$($src.clone().into()),*]) + }; +} + +// Helper macro to define legalization recipes. +macro_rules! def { + // x = ... + ($dest:ident = $($tt:tt)*) => { + def_rhs!($($tt)*).assign_to(vec![$dest.clone()]) + }; + + // (x, y, ...) = ... + (($($dest:ident),*) = $($tt:tt)*) => { + def_rhs!($($tt)*).assign_to(vec![$($dest.clone()),*]) + }; + + // An instruction with no results. 
+ ($($tt:tt)*) => { + def_rhs!($($tt)*).assign_to(Vec::new()) + } +} + +// Helper macro to define legalization recipes. +macro_rules! block { + // a basic block definition, splitting the current block in 2. + ($block: ident) => { + ExprBuilder::block($block).assign_to(Vec::new()) + }; +} + +#[cfg(test)] +mod tests { + use crate::cdsl::ast::ConstPool; + + #[test] + fn const_pool_returns_var_names() { + let mut c = ConstPool::new(); + assert_eq!(c.insert([0, 1, 2].to_vec()), "const0"); + assert_eq!(c.insert([1, 2, 3].to_vec()), "const1"); + } + + #[test] + fn const_pool_avoids_duplicates() { + let data = [0, 1, 2].to_vec(); + let mut c = ConstPool::new(); + assert_eq!(c.pool.len(), 0); + + assert_eq!(c.insert(data.clone()), "const0"); + assert_eq!(c.pool.len(), 1); + + assert_eq!(c.insert(data), "const0"); + assert_eq!(c.pool.len(), 1); + } + + #[test] + fn const_pool_iterates() { + let mut c = ConstPool::new(); + c.insert([0, 1, 2].to_vec()); + c.insert([3, 4, 5].to_vec()); + + let mut iter = c.iter(); + assert_eq!(iter.next(), Some(("const0".to_owned(), &vec![0, 1, 2]))); + assert_eq!(iter.next(), Some(("const1".to_owned(), &vec![3, 4, 5]))); + assert_eq!(iter.next(), None); + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/cpu_modes.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/cpu_modes.rs new file mode 100644 index 0000000000..7d119b00ce --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/cpu_modes.rs @@ -0,0 +1,88 @@ +use std::collections::{hash_map, HashMap, HashSet}; +use std::iter::FromIterator; + +use crate::cdsl::encodings::Encoding; +use crate::cdsl::types::{LaneType, ValueType}; +use crate::cdsl::xform::{TransformGroup, TransformGroupIndex}; + +pub(crate) struct CpuMode { + pub name: &'static str, + default_legalize: Option<TransformGroupIndex>, + monomorphic_legalize: Option<TransformGroupIndex>, + typed_legalize: HashMap<ValueType, TransformGroupIndex>, + pub encodings: Vec<Encoding>, +} + +impl CpuMode { + 
pub fn new(name: &'static str) -> Self { + Self { + name, + default_legalize: None, + monomorphic_legalize: None, + typed_legalize: HashMap::new(), + encodings: Vec::new(), + } + } + + pub fn set_encodings(&mut self, encodings: Vec<Encoding>) { + assert!(self.encodings.is_empty(), "clobbering encodings"); + self.encodings = encodings; + } + + pub fn legalize_monomorphic(&mut self, group: &TransformGroup) { + assert!(self.monomorphic_legalize.is_none()); + self.monomorphic_legalize = Some(group.id); + } + pub fn legalize_default(&mut self, group: &TransformGroup) { + assert!(self.default_legalize.is_none()); + self.default_legalize = Some(group.id); + } + pub fn legalize_value_type(&mut self, lane_type: impl Into<ValueType>, group: &TransformGroup) { + assert!(self + .typed_legalize + .insert(lane_type.into(), group.id) + .is_none()); + } + pub fn legalize_type(&mut self, lane_type: impl Into<LaneType>, group: &TransformGroup) { + assert!(self + .typed_legalize + .insert(lane_type.into().into(), group.id) + .is_none()); + } + + pub fn get_default_legalize_code(&self) -> TransformGroupIndex { + self.default_legalize + .expect("a finished CpuMode must have a default legalize code") + } + pub fn get_legalize_code_for(&self, typ: &Option<ValueType>) -> TransformGroupIndex { + match typ { + Some(typ) => self + .typed_legalize + .get(typ) + .copied() + .unwrap_or_else(|| self.get_default_legalize_code()), + None => self + .monomorphic_legalize + .unwrap_or_else(|| self.get_default_legalize_code()), + } + } + pub fn get_legalized_types(&self) -> hash_map::Keys<ValueType, TransformGroupIndex> { + self.typed_legalize.keys() + } + + /// Returns a deterministically ordered, deduplicated list of TransformGroupIndex for the directly + /// reachable set of TransformGroup this TargetIsa uses. 
+ pub fn direct_transform_groups(&self) -> Vec<TransformGroupIndex> { + let mut set = HashSet::new(); + if let Some(i) = &self.default_legalize { + set.insert(*i); + } + if let Some(i) = &self.monomorphic_legalize { + set.insert(*i); + } + set.extend(self.typed_legalize.values().cloned()); + let mut ret = Vec::from_iter(set); + ret.sort(); + ret + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/encodings.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/encodings.rs new file mode 100644 index 0000000000..f66746f92f --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/encodings.rs @@ -0,0 +1,179 @@ +use crate::cdsl::instructions::{ + InstSpec, Instruction, InstructionPredicate, InstructionPredicateNode, + InstructionPredicateNumber, InstructionPredicateRegistry, ValueTypeOrAny, +}; +use crate::cdsl::recipes::{EncodingRecipeNumber, Recipes}; +use crate::cdsl::settings::SettingPredicateNumber; +use crate::cdsl::types::ValueType; +use std::rc::Rc; +use std::string::ToString; + +/// Encoding for a concrete instruction. +/// +/// An `Encoding` object ties an instruction opcode with concrete type variables together with an +/// encoding recipe and encoding encbits. +/// +/// The concrete instruction can be in three different forms: +/// +/// 1. A naked opcode: `trap` for non-polymorphic instructions. +/// 2. With bound type variables: `iadd.i32` for polymorphic instructions. +/// 3. With operands providing constraints: `icmp.i32(intcc.eq, x, y)`. +/// +/// If the instruction is polymorphic, all type variables must be provided. +pub(crate) struct EncodingContent { + /// The `Instruction` or `BoundInstruction` being encoded. + inst: InstSpec, + + /// The `EncodingRecipe` to use. + pub recipe: EncodingRecipeNumber, + + /// Additional encoding bits to be interpreted by `recipe`. + pub encbits: u16, + + /// An instruction predicate that must be true to allow selecting this encoding. 
+ pub inst_predicate: Option<InstructionPredicateNumber>, + + /// An ISA predicate that must be true to allow selecting this encoding. + pub isa_predicate: Option<SettingPredicateNumber>, + + /// The value type this encoding has been bound to, for encodings of polymorphic instructions. + pub bound_type: Option<ValueType>, +} + +impl EncodingContent { + pub fn inst(&self) -> &Instruction { + self.inst.inst() + } + pub fn to_rust_comment(&self, recipes: &Recipes) -> String { + format!("[{}#{:02x}]", recipes[self.recipe].name, self.encbits) + } +} + +pub(crate) type Encoding = Rc<EncodingContent>; + +pub(crate) struct EncodingBuilder { + inst: InstSpec, + recipe: EncodingRecipeNumber, + encbits: u16, + inst_predicate: Option<InstructionPredicate>, + isa_predicate: Option<SettingPredicateNumber>, + bound_type: Option<ValueType>, +} + +impl EncodingBuilder { + pub fn new(inst: InstSpec, recipe: EncodingRecipeNumber, encbits: u16) -> Self { + let (inst_predicate, bound_type) = match &inst { + InstSpec::Bound(inst) => { + let other_typevars = &inst.inst.polymorphic_info.as_ref().unwrap().other_typevars; + + assert_eq!( + inst.value_types.len(), + other_typevars.len() + 1, + "partially bound polymorphic instruction" + ); + + // Add secondary type variables to the instruction predicate. 
+ let value_types = &inst.value_types; + let mut inst_predicate: Option<InstructionPredicate> = None; + for (typevar, value_type) in other_typevars.iter().zip(value_types.iter().skip(1)) { + let value_type = match value_type { + ValueTypeOrAny::Any => continue, + ValueTypeOrAny::ValueType(vt) => vt, + }; + let type_predicate = + InstructionPredicate::new_typevar_check(&inst.inst, typevar, value_type); + inst_predicate = Some(type_predicate.into()); + } + + // Add immediate value predicates + for (immediate_value, immediate_operand) in inst + .immediate_values + .iter() + .zip(inst.inst.operands_in.iter().filter(|o| o.is_immediate())) + { + let immediate_predicate = InstructionPredicate::new_is_field_equal( + &inst.inst.format, + immediate_operand.kind.rust_field_name, + immediate_value.to_string(), + ); + inst_predicate = if let Some(type_predicate) = inst_predicate { + Some(type_predicate.and(immediate_predicate)) + } else { + Some(immediate_predicate.into()) + } + } + + let ctrl_type = value_types[0] + .clone() + .expect("Controlling type shouldn't be Any"); + (inst_predicate, Some(ctrl_type)) + } + + InstSpec::Inst(inst) => { + assert!( + inst.polymorphic_info.is_none(), + "unbound polymorphic instruction" + ); + (None, None) + } + }; + + Self { + inst, + recipe, + encbits, + inst_predicate, + isa_predicate: None, + bound_type, + } + } + + pub fn inst_predicate(mut self, inst_predicate: InstructionPredicateNode) -> Self { + let inst_predicate = Some(match self.inst_predicate { + Some(node) => node.and(inst_predicate), + None => inst_predicate.into(), + }); + self.inst_predicate = inst_predicate; + self + } + + pub fn isa_predicate(mut self, isa_predicate: SettingPredicateNumber) -> Self { + assert!(self.isa_predicate.is_none()); + self.isa_predicate = Some(isa_predicate); + self + } + + pub fn build( + self, + recipes: &Recipes, + inst_pred_reg: &mut InstructionPredicateRegistry, + ) -> Encoding { + let inst_predicate = self.inst_predicate.map(|pred| 
inst_pred_reg.insert(pred)); + + let inst = self.inst.inst(); + assert!( + Rc::ptr_eq(&inst.format, &recipes[self.recipe].format), + format!( + "Inst {} and recipe {} must have the same format!", + inst.name, recipes[self.recipe].name + ) + ); + + assert_eq!( + inst.is_branch && !inst.is_indirect_branch, + recipes[self.recipe].branch_range.is_some(), + "Inst {}'s is_branch contradicts recipe {} branch_range!", + inst.name, + recipes[self.recipe].name + ); + + Rc::new(EncodingContent { + inst: self.inst, + recipe: self.recipe, + encbits: self.encbits, + inst_predicate, + isa_predicate: self.isa_predicate, + bound_type: self.bound_type, + }) + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/formats.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/formats.rs new file mode 100644 index 0000000000..e713a8bccb --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/formats.rs @@ -0,0 +1,171 @@ +use crate::cdsl::operands::OperandKind; +use std::fmt; +use std::rc::Rc; + +/// An immediate field in an instruction format. +/// +/// This corresponds to a single member of a variant of the `InstructionData` +/// data type. +#[derive(Debug)] +pub(crate) struct FormatField { + /// Immediate operand kind. + pub kind: OperandKind, + + /// Member name in InstructionData variant. + pub member: &'static str, +} + +/// Every instruction opcode has a corresponding instruction format which determines the number of +/// operands and their kinds. Instruction formats are identified structurally, i.e., the format of +/// an instruction is derived from the kinds of operands used in its declaration. +/// +/// The instruction format stores two separate lists of operands: Immediates and values. Immediate +/// operands (including entity references) are represented as explicit members in the +/// `InstructionData` variants. The value operands are stored differently, depending on how many +/// there are. 
Beyond a certain point, instruction formats switch to an external value list for +/// storing value arguments. Value lists can hold an arbitrary number of values. +/// +/// All instruction formats must be predefined in the meta shared/formats.rs module. +#[derive(Debug)] +pub(crate) struct InstructionFormat { + /// Instruction format name in CamelCase. This is used as a Rust variant name in both the + /// `InstructionData` and `InstructionFormat` enums. + pub name: &'static str, + + pub num_value_operands: usize, + + pub has_value_list: bool, + + pub imm_fields: Vec<FormatField>, + + /// Index of the value input operand that is used to infer the controlling type variable. By + /// default, this is `0`, the first `value` operand. The index is relative to the values only, + /// ignoring immediate operands. + pub typevar_operand: Option<usize>, +} + +/// A tuple serving as a key to deduplicate InstructionFormat. +#[derive(Hash, PartialEq, Eq)] +pub(crate) struct FormatStructure { + pub num_value_operands: usize, + pub has_value_list: bool, + /// Tuples of (Rust field name / Rust type) for each immediate field. + pub imm_field_names: Vec<(&'static str, &'static str)>, +} + +impl fmt::Display for InstructionFormat { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + let imm_args = self + .imm_fields + .iter() + .map(|field| format!("{}: {}", field.member, field.kind.rust_type)) + .collect::<Vec<_>>() + .join(", "); + fmt.write_fmt(format_args!( + "{}(imms=({}), vals={})", + self.name, imm_args, self.num_value_operands + ))?; + Ok(()) + } +} + +impl InstructionFormat { + pub fn imm_by_name(&self, name: &'static str) -> &FormatField { + self.imm_fields + .iter() + .find(|&field| field.member == name) + .unwrap_or_else(|| { + panic!( + "unexpected immediate field named {} in instruction format {}", + name, self.name + ) + }) + } + + /// Returns a tuple that uniquely identifies the structure. 
+ pub fn structure(&self) -> FormatStructure { + FormatStructure { + num_value_operands: self.num_value_operands, + has_value_list: self.has_value_list, + imm_field_names: self + .imm_fields + .iter() + .map(|field| (field.kind.rust_field_name, field.kind.rust_type)) + .collect::<Vec<_>>(), + } + } +} + +pub(crate) struct InstructionFormatBuilder { + name: &'static str, + num_value_operands: usize, + has_value_list: bool, + imm_fields: Vec<FormatField>, + typevar_operand: Option<usize>, +} + +impl InstructionFormatBuilder { + pub fn new(name: &'static str) -> Self { + Self { + name, + num_value_operands: 0, + has_value_list: false, + imm_fields: Vec::new(), + typevar_operand: None, + } + } + + pub fn value(mut self) -> Self { + self.num_value_operands += 1; + self + } + + pub fn varargs(mut self) -> Self { + self.has_value_list = true; + self + } + + pub fn imm(mut self, operand_kind: &OperandKind) -> Self { + let field = FormatField { + kind: operand_kind.clone(), + member: operand_kind.rust_field_name, + }; + self.imm_fields.push(field); + self + } + + pub fn imm_with_name(mut self, member: &'static str, operand_kind: &OperandKind) -> Self { + let field = FormatField { + kind: operand_kind.clone(), + member, + }; + self.imm_fields.push(field); + self + } + + pub fn typevar_operand(mut self, operand_index: usize) -> Self { + assert!(self.typevar_operand.is_none()); + assert!(self.has_value_list || operand_index < self.num_value_operands); + self.typevar_operand = Some(operand_index); + self + } + + pub fn build(self) -> Rc<InstructionFormat> { + let typevar_operand = if self.typevar_operand.is_some() { + self.typevar_operand + } else if self.has_value_list || self.num_value_operands > 0 { + // Default to the first value operand, if there's one. 
+ Some(0) + } else { + None + }; + + Rc::new(InstructionFormat { + name: self.name, + num_value_operands: self.num_value_operands, + has_value_list: self.has_value_list, + imm_fields: self.imm_fields, + typevar_operand, + }) + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/instructions.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/instructions.rs new file mode 100644 index 0000000000..88a15c6038 --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/instructions.rs @@ -0,0 +1,1395 @@ +use cranelift_codegen_shared::condcodes::IntCC; +use cranelift_entity::{entity_impl, PrimaryMap}; + +use std::collections::HashMap; +use std::fmt; +use std::fmt::{Display, Error, Formatter}; +use std::rc::Rc; + +use crate::cdsl::camel_case; +use crate::cdsl::formats::{FormatField, InstructionFormat}; +use crate::cdsl::operands::Operand; +use crate::cdsl::type_inference::Constraint; +use crate::cdsl::types::{LaneType, ReferenceType, ValueType, VectorType}; +use crate::cdsl::typevar::TypeVar; + +use crate::shared::formats::Formats; +use crate::shared::types::{Bool, Float, Int, Reference}; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub(crate) struct OpcodeNumber(u32); +entity_impl!(OpcodeNumber); + +pub(crate) type AllInstructions = PrimaryMap<OpcodeNumber, Instruction>; + +pub(crate) struct InstructionGroupBuilder<'all_inst> { + all_instructions: &'all_inst mut AllInstructions, + own_instructions: Vec<Instruction>, +} + +impl<'all_inst> InstructionGroupBuilder<'all_inst> { + pub fn new(all_instructions: &'all_inst mut AllInstructions) -> Self { + Self { + all_instructions, + own_instructions: Vec::new(), + } + } + + pub fn push(&mut self, builder: InstructionBuilder) { + let opcode_number = OpcodeNumber(self.all_instructions.next_key().as_u32()); + let inst = builder.build(opcode_number); + // Note this clone is cheap, since Instruction is a Rc<> wrapper for InstructionContent. 
+ self.own_instructions.push(inst.clone()); + self.all_instructions.push(inst); + } + + pub fn build(self) -> InstructionGroup { + InstructionGroup { + instructions: self.own_instructions, + } + } +} + +/// Every instruction must belong to exactly one instruction group. A given +/// target architecture can support instructions from multiple groups, and it +/// does not necessarily support all instructions in a group. +pub(crate) struct InstructionGroup { + instructions: Vec<Instruction>, +} + +impl InstructionGroup { + pub fn by_name(&self, name: &'static str) -> &Instruction { + self.instructions + .iter() + .find(|inst| inst.name == name) + .unwrap_or_else(|| panic!("instruction with name '{}' does not exist", name)) + } +} + +/// Instructions can have parameters bound to them to specialize them for more specific encodings +/// (e.g. the encoding for adding two float types may be different than that of adding two +/// integer types) +pub(crate) trait Bindable { + /// Bind a parameter to an instruction + fn bind(&self, parameter: impl Into<BindParameter>) -> BoundInstruction; +} + +#[derive(Debug)] +pub(crate) struct PolymorphicInfo { + pub use_typevar_operand: bool, + pub ctrl_typevar: TypeVar, + pub other_typevars: Vec<TypeVar>, +} + +#[derive(Debug)] +pub(crate) struct InstructionContent { + /// Instruction mnemonic, also becomes opcode name. + pub name: String, + pub camel_name: String, + pub opcode_number: OpcodeNumber, + + /// Documentation string. + pub doc: String, + + /// Input operands. This can be a mix of SSA value operands and other operand kinds. + pub operands_in: Vec<Operand>, + /// Output operands. The output operands must be SSA values or `variable_args`. + pub operands_out: Vec<Operand>, + /// Instruction-specific TypeConstraints. + pub constraints: Vec<Constraint>, + + /// Instruction format, automatically derived from the input operands. + pub format: Rc<InstructionFormat>, + + /// One of the input or output operands is a free type variable. 
None if the instruction is not + /// polymorphic, set otherwise. + pub polymorphic_info: Option<PolymorphicInfo>, + + /// Indices in operands_in of input operands that are values. + pub value_opnums: Vec<usize>, + /// Indices in operands_in of input operands that are immediates or entities. + pub imm_opnums: Vec<usize>, + /// Indices in operands_out of output operands that are values. + pub value_results: Vec<usize>, + + /// True for instructions that terminate the block. + pub is_terminator: bool, + /// True for all branch or jump instructions. + pub is_branch: bool, + /// True for all indirect branch or jump instructions.', + pub is_indirect_branch: bool, + /// Is this a call instruction? + pub is_call: bool, + /// Is this a return instruction? + pub is_return: bool, + /// Is this a ghost instruction? + pub is_ghost: bool, + /// Can this instruction read from memory? + pub can_load: bool, + /// Can this instruction write to memory? + pub can_store: bool, + /// Can this instruction cause a trap? + pub can_trap: bool, + /// Does this instruction have other side effects besides can_* flags? + pub other_side_effects: bool, + /// Does this instruction write to CPU flags? + pub writes_cpu_flags: bool, + /// Should this opcode be considered to clobber all live registers, during regalloc? 
+ pub clobbers_all_regs: bool, +} + +impl InstructionContent { + pub fn snake_name(&self) -> &str { + if &self.name == "return" { + "return_" + } else { + &self.name + } + } + + pub fn all_typevars(&self) -> Vec<&TypeVar> { + match &self.polymorphic_info { + Some(poly) => { + let mut result = vec![&poly.ctrl_typevar]; + result.extend(&poly.other_typevars); + result + } + None => Vec::new(), + } + } +} + +pub(crate) type Instruction = Rc<InstructionContent>; + +impl Bindable for Instruction { + fn bind(&self, parameter: impl Into<BindParameter>) -> BoundInstruction { + BoundInstruction::new(self).bind(parameter) + } +} + +impl fmt::Display for InstructionContent { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + if !self.operands_out.is_empty() { + let operands_out = self + .operands_out + .iter() + .map(|op| op.name) + .collect::<Vec<_>>() + .join(", "); + fmt.write_str(&operands_out)?; + fmt.write_str(" = ")?; + } + + fmt.write_str(&self.name)?; + + if !self.operands_in.is_empty() { + let operands_in = self + .operands_in + .iter() + .map(|op| op.name) + .collect::<Vec<_>>() + .join(", "); + fmt.write_str(" ")?; + fmt.write_str(&operands_in)?; + } + + Ok(()) + } +} + +pub(crate) struct InstructionBuilder { + name: String, + doc: String, + format: Rc<InstructionFormat>, + operands_in: Option<Vec<Operand>>, + operands_out: Option<Vec<Operand>>, + constraints: Option<Vec<Constraint>>, + + // See Instruction comments for the meaning of these fields. 
+ is_terminator: bool, + is_branch: bool, + is_indirect_branch: bool, + is_call: bool, + is_return: bool, + is_ghost: bool, + can_load: bool, + can_store: bool, + can_trap: bool, + other_side_effects: bool, + clobbers_all_regs: bool, +} + +impl InstructionBuilder { + pub fn new<S: Into<String>>(name: S, doc: S, format: &Rc<InstructionFormat>) -> Self { + Self { + name: name.into(), + doc: doc.into(), + format: format.clone(), + operands_in: None, + operands_out: None, + constraints: None, + + is_terminator: false, + is_branch: false, + is_indirect_branch: false, + is_call: false, + is_return: false, + is_ghost: false, + can_load: false, + can_store: false, + can_trap: false, + other_side_effects: false, + clobbers_all_regs: false, + } + } + + pub fn operands_in(mut self, operands: Vec<&Operand>) -> Self { + assert!(self.operands_in.is_none()); + self.operands_in = Some(operands.iter().map(|x| (*x).clone()).collect()); + self + } + + pub fn operands_out(mut self, operands: Vec<&Operand>) -> Self { + assert!(self.operands_out.is_none()); + self.operands_out = Some(operands.iter().map(|x| (*x).clone()).collect()); + self + } + + pub fn constraints(mut self, constraints: Vec<Constraint>) -> Self { + assert!(self.constraints.is_none()); + self.constraints = Some(constraints); + self + } + + #[allow(clippy::wrong_self_convention)] + pub fn is_terminator(mut self, val: bool) -> Self { + self.is_terminator = val; + self + } + + #[allow(clippy::wrong_self_convention)] + pub fn is_branch(mut self, val: bool) -> Self { + self.is_branch = val; + self + } + + #[allow(clippy::wrong_self_convention)] + pub fn is_indirect_branch(mut self, val: bool) -> Self { + self.is_indirect_branch = val; + self + } + + #[allow(clippy::wrong_self_convention)] + pub fn is_call(mut self, val: bool) -> Self { + self.is_call = val; + self + } + + #[allow(clippy::wrong_self_convention)] + pub fn is_return(mut self, val: bool) -> Self { + self.is_return = val; + self + } + + 
#[allow(clippy::wrong_self_convention)] + pub fn is_ghost(mut self, val: bool) -> Self { + self.is_ghost = val; + self + } + + pub fn can_load(mut self, val: bool) -> Self { + self.can_load = val; + self + } + + pub fn can_store(mut self, val: bool) -> Self { + self.can_store = val; + self + } + + pub fn can_trap(mut self, val: bool) -> Self { + self.can_trap = val; + self + } + + pub fn other_side_effects(mut self, val: bool) -> Self { + self.other_side_effects = val; + self + } + + pub fn clobbers_all_regs(mut self, val: bool) -> Self { + self.clobbers_all_regs = val; + self + } + + fn build(self, opcode_number: OpcodeNumber) -> Instruction { + let operands_in = self.operands_in.unwrap_or_else(Vec::new); + let operands_out = self.operands_out.unwrap_or_else(Vec::new); + + let mut value_opnums = Vec::new(); + let mut imm_opnums = Vec::new(); + for (i, op) in operands_in.iter().enumerate() { + if op.is_value() { + value_opnums.push(i); + } else if op.is_immediate_or_entityref() { + imm_opnums.push(i); + } else { + assert!(op.is_varargs()); + } + } + + let value_results = operands_out + .iter() + .enumerate() + .filter_map(|(i, op)| if op.is_value() { Some(i) } else { None }) + .collect(); + + verify_format(&self.name, &operands_in, &self.format); + + let polymorphic_info = + verify_polymorphic(&operands_in, &operands_out, &self.format, &value_opnums); + + // Infer from output operands whether an instruction clobbers CPU flags or not. 
+ let writes_cpu_flags = operands_out.iter().any(|op| op.is_cpu_flags()); + + let camel_name = camel_case(&self.name); + + Rc::new(InstructionContent { + name: self.name, + camel_name, + opcode_number, + doc: self.doc, + operands_in, + operands_out, + constraints: self.constraints.unwrap_or_else(Vec::new), + format: self.format, + polymorphic_info, + value_opnums, + value_results, + imm_opnums, + is_terminator: self.is_terminator, + is_branch: self.is_branch, + is_indirect_branch: self.is_indirect_branch, + is_call: self.is_call, + is_return: self.is_return, + is_ghost: self.is_ghost, + can_load: self.can_load, + can_store: self.can_store, + can_trap: self.can_trap, + other_side_effects: self.other_side_effects, + writes_cpu_flags, + clobbers_all_regs: self.clobbers_all_regs, + }) + } +} + +/// A thin wrapper like Option<ValueType>, but with more precise semantics. +#[derive(Clone)] +pub(crate) enum ValueTypeOrAny { + ValueType(ValueType), + Any, +} + +impl ValueTypeOrAny { + pub fn expect(self, msg: &str) -> ValueType { + match self { + ValueTypeOrAny::ValueType(vt) => vt, + ValueTypeOrAny::Any => panic!(format!("Unexpected Any: {}", msg)), + } + } +} + +/// The number of bits in the vector +type VectorBitWidth = u64; + +/// An parameter used for binding instructions to specific types or values +pub(crate) enum BindParameter { + Any, + Lane(LaneType), + Vector(LaneType, VectorBitWidth), + Reference(ReferenceType), + Immediate(Immediate), +} + +/// Constructor for more easily building vector parameters from any lane type +pub(crate) fn vector(parameter: impl Into<LaneType>, vector_size: VectorBitWidth) -> BindParameter { + BindParameter::Vector(parameter.into(), vector_size) +} + +impl From<Int> for BindParameter { + fn from(ty: Int) -> Self { + BindParameter::Lane(ty.into()) + } +} + +impl From<Bool> for BindParameter { + fn from(ty: Bool) -> Self { + BindParameter::Lane(ty.into()) + } +} + +impl From<Float> for BindParameter { + fn from(ty: Float) -> Self { + 
BindParameter::Lane(ty.into()) + } +} + +impl From<LaneType> for BindParameter { + fn from(ty: LaneType) -> Self { + BindParameter::Lane(ty) + } +} + +impl From<Reference> for BindParameter { + fn from(ty: Reference) -> Self { + BindParameter::Reference(ty.into()) + } +} + +impl From<Immediate> for BindParameter { + fn from(imm: Immediate) -> Self { + BindParameter::Immediate(imm) + } +} + +#[derive(Clone)] +pub(crate) enum Immediate { + // When needed, this enum should be expanded to include other immediate types (e.g. u8, u128). + IntCC(IntCC), +} + +impl Display for Immediate { + fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { + match self { + Immediate::IntCC(x) => write!(f, "IntCC::{:?}", x), + } + } +} + +#[derive(Clone)] +pub(crate) struct BoundInstruction { + pub inst: Instruction, + pub value_types: Vec<ValueTypeOrAny>, + pub immediate_values: Vec<Immediate>, +} + +impl BoundInstruction { + /// Construct a new bound instruction (with nothing bound yet) from an instruction + fn new(inst: &Instruction) -> Self { + BoundInstruction { + inst: inst.clone(), + value_types: vec![], + immediate_values: vec![], + } + } + + /// Verify that the bindings for a BoundInstruction are correct. + fn verify_bindings(&self) -> Result<(), String> { + // Verify that binding types to the instruction does not violate the polymorphic rules. + if !self.value_types.is_empty() { + match &self.inst.polymorphic_info { + Some(poly) => { + if self.value_types.len() > 1 + poly.other_typevars.len() { + return Err(format!( + "trying to bind too many types for {}", + self.inst.name + )); + } + } + None => { + return Err(format!( + "trying to bind a type for {} which is not a polymorphic instruction", + self.inst.name + )); + } + } + } + + // Verify that only the right number of immediates are bound. 
+ let immediate_count = self + .inst + .operands_in + .iter() + .filter(|o| o.is_immediate_or_entityref()) + .count(); + if self.immediate_values.len() > immediate_count { + return Err(format!( + "trying to bind too many immediates ({}) to instruction {} which only expects {} \ + immediates", + self.immediate_values.len(), + self.inst.name, + immediate_count + )); + } + + Ok(()) + } +} + +impl Bindable for BoundInstruction { + fn bind(&self, parameter: impl Into<BindParameter>) -> BoundInstruction { + let mut modified = self.clone(); + match parameter.into() { + BindParameter::Any => modified.value_types.push(ValueTypeOrAny::Any), + BindParameter::Lane(lane_type) => modified + .value_types + .push(ValueTypeOrAny::ValueType(lane_type.into())), + BindParameter::Vector(lane_type, vector_size_in_bits) => { + let num_lanes = vector_size_in_bits / lane_type.lane_bits(); + assert!( + num_lanes >= 2, + "Minimum lane number for bind_vector is 2, found {}.", + num_lanes, + ); + let vector_type = ValueType::Vector(VectorType::new(lane_type, num_lanes)); + modified + .value_types + .push(ValueTypeOrAny::ValueType(vector_type)); + } + BindParameter::Reference(reference_type) => { + modified + .value_types + .push(ValueTypeOrAny::ValueType(reference_type.into())); + } + BindParameter::Immediate(immediate) => modified.immediate_values.push(immediate), + } + modified.verify_bindings().unwrap(); + modified + } +} + +/// Checks that the input operands actually match the given format. +fn verify_format(inst_name: &str, operands_in: &[Operand], format: &InstructionFormat) { + // A format is defined by: + // - its number of input value operands, + // - its number and names of input immediate operands, + // - whether it has a value list or not. 
+ let mut num_values = 0; + let mut num_immediates = 0; + + for operand in operands_in.iter() { + if operand.is_varargs() { + assert!( + format.has_value_list, + "instruction {} has varargs, but its format {} doesn't have a value list; you may \ + need to use a different format.", + inst_name, format.name + ); + } + if operand.is_value() { + num_values += 1; + } + if operand.is_immediate_or_entityref() { + if let Some(format_field) = format.imm_fields.get(num_immediates) { + assert_eq!( + format_field.kind.rust_field_name, + operand.kind.rust_field_name, + "{}th operand of {} should be {} (according to format), not {} (according to \ + inst definition). You may need to use a different format.", + num_immediates, + inst_name, + format_field.kind.rust_field_name, + operand.kind.rust_field_name + ); + num_immediates += 1; + } + } + } + + assert_eq!( + num_values, format.num_value_operands, + "inst {} doesn't have as many value input operands as its format {} declares; you may need \ + to use a different format.", + inst_name, format.name + ); + + assert_eq!( + num_immediates, + format.imm_fields.len(), + "inst {} doesn't have as many immediate input \ + operands as its format {} declares; you may need to use a different format.", + inst_name, + format.name + ); +} + +/// Check if this instruction is polymorphic, and verify its use of type variables. +fn verify_polymorphic( + operands_in: &[Operand], + operands_out: &[Operand], + format: &InstructionFormat, + value_opnums: &[usize], +) -> Option<PolymorphicInfo> { + // The instruction is polymorphic if it has one free input or output operand. + let is_polymorphic = operands_in + .iter() + .any(|op| op.is_value() && op.type_var().unwrap().free_typevar().is_some()) + || operands_out + .iter() + .any(|op| op.is_value() && op.type_var().unwrap().free_typevar().is_some()); + + if !is_polymorphic { + return None; + } + + // Verify the use of type variables. 
+ let tv_op = format.typevar_operand; + let mut maybe_error_message = None; + if let Some(tv_op) = tv_op { + if tv_op < value_opnums.len() { + let op_num = value_opnums[tv_op]; + let tv = operands_in[op_num].type_var().unwrap(); + let free_typevar = tv.free_typevar(); + if (free_typevar.is_some() && tv == &free_typevar.unwrap()) + || tv.singleton_type().is_some() + { + match is_ctrl_typevar_candidate(tv, &operands_in, &operands_out) { + Ok(other_typevars) => { + return Some(PolymorphicInfo { + use_typevar_operand: true, + ctrl_typevar: tv.clone(), + other_typevars, + }); + } + Err(error_message) => { + maybe_error_message = Some(error_message); + } + } + } + } + }; + + // If we reached here, it means the type variable indicated as the typevar operand couldn't + // control every other input and output type variable. We need to look at the result type + // variables. + if operands_out.is_empty() { + // No result means no other possible type variable, so it's a type inference failure. + match maybe_error_message { + Some(msg) => panic!(msg), + None => panic!("typevar_operand must be a free type variable"), + } + } + + // Otherwise, try to infer the controlling type variable by looking at the first result. + let tv = operands_out[0].type_var().unwrap(); + let free_typevar = tv.free_typevar(); + if free_typevar.is_some() && tv != &free_typevar.unwrap() { + panic!("first result must be a free type variable"); + } + + // At this point, if the next unwrap() fails, it means the output type couldn't be used as a + // controlling type variable either; panicking is the right behavior. + let other_typevars = is_ctrl_typevar_candidate(tv, &operands_in, &operands_out).unwrap(); + + Some(PolymorphicInfo { + use_typevar_operand: false, + ctrl_typevar: tv.clone(), + other_typevars, + }) +} + +/// Verify that the use of TypeVars is consistent with `ctrl_typevar` as the controlling type +/// variable. 
+/// +/// All polymorhic inputs must either be derived from `ctrl_typevar` or be independent free type +/// variables only used once. +/// +/// All polymorphic results must be derived from `ctrl_typevar`. +/// +/// Return a vector of other type variables used, or a string explaining what went wrong. +fn is_ctrl_typevar_candidate( + ctrl_typevar: &TypeVar, + operands_in: &[Operand], + operands_out: &[Operand], +) -> Result<Vec<TypeVar>, String> { + let mut other_typevars = Vec::new(); + + // Check value inputs. + for input in operands_in { + if !input.is_value() { + continue; + } + + let typ = input.type_var().unwrap(); + let free_typevar = typ.free_typevar(); + + // Non-polymorphic or derived from ctrl_typevar is OK. + if free_typevar.is_none() { + continue; + } + let free_typevar = free_typevar.unwrap(); + if &free_typevar == ctrl_typevar { + continue; + } + + // No other derived typevars allowed. + if typ != &free_typevar { + return Err(format!( + "{:?}: type variable {} must be derived from {:?} while it is derived from {:?}", + input, typ.name, ctrl_typevar, free_typevar + )); + } + + // Other free type variables can only be used once each. + for other_tv in &other_typevars { + if &free_typevar == other_tv { + return Err(format!( + "non-controlling type variable {} can't be used more than once", + free_typevar.name + )); + } + } + + other_typevars.push(free_typevar); + } + + // Check outputs. + for result in operands_out { + if !result.is_value() { + continue; + } + + let typ = result.type_var().unwrap(); + let free_typevar = typ.free_typevar(); + + // Non-polymorphic or derived from ctrl_typevar is OK. + if free_typevar.is_none() || &free_typevar.unwrap() == ctrl_typevar { + continue; + } + + return Err("type variable in output not derived from ctrl_typevar".into()); + } + + Ok(other_typevars) +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub(crate) enum FormatPredicateKind { + /// Is the field member equal to the expected value (stored here)? 
+ IsEqual(String), + + /// Is the immediate instruction format field representable as an n-bit two's complement + /// integer? (with width: first member, scale: second member). + /// The predicate is true if the field is in the range: `-2^(width-1) -- 2^(width-1)-1` and a + /// multiple of `2^scale`. + IsSignedInt(usize, usize), + + /// Is the immediate instruction format field representable as an n-bit unsigned integer? (with + /// width: first member, scale: second member). + /// The predicate is true if the field is in the range: `0 -- 2^width - 1` and a multiple of + /// `2^scale`. + IsUnsignedInt(usize, usize), + + /// Is the immediate format field member an integer equal to zero? + IsZeroInt, + /// Is the immediate format field member equal to zero? (float32 version) + IsZero32BitFloat, + + /// Is the immediate format field member equal to zero? (float64 version) + IsZero64BitFloat, + + /// Is the immediate format field member equal zero in all lanes? + IsAllZeroes, + + /// Does the immediate format field member have ones in all bits of all lanes? + IsAllOnes, + + /// Has the value list (in member_name) the size specified in parameter? + LengthEquals(usize), + + /// Is the referenced function colocated? + IsColocatedFunc, + + /// Is the referenced data object colocated? 
+ IsColocatedData, +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub(crate) struct FormatPredicateNode { + format_name: &'static str, + member_name: &'static str, + kind: FormatPredicateKind, +} + +impl FormatPredicateNode { + fn new( + format: &InstructionFormat, + field_name: &'static str, + kind: FormatPredicateKind, + ) -> Self { + let member_name = format.imm_by_name(field_name).member; + Self { + format_name: format.name, + member_name, + kind, + } + } + + fn new_raw( + format: &InstructionFormat, + member_name: &'static str, + kind: FormatPredicateKind, + ) -> Self { + Self { + format_name: format.name, + member_name, + kind, + } + } + + fn destructuring_member_name(&self) -> &'static str { + match &self.kind { + FormatPredicateKind::LengthEquals(_) => { + // Length operates on the argument value list. + assert!(self.member_name == "args"); + "ref args" + } + _ => self.member_name, + } + } + + fn rust_predicate(&self) -> String { + match &self.kind { + FormatPredicateKind::IsEqual(arg) => { + format!("predicates::is_equal({}, {})", self.member_name, arg) + } + FormatPredicateKind::IsSignedInt(width, scale) => format!( + "predicates::is_signed_int({}, {}, {})", + self.member_name, width, scale + ), + FormatPredicateKind::IsUnsignedInt(width, scale) => format!( + "predicates::is_unsigned_int({}, {}, {})", + self.member_name, width, scale + ), + FormatPredicateKind::IsZeroInt => { + format!("predicates::is_zero_int({})", self.member_name) + } + FormatPredicateKind::IsZero32BitFloat => { + format!("predicates::is_zero_32_bit_float({})", self.member_name) + } + FormatPredicateKind::IsZero64BitFloat => { + format!("predicates::is_zero_64_bit_float({})", self.member_name) + } + FormatPredicateKind::IsAllZeroes => format!( + "predicates::is_all_zeroes(func.dfg.constants.get({}))", + self.member_name + ), + FormatPredicateKind::IsAllOnes => format!( + "predicates::is_all_ones(func.dfg.constants.get({}))", + self.member_name + ), + 
FormatPredicateKind::LengthEquals(num) => format!( + "predicates::has_length_of({}, {}, func)", + self.member_name, num + ), + FormatPredicateKind::IsColocatedFunc => { + format!("predicates::is_colocated_func({}, func)", self.member_name,) + } + FormatPredicateKind::IsColocatedData => { + format!("predicates::is_colocated_data({}, func)", self.member_name) + } + } + } +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub(crate) enum TypePredicateNode { + /// Is the value argument (at the index designated by the first member) the same type as the + /// type name (second member)? + TypeVarCheck(usize, String), + + /// Is the controlling type variable the same type as the one designated by the type name + /// (only member)? + CtrlTypeVarCheck(String), +} + +impl TypePredicateNode { + fn rust_predicate(&self, func_str: &str) -> String { + match self { + TypePredicateNode::TypeVarCheck(index, value_type_name) => format!( + "{}.dfg.value_type(args[{}]) == {}", + func_str, index, value_type_name + ), + TypePredicateNode::CtrlTypeVarCheck(value_type_name) => { + format!("{}.dfg.ctrl_typevar(inst) == {}", func_str, value_type_name) + } + } + } +} + +/// A basic node in an instruction predicate: either an atom, or an AND of two conditions. +#[derive(Clone, Hash, PartialEq, Eq)] +pub(crate) enum InstructionPredicateNode { + FormatPredicate(FormatPredicateNode), + + TypePredicate(TypePredicateNode), + + /// An AND-combination of two or more other predicates. + And(Vec<InstructionPredicateNode>), + + /// An OR-combination of two or more other predicates. 
+ Or(Vec<InstructionPredicateNode>), +} + +impl InstructionPredicateNode { + fn rust_predicate(&self, func_str: &str) -> String { + match self { + InstructionPredicateNode::FormatPredicate(node) => node.rust_predicate(), + InstructionPredicateNode::TypePredicate(node) => node.rust_predicate(func_str), + InstructionPredicateNode::And(nodes) => nodes + .iter() + .map(|x| x.rust_predicate(func_str)) + .collect::<Vec<_>>() + .join(" && "), + InstructionPredicateNode::Or(nodes) => nodes + .iter() + .map(|x| x.rust_predicate(func_str)) + .collect::<Vec<_>>() + .join(" || "), + } + } + + pub fn format_destructuring_member_name(&self) -> &str { + match self { + InstructionPredicateNode::FormatPredicate(format_pred) => { + format_pred.destructuring_member_name() + } + _ => panic!("Only for leaf format predicates"), + } + } + + pub fn format_name(&self) -> &str { + match self { + InstructionPredicateNode::FormatPredicate(format_pred) => format_pred.format_name, + _ => panic!("Only for leaf format predicates"), + } + } + + pub fn is_type_predicate(&self) -> bool { + match self { + InstructionPredicateNode::FormatPredicate(_) + | InstructionPredicateNode::And(_) + | InstructionPredicateNode::Or(_) => false, + InstructionPredicateNode::TypePredicate(_) => true, + } + } + + fn collect_leaves(&self) -> Vec<&InstructionPredicateNode> { + let mut ret = Vec::new(); + match self { + InstructionPredicateNode::And(nodes) | InstructionPredicateNode::Or(nodes) => { + for node in nodes { + ret.extend(node.collect_leaves()); + } + } + _ => ret.push(self), + } + ret + } +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub(crate) struct InstructionPredicate { + node: Option<InstructionPredicateNode>, +} + +impl Into<InstructionPredicate> for InstructionPredicateNode { + fn into(self) -> InstructionPredicate { + InstructionPredicate { node: Some(self) } + } +} + +impl InstructionPredicate { + pub fn new() -> Self { + Self { node: None } + } + + pub fn unwrap(self) -> InstructionPredicateNode { + 
self.node.unwrap() + } + + pub fn new_typevar_check( + inst: &Instruction, + type_var: &TypeVar, + value_type: &ValueType, + ) -> InstructionPredicateNode { + let index = inst + .value_opnums + .iter() + .enumerate() + .find(|(_, &op_num)| inst.operands_in[op_num].type_var().unwrap() == type_var) + .unwrap() + .0; + InstructionPredicateNode::TypePredicate(TypePredicateNode::TypeVarCheck( + index, + value_type.rust_name(), + )) + } + + pub fn new_ctrl_typevar_check(value_type: &ValueType) -> InstructionPredicateNode { + InstructionPredicateNode::TypePredicate(TypePredicateNode::CtrlTypeVarCheck( + value_type.rust_name(), + )) + } + + pub fn new_is_field_equal( + format: &InstructionFormat, + field_name: &'static str, + imm_value: String, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsEqual(imm_value), + )) + } + + /// Used only for the AST module, which directly passes in the format field. 
+ pub fn new_is_field_equal_ast( + format: &InstructionFormat, + field: &FormatField, + imm_value: String, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new_raw( + format, + field.member, + FormatPredicateKind::IsEqual(imm_value), + )) + } + + pub fn new_is_signed_int( + format: &InstructionFormat, + field_name: &'static str, + width: usize, + scale: usize, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsSignedInt(width, scale), + )) + } + + pub fn new_is_unsigned_int( + format: &InstructionFormat, + field_name: &'static str, + width: usize, + scale: usize, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsUnsignedInt(width, scale), + )) + } + + pub fn new_is_zero_int( + format: &InstructionFormat, + field_name: &'static str, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsZeroInt, + )) + } + + pub fn new_is_zero_32bit_float( + format: &InstructionFormat, + field_name: &'static str, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsZero32BitFloat, + )) + } + + pub fn new_is_zero_64bit_float( + format: &InstructionFormat, + field_name: &'static str, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsZero64BitFloat, + )) + } + + pub fn new_is_all_zeroes( + format: &InstructionFormat, + field_name: &'static str, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsAllZeroes, + )) + } + + pub fn 
new_is_all_ones( + format: &InstructionFormat, + field_name: &'static str, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsAllOnes, + )) + } + + pub fn new_length_equals(format: &InstructionFormat, size: usize) -> InstructionPredicateNode { + assert!( + format.has_value_list, + "the format must be variadic in number of arguments" + ); + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new_raw( + format, + "args", + FormatPredicateKind::LengthEquals(size), + )) + } + + pub fn new_is_colocated_func( + format: &InstructionFormat, + field_name: &'static str, + ) -> InstructionPredicateNode { + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + format, + field_name, + FormatPredicateKind::IsColocatedFunc, + )) + } + + pub fn new_is_colocated_data(formats: &Formats) -> InstructionPredicateNode { + let format = &formats.unary_global_value; + InstructionPredicateNode::FormatPredicate(FormatPredicateNode::new( + &*format, + "global_value", + FormatPredicateKind::IsColocatedData, + )) + } + + pub fn and(mut self, new_node: InstructionPredicateNode) -> Self { + let node = self.node; + let mut and_nodes = match node { + Some(node) => match node { + InstructionPredicateNode::And(nodes) => nodes, + InstructionPredicateNode::Or(_) => { + panic!("Can't mix and/or without implementing operator precedence!") + } + _ => vec![node], + }, + _ => Vec::new(), + }; + and_nodes.push(new_node); + self.node = Some(InstructionPredicateNode::And(and_nodes)); + self + } + + pub fn or(mut self, new_node: InstructionPredicateNode) -> Self { + let node = self.node; + let mut or_nodes = match node { + Some(node) => match node { + InstructionPredicateNode::Or(nodes) => nodes, + InstructionPredicateNode::And(_) => { + panic!("Can't mix and/or without implementing operator precedence!") + } + _ => vec![node], + }, + _ => Vec::new(), + }; + or_nodes.push(new_node); + 
self.node = Some(InstructionPredicateNode::Or(or_nodes)); + self + } + + pub fn rust_predicate(&self, func_str: &str) -> Option<String> { + self.node.as_ref().map(|root| root.rust_predicate(func_str)) + } + + /// Returns the type predicate if this is one, or None otherwise. + pub fn type_predicate(&self, func_str: &str) -> Option<String> { + let node = self.node.as_ref().unwrap(); + if node.is_type_predicate() { + Some(node.rust_predicate(func_str)) + } else { + None + } + } + + /// Returns references to all the nodes that are leaves in the condition (i.e. by flattening + /// AND/OR). + pub fn collect_leaves(&self) -> Vec<&InstructionPredicateNode> { + self.node.as_ref().unwrap().collect_leaves() + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub(crate) struct InstructionPredicateNumber(u32); +entity_impl!(InstructionPredicateNumber); + +pub(crate) type InstructionPredicateMap = + PrimaryMap<InstructionPredicateNumber, InstructionPredicate>; + +/// A registry of predicates to help deduplicating them, during Encodings construction. When the +/// construction process is over, it needs to be extracted with `extract` and associated to the +/// TargetIsa. +pub(crate) struct InstructionPredicateRegistry { + /// Maps a predicate number to its actual predicate. + map: InstructionPredicateMap, + + /// Inverse map: maps a predicate to its predicate number. This is used before inserting a + /// predicate, to check whether it already exists. 
+ inverted_map: HashMap<InstructionPredicate, InstructionPredicateNumber>, +} + +impl InstructionPredicateRegistry { + pub fn new() -> Self { + Self { + map: PrimaryMap::new(), + inverted_map: HashMap::new(), + } + } + pub fn insert(&mut self, predicate: InstructionPredicate) -> InstructionPredicateNumber { + match self.inverted_map.get(&predicate) { + Some(&found) => found, + None => { + let key = self.map.push(predicate.clone()); + self.inverted_map.insert(predicate, key); + key + } + } + } + pub fn extract(self) -> InstructionPredicateMap { + self.map + } +} + +/// An instruction specification, containing an instruction that has bound types or not. +pub(crate) enum InstSpec { + Inst(Instruction), + Bound(BoundInstruction), +} + +impl InstSpec { + pub fn inst(&self) -> &Instruction { + match &self { + InstSpec::Inst(inst) => inst, + InstSpec::Bound(bound_inst) => &bound_inst.inst, + } + } +} + +impl Bindable for InstSpec { + fn bind(&self, parameter: impl Into<BindParameter>) -> BoundInstruction { + match self { + InstSpec::Inst(inst) => inst.bind(parameter.into()), + InstSpec::Bound(inst) => inst.bind(parameter.into()), + } + } +} + +impl Into<InstSpec> for &Instruction { + fn into(self) -> InstSpec { + InstSpec::Inst(self.clone()) + } +} + +impl Into<InstSpec> for BoundInstruction { + fn into(self) -> InstSpec { + InstSpec::Bound(self) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::cdsl::formats::InstructionFormatBuilder; + use crate::cdsl::operands::{OperandKind, OperandKindFields}; + use crate::cdsl::typevar::TypeSetBuilder; + use crate::shared::types::Int::{I32, I64}; + + fn field_to_operand(index: usize, field: OperandKindFields) -> Operand { + // Pretend the index string is &'static. + let name = Box::leak(index.to_string().into_boxed_str()); + // Format's name / rust_type don't matter here. 
+ let kind = OperandKind::new(name, name, field); + let operand = Operand::new(name, kind); + operand + } + + fn field_to_operands(types: Vec<OperandKindFields>) -> Vec<Operand> { + types + .iter() + .enumerate() + .map(|(i, f)| field_to_operand(i, f.clone())) + .collect() + } + + fn build_fake_instruction( + inputs: Vec<OperandKindFields>, + outputs: Vec<OperandKindFields>, + ) -> Instruction { + // Setup a format from the input operands. + let mut format = InstructionFormatBuilder::new("fake"); + for (i, f) in inputs.iter().enumerate() { + match f { + OperandKindFields::TypeVar(_) => format = format.value(), + OperandKindFields::ImmValue => { + format = format.imm(&field_to_operand(i, f.clone()).kind) + } + _ => {} + }; + } + let format = format.build(); + + // Create the fake instruction. + InstructionBuilder::new("fake", "A fake instruction for testing.", &format) + .operands_in(field_to_operands(inputs).iter().collect()) + .operands_out(field_to_operands(outputs).iter().collect()) + .build(OpcodeNumber(42)) + } + + #[test] + fn ensure_bound_instructions_can_bind_lane_types() { + let type1 = TypeSetBuilder::new().ints(8..64).build(); + let in1 = OperandKindFields::TypeVar(TypeVar::new("a", "...", type1)); + let inst = build_fake_instruction(vec![in1], vec![]); + inst.bind(LaneType::Int(I32)); + } + + #[test] + fn ensure_bound_instructions_can_bind_immediates() { + let inst = build_fake_instruction(vec![OperandKindFields::ImmValue], vec![]); + let bound_inst = inst.bind(Immediate::IntCC(IntCC::Equal)); + assert!(bound_inst.verify_bindings().is_ok()); + } + + #[test] + #[should_panic] + fn ensure_instructions_fail_to_bind() { + let inst = build_fake_instruction(vec![], vec![]); + inst.bind(BindParameter::Lane(LaneType::Int(I32))); + // Trying to bind to an instruction with no inputs should fail. 
+ } + + #[test] + #[should_panic] + fn ensure_bound_instructions_fail_to_bind_too_many_types() { + let type1 = TypeSetBuilder::new().ints(8..64).build(); + let in1 = OperandKindFields::TypeVar(TypeVar::new("a", "...", type1)); + let inst = build_fake_instruction(vec![in1], vec![]); + inst.bind(LaneType::Int(I32)).bind(LaneType::Int(I64)); + } + + #[test] + #[should_panic] + fn ensure_instructions_fail_to_bind_too_many_immediates() { + let inst = build_fake_instruction(vec![OperandKindFields::ImmValue], vec![]); + inst.bind(BindParameter::Immediate(Immediate::IntCC(IntCC::Equal))) + .bind(BindParameter::Immediate(Immediate::IntCC(IntCC::Equal))); + // Trying to bind too many immediates to an instruction should fail; note that the immediate + // values are nonsensical but irrelevant to the purpose of this test. + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/isa.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/isa.rs new file mode 100644 index 0000000000..512105d09a --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/isa.rs @@ -0,0 +1,99 @@ +use std::collections::HashSet; +use std::iter::FromIterator; + +use crate::cdsl::cpu_modes::CpuMode; +use crate::cdsl::instructions::{InstructionGroup, InstructionPredicateMap}; +use crate::cdsl::recipes::Recipes; +use crate::cdsl::regs::IsaRegs; +use crate::cdsl::settings::SettingGroup; +use crate::cdsl::xform::{TransformGroupIndex, TransformGroups}; + +pub(crate) struct TargetIsa { + pub name: &'static str, + pub instructions: InstructionGroup, + pub settings: SettingGroup, + pub regs: IsaRegs, + pub recipes: Recipes, + pub cpu_modes: Vec<CpuMode>, + pub encodings_predicates: InstructionPredicateMap, + + /// TransformGroupIndex are global to all the ISAs, while we want to have indices into the + /// local array of transform groups that are directly used. We use this map to get this + /// information. 
+ pub local_transform_groups: Vec<TransformGroupIndex>, +} + +impl TargetIsa { + pub fn new( + name: &'static str, + instructions: InstructionGroup, + settings: SettingGroup, + regs: IsaRegs, + recipes: Recipes, + cpu_modes: Vec<CpuMode>, + encodings_predicates: InstructionPredicateMap, + ) -> Self { + // Compute the local TransformGroup index. + let mut local_transform_groups = Vec::new(); + for cpu_mode in &cpu_modes { + let transform_groups = cpu_mode.direct_transform_groups(); + for group_index in transform_groups { + // find() is fine here: the number of transform group is < 5 as of June 2019. + if local_transform_groups + .iter() + .find(|&val| group_index == *val) + .is_none() + { + local_transform_groups.push(group_index); + } + } + } + + Self { + name, + instructions, + settings, + regs, + recipes, + cpu_modes, + encodings_predicates, + local_transform_groups, + } + } + + /// Returns a deterministically ordered, deduplicated list of TransformGroupIndex for the + /// transitive set of TransformGroup this TargetIsa uses. + pub fn transitive_transform_groups( + &self, + all_groups: &TransformGroups, + ) -> Vec<TransformGroupIndex> { + let mut set = HashSet::new(); + + for &root in self.local_transform_groups.iter() { + set.insert(root); + let mut base = root; + // Follow the chain of chain_with. + while let Some(chain_with) = &all_groups.get(base).chain_with { + set.insert(*chain_with); + base = *chain_with; + } + } + + let mut vec = Vec::from_iter(set); + vec.sort(); + vec + } + + /// Returns a deterministically ordered, deduplicated list of TransformGroupIndex for the directly + /// reachable set of TransformGroup this TargetIsa uses. 
    pub fn direct_transform_groups(&self) -> &Vec<TransformGroupIndex> {
        &self.local_transform_groups
    }

    /// Translates a global `TransformGroupIndex` into its position in this ISA's
    /// `local_transform_groups`.
    ///
    /// Panics if `group_index` is not one of the groups used by this ISA.
    pub fn translate_group_index(&self, group_index: TransformGroupIndex) -> usize {
        self.local_transform_groups
            .iter()
            .position(|&val| val == group_index)
            .expect("TransformGroup unused by this TargetIsa!")
    }
}

// diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/mod.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/mod.rs
// new file mode 100644
// index 0000000000..698b64dff3
// --- /dev/null
// +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/mod.rs
// @@ -0,0 +1,89 @@
//! Cranelift DSL classes.
//!
//! This module defines the classes that are used to define Cranelift
//! instructions and other entities.

#[macro_use]
pub mod ast;
pub mod cpu_modes;
pub mod encodings;
pub mod formats;
pub mod instructions;
pub mod isa;
pub mod operands;
pub mod recipes;
pub mod regs;
pub mod settings;
pub mod type_inference;
pub mod types;
pub mod typevar;
pub mod xform;

/// A macro that converts boolean settings into predicates to look more natural.
///
/// Accepts a `&&`-separated chain of identifiers, each optionally prefixed with `!`. The chain
/// is expanded right-nested: `a && rest` becomes `And(a, predicate!(rest))`, and a `!a` term is
/// wrapped in `PredicateNode::Not`. Every identifier is converted with `.into()`.
#[macro_export]
macro_rules! predicate {
    // `a && <rest>`: recurse on the remaining tokens.
    ($a:ident && $($b:tt)*) => {
        PredicateNode::And(Box::new($a.into()), Box::new(predicate!($($b)*)))
    };
    // `!a && <rest>`: negate the head, then recurse on the remaining tokens.
    (!$a:ident && $($b:tt)*) => {
        PredicateNode::And(
            Box::new(PredicateNode::Not(Box::new($a.into()))),
            Box::new(predicate!($($b)*))
        )
    };
    // Terminal negated identifier.
    (!$a:ident) => {
        PredicateNode::Not(Box::new($a.into()))
    };
    // Terminal identifier.
    ($a:ident) => {
        $a.into()
    };
}

/// A macro that joins boolean settings into a list (e.g. `preset!(feature_a && feature_b)`).
///
/// An empty invocation yields an empty vector; otherwise each identifier is converted with
/// `.into()` and pushed in the order written.
#[macro_export]
macro_rules! preset {
    () => {
        vec![]
    };
    ($($x:ident)&&*) => {
        {
            let mut v = Vec::new();
            $(
                v.push($x.into());
            )*
            v
        }
    };
}

/// Convert the string `s` to CamelCase.
+pub fn camel_case(s: &str) -> String { + let mut output_chars = String::with_capacity(s.len()); + + let mut capitalize = true; + for curr_char in s.chars() { + if curr_char == '_' { + capitalize = true; + } else { + if capitalize { + output_chars.extend(curr_char.to_uppercase()); + } else { + output_chars.push(curr_char); + } + capitalize = false; + } + } + + output_chars +} + +#[cfg(test)] +mod tests { + use super::camel_case; + + #[test] + fn camel_case_works() { + assert_eq!(camel_case("x"), "X"); + assert_eq!(camel_case("camel_case"), "CamelCase"); + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/operands.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/operands.rs new file mode 100644 index 0000000000..605df24862 --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/operands.rs @@ -0,0 +1,173 @@ +use std::collections::HashMap; + +use crate::cdsl::typevar::TypeVar; + +/// An instruction operand can be an *immediate*, an *SSA value*, or an *entity reference*. The +/// type of the operand is one of: +/// +/// 1. A `ValueType` instance indicates an SSA value operand with a concrete type. +/// +/// 2. A `TypeVar` instance indicates an SSA value operand, and the instruction is polymorphic over +/// the possible concrete types that the type variable can assume. +/// +/// 3. An `ImmediateKind` instance indicates an immediate operand whose value is encoded in the +/// instruction itself rather than being passed as an SSA value. +/// +/// 4. An `EntityRefKind` instance indicates an operand that references another entity in the +/// function, typically something declared in the function preamble. +#[derive(Clone, Debug)] +pub(crate) struct Operand { + /// Name of the operand variable, as it appears in function parameters, legalizations, etc. + pub name: &'static str, + + /// Type of the operand. 
+ pub kind: OperandKind, + + doc: Option<&'static str>, +} + +impl Operand { + pub fn new(name: &'static str, kind: impl Into<OperandKind>) -> Self { + Self { + name, + doc: None, + kind: kind.into(), + } + } + pub fn with_doc(mut self, doc: &'static str) -> Self { + self.doc = Some(doc); + self + } + + pub fn doc(&self) -> Option<&str> { + if let Some(doc) = &self.doc { + return Some(doc); + } + match &self.kind.fields { + OperandKindFields::TypeVar(tvar) => Some(&tvar.doc), + _ => self.kind.doc(), + } + } + + pub fn is_value(&self) -> bool { + match self.kind.fields { + OperandKindFields::TypeVar(_) => true, + _ => false, + } + } + + pub fn type_var(&self) -> Option<&TypeVar> { + match &self.kind.fields { + OperandKindFields::TypeVar(typevar) => Some(typevar), + _ => None, + } + } + + pub fn is_varargs(&self) -> bool { + match self.kind.fields { + OperandKindFields::VariableArgs => true, + _ => false, + } + } + + /// Returns true if the operand has an immediate kind or is an EntityRef. + pub fn is_immediate_or_entityref(&self) -> bool { + match self.kind.fields { + OperandKindFields::ImmEnum(_) + | OperandKindFields::ImmValue + | OperandKindFields::EntityRef => true, + _ => false, + } + } + + /// Returns true if the operand has an immediate kind. + pub fn is_immediate(&self) -> bool { + match self.kind.fields { + OperandKindFields::ImmEnum(_) | OperandKindFields::ImmValue => true, + _ => false, + } + } + + pub fn is_cpu_flags(&self) -> bool { + match &self.kind.fields { + OperandKindFields::TypeVar(type_var) + if type_var.name == "iflags" || type_var.name == "fflags" => + { + true + } + _ => false, + } + } +} + +pub type EnumValues = HashMap<&'static str, &'static str>; + +#[derive(Clone, Debug)] +pub(crate) enum OperandKindFields { + EntityRef, + VariableArgs, + ImmValue, + ImmEnum(EnumValues), + TypeVar(TypeVar), +} + +#[derive(Clone, Debug)] +pub(crate) struct OperandKind { + /// String representation of the Rust type mapping to this OperandKind. 
+ pub rust_type: &'static str, + + /// Name of this OperandKind in the format's member field. + pub rust_field_name: &'static str, + + /// Type-specific fields for this OperandKind. + pub fields: OperandKindFields, + + doc: Option<&'static str>, +} + +impl OperandKind { + pub fn new( + rust_field_name: &'static str, + rust_type: &'static str, + fields: OperandKindFields, + ) -> Self { + Self { + rust_field_name, + rust_type, + fields, + doc: None, + } + } + pub fn with_doc(mut self, doc: &'static str) -> Self { + assert!(self.doc.is_none()); + self.doc = Some(doc); + self + } + fn doc(&self) -> Option<&str> { + if let Some(doc) = &self.doc { + return Some(doc); + } + match &self.fields { + OperandKindFields::TypeVar(type_var) => Some(&type_var.doc), + OperandKindFields::ImmEnum(_) + | OperandKindFields::ImmValue + | OperandKindFields::EntityRef + | OperandKindFields::VariableArgs => None, + } + } +} + +impl Into<OperandKind> for &TypeVar { + fn into(self) -> OperandKind { + OperandKind::new( + "value", + "ir::Value", + OperandKindFields::TypeVar(self.into()), + ) + } +} +impl Into<OperandKind> for &OperandKind { + fn into(self) -> OperandKind { + self.clone() + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/recipes.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/recipes.rs new file mode 100644 index 0000000000..dfe4cd67a5 --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/recipes.rs @@ -0,0 +1,298 @@ +use std::rc::Rc; + +use cranelift_entity::{entity_impl, PrimaryMap}; + +use crate::cdsl::formats::InstructionFormat; +use crate::cdsl::instructions::InstructionPredicate; +use crate::cdsl::regs::RegClassIndex; +use crate::cdsl::settings::SettingPredicateNumber; + +/// A specific register in a register class. +/// +/// A register is identified by the top-level register class it belongs to and +/// its first register unit. 
+/// +/// Specific registers are used to describe constraints on instructions where +/// some operands must use a fixed register. +/// +/// Register instances can be created with the constructor, or accessed as +/// attributes on the register class: `GPR.rcx`. +#[derive(Copy, Clone, Hash, PartialEq, Eq)] +pub(crate) struct Register { + pub regclass: RegClassIndex, + pub unit: u8, +} + +impl Register { + pub fn new(regclass: RegClassIndex, unit: u8) -> Self { + Self { regclass, unit } + } +} + +/// An operand that must be in a stack slot. +/// +/// A `Stack` object can be used to indicate an operand constraint for a value +/// operand that must live in a stack slot. +#[derive(Copy, Clone, Hash, PartialEq)] +pub(crate) struct Stack { + pub regclass: RegClassIndex, +} + +impl Stack { + pub fn new(regclass: RegClassIndex) -> Self { + Self { regclass } + } + pub fn stack_base_mask(self) -> &'static str { + // TODO: Make this configurable instead of just using the SP. + "StackBaseMask(1)" + } +} + +#[derive(Clone, Hash, PartialEq)] +pub(crate) struct BranchRange { + pub inst_size: u64, + pub range: u64, +} + +#[derive(Copy, Clone, Hash, PartialEq)] +pub(crate) enum OperandConstraint { + RegClass(RegClassIndex), + FixedReg(Register), + TiedInput(usize), + Stack(Stack), +} + +impl Into<OperandConstraint> for RegClassIndex { + fn into(self) -> OperandConstraint { + OperandConstraint::RegClass(self) + } +} + +impl Into<OperandConstraint> for Register { + fn into(self) -> OperandConstraint { + OperandConstraint::FixedReg(self) + } +} + +impl Into<OperandConstraint> for usize { + fn into(self) -> OperandConstraint { + OperandConstraint::TiedInput(self) + } +} + +impl Into<OperandConstraint> for Stack { + fn into(self) -> OperandConstraint { + OperandConstraint::Stack(self) + } +} + +/// A recipe for encoding instructions with a given format. +/// +/// Many different instructions can be encoded by the same recipe, but they +/// must all have the same instruction format. 
///
/// The `operands_in` and `operands_out` arguments are tuples specifying the register
/// allocation constraints for the value operands and results respectively. The
/// possible constraints for an operand are:
///
/// - A `RegClass` specifying the set of allowed registers.
/// - A `Register` specifying a fixed-register operand.
/// - An integer indicating that this result is tied to a value operand, so
///   they must use the same register.
/// - A `Stack` specifying a value in a stack slot.
///
/// The `branch_range` argument must be provided for recipes that can encode
/// branch instructions. It is an `(origin, bits)` tuple describing the exact
/// range that can be encoded in a branch instruction.
#[derive(Clone)]
pub(crate) struct EncodingRecipe {
    /// Short mnemonic name for this recipe.
    pub name: String,

    /// Associated instruction format.
    pub format: Rc<InstructionFormat>,

    /// Base number of bytes in the binary encoded instruction.
    pub base_size: u64,

    /// Tuple of register constraints for value operands.
    pub operands_in: Vec<OperandConstraint>,

    /// Tuple of register constraints for results.
    pub operands_out: Vec<OperandConstraint>,

    /// Function name to use when computing actual size.
    pub compute_size: &'static str,

    /// `(origin, bits)` range for branches.
    pub branch_range: Option<BranchRange>,

    /// This instruction clobbers `iflags` and `fflags`; true by default.
    pub clobbers_flags: bool,

    /// Instruction predicate.
    pub inst_predicate: Option<InstructionPredicate>,

    /// ISA predicate.
    pub isa_predicate: Option<SettingPredicateNumber>,

    /// Rust code for binary emission.
    pub emit: Option<String>,
}

// Implement PartialEq ourselves: take all the fields into account but the name.
+impl PartialEq for EncodingRecipe { + fn eq(&self, other: &Self) -> bool { + Rc::ptr_eq(&self.format, &other.format) + && self.base_size == other.base_size + && self.operands_in == other.operands_in + && self.operands_out == other.operands_out + && self.compute_size == other.compute_size + && self.branch_range == other.branch_range + && self.clobbers_flags == other.clobbers_flags + && self.inst_predicate == other.inst_predicate + && self.isa_predicate == other.isa_predicate + && self.emit == other.emit + } +} + +// To allow using it in a hashmap. +impl Eq for EncodingRecipe {} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub(crate) struct EncodingRecipeNumber(u32); +entity_impl!(EncodingRecipeNumber); + +pub(crate) type Recipes = PrimaryMap<EncodingRecipeNumber, EncodingRecipe>; + +#[derive(Clone)] +pub(crate) struct EncodingRecipeBuilder { + pub name: String, + format: Rc<InstructionFormat>, + pub base_size: u64, + pub operands_in: Option<Vec<OperandConstraint>>, + pub operands_out: Option<Vec<OperandConstraint>>, + pub compute_size: Option<&'static str>, + pub branch_range: Option<BranchRange>, + pub emit: Option<String>, + clobbers_flags: Option<bool>, + inst_predicate: Option<InstructionPredicate>, + isa_predicate: Option<SettingPredicateNumber>, +} + +impl EncodingRecipeBuilder { + pub fn new(name: impl Into<String>, format: &Rc<InstructionFormat>, base_size: u64) -> Self { + Self { + name: name.into(), + format: format.clone(), + base_size, + operands_in: None, + operands_out: None, + compute_size: None, + branch_range: None, + emit: None, + clobbers_flags: None, + inst_predicate: None, + isa_predicate: None, + } + } + + // Setters. 
+ pub fn operands_in(mut self, constraints: Vec<impl Into<OperandConstraint>>) -> Self { + assert!(self.operands_in.is_none()); + self.operands_in = Some( + constraints + .into_iter() + .map(|constr| constr.into()) + .collect(), + ); + self + } + pub fn operands_out(mut self, constraints: Vec<impl Into<OperandConstraint>>) -> Self { + assert!(self.operands_out.is_none()); + self.operands_out = Some( + constraints + .into_iter() + .map(|constr| constr.into()) + .collect(), + ); + self + } + pub fn clobbers_flags(mut self, flag: bool) -> Self { + assert!(self.clobbers_flags.is_none()); + self.clobbers_flags = Some(flag); + self + } + pub fn emit(mut self, code: impl Into<String>) -> Self { + assert!(self.emit.is_none()); + self.emit = Some(code.into()); + self + } + pub fn branch_range(mut self, range: (u64, u64)) -> Self { + assert!(self.branch_range.is_none()); + self.branch_range = Some(BranchRange { + inst_size: range.0, + range: range.1, + }); + self + } + pub fn isa_predicate(mut self, pred: SettingPredicateNumber) -> Self { + assert!(self.isa_predicate.is_none()); + self.isa_predicate = Some(pred); + self + } + pub fn inst_predicate(mut self, inst_predicate: impl Into<InstructionPredicate>) -> Self { + assert!(self.inst_predicate.is_none()); + self.inst_predicate = Some(inst_predicate.into()); + self + } + pub fn compute_size(mut self, compute_size: &'static str) -> Self { + assert!(self.compute_size.is_none()); + self.compute_size = Some(compute_size); + self + } + + pub fn build(self) -> EncodingRecipe { + let operands_in = self.operands_in.unwrap_or_default(); + let operands_out = self.operands_out.unwrap_or_default(); + + // The number of input constraints must match the number of format input operands. 
+ if !self.format.has_value_list { + assert!( + operands_in.len() == self.format.num_value_operands, + format!( + "missing operand constraints for recipe {} (format {})", + self.name, self.format.name + ) + ); + } + + // Ensure tied inputs actually refer to existing inputs. + for constraint in operands_in.iter().chain(operands_out.iter()) { + if let OperandConstraint::TiedInput(n) = *constraint { + assert!(n < operands_in.len()); + } + } + + let compute_size = match self.compute_size { + Some(compute_size) => compute_size, + None => "base_size", + }; + + let clobbers_flags = self.clobbers_flags.unwrap_or(true); + + EncodingRecipe { + name: self.name, + format: self.format, + base_size: self.base_size, + operands_in, + operands_out, + compute_size, + branch_range: self.branch_range, + clobbers_flags, + inst_predicate: self.inst_predicate, + isa_predicate: self.isa_predicate, + emit: self.emit, + } + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/regs.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/regs.rs new file mode 100644 index 0000000000..864826ee43 --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/regs.rs @@ -0,0 +1,412 @@ +use cranelift_codegen_shared::constants; +use cranelift_entity::{entity_impl, EntityRef, PrimaryMap}; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub(crate) struct RegBankIndex(u32); +entity_impl!(RegBankIndex); + +pub(crate) struct RegBank { + pub name: &'static str, + pub first_unit: u8, + pub units: u8, + pub names: Vec<&'static str>, + pub prefix: &'static str, + pub pressure_tracking: bool, + pub pinned_reg: Option<u16>, + pub toprcs: Vec<RegClassIndex>, + pub classes: Vec<RegClassIndex>, +} + +impl RegBank { + pub fn new( + name: &'static str, + first_unit: u8, + units: u8, + names: Vec<&'static str>, + prefix: &'static str, + pressure_tracking: bool, + pinned_reg: Option<u16>, + ) -> Self { + RegBank { + name, + first_unit, + units, + names, + prefix, + 
pressure_tracking, + pinned_reg, + toprcs: Vec::new(), + classes: Vec::new(), + } + } + + fn unit_by_name(&self, name: &'static str) -> u8 { + let unit = if let Some(found) = self.names.iter().position(|®_name| reg_name == name) { + found + } else { + // Try to match without the bank prefix. + assert!(name.starts_with(self.prefix)); + let name_without_prefix = &name[self.prefix.len()..]; + if let Some(found) = self + .names + .iter() + .position(|®_name| reg_name == name_without_prefix) + { + found + } else { + // Ultimate try: try to parse a number and use this in the array, eg r15 on x86. + if let Ok(as_num) = name_without_prefix.parse::<u8>() { + assert!( + as_num < self.units, + "trying to get {}, but bank only has {} registers!", + name, + self.units + ); + as_num as usize + } else { + panic!("invalid register name {}", name); + } + } + }; + self.first_unit + (unit as u8) + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)] +pub(crate) struct RegClassIndex(u32); +entity_impl!(RegClassIndex); + +pub(crate) struct RegClass { + pub name: &'static str, + pub index: RegClassIndex, + pub width: u8, + pub bank: RegBankIndex, + pub toprc: RegClassIndex, + pub count: u8, + pub start: u8, + pub subclasses: Vec<RegClassIndex>, +} + +impl RegClass { + pub fn new( + name: &'static str, + index: RegClassIndex, + width: u8, + bank: RegBankIndex, + toprc: RegClassIndex, + count: u8, + start: u8, + ) -> Self { + Self { + name, + index, + width, + bank, + toprc, + count, + start, + subclasses: Vec::new(), + } + } + + /// Compute a bit-mask of subclasses, including self. + pub fn subclass_mask(&self) -> u64 { + let mut m = 1 << self.index.index(); + for rc in self.subclasses.iter() { + m |= 1 << rc.index(); + } + m + } + + /// Compute a bit-mask of the register units allocated by this register class. 
+ pub fn mask(&self, bank_first_unit: u8) -> Vec<u32> { + let mut u = (self.start + bank_first_unit) as usize; + let mut out_mask = vec![0, 0, 0]; + for _ in 0..self.count { + out_mask[u / 32] |= 1 << (u % 32); + u += self.width as usize; + } + out_mask + } +} + +pub(crate) enum RegClassProto { + TopLevel(RegBankIndex), + SubClass(RegClassIndex), +} + +pub(crate) struct RegClassBuilder { + pub name: &'static str, + pub width: u8, + pub count: u8, + pub start: u8, + pub proto: RegClassProto, +} + +impl RegClassBuilder { + pub fn new_toplevel(name: &'static str, bank: RegBankIndex) -> Self { + Self { + name, + width: 1, + count: 0, + start: 0, + proto: RegClassProto::TopLevel(bank), + } + } + pub fn subclass_of( + name: &'static str, + parent_index: RegClassIndex, + start: u8, + stop: u8, + ) -> Self { + assert!(stop >= start); + Self { + name, + width: 0, + count: stop - start, + start, + proto: RegClassProto::SubClass(parent_index), + } + } + pub fn count(mut self, count: u8) -> Self { + self.count = count; + self + } + pub fn width(mut self, width: u8) -> Self { + match self.proto { + RegClassProto::TopLevel(_) => self.width = width, + RegClassProto::SubClass(_) => panic!("Subclasses inherit their parent's width."), + } + self + } +} + +pub(crate) struct RegBankBuilder { + pub name: &'static str, + pub units: u8, + pub names: Vec<&'static str>, + pub prefix: &'static str, + pub pressure_tracking: Option<bool>, + pub pinned_reg: Option<u16>, +} + +impl RegBankBuilder { + pub fn new(name: &'static str, prefix: &'static str) -> Self { + Self { + name, + units: 0, + names: vec![], + prefix, + pressure_tracking: None, + pinned_reg: None, + } + } + pub fn units(mut self, units: u8) -> Self { + self.units = units; + self + } + pub fn names(mut self, names: Vec<&'static str>) -> Self { + self.names = names; + self + } + pub fn track_pressure(mut self, track: bool) -> Self { + self.pressure_tracking = Some(track); + self + } + pub fn pinned_reg(mut self, unit: u16) -> Self 
{ + assert!(unit < u16::from(self.units)); + self.pinned_reg = Some(unit); + self + } +} + +pub(crate) struct IsaRegsBuilder { + pub banks: PrimaryMap<RegBankIndex, RegBank>, + pub classes: PrimaryMap<RegClassIndex, RegClass>, +} + +impl IsaRegsBuilder { + pub fn new() -> Self { + Self { + banks: PrimaryMap::new(), + classes: PrimaryMap::new(), + } + } + + pub fn add_bank(&mut self, builder: RegBankBuilder) -> RegBankIndex { + let first_unit = if self.banks.is_empty() { + 0 + } else { + let last = &self.banks.last().unwrap(); + let first_available_unit = (last.first_unit + last.units) as i8; + let units = builder.units; + let align = if units.is_power_of_two() { + units + } else { + units.next_power_of_two() + } as i8; + (first_available_unit + align - 1) & -align + } as u8; + + self.banks.push(RegBank::new( + builder.name, + first_unit, + builder.units, + builder.names, + builder.prefix, + builder + .pressure_tracking + .expect("Pressure tracking must be explicitly set"), + builder.pinned_reg, + )) + } + + pub fn add_class(&mut self, builder: RegClassBuilder) -> RegClassIndex { + let class_index = self.classes.next_key(); + + // Finish delayed construction of RegClass. 
+ let (bank, toprc, start, width) = match builder.proto { + RegClassProto::TopLevel(bank_index) => { + self.banks + .get_mut(bank_index) + .unwrap() + .toprcs + .push(class_index); + (bank_index, class_index, builder.start, builder.width) + } + RegClassProto::SubClass(parent_class_index) => { + assert!(builder.width == 0); + let (bank, toprc, start, width) = { + let parent = self.classes.get(parent_class_index).unwrap(); + (parent.bank, parent.toprc, parent.start, parent.width) + }; + for reg_class in self.classes.values_mut() { + if reg_class.toprc == toprc { + reg_class.subclasses.push(class_index); + } + } + let subclass_start = start + builder.start * width; + (bank, toprc, subclass_start, width) + } + }; + + let reg_bank_units = self.banks.get(bank).unwrap().units; + assert!(start < reg_bank_units); + + let count = if builder.count != 0 { + builder.count + } else { + reg_bank_units / width + }; + + let reg_class = RegClass::new(builder.name, class_index, width, bank, toprc, count, start); + self.classes.push(reg_class); + + let reg_bank = self.banks.get_mut(bank).unwrap(); + reg_bank.classes.push(class_index); + + class_index + } + + /// Checks that the set of register classes satisfies: + /// + /// 1. Closed under intersection: The intersection of any two register + /// classes in the set is either empty or identical to a member of the + /// set. + /// 2. There are no identical classes under different names. + /// 3. Classes are sorted topologically such that all subclasses have a + /// higher index than the superclass. 
+ pub fn build(self) -> IsaRegs { + for reg_bank in self.banks.values() { + for i1 in reg_bank.classes.iter() { + for i2 in reg_bank.classes.iter() { + if i1 >= i2 { + continue; + } + + let rc1 = self.classes.get(*i1).unwrap(); + let rc2 = self.classes.get(*i2).unwrap(); + + let rc1_mask = rc1.mask(0); + let rc2_mask = rc2.mask(0); + + assert!( + rc1.width != rc2.width || rc1_mask != rc2_mask, + "no duplicates" + ); + if rc1.width != rc2.width { + continue; + } + + let mut intersect = Vec::new(); + for (a, b) in rc1_mask.iter().zip(rc2_mask.iter()) { + intersect.push(a & b); + } + if intersect == vec![0; intersect.len()] { + continue; + } + + // Classes must be topologically ordered, so the intersection can't be the + // superclass. + assert!(intersect != rc1_mask); + + // If the intersection is the second one, then it must be a subclass. + if intersect == rc2_mask { + assert!(self + .classes + .get(*i1) + .unwrap() + .subclasses + .iter() + .any(|x| *x == *i2)); + } + } + } + } + + assert!( + self.classes.len() <= constants::MAX_NUM_REG_CLASSES, + "Too many register classes" + ); + + let num_toplevel = self + .classes + .values() + .filter(|x| x.toprc == x.index && self.banks.get(x.bank).unwrap().pressure_tracking) + .count(); + + assert!( + num_toplevel <= constants::MAX_TRACKED_TOP_RCS, + "Too many top-level register classes" + ); + + IsaRegs::new(self.banks, self.classes) + } +} + +pub(crate) struct IsaRegs { + pub banks: PrimaryMap<RegBankIndex, RegBank>, + pub classes: PrimaryMap<RegClassIndex, RegClass>, +} + +impl IsaRegs { + fn new( + banks: PrimaryMap<RegBankIndex, RegBank>, + classes: PrimaryMap<RegClassIndex, RegClass>, + ) -> Self { + Self { banks, classes } + } + + pub fn class_by_name(&self, name: &str) -> RegClassIndex { + self.classes + .values() + .find(|&class| class.name == name) + .unwrap_or_else(|| panic!("register class {} not found", name)) + .index + } + + pub fn regunit_by_name(&self, class_index: RegClassIndex, name: &'static str) -> u8 { 
+ let bank_index = self.classes.get(class_index).unwrap().bank; + self.banks.get(bank_index).unwrap().unit_by_name(name) + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/settings.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/settings.rs new file mode 100644 index 0000000000..217bad9955 --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/settings.rs @@ -0,0 +1,407 @@ +use std::iter; + +#[derive(Clone, Copy, Hash, PartialEq, Eq)] +pub(crate) struct BoolSettingIndex(usize); + +#[derive(Hash, PartialEq, Eq)] +pub(crate) struct BoolSetting { + pub default: bool, + pub bit_offset: u8, + pub predicate_number: u8, +} + +#[derive(Hash, PartialEq, Eq)] +pub(crate) enum SpecificSetting { + Bool(BoolSetting), + Enum(Vec<&'static str>), + Num(u8), +} + +#[derive(Hash, PartialEq, Eq)] +pub(crate) struct Setting { + pub name: &'static str, + pub comment: &'static str, + pub specific: SpecificSetting, + pub byte_offset: u8, +} + +impl Setting { + pub fn default_byte(&self) -> u8 { + match self.specific { + SpecificSetting::Bool(BoolSetting { + default, + bit_offset, + .. + }) => { + if default { + 1 << bit_offset + } else { + 0 + } + } + SpecificSetting::Enum(_) => 0, + SpecificSetting::Num(default) => default, + } + } + + fn byte_for_value(&self, v: bool) -> u8 { + match self.specific { + SpecificSetting::Bool(BoolSetting { bit_offset, .. }) => { + if v { + 1 << bit_offset + } else { + 0 + } + } + _ => panic!("byte_for_value shouldn't be used for non-boolean settings."), + } + } + + fn byte_mask(&self) -> u8 { + match self.specific { + SpecificSetting::Bool(BoolSetting { bit_offset, .. 
}) => 1 << bit_offset, + _ => panic!("byte_for_value shouldn't be used for non-boolean settings."), + } + } +} + +#[derive(Hash, PartialEq, Eq)] +pub(crate) struct PresetIndex(usize); + +#[derive(Hash, PartialEq, Eq)] +pub(crate) enum PresetType { + BoolSetting(BoolSettingIndex), + OtherPreset(PresetIndex), +} + +impl Into<PresetType> for BoolSettingIndex { + fn into(self) -> PresetType { + PresetType::BoolSetting(self) + } +} +impl Into<PresetType> for PresetIndex { + fn into(self) -> PresetType { + PresetType::OtherPreset(self) + } +} + +#[derive(Hash, PartialEq, Eq)] +pub(crate) struct Preset { + pub name: &'static str, + values: Vec<BoolSettingIndex>, +} + +impl Preset { + pub fn layout(&self, group: &SettingGroup) -> Vec<(u8, u8)> { + let mut layout: Vec<(u8, u8)> = iter::repeat((0, 0)) + .take(group.settings_size as usize) + .collect(); + for bool_index in &self.values { + let setting = &group.settings[bool_index.0]; + let mask = setting.byte_mask(); + let val = setting.byte_for_value(true); + assert!((val & !mask) == 0); + let (ref mut l_mask, ref mut l_val) = + *layout.get_mut(setting.byte_offset as usize).unwrap(); + *l_mask |= mask; + *l_val = (*l_val & !mask) | val; + } + layout + } +} + +pub(crate) struct SettingGroup { + pub name: &'static str, + pub settings: Vec<Setting>, + pub bool_start_byte_offset: u8, + pub settings_size: u8, + pub presets: Vec<Preset>, + pub predicates: Vec<Predicate>, +} + +impl SettingGroup { + fn num_bool_settings(&self) -> u8 { + self.settings + .iter() + .filter(|s| { + if let SpecificSetting::Bool(_) = s.specific { + true + } else { + false + } + }) + .count() as u8 + } + + pub fn byte_size(&self) -> u8 { + let num_predicates = self.num_bool_settings() + (self.predicates.len() as u8); + self.bool_start_byte_offset + (num_predicates + 7) / 8 + } + + pub fn get_bool(&self, name: &'static str) -> (BoolSettingIndex, &Self) { + for (i, s) in self.settings.iter().enumerate() { + if let SpecificSetting::Bool(_) = s.specific { + 
if s.name == name { + return (BoolSettingIndex(i), self); + } + } + } + panic!("Should have found bool setting by name."); + } + + pub fn predicate_by_name(&self, name: &'static str) -> SettingPredicateNumber { + self.predicates + .iter() + .find(|pred| pred.name == name) + .unwrap_or_else(|| panic!("unknown predicate {}", name)) + .number + } +} + +/// This is the basic information needed to track the specific parts of a setting when building +/// them. +pub(crate) enum ProtoSpecificSetting { + Bool(bool), + Enum(Vec<&'static str>), + Num(u8), +} + +/// This is the information provided during building for a setting. +struct ProtoSetting { + name: &'static str, + comment: &'static str, + specific: ProtoSpecificSetting, +} + +#[derive(Hash, PartialEq, Eq)] +pub(crate) enum PredicateNode { + OwnedBool(BoolSettingIndex), + SharedBool(&'static str, &'static str), + Not(Box<PredicateNode>), + And(Box<PredicateNode>, Box<PredicateNode>), +} + +impl Into<PredicateNode> for BoolSettingIndex { + fn into(self) -> PredicateNode { + PredicateNode::OwnedBool(self) + } +} +impl<'a> Into<PredicateNode> for (BoolSettingIndex, &'a SettingGroup) { + fn into(self) -> PredicateNode { + let (index, group) = (self.0, self.1); + let setting = &group.settings[index.0]; + PredicateNode::SharedBool(group.name, setting.name) + } +} + +impl PredicateNode { + fn render(&self, group: &SettingGroup) -> String { + match *self { + PredicateNode::OwnedBool(bool_setting_index) => format!( + "{}.{}()", + group.name, group.settings[bool_setting_index.0].name + ), + PredicateNode::SharedBool(ref group_name, ref bool_name) => { + format!("{}.{}()", group_name, bool_name) + } + PredicateNode::And(ref lhs, ref rhs) => { + format!("{} && {}", lhs.render(group), rhs.render(group)) + } + PredicateNode::Not(ref node) => format!("!({})", node.render(group)), + } + } +} + +struct ProtoPredicate { + pub name: &'static str, + node: PredicateNode, +} + +pub(crate) type SettingPredicateNumber = u8; + +pub(crate) 
struct Predicate { + pub name: &'static str, + node: PredicateNode, + pub number: SettingPredicateNumber, +} + +impl Predicate { + pub fn render(&self, group: &SettingGroup) -> String { + self.node.render(group) + } +} + +pub(crate) struct SettingGroupBuilder { + name: &'static str, + settings: Vec<ProtoSetting>, + presets: Vec<Preset>, + predicates: Vec<ProtoPredicate>, +} + +impl SettingGroupBuilder { + pub fn new(name: &'static str) -> Self { + Self { + name, + settings: Vec::new(), + presets: Vec::new(), + predicates: Vec::new(), + } + } + + fn add_setting( + &mut self, + name: &'static str, + comment: &'static str, + specific: ProtoSpecificSetting, + ) { + self.settings.push(ProtoSetting { + name, + comment, + specific, + }) + } + + pub fn add_bool( + &mut self, + name: &'static str, + comment: &'static str, + default: bool, + ) -> BoolSettingIndex { + assert!( + self.predicates.is_empty(), + "predicates must be added after the boolean settings" + ); + self.add_setting(name, comment, ProtoSpecificSetting::Bool(default)); + BoolSettingIndex(self.settings.len() - 1) + } + + pub fn add_enum( + &mut self, + name: &'static str, + comment: &'static str, + values: Vec<&'static str>, + ) { + self.add_setting(name, comment, ProtoSpecificSetting::Enum(values)); + } + + pub fn add_num(&mut self, name: &'static str, comment: &'static str, default: u8) { + self.add_setting(name, comment, ProtoSpecificSetting::Num(default)); + } + + pub fn add_predicate(&mut self, name: &'static str, node: PredicateNode) { + self.predicates.push(ProtoPredicate { name, node }); + } + + pub fn add_preset(&mut self, name: &'static str, args: Vec<PresetType>) -> PresetIndex { + let mut values = Vec::new(); + for arg in args { + match arg { + PresetType::OtherPreset(index) => { + values.extend(self.presets[index.0].values.iter()); + } + PresetType::BoolSetting(index) => values.push(index), + } + } + self.presets.push(Preset { name, values }); + PresetIndex(self.presets.len() - 1) + } + + /// 
Compute the layout of the byte vector used to represent this settings + /// group. + /// + /// The byte vector contains the following entries in order: + /// + /// 1. Byte-sized settings like `NumSetting` and `EnumSetting`. + /// 2. `BoolSetting` settings. + /// 3. Precomputed named predicates. + /// 4. Other numbered predicates, including parent predicates that need to be accessible by + /// number. + /// + /// Set `settings_size` to the length of the byte vector prefix that + /// contains the settings. All bytes after that are computed, not + /// configured. + /// + /// Set `bool_start_byte_offset` to the beginning of the numbered predicates, + /// 2. in the list above. + /// + /// Assign `byte_offset` and `bit_offset` fields in all settings. + pub fn build(self) -> SettingGroup { + let mut group = SettingGroup { + name: self.name, + settings: Vec::new(), + bool_start_byte_offset: 0, + settings_size: 0, + presets: Vec::new(), + predicates: Vec::new(), + }; + + let mut byte_offset = 0; + + // Assign the non-boolean settings first. + for s in &self.settings { + let specific = match s.specific { + ProtoSpecificSetting::Bool(..) => continue, + ProtoSpecificSetting::Enum(ref values) => SpecificSetting::Enum(values.clone()), + ProtoSpecificSetting::Num(default) => SpecificSetting::Num(default), + }; + + group.settings.push(Setting { + name: s.name, + comment: s.comment, + byte_offset, + specific, + }); + + byte_offset += 1; + } + + group.bool_start_byte_offset = byte_offset; + + let mut predicate_number = 0; + + // Then the boolean settings. 
+ for s in &self.settings { + let default = match s.specific { + ProtoSpecificSetting::Bool(default) => default, + ProtoSpecificSetting::Enum(_) | ProtoSpecificSetting::Num(_) => continue, + }; + group.settings.push(Setting { + name: s.name, + comment: s.comment, + byte_offset: byte_offset + predicate_number / 8, + specific: SpecificSetting::Bool(BoolSetting { + default, + bit_offset: predicate_number % 8, + predicate_number, + }), + }); + predicate_number += 1; + } + + assert!( + group.predicates.is_empty(), + "settings_size is the byte size before adding predicates" + ); + group.settings_size = group.byte_size(); + + // Sort predicates by name to ensure the same order as the Python code. + let mut predicates = self.predicates; + predicates.sort_by_key(|predicate| predicate.name); + + group + .predicates + .extend(predicates.into_iter().map(|predicate| { + let number = predicate_number; + predicate_number += 1; + Predicate { + name: predicate.name, + node: predicate.node, + number, + } + })); + + group.presets.extend(self.presets); + + group + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/type_inference.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/type_inference.rs new file mode 100644 index 0000000000..25a07a9b84 --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/type_inference.rs @@ -0,0 +1,660 @@ +use crate::cdsl::ast::{Def, DefIndex, DefPool, Var, VarIndex, VarPool}; +use crate::cdsl::typevar::{DerivedFunc, TypeSet, TypeVar}; + +use std::collections::{HashMap, HashSet}; +use std::iter::FromIterator; + +#[derive(Debug, Hash, PartialEq, Eq)] +pub(crate) enum Constraint { + /// Constraint specifying that a type var tv1 must be wider than or equal to type var tv2 at + /// runtime. This requires that: + /// 1) They have the same number of lanes + /// 2) In a lane tv1 has at least as many bits as tv2. 
+ WiderOrEq(TypeVar, TypeVar), + + /// Constraint specifying that two derived type vars must have the same runtime type. + Eq(TypeVar, TypeVar), + + /// Constraint specifying that a type var must belong to some typeset. + InTypeset(TypeVar, TypeSet), +} + +impl Constraint { + fn translate_with<F: Fn(&TypeVar) -> TypeVar>(&self, func: F) -> Constraint { + match self { + Constraint::WiderOrEq(lhs, rhs) => { + let lhs = func(&lhs); + let rhs = func(&rhs); + Constraint::WiderOrEq(lhs, rhs) + } + Constraint::Eq(lhs, rhs) => { + let lhs = func(&lhs); + let rhs = func(&rhs); + Constraint::Eq(lhs, rhs) + } + Constraint::InTypeset(tv, ts) => { + let tv = func(&tv); + Constraint::InTypeset(tv, ts.clone()) + } + } + } + + /// Creates a new constraint by replacing type vars by their hashmap equivalent. + fn translate_with_map( + &self, + original_to_own_typevar: &HashMap<&TypeVar, TypeVar>, + ) -> Constraint { + self.translate_with(|tv| substitute(original_to_own_typevar, tv)) + } + + /// Creates a new constraint by replacing type vars by their canonical equivalent. + fn translate_with_env(&self, type_env: &TypeEnvironment) -> Constraint { + self.translate_with(|tv| type_env.get_equivalent(tv)) + } + + fn is_trivial(&self) -> bool { + match self { + Constraint::WiderOrEq(lhs, rhs) => { + // Trivially true. + if lhs == rhs { + return true; + } + + let ts1 = lhs.get_typeset(); + let ts2 = rhs.get_typeset(); + + // Trivially true. + if ts1.is_wider_or_equal(&ts2) { + return true; + } + + // Trivially false. + if ts1.is_narrower(&ts2) { + return true; + } + + // Trivially false. + if (&ts1.lanes & &ts2.lanes).is_empty() { + return true; + } + + self.is_concrete() + } + Constraint::Eq(lhs, rhs) => lhs == rhs || self.is_concrete(), + Constraint::InTypeset(_, _) => { + // The way InTypeset are made, they would always be trivial if we were applying the + // same logic as the Python code did, so ignore this. 
+ self.is_concrete() + } + } + } + + /// Returns true iff all the referenced type vars are singletons. + fn is_concrete(&self) -> bool { + match self { + Constraint::WiderOrEq(lhs, rhs) => { + lhs.singleton_type().is_some() && rhs.singleton_type().is_some() + } + Constraint::Eq(lhs, rhs) => { + lhs.singleton_type().is_some() && rhs.singleton_type().is_some() + } + Constraint::InTypeset(tv, _) => tv.singleton_type().is_some(), + } + } + + fn typevar_args(&self) -> Vec<&TypeVar> { + match self { + Constraint::WiderOrEq(lhs, rhs) => vec![lhs, rhs], + Constraint::Eq(lhs, rhs) => vec![lhs, rhs], + Constraint::InTypeset(tv, _) => vec![tv], + } + } +} + +#[derive(Clone, Copy)] +enum TypeEnvRank { + Singleton = 5, + Input = 4, + Intermediate = 3, + Output = 2, + Temp = 1, + Internal = 0, +} + +/// Class encapsulating the necessary bookkeeping for type inference. +pub(crate) struct TypeEnvironment { + vars: HashSet<VarIndex>, + ranks: HashMap<TypeVar, TypeEnvRank>, + equivalency_map: HashMap<TypeVar, TypeVar>, + pub constraints: Vec<Constraint>, +} + +impl TypeEnvironment { + fn new() -> Self { + TypeEnvironment { + vars: HashSet::new(), + ranks: HashMap::new(), + equivalency_map: HashMap::new(), + constraints: Vec::new(), + } + } + + fn register(&mut self, var_index: VarIndex, var: &mut Var) { + self.vars.insert(var_index); + let rank = if var.is_input() { + TypeEnvRank::Input + } else if var.is_intermediate() { + TypeEnvRank::Intermediate + } else if var.is_output() { + TypeEnvRank::Output + } else { + assert!(var.is_temp()); + TypeEnvRank::Temp + }; + self.ranks.insert(var.get_or_create_typevar(), rank); + } + + fn add_constraint(&mut self, constraint: Constraint) { + if self.constraints.iter().any(|item| *item == constraint) { + return; + } + + // Check extra conditions for InTypeset constraints. 
+ if let Constraint::InTypeset(tv, _) = &constraint { + assert!( + tv.base.is_none(), + "type variable is {:?}, while expecting none", + tv + ); + assert!( + tv.name.starts_with("typeof_"), + "Name \"{}\" should start with \"typeof_\"", + tv.name + ); + } + + self.constraints.push(constraint); + } + + /// Returns the canonical representative of the equivalency class of the given argument, or + /// duplicates it if it's not there yet. + pub fn get_equivalent(&self, tv: &TypeVar) -> TypeVar { + let mut tv = tv; + while let Some(found) = self.equivalency_map.get(tv) { + tv = found; + } + match &tv.base { + Some(parent) => self + .get_equivalent(&parent.type_var) + .derived(parent.derived_func), + None => tv.clone(), + } + } + + /// Get the rank of tv in the partial order: + /// - TVs directly associated with a Var get their rank from the Var (see register()). + /// - Internally generated non-derived TVs implicitly get the lowest rank (0). + /// - Derived variables get their rank from their free typevar. + /// - Singletons have the highest rank. + /// - TVs associated with vars in a source pattern have a higher rank than TVs associated with + /// temporary vars. + fn rank(&self, tv: &TypeVar) -> u8 { + let actual_tv = match tv.base { + Some(_) => tv.free_typevar(), + None => Some(tv.clone()), + }; + + let rank = match actual_tv { + Some(actual_tv) => match self.ranks.get(&actual_tv) { + Some(rank) => Some(*rank), + None => { + assert!( + !actual_tv.name.starts_with("typeof_"), + format!("variable {} should be explicitly ranked", actual_tv.name) + ); + None + } + }, + None => None, + }; + + let rank = match rank { + Some(rank) => rank, + None => { + if tv.singleton_type().is_some() { + TypeEnvRank::Singleton + } else { + TypeEnvRank::Internal + } + } + }; + + rank as u8 + } + + /// Record the fact that the free tv1 is part of the same equivalence class as tv2. The + /// canonical representative of the merged class is tv2's canonical representative. 
+ fn record_equivalent(&mut self, tv1: TypeVar, tv2: TypeVar) { + assert!(tv1.base.is_none()); + assert!(self.get_equivalent(&tv1) == tv1); + if let Some(tv2_base) = &tv2.base { + // Ensure there are no cycles. + assert!(self.get_equivalent(&tv2_base.type_var) != tv1); + } + self.equivalency_map.insert(tv1, tv2); + } + + /// Get the free typevars in the current type environment. + pub fn free_typevars(&self, var_pool: &mut VarPool) -> Vec<TypeVar> { + let mut typevars = Vec::new(); + typevars.extend(self.equivalency_map.keys().cloned()); + typevars.extend( + self.vars + .iter() + .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()), + ); + + let set: HashSet<TypeVar> = HashSet::from_iter( + typevars + .iter() + .map(|tv| self.get_equivalent(tv).free_typevar()) + .filter(|opt_tv| { + // Filter out singleton types. + opt_tv.is_some() + }) + .map(|tv| tv.unwrap()), + ); + Vec::from_iter(set) + } + + /// Normalize by collapsing any roots that don't correspond to a concrete type var AND have a + /// single type var derived from them or equivalent to them. + /// + /// e.g. if we have a root of the tree that looks like: + /// + /// typeof_a typeof_b + /// \\ / + /// typeof_x + /// | + /// half_width(1) + /// | + /// 1 + /// + /// we want to collapse the linear path between 1 and typeof_x. The resulting graph is: + /// + /// typeof_a typeof_b + /// \\ / + /// typeof_x + fn normalize(&mut self, var_pool: &mut VarPool) { + let source_tvs: HashSet<TypeVar> = HashSet::from_iter( + self.vars + .iter() + .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()), + ); + + let mut children: HashMap<TypeVar, HashSet<TypeVar>> = HashMap::new(); + + // Insert all the parents found by the derivation relationship. + for type_var in self.equivalency_map.values() { + if type_var.base.is_none() { + continue; + } + + let parent_tv = type_var.free_typevar(); + if parent_tv.is_none() { + // Ignore this type variable, it's a singleton. 
+ continue; + } + let parent_tv = parent_tv.unwrap(); + + children + .entry(parent_tv) + .or_insert_with(HashSet::new) + .insert(type_var.clone()); + } + + // Insert all the explicit equivalency links. + for (equivalent_tv, canon_tv) in self.equivalency_map.iter() { + children + .entry(canon_tv.clone()) + .or_insert_with(HashSet::new) + .insert(equivalent_tv.clone()); + } + + // Remove links that are straight paths up to typevar of variables. + for free_root in self.free_typevars(var_pool) { + let mut root = &free_root; + while !source_tvs.contains(&root) + && children.contains_key(&root) + && children.get(&root).unwrap().len() == 1 + { + let child = children.get(&root).unwrap().iter().next().unwrap(); + assert_eq!(self.equivalency_map[child], root.clone()); + self.equivalency_map.remove(child); + root = child; + } + } + } + + /// Extract a clean type environment from self, that only mentions type vars associated with + /// real variables. + fn extract(self, var_pool: &mut VarPool) -> TypeEnvironment { + let vars_tv: HashSet<TypeVar> = HashSet::from_iter( + self.vars + .iter() + .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()), + ); + + let mut new_equivalency_map: HashMap<TypeVar, TypeVar> = HashMap::new(); + for tv in &vars_tv { + let canon_tv = self.get_equivalent(tv); + if *tv != canon_tv { + new_equivalency_map.insert(tv.clone(), canon_tv.clone()); + } + + // Sanity check: the translated type map should only refer to real variables. + assert!(vars_tv.contains(tv)); + let canon_free_tv = canon_tv.free_typevar(); + assert!(canon_free_tv.is_none() || vars_tv.contains(&canon_free_tv.unwrap())); + } + + let mut new_constraints: HashSet<Constraint> = HashSet::new(); + for constraint in &self.constraints { + let constraint = constraint.translate_with_env(&self); + if constraint.is_trivial() || new_constraints.contains(&constraint) { + continue; + } + + // Sanity check: translated constraints should refer only to real variables. 
+ for arg in constraint.typevar_args() { + let arg_free_tv = arg.free_typevar(); + assert!(arg_free_tv.is_none() || vars_tv.contains(&arg_free_tv.unwrap())); + } + + new_constraints.insert(constraint); + } + + TypeEnvironment { + vars: self.vars, + ranks: self.ranks, + equivalency_map: new_equivalency_map, + constraints: Vec::from_iter(new_constraints), + } + } +} + +/// Replaces an external type variable according to the following rules: +/// - if a local copy is present in the map, return it. +/// - or if it's derived, create a local derived one that recursively substitutes the parent. +/// - or return itself. +fn substitute(map: &HashMap<&TypeVar, TypeVar>, external_type_var: &TypeVar) -> TypeVar { + match map.get(&external_type_var) { + Some(own_type_var) => own_type_var.clone(), + None => match &external_type_var.base { + Some(parent) => { + let parent_substitute = substitute(map, &parent.type_var); + TypeVar::derived(&parent_substitute, parent.derived_func) + } + None => external_type_var.clone(), + }, + } +} + +/// Normalize a (potentially derived) typevar using the following rules: +/// +/// - vector and width derived functions commute +/// {HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) -> +/// {HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base)) +/// +/// - half/double pairs collapse +/// {HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base +/// {HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base +fn canonicalize_derivations(tv: TypeVar) -> TypeVar { + let base = match &tv.base { + Some(base) => base, + None => return tv, + }; + + let derived_func = base.derived_func; + + if let Some(base_base) = &base.type_var.base { + let base_base_tv = &base_base.type_var; + match (derived_func, base_base.derived_func) { + (DerivedFunc::HalfWidth, DerivedFunc::DoubleWidth) + | (DerivedFunc::DoubleWidth, DerivedFunc::HalfWidth) + | (DerivedFunc::HalfVector, DerivedFunc::DoubleVector) + | (DerivedFunc::DoubleVector, DerivedFunc::HalfVector) => { + // Cancelling bijective 
transformations. This doesn't hide any overflow issues + // since derived type sets are checked upon derivation, and base typesets are only + // allowed to shrink. + return canonicalize_derivations(base_base_tv.clone()); + } + (DerivedFunc::HalfWidth, DerivedFunc::HalfVector) + | (DerivedFunc::HalfWidth, DerivedFunc::DoubleVector) + | (DerivedFunc::DoubleWidth, DerivedFunc::DoubleVector) + | (DerivedFunc::DoubleWidth, DerivedFunc::HalfVector) => { + // Arbitrarily put WIDTH derivations before VECTOR derivations, since they commute. + return canonicalize_derivations( + base_base_tv + .derived(derived_func) + .derived(base_base.derived_func), + ); + } + _ => {} + }; + } + + canonicalize_derivations(base.type_var.clone()).derived(derived_func) +} + +/// Given typevars tv1 and tv2 (which could be derived from one another), constrain their typesets +/// to be the same. When one is derived from the other, repeat the constraining process until +/// a fixed point is reached. +fn constrain_fixpoint(tv1: &TypeVar, tv2: &TypeVar) { + loop { + let old_tv1_ts = tv1.get_typeset().clone(); + tv2.constrain_types(tv1.clone()); + if tv1.get_typeset() == old_tv1_ts { + break; + } + } + + let old_tv2_ts = tv2.get_typeset(); + tv1.constrain_types(tv2.clone()); + // The above loop should ensure that all reference cycles have been handled. + assert!(old_tv2_ts == tv2.get_typeset()); +} + +/// Unify tv1 and tv2 in the given type environment. tv1 must have a rank greater or equal to tv2's +/// one, modulo commutations. +fn unify(tv1: &TypeVar, tv2: &TypeVar, type_env: &mut TypeEnvironment) -> Result<(), String> { + let tv1 = canonicalize_derivations(type_env.get_equivalent(tv1)); + let tv2 = canonicalize_derivations(type_env.get_equivalent(tv2)); + + if tv1 == tv2 { + // Already unified. 
+ return Ok(()); + } + + if type_env.rank(&tv2) < type_env.rank(&tv1) { + // Make sure tv1 always has the smallest rank, since real variables have the higher rank + // and we want them to be the canonical representatives of their equivalency classes. + return unify(&tv2, &tv1, type_env); + } + + constrain_fixpoint(&tv1, &tv2); + + if tv1.get_typeset().size() == 0 || tv2.get_typeset().size() == 0 { + return Err(format!( + "Error: empty type created when unifying {} and {}", + tv1.name, tv2.name + )); + } + + let base = match &tv1.base { + Some(base) => base, + None => { + type_env.record_equivalent(tv1, tv2); + return Ok(()); + } + }; + + if let Some(inverse) = base.derived_func.inverse() { + return unify(&base.type_var, &tv2.derived(inverse), type_env); + } + + type_env.add_constraint(Constraint::Eq(tv1, tv2)); + Ok(()) +} + +/// Perform type inference on one Def in the current type environment and return an updated type +/// environment or error. +/// +/// At a high level this works by creating fresh copies of each formal type var in the Def's +/// instruction's signature, and unifying the formal typevar with the corresponding actual typevar. 
+fn infer_definition( + def: &Def, + var_pool: &mut VarPool, + type_env: TypeEnvironment, + last_type_index: &mut usize, +) -> Result<TypeEnvironment, String> { + let apply = &def.apply; + let inst = &apply.inst; + + let mut type_env = type_env; + let free_formal_tvs = inst.all_typevars(); + + let mut original_to_own_typevar: HashMap<&TypeVar, TypeVar> = HashMap::new(); + for &tv in &free_formal_tvs { + assert!(original_to_own_typevar + .insert( + tv, + TypeVar::copy_from(tv, format!("own_{}", last_type_index)) + ) + .is_none()); + *last_type_index += 1; + } + + // Update the mapping with any explicitly bound type vars: + for (i, value_type) in apply.value_types.iter().enumerate() { + let singleton = TypeVar::new_singleton(value_type.clone()); + assert!(original_to_own_typevar + .insert(free_formal_tvs[i], singleton) + .is_some()); + } + + // Get fresh copies for each typevar in the signature (both free and derived). + let mut formal_tvs = Vec::new(); + formal_tvs.extend(inst.value_results.iter().map(|&i| { + substitute( + &original_to_own_typevar, + inst.operands_out[i].type_var().unwrap(), + ) + })); + formal_tvs.extend(inst.value_opnums.iter().map(|&i| { + substitute( + &original_to_own_typevar, + inst.operands_in[i].type_var().unwrap(), + ) + })); + + // Get the list of actual vars. + let mut actual_vars = Vec::new(); + actual_vars.extend(inst.value_results.iter().map(|&i| def.defined_vars[i])); + actual_vars.extend( + inst.value_opnums + .iter() + .map(|&i| apply.args[i].unwrap_var()), + ); + + // Get the list of the actual TypeVars. + let mut actual_tvs = Vec::new(); + for var_index in actual_vars { + let var = var_pool.get_mut(var_index); + type_env.register(var_index, var); + actual_tvs.push(var.get_or_create_typevar()); + } + + // Make sure we start unifying with the control type variable first, by putting it at the + // front of both vectors. 
+ if let Some(poly) = &inst.polymorphic_info { + let own_ctrl_tv = &original_to_own_typevar[&poly.ctrl_typevar]; + let ctrl_index = formal_tvs.iter().position(|tv| tv == own_ctrl_tv).unwrap(); + if ctrl_index != 0 { + formal_tvs.swap(0, ctrl_index); + actual_tvs.swap(0, ctrl_index); + } + } + + // Unify each actual type variable with the corresponding formal type variable. + for (actual_tv, formal_tv) in actual_tvs.iter().zip(&formal_tvs) { + if let Err(msg) = unify(actual_tv, formal_tv, &mut type_env) { + return Err(format!( + "fail ti on {} <: {}: {}", + actual_tv.name, formal_tv.name, msg + )); + } + } + + // Add any instruction specific constraints. + for constraint in &inst.constraints { + type_env.add_constraint(constraint.translate_with_map(&original_to_own_typevar)); + } + + Ok(type_env) +} + +/// Perform type inference on a transformation. Return an updated type environment or error. +pub(crate) fn infer_transform( + src: DefIndex, + dst: &[DefIndex], + def_pool: &DefPool, + var_pool: &mut VarPool, +) -> Result<TypeEnvironment, String> { + let mut type_env = TypeEnvironment::new(); + let mut last_type_index = 0; + + // Execute type inference on the source pattern. + type_env = infer_definition(def_pool.get(src), var_pool, type_env, &mut last_type_index) + .map_err(|err| format!("In src pattern: {}", err))?; + + // Collect the type sets once after applying the source pattern; we'll compare the typesets + // after we've also considered the destination pattern, and will emit supplementary InTypeset + // checks if they don't match. + let src_typesets = type_env + .vars + .iter() + .map(|&var_index| { + let var = var_pool.get_mut(var_index); + let tv = type_env.get_equivalent(&var.get_or_create_typevar()); + (var_index, tv.get_typeset()) + }) + .collect::<Vec<_>>(); + + // Execute type inference on the destination pattern. 
+ for (i, &def_index) in dst.iter().enumerate() { + let def = def_pool.get(def_index); + type_env = infer_definition(def, var_pool, type_env, &mut last_type_index) + .map_err(|err| format!("line {}: {}", i, err))?; + } + + for (var_index, src_typeset) in src_typesets { + let var = var_pool.get(var_index); + if !var.has_free_typevar() { + continue; + } + let tv = type_env.get_equivalent(&var.get_typevar().unwrap()); + let new_typeset = tv.get_typeset(); + assert!( + new_typeset.is_subset(&src_typeset), + "type sets can only get narrower" + ); + if new_typeset != src_typeset { + type_env.add_constraint(Constraint::InTypeset(tv.clone(), new_typeset.clone())); + } + } + + type_env.normalize(var_pool); + + Ok(type_env.extract(var_pool)) +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/types.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/types.rs new file mode 100644 index 0000000000..7e03c873db --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/types.rs @@ -0,0 +1,587 @@ +//! Cranelift ValueType hierarchy + +use std::fmt; + +use crate::shared::types as shared_types; +use cranelift_codegen_shared::constants; + +// Rust name prefix used for the `rust_name` method. +static _RUST_NAME_PREFIX: &str = "ir::types::"; + +// ValueType variants (i8, i32, ...) are provided in `shared::types.rs`. + +/// A concrete SSA value type. +/// +/// All SSA values have a type that is described by an instance of `ValueType` +/// or one of its subclasses. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub(crate) enum ValueType { + Lane(LaneType), + Reference(ReferenceType), + Special(SpecialType), + Vector(VectorType), +} + +impl ValueType { + /// Iterate through all of the lane types. + pub fn all_lane_types() -> LaneTypeIterator { + LaneTypeIterator::new() + } + + /// Iterate through all of the special types (neither lanes nor vectors). 
+ pub fn all_special_types() -> SpecialTypeIterator { + SpecialTypeIterator::new() + } + + pub fn all_reference_types() -> ReferenceTypeIterator { + ReferenceTypeIterator::new() + } + + /// Return a string containing the documentation comment for this type. + pub fn doc(&self) -> String { + match *self { + ValueType::Lane(l) => l.doc(), + ValueType::Reference(r) => r.doc(), + ValueType::Special(s) => s.doc(), + ValueType::Vector(ref v) => v.doc(), + } + } + + /// Return the number of bits in a lane. + pub fn lane_bits(&self) -> u64 { + match *self { + ValueType::Lane(l) => l.lane_bits(), + ValueType::Reference(r) => r.lane_bits(), + ValueType::Special(s) => s.lane_bits(), + ValueType::Vector(ref v) => v.lane_bits(), + } + } + + /// Return the number of lanes. + pub fn lane_count(&self) -> u64 { + match *self { + ValueType::Vector(ref v) => v.lane_count(), + _ => 1, + } + } + + /// Find the number of bytes that this type occupies in memory. + pub fn membytes(&self) -> u64 { + self.width() / 8 + } + + /// Find the unique number associated with this type. + pub fn number(&self) -> Option<u8> { + match *self { + ValueType::Lane(l) => Some(l.number()), + ValueType::Reference(r) => Some(r.number()), + ValueType::Special(s) => Some(s.number()), + ValueType::Vector(ref v) => Some(v.number()), + } + } + + /// Return the name of this type for generated Rust source files. + pub fn rust_name(&self) -> String { + format!("{}{}", _RUST_NAME_PREFIX, self.to_string().to_uppercase()) + } + + /// Return true iff: + /// 1. self and other have equal number of lanes + /// 2. each lane in self has at least as many bits as a lane in other + pub fn _wider_or_equal(&self, rhs: &ValueType) -> bool { + (self.lane_count() == rhs.lane_count()) && (self.lane_bits() >= rhs.lane_bits()) + } + + /// Return the total number of bits of an instance of this type. 
+ pub fn width(&self) -> u64 { + self.lane_count() * self.lane_bits() + } +} + +impl fmt::Display for ValueType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ValueType::Lane(l) => l.fmt(f), + ValueType::Reference(r) => r.fmt(f), + ValueType::Special(s) => s.fmt(f), + ValueType::Vector(ref v) => v.fmt(f), + } + } +} + +/// Create a ValueType from a given lane type. +impl From<LaneType> for ValueType { + fn from(lane: LaneType) -> Self { + ValueType::Lane(lane) + } +} + +/// Create a ValueType from a given reference type. +impl From<ReferenceType> for ValueType { + fn from(reference: ReferenceType) -> Self { + ValueType::Reference(reference) + } +} + +/// Create a ValueType from a given special type. +impl From<SpecialType> for ValueType { + fn from(spec: SpecialType) -> Self { + ValueType::Special(spec) + } +} + +/// Create a ValueType from a given vector type. +impl From<VectorType> for ValueType { + fn from(vector: VectorType) -> Self { + ValueType::Vector(vector) + } +} + +/// A concrete scalar type that can appear as a vector lane too. +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) enum LaneType { + Bool(shared_types::Bool), + Float(shared_types::Float), + Int(shared_types::Int), +} + +impl LaneType { + /// Return a string containing the documentation comment for this lane type. + pub fn doc(self) -> String { + match self { + LaneType::Bool(_) => format!("A boolean type with {} bits.", self.lane_bits()), + LaneType::Float(shared_types::Float::F32) => String::from( + "A 32-bit floating point type represented in the IEEE 754-2008 + *binary32* interchange format. This corresponds to the :c:type:`float` + type in most C implementations.", + ), + LaneType::Float(shared_types::Float::F64) => String::from( + "A 64-bit floating point type represented in the IEEE 754-2008 + *binary64* interchange format. 
This corresponds to the :c:type:`double` + type in most C implementations.", + ), + LaneType::Int(_) if self.lane_bits() < 32 => format!( + "An integer type with {} bits. + WARNING: arithmetic on {}bit integers is incomplete", + self.lane_bits(), + self.lane_bits() + ), + LaneType::Int(_) => format!("An integer type with {} bits.", self.lane_bits()), + } + } + + /// Return the number of bits in a lane. + pub fn lane_bits(self) -> u64 { + match self { + LaneType::Bool(ref b) => *b as u64, + LaneType::Float(ref f) => *f as u64, + LaneType::Int(ref i) => *i as u64, + } + } + + /// Find the unique number associated with this lane type. + pub fn number(self) -> u8 { + constants::LANE_BASE + + match self { + LaneType::Bool(shared_types::Bool::B1) => 0, + LaneType::Bool(shared_types::Bool::B8) => 1, + LaneType::Bool(shared_types::Bool::B16) => 2, + LaneType::Bool(shared_types::Bool::B32) => 3, + LaneType::Bool(shared_types::Bool::B64) => 4, + LaneType::Bool(shared_types::Bool::B128) => 5, + LaneType::Int(shared_types::Int::I8) => 6, + LaneType::Int(shared_types::Int::I16) => 7, + LaneType::Int(shared_types::Int::I32) => 8, + LaneType::Int(shared_types::Int::I64) => 9, + LaneType::Int(shared_types::Int::I128) => 10, + LaneType::Float(shared_types::Float::F32) => 11, + LaneType::Float(shared_types::Float::F64) => 12, + } + } + + pub fn bool_from_bits(num_bits: u16) -> LaneType { + LaneType::Bool(match num_bits { + 1 => shared_types::Bool::B1, + 8 => shared_types::Bool::B8, + 16 => shared_types::Bool::B16, + 32 => shared_types::Bool::B32, + 64 => shared_types::Bool::B64, + 128 => shared_types::Bool::B128, + _ => unreachable!("unxpected num bits for bool"), + }) + } + + pub fn int_from_bits(num_bits: u16) -> LaneType { + LaneType::Int(match num_bits { + 8 => shared_types::Int::I8, + 16 => shared_types::Int::I16, + 32 => shared_types::Int::I32, + 64 => shared_types::Int::I64, + 128 => shared_types::Int::I128, + _ => unreachable!("unxpected num bits for int"), + }) + } + + pub 
fn float_from_bits(num_bits: u16) -> LaneType { + LaneType::Float(match num_bits { + 32 => shared_types::Float::F32, + 64 => shared_types::Float::F64, + _ => unreachable!("unxpected num bits for float"), + }) + } + + pub fn by(self, lanes: u16) -> ValueType { + if lanes == 1 { + self.into() + } else { + ValueType::Vector(VectorType::new(self, lanes.into())) + } + } + + pub fn is_float(self) -> bool { + match self { + LaneType::Float(_) => true, + _ => false, + } + } + + pub fn is_int(self) -> bool { + match self { + LaneType::Int(_) => true, + _ => false, + } + } +} + +impl fmt::Display for LaneType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + LaneType::Bool(_) => write!(f, "b{}", self.lane_bits()), + LaneType::Float(_) => write!(f, "f{}", self.lane_bits()), + LaneType::Int(_) => write!(f, "i{}", self.lane_bits()), + } + } +} + +impl fmt::Debug for LaneType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let inner_msg = format!("bits={}", self.lane_bits()); + write!( + f, + "{}", + match *self { + LaneType::Bool(_) => format!("BoolType({})", inner_msg), + LaneType::Float(_) => format!("FloatType({})", inner_msg), + LaneType::Int(_) => format!("IntType({})", inner_msg), + } + ) + } +} + +/// Create a LaneType from a given bool variant. +impl From<shared_types::Bool> for LaneType { + fn from(b: shared_types::Bool) -> Self { + LaneType::Bool(b) + } +} + +/// Create a LaneType from a given float variant. +impl From<shared_types::Float> for LaneType { + fn from(f: shared_types::Float) -> Self { + LaneType::Float(f) + } +} + +/// Create a LaneType from a given int variant. +impl From<shared_types::Int> for LaneType { + fn from(i: shared_types::Int) -> Self { + LaneType::Int(i) + } +} + +/// An iterator for different lane types. 
+pub(crate) struct LaneTypeIterator { + bool_iter: shared_types::BoolIterator, + int_iter: shared_types::IntIterator, + float_iter: shared_types::FloatIterator, +} + +impl LaneTypeIterator { + /// Create a new lane type iterator. + fn new() -> Self { + Self { + bool_iter: shared_types::BoolIterator::new(), + int_iter: shared_types::IntIterator::new(), + float_iter: shared_types::FloatIterator::new(), + } + } +} + +impl Iterator for LaneTypeIterator { + type Item = LaneType; + fn next(&mut self) -> Option<Self::Item> { + if let Some(b) = self.bool_iter.next() { + Some(LaneType::from(b)) + } else if let Some(i) = self.int_iter.next() { + Some(LaneType::from(i)) + } else if let Some(f) = self.float_iter.next() { + Some(LaneType::from(f)) + } else { + None + } + } +} + +/// A concrete SIMD vector type. +/// +/// A vector type has a lane type which is an instance of `LaneType`, +/// and a positive number of lanes. +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) struct VectorType { + base: LaneType, + lanes: u64, +} + +impl VectorType { + /// Initialize a new vector type with lane type `base` and `lanes` lanes. + pub fn new(base: LaneType, lanes: u64) -> Self { + Self { base, lanes } + } + + /// Return a string containing the documentation comment for this vector type. + pub fn doc(&self) -> String { + format!( + "A SIMD vector with {} lanes containing a `{}` each.", + self.lane_count(), + self.base + ) + } + + /// Return the number of bits in a lane. + pub fn lane_bits(&self) -> u64 { + self.base.lane_bits() + } + + /// Return the number of lanes. + pub fn lane_count(&self) -> u64 { + self.lanes + } + + /// Return the lane type. + pub fn lane_type(&self) -> LaneType { + self.base + } + + /// Find the unique number associated with this vector type. + /// + /// Vector types are encoded with the lane type in the low 4 bits and + /// log2(lanes) in the high 4 bits, giving a range of 2-256 lanes.
+ pub fn number(&self) -> u8 { + let lanes_log_2: u32 = 63 - self.lane_count().leading_zeros(); + let base_num = u32::from(self.base.number()); + let num = (lanes_log_2 << 4) + base_num; + num as u8 + } +} + +impl fmt::Display for VectorType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}x{}", self.base, self.lane_count()) + } +} + +impl fmt::Debug for VectorType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "VectorType(base={}, lanes={})", + self.base, + self.lane_count() + ) + } +} + +/// A concrete scalar type that is neither a vector nor a lane type. +/// +/// Special types cannot be used to form vectors. +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) enum SpecialType { + Flag(shared_types::Flag), + // FIXME remove once the old style backends are removed. + StructArgument, +} + +impl SpecialType { + /// Return a string containing the documentation comment for this special type. + pub fn doc(self) -> String { + match self { + SpecialType::Flag(shared_types::Flag::IFlags) => String::from( + "CPU flags representing the result of an integer comparison. These flags + can be tested with an :type:`intcc` condition code.", + ), + SpecialType::Flag(shared_types::Flag::FFlags) => String::from( + "CPU flags representing the result of a floating point comparison. These + flags can be tested with a :type:`floatcc` condition code.", + ), + SpecialType::StructArgument => { + String::from("After legalization sarg_t arguments will get this type.") + } + } + } + + /// Return the number of bits in a lane. + pub fn lane_bits(self) -> u64 { + match self { + SpecialType::Flag(_) => 0, + SpecialType::StructArgument => 0, + } + } + + /// Find the unique number associated with this special type. 
+ pub fn number(self) -> u8 { + match self { + SpecialType::Flag(shared_types::Flag::IFlags) => 1, + SpecialType::Flag(shared_types::Flag::FFlags) => 2, + SpecialType::StructArgument => 3, + } + } +} + +impl fmt::Display for SpecialType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + SpecialType::Flag(shared_types::Flag::IFlags) => write!(f, "iflags"), + SpecialType::Flag(shared_types::Flag::FFlags) => write!(f, "fflags"), + SpecialType::StructArgument => write!(f, "sarg_t"), + } + } +} + +impl fmt::Debug for SpecialType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}", + match *self { + SpecialType::Flag(_) => format!("FlagsType({})", self), + SpecialType::StructArgument => format!("StructArgument"), + } + ) + } +} + +impl From<shared_types::Flag> for SpecialType { + fn from(f: shared_types::Flag) -> Self { + SpecialType::Flag(f) + } +} + +pub(crate) struct SpecialTypeIterator { + flag_iter: shared_types::FlagIterator, + done: bool, +} + +impl SpecialTypeIterator { + fn new() -> Self { + Self { + flag_iter: shared_types::FlagIterator::new(), + done: false, + } + } +} + +impl Iterator for SpecialTypeIterator { + type Item = SpecialType; + fn next(&mut self) -> Option<Self::Item> { + if let Some(f) = self.flag_iter.next() { + Some(SpecialType::from(f)) + } else { + if !self.done { + self.done = true; + Some(SpecialType::StructArgument) + } else { + None + } + } + } +} + +/// Reference type is scalar type, but not lane type. +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) struct ReferenceType(pub shared_types::Reference); + +impl ReferenceType { + /// Return a string containing the documentation comment for this reference type. + pub fn doc(self) -> String { + format!("An opaque reference type with {} bits.", self.lane_bits()) + } + + /// Return the number of bits in a lane. 
+ pub fn lane_bits(self) -> u64 { + match self.0 { + shared_types::Reference::R32 => 32, + shared_types::Reference::R64 => 64, + } + } + + /// Find the unique number associated with this reference type. + pub fn number(self) -> u8 { + constants::REFERENCE_BASE + + match self { + ReferenceType(shared_types::Reference::R32) => 0, + ReferenceType(shared_types::Reference::R64) => 1, + } + } + + pub fn ref_from_bits(num_bits: u16) -> ReferenceType { + ReferenceType(match num_bits { + 32 => shared_types::Reference::R32, + 64 => shared_types::Reference::R64, + _ => unreachable!("unexpected number of bits for a reference type"), + }) + } +} + +impl fmt::Display for ReferenceType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "r{}", self.lane_bits()) + } +} + +impl fmt::Debug for ReferenceType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "ReferenceType(bits={})", self.lane_bits()) + } +} + +/// Create a ReferenceType from a given reference variant. +impl From<shared_types::Reference> for ReferenceType { + fn from(r: shared_types::Reference) -> Self { + ReferenceType(r) + } +} + +/// An iterator for different reference types. +pub(crate) struct ReferenceTypeIterator { + reference_iter: shared_types::ReferenceIterator, +} + +impl ReferenceTypeIterator { + /// Create a new reference type iterator. 
+ fn new() -> Self { + Self { + reference_iter: shared_types::ReferenceIterator::new(), + } + } +} + +impl Iterator for ReferenceTypeIterator { + type Item = ReferenceType; + fn next(&mut self) -> Option<Self::Item> { + if let Some(r) = self.reference_iter.next() { + Some(ReferenceType::from(r)) + } else { + None + } + } +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/typevar.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/typevar.rs new file mode 100644 index 0000000000..c1027bf847 --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/typevar.rs @@ -0,0 +1,1274 @@ +use std::cell::RefCell; +use std::collections::{BTreeSet, HashSet}; +use std::fmt; +use std::hash; +use std::iter::FromIterator; +use std::ops; +use std::rc::Rc; + +use crate::cdsl::types::{LaneType, ReferenceType, SpecialType, ValueType}; + +const MAX_LANES: u16 = 256; +const MAX_BITS: u16 = 128; +const MAX_FLOAT_BITS: u16 = 64; + +/// Type variables can be used in place of concrete types when defining +/// instructions. This makes the instructions *polymorphic*. +/// +/// A type variable is restricted to vary over a subset of the value types. +/// This subset is specified by a set of flags that control the permitted base +/// types and whether the type variable can assume scalar or vector types, or +/// both. +#[derive(Debug)] +pub(crate) struct TypeVarContent { + /// Short name of type variable used in instruction descriptions. + pub name: String, + + /// Documentation string. + pub doc: String, + + /// Type set associated to the type variable. + /// This field must remain private; use `get_typeset()` or `get_raw_typeset()` to get the + /// information you want. 
+ type_set: TypeSet, + + pub base: Option<TypeVarParent>, +} + +#[derive(Clone, Debug)] +pub(crate) struct TypeVar { + content: Rc<RefCell<TypeVarContent>>, +} + +impl TypeVar { + pub fn new(name: impl Into<String>, doc: impl Into<String>, type_set: TypeSet) -> Self { + Self { + content: Rc::new(RefCell::new(TypeVarContent { + name: name.into(), + doc: doc.into(), + type_set, + base: None, + })), + } + } + + pub fn new_singleton(value_type: ValueType) -> Self { + let (name, doc) = (value_type.to_string(), value_type.doc()); + let mut builder = TypeSetBuilder::new(); + + let (scalar_type, num_lanes) = match value_type { + ValueType::Special(special_type) => { + return TypeVar::new(name, doc, builder.specials(vec![special_type]).build()); + } + ValueType::Reference(ReferenceType(reference_type)) => { + let bits = reference_type as RangeBound; + return TypeVar::new(name, doc, builder.refs(bits..bits).build()); + } + ValueType::Lane(lane_type) => (lane_type, 1), + ValueType::Vector(vec_type) => { + (vec_type.lane_type(), vec_type.lane_count() as RangeBound) + } + }; + + builder = builder.simd_lanes(num_lanes..num_lanes); + + let builder = match scalar_type { + LaneType::Int(int_type) => { + let bits = int_type as RangeBound; + builder.ints(bits..bits) + } + LaneType::Float(float_type) => { + let bits = float_type as RangeBound; + builder.floats(bits..bits) + } + LaneType::Bool(bool_type) => { + let bits = bool_type as RangeBound; + builder.bools(bits..bits) + } + }; + TypeVar::new(name, doc, builder.build()) + } + + /// Get a fresh copy of self, named after `name`. Can only be called on non-derived typevars. + pub fn copy_from(other: &TypeVar, name: String) -> TypeVar { + assert!( + other.base.is_none(), + "copy_from() can only be called on non-derived type variables" + ); + TypeVar { + content: Rc::new(RefCell::new(TypeVarContent { + name, + doc: "".into(), + type_set: other.type_set.clone(), + base: None, + })), + } + } + + /// Returns the typeset for this TV. 
If the TV is derived, computes it recursively from the + /// derived function and the base's typeset. + /// Note this can't be done non-lazily in the constructor, because the TypeSet of the base may + /// change over time. + pub fn get_typeset(&self) -> TypeSet { + match &self.base { + Some(base) => base.type_var.get_typeset().image(base.derived_func), + None => self.type_set.clone(), + } + } + + /// Returns this typevar's type set, assuming this type var has no parent. + pub fn get_raw_typeset(&self) -> &TypeSet { + assert_eq!(self.type_set, self.get_typeset()); + &self.type_set + } + + /// If the associated typeset has a single type return it. Otherwise return None. + pub fn singleton_type(&self) -> Option<ValueType> { + let type_set = self.get_typeset(); + if type_set.size() == 1 { + Some(type_set.get_singleton()) + } else { + None + } + } + + /// Get the free type variable controlling this one. + pub fn free_typevar(&self) -> Option<TypeVar> { + match &self.base { + Some(base) => base.type_var.free_typevar(), + None => { + match self.singleton_type() { + // A singleton type isn't a proper free variable. + Some(_) => None, + None => Some(self.clone()), + } + } + } + } + + /// Create a type variable that is a function of another. + pub fn derived(&self, derived_func: DerivedFunc) -> TypeVar { + let ts = self.get_typeset(); + + // Safety checks to avoid over/underflows. 
+ assert!(ts.specials.is_empty(), "can't derive from special types"); + match derived_func { + DerivedFunc::HalfWidth => { + assert!( + ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8, + "can't halve all integer types" + ); + assert!( + ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 32, + "can't halve all float types" + ); + assert!( + ts.bools.is_empty() || *ts.bools.iter().min().unwrap() > 8, + "can't halve all boolean types" + ); + } + DerivedFunc::DoubleWidth => { + assert!( + ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS, + "can't double all integer types" + ); + assert!( + ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS, + "can't double all float types" + ); + assert!( + ts.bools.is_empty() || *ts.bools.iter().max().unwrap() < MAX_BITS, + "can't double all boolean types" + ); + } + DerivedFunc::HalfVector => { + assert!( + *ts.lanes.iter().min().unwrap() > 1, + "can't halve a scalar type" + ); + } + DerivedFunc::DoubleVector => { + assert!( + *ts.lanes.iter().max().unwrap() < MAX_LANES, + "can't double 256 lanes" + ); + } + DerivedFunc::SplitLanes => { + assert!( + ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8, + "can't halve all integer types" + ); + assert!( + ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 32, + "can't halve all float types" + ); + assert!( + ts.bools.is_empty() || *ts.bools.iter().min().unwrap() > 8, + "can't halve all boolean types" + ); + assert!( + *ts.lanes.iter().max().unwrap() < MAX_LANES, + "can't double 256 lanes" + ); + } + DerivedFunc::MergeLanes => { + assert!( + ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS, + "can't double all integer types" + ); + assert!( + ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS, + "can't double all float types" + ); + assert!( + ts.bools.is_empty() || *ts.bools.iter().max().unwrap() < MAX_BITS, + "can't double all boolean types" + ); + assert!( + 
*ts.lanes.iter().min().unwrap() > 1, + "can't halve a scalar type" + ); + } + DerivedFunc::LaneOf | DerivedFunc::AsBool => { /* no particular assertions */ } + } + + TypeVar { + content: Rc::new(RefCell::new(TypeVarContent { + name: format!("{}({})", derived_func.name(), self.name), + doc: "".into(), + type_set: ts, + base: Some(TypeVarParent { + type_var: self.clone(), + derived_func, + }), + })), + } + } + + pub fn lane_of(&self) -> TypeVar { + self.derived(DerivedFunc::LaneOf) + } + pub fn as_bool(&self) -> TypeVar { + self.derived(DerivedFunc::AsBool) + } + pub fn half_width(&self) -> TypeVar { + self.derived(DerivedFunc::HalfWidth) + } + pub fn double_width(&self) -> TypeVar { + self.derived(DerivedFunc::DoubleWidth) + } + pub fn half_vector(&self) -> TypeVar { + self.derived(DerivedFunc::HalfVector) + } + pub fn double_vector(&self) -> TypeVar { + self.derived(DerivedFunc::DoubleVector) + } + pub fn split_lanes(&self) -> TypeVar { + self.derived(DerivedFunc::SplitLanes) + } + pub fn merge_lanes(&self) -> TypeVar { + self.derived(DerivedFunc::MergeLanes) + } + + /// Constrain the range of types this variable can assume to a subset of those in the typeset + /// ts. + /// May mutate itself if it's not derived, or its parent if it is. + pub fn constrain_types_by_ts(&self, type_set: TypeSet) { + match &self.base { + Some(base) => { + base.type_var + .constrain_types_by_ts(type_set.preimage(base.derived_func)); + } + None => { + self.content + .borrow_mut() + .type_set + .inplace_intersect_with(&type_set); + } + } + } + + /// Constrain the range of types this variable can assume to a subset of those `other` can + /// assume. + /// May mutate itself if it's not derived, or its parent if it is. + pub fn constrain_types(&self, other: TypeVar) { + if self == &other { + return; + } + self.constrain_types_by_ts(other.get_typeset()); + } + + /// Get a Rust expression that computes the type of this type variable. 
+ pub fn to_rust_code(&self) -> String { + match &self.base { + Some(base) => format!( + "{}.{}().unwrap()", + base.type_var.to_rust_code(), + base.derived_func.name() + ), + None => { + if let Some(singleton) = self.singleton_type() { + singleton.rust_name() + } else { + self.name.clone() + } + } + } + } +} + +impl Into<TypeVar> for &TypeVar { + fn into(self) -> TypeVar { + self.clone() + } +} +impl Into<TypeVar> for ValueType { + fn into(self) -> TypeVar { + TypeVar::new_singleton(self) + } +} + +// Hash TypeVars by pointers. +// There might be a better way to do this, but since TypeVar's content (namely TypeSet) can be +// mutated, it makes sense to use pointer equality/hashing here. +impl hash::Hash for TypeVar { + fn hash<H: hash::Hasher>(&self, h: &mut H) { + match &self.base { + Some(base) => { + base.type_var.hash(h); + base.derived_func.hash(h); + } + None => { + (&**self as *const TypeVarContent).hash(h); + } + } + } +} + +impl PartialEq for TypeVar { + fn eq(&self, other: &TypeVar) -> bool { + match (&self.base, &other.base) { + (Some(base1), Some(base2)) => { + base1.type_var.eq(&base2.type_var) && base1.derived_func == base2.derived_func + } + (None, None) => Rc::ptr_eq(&self.content, &other.content), + _ => false, + } + } +} + +// Allow TypeVar as map keys, based on pointer equality (see also above PartialEq impl). 
+impl Eq for TypeVar {} + +impl ops::Deref for TypeVar { + type Target = TypeVarContent; + fn deref(&self) -> &Self::Target { + unsafe { self.content.as_ptr().as_ref().unwrap() } + } +} + +#[derive(Clone, Copy, Debug, Hash, PartialEq)] +pub(crate) enum DerivedFunc { + LaneOf, + AsBool, + HalfWidth, + DoubleWidth, + HalfVector, + DoubleVector, + SplitLanes, + MergeLanes, +} + +impl DerivedFunc { + pub fn name(self) -> &'static str { + match self { + DerivedFunc::LaneOf => "lane_of", + DerivedFunc::AsBool => "as_bool", + DerivedFunc::HalfWidth => "half_width", + DerivedFunc::DoubleWidth => "double_width", + DerivedFunc::HalfVector => "half_vector", + DerivedFunc::DoubleVector => "double_vector", + DerivedFunc::SplitLanes => "split_lanes", + DerivedFunc::MergeLanes => "merge_lanes", + } + } + + /// Returns the inverse function of this one, if it is a bijection. + pub fn inverse(self) -> Option<DerivedFunc> { + match self { + DerivedFunc::HalfWidth => Some(DerivedFunc::DoubleWidth), + DerivedFunc::DoubleWidth => Some(DerivedFunc::HalfWidth), + DerivedFunc::HalfVector => Some(DerivedFunc::DoubleVector), + DerivedFunc::DoubleVector => Some(DerivedFunc::HalfVector), + DerivedFunc::MergeLanes => Some(DerivedFunc::SplitLanes), + DerivedFunc::SplitLanes => Some(DerivedFunc::MergeLanes), + _ => None, + } + } +} + +#[derive(Debug, Hash)] +pub(crate) struct TypeVarParent { + pub type_var: TypeVar, + pub derived_func: DerivedFunc, +} + +/// A set of types. +/// +/// We don't allow arbitrary subsets of types, but use a parametrized approach +/// instead. +/// +/// Objects of this class can be used as dictionary keys. +/// +/// Parametrized type sets are specified in terms of ranges: +/// - The permitted range of vector lanes, where 1 indicates a scalar type. +/// - The permitted range of integer types. +/// - The permitted range of floating point types, and +/// - The permitted range of boolean types. 
+/// +/// The ranges are inclusive from smallest bit-width to largest bit-width. +/// +/// Finally, a type set can contain special types (derived from `SpecialType`) +/// which can't appear as lane types. + +type RangeBound = u16; +type Range = ops::Range<RangeBound>; +type NumSet = BTreeSet<RangeBound>; + +macro_rules! num_set { + ($($expr:expr),*) => { + NumSet::from_iter(vec![$($expr),*]) + }; +} + +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) struct TypeSet { + pub lanes: NumSet, + pub ints: NumSet, + pub floats: NumSet, + pub bools: NumSet, + pub refs: NumSet, + pub specials: Vec<SpecialType>, +} + +impl TypeSet { + fn new( + lanes: NumSet, + ints: NumSet, + floats: NumSet, + bools: NumSet, + refs: NumSet, + specials: Vec<SpecialType>, + ) -> Self { + Self { + lanes, + ints, + floats, + bools, + refs, + specials, + } + } + + /// Return the number of concrete types represented by this typeset. + pub fn size(&self) -> usize { + self.lanes.len() + * (self.ints.len() + self.floats.len() + self.bools.len() + self.refs.len()) + + self.specials.len() + } + + /// Return the image of self across the derived function func. + fn image(&self, derived_func: DerivedFunc) -> TypeSet { + match derived_func { + DerivedFunc::LaneOf => self.lane_of(), + DerivedFunc::AsBool => self.as_bool(), + DerivedFunc::HalfWidth => self.half_width(), + DerivedFunc::DoubleWidth => self.double_width(), + DerivedFunc::HalfVector => self.half_vector(), + DerivedFunc::DoubleVector => self.double_vector(), + DerivedFunc::SplitLanes => self.half_width().double_vector(), + DerivedFunc::MergeLanes => self.double_width().half_vector(), + } + } + + /// Return a TypeSet describing the image of self across lane_of. + fn lane_of(&self) -> TypeSet { + let mut copy = self.clone(); + copy.lanes = num_set![1]; + copy + } + + /// Return a TypeSet describing the image of self across as_bool. 
+ fn as_bool(&self) -> TypeSet { + let mut copy = self.clone(); + copy.ints = NumSet::new(); + copy.floats = NumSet::new(); + copy.refs = NumSet::new(); + if !(&self.lanes - &num_set![1]).is_empty() { + copy.bools = &self.ints | &self.floats; + copy.bools = &copy.bools | &self.bools; + } + if self.lanes.contains(&1) { + copy.bools.insert(1); + } + copy + } + + /// Return a TypeSet describing the image of self across halfwidth. + fn half_width(&self) -> TypeSet { + let mut copy = self.clone(); + copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x > 8).map(|&x| x / 2)); + copy.floats = NumSet::from_iter(self.floats.iter().filter(|&&x| x > 32).map(|&x| x / 2)); + copy.bools = NumSet::from_iter(self.bools.iter().filter(|&&x| x > 8).map(|&x| x / 2)); + copy.specials = Vec::new(); + copy + } + + /// Return a TypeSet describing the image of self across doublewidth. + fn double_width(&self) -> TypeSet { + let mut copy = self.clone(); + copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x < MAX_BITS).map(|&x| x * 2)); + copy.floats = NumSet::from_iter( + self.floats + .iter() + .filter(|&&x| x < MAX_FLOAT_BITS) + .map(|&x| x * 2), + ); + copy.bools = NumSet::from_iter( + self.bools + .iter() + .filter(|&&x| x < MAX_BITS) + .map(|&x| x * 2) + .filter(|x| legal_bool(*x)), + ); + copy.specials = Vec::new(); + copy + } + + /// Return a TypeSet describing the image of self across halfvector. + fn half_vector(&self) -> TypeSet { + let mut copy = self.clone(); + copy.lanes = NumSet::from_iter(self.lanes.iter().filter(|&&x| x > 1).map(|&x| x / 2)); + copy.specials = Vec::new(); + copy + } + + /// Return a TypeSet describing the image of self across doublevector.
+ fn double_vector(&self) -> TypeSet { + let mut copy = self.clone(); + copy.lanes = NumSet::from_iter( + self.lanes + .iter() + .filter(|&&x| x < MAX_LANES) + .map(|&x| x * 2), + ); + copy.specials = Vec::new(); + copy + } + + fn concrete_types(&self) -> Vec<ValueType> { + let mut ret = Vec::new(); + for &num_lanes in &self.lanes { + for &bits in &self.ints { + ret.push(LaneType::int_from_bits(bits).by(num_lanes)); + } + for &bits in &self.floats { + ret.push(LaneType::float_from_bits(bits).by(num_lanes)); + } + for &bits in &self.bools { + ret.push(LaneType::bool_from_bits(bits).by(num_lanes)); + } + for &bits in &self.refs { + ret.push(ReferenceType::ref_from_bits(bits).into()); + } + } + for &special in &self.specials { + ret.push(special.into()); + } + ret + } + + /// Return the singleton type represented by self. Can only call on typesets containing 1 type. + fn get_singleton(&self) -> ValueType { + let mut types = self.concrete_types(); + assert_eq!(types.len(), 1); + types.remove(0) + } + + /// Return the inverse image of self across the derived function func. + fn preimage(&self, func: DerivedFunc) -> TypeSet { + if self.size() == 0 { + // The inverse of the empty set is itself. + return self.clone(); + } + + match func { + DerivedFunc::LaneOf => { + let mut copy = self.clone(); + copy.lanes = + NumSet::from_iter((0..=MAX_LANES.trailing_zeros()).map(|i| u16::pow(2, i))); + copy + } + DerivedFunc::AsBool => { + let mut copy = self.clone(); + if self.bools.contains(&1) { + copy.ints = NumSet::from_iter(vec![8, 16, 32, 64, 128]); + copy.floats = NumSet::from_iter(vec![32, 64]); + } else { + copy.ints = &self.bools - &NumSet::from_iter(vec![1]); + copy.floats = &self.bools & &NumSet::from_iter(vec![32, 64]); + // If b1 is not in our typeset, than lanes=1 cannot be in the pre-image, as + // as_bool() of scalars is always b1. 
+ copy.lanes = &self.lanes - &NumSet::from_iter(vec![1]); + } + copy + } + DerivedFunc::HalfWidth => self.double_width(), + DerivedFunc::DoubleWidth => self.half_width(), + DerivedFunc::HalfVector => self.double_vector(), + DerivedFunc::DoubleVector => self.half_vector(), + DerivedFunc::SplitLanes => self.double_width().half_vector(), + DerivedFunc::MergeLanes => self.half_width().double_vector(), + } + } + + pub fn inplace_intersect_with(&mut self, other: &TypeSet) { + self.lanes = &self.lanes & &other.lanes; + self.ints = &self.ints & &other.ints; + self.floats = &self.floats & &other.floats; + self.bools = &self.bools & &other.bools; + self.refs = &self.refs & &other.refs; + + let mut new_specials = Vec::new(); + for spec in &self.specials { + if let Some(spec) = other.specials.iter().find(|&other_spec| other_spec == spec) { + new_specials.push(*spec); + } + } + self.specials = new_specials; + } + + pub fn is_subset(&self, other: &TypeSet) -> bool { + self.lanes.is_subset(&other.lanes) + && self.ints.is_subset(&other.ints) + && self.floats.is_subset(&other.floats) + && self.bools.is_subset(&other.bools) + && self.refs.is_subset(&other.refs) + && { + let specials: HashSet<SpecialType> = HashSet::from_iter(self.specials.clone()); + let other_specials = HashSet::from_iter(other.specials.clone()); + specials.is_subset(&other_specials) + } + } + + pub fn is_wider_or_equal(&self, other: &TypeSet) -> bool { + set_wider_or_equal(&self.ints, &other.ints) + && set_wider_or_equal(&self.floats, &other.floats) + && set_wider_or_equal(&self.bools, &other.bools) + && set_wider_or_equal(&self.refs, &other.refs) + } + + pub fn is_narrower(&self, other: &TypeSet) -> bool { + set_narrower(&self.ints, &other.ints) + && set_narrower(&self.floats, &other.floats) + && set_narrower(&self.bools, &other.bools) + && set_narrower(&self.refs, &other.refs) + } +} + +fn set_wider_or_equal(s1: &NumSet, s2: &NumSet) -> bool { + !s1.is_empty() && !s2.is_empty() && s1.iter().min() >= 
s2.iter().max() +} + +fn set_narrower(s1: &NumSet, s2: &NumSet) -> bool { + !s1.is_empty() && !s2.is_empty() && s1.iter().min() < s2.iter().max() +} + +impl fmt::Debug for TypeSet { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(fmt, "TypeSet(")?; + + let mut subsets = Vec::new(); + if !self.lanes.is_empty() { + subsets.push(format!( + "lanes={{{}}}", + Vec::from_iter(self.lanes.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.ints.is_empty() { + subsets.push(format!( + "ints={{{}}}", + Vec::from_iter(self.ints.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.floats.is_empty() { + subsets.push(format!( + "floats={{{}}}", + Vec::from_iter(self.floats.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.bools.is_empty() { + subsets.push(format!( + "bools={{{}}}", + Vec::from_iter(self.bools.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.refs.is_empty() { + subsets.push(format!( + "refs={{{}}}", + Vec::from_iter(self.refs.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.specials.is_empty() { + subsets.push(format!( + "specials={{{}}}", + Vec::from_iter(self.specials.iter().map(|x| x.to_string())).join(", ") + )); + } + + write!(fmt, "{})", subsets.join(", "))?; + Ok(()) + } +} + +pub(crate) struct TypeSetBuilder { + ints: Interval, + floats: Interval, + bools: Interval, + refs: Interval, + includes_scalars: bool, + simd_lanes: Interval, + specials: Vec<SpecialType>, +} + +impl TypeSetBuilder { + pub fn new() -> Self { + Self { + ints: Interval::None, + floats: Interval::None, + bools: Interval::None, + refs: Interval::None, + includes_scalars: true, + simd_lanes: Interval::None, + specials: Vec::new(), + } + } + + pub fn ints(mut self, interval: impl Into<Interval>) -> Self { + assert!(self.ints == Interval::None); + self.ints = interval.into(); + self + } + pub fn floats(mut self, interval: impl Into<Interval>) -> Self { + assert!(self.floats == Interval::None); + 
self.floats = interval.into(); + self + } + pub fn bools(mut self, interval: impl Into<Interval>) -> Self { + assert!(self.bools == Interval::None); + self.bools = interval.into(); + self + } + pub fn refs(mut self, interval: impl Into<Interval>) -> Self { + assert!(self.refs == Interval::None); + self.refs = interval.into(); + self + } + pub fn includes_scalars(mut self, includes_scalars: bool) -> Self { + self.includes_scalars = includes_scalars; + self + } + pub fn simd_lanes(mut self, interval: impl Into<Interval>) -> Self { + assert!(self.simd_lanes == Interval::None); + self.simd_lanes = interval.into(); + self + } + pub fn specials(mut self, specials: Vec<SpecialType>) -> Self { + assert!(self.specials.is_empty()); + self.specials = specials; + self + } + + pub fn build(self) -> TypeSet { + let min_lanes = if self.includes_scalars { 1 } else { 2 }; + + let bools = range_to_set(self.bools.to_range(1..MAX_BITS, None)) + .into_iter() + .filter(|x| legal_bool(*x)) + .collect(); + + TypeSet::new( + range_to_set(self.simd_lanes.to_range(min_lanes..MAX_LANES, Some(1))), + range_to_set(self.ints.to_range(8..MAX_BITS, None)), + range_to_set(self.floats.to_range(32..64, None)), + bools, + range_to_set(self.refs.to_range(32..64, None)), + self.specials, + ) + } + + pub fn all() -> TypeSet { + TypeSetBuilder::new() + .ints(Interval::All) + .floats(Interval::All) + .bools(Interval::All) + .refs(Interval::All) + .simd_lanes(Interval::All) + .specials(ValueType::all_special_types().collect()) + .includes_scalars(true) + .build() + } +} + +#[derive(PartialEq)] +pub(crate) enum Interval { + None, + All, + Range(Range), +} + +impl Interval { + fn to_range(&self, full_range: Range, default: Option<RangeBound>) -> Option<Range> { + match self { + Interval::None => { + if let Some(default_val) = default { + Some(default_val..default_val) + } else { + None + } + } + + Interval::All => Some(full_range), + + Interval::Range(range) => { + let (low, high) = (range.start, range.end); 
+ assert!(low.is_power_of_two()); + assert!(high.is_power_of_two()); + assert!(low <= high); + assert!(low >= full_range.start); + assert!(high <= full_range.end); + Some(low..high) + } + } + } +} + +impl Into<Interval> for Range { + fn into(self) -> Interval { + Interval::Range(self) + } +} + +fn legal_bool(bits: RangeBound) -> bool { + // Only allow legal bit widths for bool types. + bits == 1 || (bits >= 8 && bits <= MAX_BITS && bits.is_power_of_two()) +} + +/// Generates a set with all the powers of two included in the range. +fn range_to_set(range: Option<Range>) -> NumSet { + let mut set = NumSet::new(); + + let (low, high) = match range { + Some(range) => (range.start, range.end), + None => return set, + }; + + assert!(low.is_power_of_two()); + assert!(high.is_power_of_two()); + assert!(low <= high); + + for i in low.trailing_zeros()..=high.trailing_zeros() { + assert!(1 << i <= RangeBound::max_value()); + set.insert(1 << i); + } + set +} + +#[test] +fn test_typevar_builder() { + let type_set = TypeSetBuilder::new().ints(Interval::All).build(); + assert_eq!(type_set.lanes, num_set![1]); + assert!(type_set.floats.is_empty()); + assert_eq!(type_set.ints, num_set![8, 16, 32, 64, 128]); + assert!(type_set.bools.is_empty()); + assert!(type_set.specials.is_empty()); + + let type_set = TypeSetBuilder::new().bools(Interval::All).build(); + assert_eq!(type_set.lanes, num_set![1]); + assert!(type_set.floats.is_empty()); + assert!(type_set.ints.is_empty()); + assert_eq!(type_set.bools, num_set![1, 8, 16, 32, 64, 128]); + assert!(type_set.specials.is_empty()); + + let type_set = TypeSetBuilder::new().floats(Interval::All).build(); + assert_eq!(type_set.lanes, num_set![1]); + assert_eq!(type_set.floats, num_set![32, 64]); + assert!(type_set.ints.is_empty()); + assert!(type_set.bools.is_empty()); + assert!(type_set.specials.is_empty()); + + let type_set = TypeSetBuilder::new() + .floats(Interval::All) + .simd_lanes(Interval::All) + .includes_scalars(false) + .build(); + 
assert_eq!(type_set.lanes, num_set![2, 4, 8, 16, 32, 64, 128, 256]); + assert_eq!(type_set.floats, num_set![32, 64]); + assert!(type_set.ints.is_empty()); + assert!(type_set.bools.is_empty()); + assert!(type_set.specials.is_empty()); + + let type_set = TypeSetBuilder::new() + .floats(Interval::All) + .simd_lanes(Interval::All) + .includes_scalars(true) + .build(); + assert_eq!(type_set.lanes, num_set![1, 2, 4, 8, 16, 32, 64, 128, 256]); + assert_eq!(type_set.floats, num_set![32, 64]); + assert!(type_set.ints.is_empty()); + assert!(type_set.bools.is_empty()); + assert!(type_set.specials.is_empty()); + + let type_set = TypeSetBuilder::new().ints(16..64).build(); + assert_eq!(type_set.lanes, num_set![1]); + assert_eq!(type_set.ints, num_set![16, 32, 64]); + assert!(type_set.floats.is_empty()); + assert!(type_set.bools.is_empty()); + assert!(type_set.specials.is_empty()); +} + +#[test] +#[should_panic] +fn test_typevar_builder_too_high_bound_panic() { + TypeSetBuilder::new().ints(16..2 * MAX_BITS).build(); +} + +#[test] +#[should_panic] +fn test_typevar_builder_inverted_bounds_panic() { + TypeSetBuilder::new().ints(32..16).build(); +} + +#[test] +fn test_as_bool() { + let a = TypeSetBuilder::new() + .simd_lanes(2..8) + .ints(8..8) + .floats(32..32) + .build(); + assert_eq!( + a.lane_of(), + TypeSetBuilder::new().ints(8..8).floats(32..32).build() + ); + + // Test as_bool with disjoint intervals. + let mut a_as_bool = TypeSetBuilder::new().simd_lanes(2..8).build(); + a_as_bool.bools = num_set![8, 32]; + assert_eq!(a.as_bool(), a_as_bool); + + let b = TypeSetBuilder::new() + .simd_lanes(1..8) + .ints(8..8) + .floats(32..32) + .build(); + let mut b_as_bool = TypeSetBuilder::new().simd_lanes(1..8).build(); + b_as_bool.bools = num_set![1, 8, 32]; + assert_eq!(b.as_bool(), b_as_bool); +} + +#[test] +fn test_forward_images() { + let empty_set = TypeSetBuilder::new().build(); + + // Half vector. 
+ assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..32) + .build() + .half_vector(), + TypeSetBuilder::new().simd_lanes(1..16).build() + ); + + // Double vector. + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..32) + .build() + .double_vector(), + TypeSetBuilder::new().simd_lanes(2..64).build() + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(128..256) + .build() + .double_vector(), + TypeSetBuilder::new().simd_lanes(256..256).build() + ); + + // Half width. + assert_eq!( + TypeSetBuilder::new().ints(8..32).build().half_width(), + TypeSetBuilder::new().ints(8..16).build() + ); + assert_eq!( + TypeSetBuilder::new().floats(32..32).build().half_width(), + empty_set + ); + assert_eq!( + TypeSetBuilder::new().floats(32..64).build().half_width(), + TypeSetBuilder::new().floats(32..32).build() + ); + assert_eq!( + TypeSetBuilder::new().bools(1..8).build().half_width(), + empty_set + ); + assert_eq!( + TypeSetBuilder::new().bools(1..32).build().half_width(), + TypeSetBuilder::new().bools(8..16).build() + ); + + // Double width. + assert_eq!( + TypeSetBuilder::new().ints(8..32).build().double_width(), + TypeSetBuilder::new().ints(16..64).build() + ); + assert_eq!( + TypeSetBuilder::new().ints(32..64).build().double_width(), + TypeSetBuilder::new().ints(64..128).build() + ); + assert_eq!( + TypeSetBuilder::new().floats(32..32).build().double_width(), + TypeSetBuilder::new().floats(64..64).build() + ); + assert_eq!( + TypeSetBuilder::new().floats(32..64).build().double_width(), + TypeSetBuilder::new().floats(64..64).build() + ); + assert_eq!( + TypeSetBuilder::new().bools(1..16).build().double_width(), + TypeSetBuilder::new().bools(16..32).build() + ); + assert_eq!( + TypeSetBuilder::new().bools(32..64).build().double_width(), + TypeSetBuilder::new().bools(64..128).build() + ); +} + +#[test] +fn test_backward_images() { + let empty_set = TypeSetBuilder::new().build(); + + // LaneOf. 
+ assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..1) + .ints(8..8) + .floats(32..32) + .build() + .preimage(DerivedFunc::LaneOf), + TypeSetBuilder::new() + .simd_lanes(Interval::All) + .ints(8..8) + .floats(32..32) + .build() + ); + assert_eq!(empty_set.preimage(DerivedFunc::LaneOf), empty_set); + + // AsBool. + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..4) + .bools(1..128) + .build() + .preimage(DerivedFunc::AsBool), + TypeSetBuilder::new() + .simd_lanes(1..4) + .ints(Interval::All) + .bools(Interval::All) + .floats(Interval::All) + .build() + ); + + // Double vector. + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..1) + .ints(8..8) + .build() + .preimage(DerivedFunc::DoubleVector) + .size(), + 0 + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..16) + .ints(8..16) + .floats(32..32) + .build() + .preimage(DerivedFunc::DoubleVector), + TypeSetBuilder::new() + .simd_lanes(1..8) + .ints(8..16) + .floats(32..32) + .build(), + ); + + // Half vector. + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(256..256) + .ints(8..8) + .build() + .preimage(DerivedFunc::HalfVector) + .size(), + 0 + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(64..128) + .bools(1..32) + .build() + .preimage(DerivedFunc::HalfVector), + TypeSetBuilder::new() + .simd_lanes(128..256) + .bools(1..32) + .build(), + ); + + // Half width. + assert_eq!( + TypeSetBuilder::new() + .ints(128..128) + .floats(64..64) + .bools(128..128) + .build() + .preimage(DerivedFunc::HalfWidth) + .size(), + 0 + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(64..256) + .bools(1..64) + .build() + .preimage(DerivedFunc::HalfWidth), + TypeSetBuilder::new() + .simd_lanes(64..256) + .bools(16..128) + .build(), + ); + + // Double width. 
+ assert_eq!( + TypeSetBuilder::new() + .ints(8..8) + .floats(32..32) + .bools(1..8) + .build() + .preimage(DerivedFunc::DoubleWidth) + .size(), + 0 + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..16) + .ints(8..16) + .floats(32..64) + .build() + .preimage(DerivedFunc::DoubleWidth), + TypeSetBuilder::new() + .simd_lanes(1..16) + .ints(8..8) + .floats(32..32) + .build() + ); +} + +#[test] +#[should_panic] +fn test_typeset_singleton_panic_nonsingleton_types() { + TypeSetBuilder::new() + .ints(8..8) + .floats(32..32) + .build() + .get_singleton(); +} + +#[test] +#[should_panic] +fn test_typeset_singleton_panic_nonsingleton_lanes() { + TypeSetBuilder::new() + .simd_lanes(1..2) + .floats(32..32) + .build() + .get_singleton(); +} + +#[test] +fn test_typeset_singleton() { + use crate::shared::types as shared_types; + assert_eq!( + TypeSetBuilder::new().ints(16..16).build().get_singleton(), + ValueType::Lane(shared_types::Int::I16.into()) + ); + assert_eq!( + TypeSetBuilder::new().floats(64..64).build().get_singleton(), + ValueType::Lane(shared_types::Float::F64.into()) + ); + assert_eq!( + TypeSetBuilder::new().bools(1..1).build().get_singleton(), + ValueType::Lane(shared_types::Bool::B1.into()) + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(4..4) + .ints(32..32) + .build() + .get_singleton(), + LaneType::from(shared_types::Int::I32).by(4) + ); +} + +#[test] +fn test_typevar_functions() { + let x = TypeVar::new( + "x", + "i16 and up", + TypeSetBuilder::new().ints(16..64).build(), + ); + assert_eq!(x.half_width().name, "half_width(x)"); + assert_eq!( + x.half_width().double_width().name, + "double_width(half_width(x))" + ); + + let x = TypeVar::new("x", "up to i32", TypeSetBuilder::new().ints(8..32).build()); + assert_eq!(x.double_width().name, "double_width(x)"); +} + +#[test] +fn test_typevar_singleton() { + use crate::cdsl::types::VectorType; + use crate::shared::types as shared_types; + + // Test i32. 
+ let typevar = TypeVar::new_singleton(ValueType::Lane(LaneType::Int(shared_types::Int::I32))); + assert_eq!(typevar.name, "i32"); + assert_eq!(typevar.type_set.ints, num_set![32]); + assert!(typevar.type_set.floats.is_empty()); + assert!(typevar.type_set.bools.is_empty()); + assert!(typevar.type_set.specials.is_empty()); + assert_eq!(typevar.type_set.lanes, num_set![1]); + + // Test f32x4. + let typevar = TypeVar::new_singleton(ValueType::Vector(VectorType::new( + LaneType::Float(shared_types::Float::F32), + 4, + ))); + assert_eq!(typevar.name, "f32x4"); + assert!(typevar.type_set.ints.is_empty()); + assert_eq!(typevar.type_set.floats, num_set![32]); + assert_eq!(typevar.type_set.lanes, num_set![4]); + assert!(typevar.type_set.bools.is_empty()); + assert!(typevar.type_set.specials.is_empty()); +} diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs new file mode 100644 index 0000000000..d21e93128d --- /dev/null +++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs @@ -0,0 +1,484 @@ +use crate::cdsl::ast::{ + Apply, BlockPool, ConstPool, DefIndex, DefPool, DummyDef, DummyExpr, Expr, PatternPosition, + VarIndex, VarPool, +}; +use crate::cdsl::instructions::Instruction; +use crate::cdsl::type_inference::{infer_transform, TypeEnvironment}; +use crate::cdsl::typevar::TypeVar; + +use cranelift_entity::{entity_impl, PrimaryMap}; + +use std::collections::{HashMap, HashSet}; +use std::iter::FromIterator; + +/// An instruction transformation consists of a source and destination pattern. +/// +/// Patterns are expressed in *register transfer language* as tuples of Def or Expr nodes. A +/// pattern may optionally have a sequence of TypeConstraints, that additionally limit the set of +/// cases when it applies. +/// +/// The source pattern can contain only a single instruction. 
+pub(crate) struct Transform { + pub src: DefIndex, + pub dst: Vec<DefIndex>, + pub var_pool: VarPool, + pub def_pool: DefPool, + pub block_pool: BlockPool, + pub const_pool: ConstPool, + pub type_env: TypeEnvironment, +} + +type SymbolTable = HashMap<String, VarIndex>; + +impl Transform { + fn new(src: DummyDef, dst: Vec<DummyDef>) -> Self { + let mut var_pool = VarPool::new(); + let mut def_pool = DefPool::new(); + let mut block_pool = BlockPool::new(); + let mut const_pool = ConstPool::new(); + + let mut input_vars: Vec<VarIndex> = Vec::new(); + let mut defined_vars: Vec<VarIndex> = Vec::new(); + + // Maps variable names to our own Var copies. + let mut symbol_table: SymbolTable = SymbolTable::new(); + + // Rewrite variables in src and dst using our own copies. + let src = rewrite_def_list( + PatternPosition::Source, + vec![src], + &mut symbol_table, + &mut input_vars, + &mut defined_vars, + &mut var_pool, + &mut def_pool, + &mut block_pool, + &mut const_pool, + )[0]; + + let num_src_inputs = input_vars.len(); + + let dst = rewrite_def_list( + PatternPosition::Destination, + dst, + &mut symbol_table, + &mut input_vars, + &mut defined_vars, + &mut var_pool, + &mut def_pool, + &mut block_pool, + &mut const_pool, + ); + + // Sanity checks. + for &var_index in &input_vars { + assert!( + var_pool.get(var_index).is_input(), + format!("'{:?}' used as both input and def", var_pool.get(var_index)) + ); + } + assert!( + input_vars.len() == num_src_inputs, + format!( + "extra input vars in dst pattern: {:?}", + input_vars + .iter() + .map(|&i| var_pool.get(i)) + .skip(num_src_inputs) + .collect::<Vec<_>>() + ) + ); + + // Perform type inference and cleanup. + let type_env = infer_transform(src, &dst, &def_pool, &mut var_pool).unwrap(); + + // Sanity check: the set of inferred free type variables should be a subset of the type + // variables corresponding to Vars appearing in the source pattern. 
+ { + let free_typevars: HashSet<TypeVar> = + HashSet::from_iter(type_env.free_typevars(&mut var_pool)); + let src_tvs = HashSet::from_iter( + input_vars + .clone() + .iter() + .chain( + defined_vars + .iter() + .filter(|&&var_index| !var_pool.get(var_index).is_temp()), + ) + .map(|&var_index| var_pool.get(var_index).get_typevar()) + .filter(|maybe_var| maybe_var.is_some()) + .map(|var| var.unwrap()), + ); + if !free_typevars.is_subset(&src_tvs) { + let missing_tvs = (&free_typevars - &src_tvs) + .iter() + .map(|tv| tv.name.clone()) + .collect::<Vec<_>>() + .join(", "); + panic!("Some free vars don't appear in src: {}", missing_tvs); + } + } + + for &var_index in input_vars.iter().chain(defined_vars.iter()) { + let var = var_pool.get_mut(var_index); + let canon_tv = type_env.get_equivalent(&var.get_or_create_typevar()); + var.set_typevar(canon_tv); + } + + Self { + src, + dst, + var_pool, + def_pool, + block_pool, + const_pool, + type_env, + } + } + + fn verify_legalize(&self) { + let def = self.def_pool.get(self.src); + for &var_index in def.defined_vars.iter() { + let defined_var = self.var_pool.get(var_index); + assert!( + defined_var.is_output(), + format!("{:?} not defined in the destination pattern", defined_var) + ); + } + } +} + +/// Inserts, if not present, a name in the `symbol_table`. Then returns its index in the variable +/// pool `var_pool`. If the variable was not present in the symbol table, then add it to the list of +/// `defined_vars`. +fn var_index( + name: &str, + symbol_table: &mut SymbolTable, + defined_vars: &mut Vec<VarIndex>, + var_pool: &mut VarPool, +) -> VarIndex { + let name = name.to_string(); + match symbol_table.get(&name) { + Some(&existing_var) => existing_var, + None => { + // Materialize the variable. + let new_var = var_pool.create(name.clone()); + symbol_table.insert(name, new_var); + defined_vars.push(new_var); + new_var + } + } +} + +/// Given a list of symbols defined in a Def, rewrite them to local symbols. 
Yield the new locals. +fn rewrite_defined_vars( + position: PatternPosition, + dummy_def: &DummyDef, + def_index: DefIndex, + symbol_table: &mut SymbolTable, + defined_vars: &mut Vec<VarIndex>, + var_pool: &mut VarPool, +) -> Vec<VarIndex> { + let mut new_defined_vars = Vec::new(); + for var in &dummy_def.defined_vars { + let own_var = var_index(&var.name, symbol_table, defined_vars, var_pool); + var_pool.get_mut(own_var).set_def(position, def_index); + new_defined_vars.push(own_var); + } + new_defined_vars +} + +/// Find all uses of variables in `expr` and replace them with our own local symbols. +fn rewrite_expr( + position: PatternPosition, + dummy_expr: DummyExpr, + symbol_table: &mut SymbolTable, + input_vars: &mut Vec<VarIndex>, + var_pool: &mut VarPool, + const_pool: &mut ConstPool, +) -> Apply { + let (apply_target, dummy_args) = if let DummyExpr::Apply(apply_target, dummy_args) = dummy_expr + { + (apply_target, dummy_args) + } else { + panic!("we only rewrite apply expressions"); + }; + + assert_eq!( + apply_target.inst().operands_in.len(), + dummy_args.len(), + "number of arguments in instruction {} is incorrect\nexpected: {:?}", + apply_target.inst().name, + apply_target + .inst() + .operands_in + .iter() + .map(|operand| format!("{}: {}", operand.name, operand.kind.rust_type)) + .collect::<Vec<_>>(), + ); + + let mut args = Vec::new(); + for (i, arg) in dummy_args.into_iter().enumerate() { + match arg { + DummyExpr::Var(var) => { + let own_var = var_index(&var.name, symbol_table, input_vars, var_pool); + let var = var_pool.get(own_var); + assert!( + var.is_input() || var.get_def(position).is_some(), + format!("{:?} used as both input and def", var) + ); + args.push(Expr::Var(own_var)); + } + DummyExpr::Literal(literal) => { + assert!(!apply_target.inst().operands_in[i].is_value()); + args.push(Expr::Literal(literal)); + } + DummyExpr::Constant(constant) => { + let const_name = const_pool.insert(constant.0); + // Here we abuse var_index by passing an 
empty, immediately-dropped vector to + // `defined_vars`; the reason for this is that unlike the `Var` case above, + // constants will create a variable that is not an input variable (it is tracked + // instead by ConstPool). + let const_var = var_index(&const_name, symbol_table, &mut vec![], var_pool); + args.push(Expr::Var(const_var)); + } + DummyExpr::Apply(..) => { + panic!("Recursive apply is not allowed."); + } + DummyExpr::Block(_block) => { + panic!("Blocks are not valid arguments."); + } + } + } + + Apply::new(apply_target, args) +} + +#[allow(clippy::too_many_arguments)] +fn rewrite_def_list( + position: PatternPosition, + dummy_defs: Vec<DummyDef>, + symbol_table: &mut SymbolTable, + input_vars: &mut Vec<VarIndex>, + defined_vars: &mut Vec<VarIndex>, + var_pool: &mut VarPool, + def_pool: &mut DefPool, + block_pool: &mut BlockPool, + const_pool: &mut ConstPool, +) -> Vec<DefIndex> { + let mut new_defs = Vec::new(); + // Register variable names of new blocks first as a block name can be used to jump forward. Thus + // the name has to be registered first to avoid misinterpreting it as an input-var. + for dummy_def in dummy_defs.iter() { + if let DummyExpr::Block(ref var) = dummy_def.expr { + var_index(&var.name, symbol_table, defined_vars, var_pool); + } + } + + // Iterate over the definitions and blocks, to map variables names to inputs or outputs. 
+ for dummy_def in dummy_defs { + let def_index = def_pool.next_index(); + + let new_defined_vars = rewrite_defined_vars( + position, + &dummy_def, + def_index, + symbol_table, + defined_vars, + var_pool, + ); + if let DummyExpr::Block(var) = dummy_def.expr { + let var_index = *symbol_table + .get(&var.name) + .or_else(|| { + panic!( + "Block {} was not registered during the first visit", + var.name + ) + }) + .unwrap(); + var_pool.get_mut(var_index).set_def(position, def_index); + block_pool.create_block(var_index, def_index); + } else { + let new_apply = rewrite_expr( + position, + dummy_def.expr, + symbol_table, + input_vars, + var_pool, + const_pool, + ); + + assert!( + def_pool.next_index() == def_index, + "shouldn't have created new defs in the meanwhile" + ); + assert_eq!( + new_apply.inst.value_results.len(), + new_defined_vars.len(), + "number of Var results in instruction is incorrect" + ); + + new_defs.push(def_pool.create_inst(new_apply, new_defined_vars)); + } + } + new_defs +} + +/// A group of related transformations. +pub(crate) struct TransformGroup { + pub name: &'static str, + pub doc: &'static str, + pub chain_with: Option<TransformGroupIndex>, + pub isa_name: Option<&'static str>, + pub id: TransformGroupIndex, + + /// Maps Instruction camel_case names to custom legalization functions names. + pub custom_legalizes: HashMap<String, &'static str>, + pub transforms: Vec<Transform>, +} + +impl TransformGroup { + pub fn rust_name(&self) -> String { + match self.isa_name { + Some(_) => { + // This is a function in the same module as the LEGALIZE_ACTIONS table referring to + // it. 
+ self.name.to_string() + } + None => format!("crate::legalizer::{}", self.name), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub(crate) struct TransformGroupIndex(u32); +entity_impl!(TransformGroupIndex); + +pub(crate) struct TransformGroupBuilder { + name: &'static str, + doc: &'static str, + chain_with: Option<TransformGroupIndex>, + isa_name: Option<&'static str>, + pub custom_legalizes: HashMap<String, &'static str>, + pub transforms: Vec<Transform>, +} + +impl TransformGroupBuilder { + pub fn new(name: &'static str, doc: &'static str) -> Self { + Self { + name, + doc, + chain_with: None, + isa_name: None, + custom_legalizes: HashMap::new(), + transforms: Vec::new(), + } + } + + pub fn chain_with(mut self, next_id: TransformGroupIndex) -> Self { + assert!(self.chain_with.is_none()); + self.chain_with = Some(next_id); + self + } + + pub fn isa(mut self, isa_name: &'static str) -> Self { + assert!(self.isa_name.is_none()); + self.isa_name = Some(isa_name); + self + } + + /// Add a custom legalization action for `inst`. + /// + /// The `func_name` parameter is the fully qualified name of a Rust function which takes the + /// same arguments as the `isa::Legalize` actions. + /// + /// The custom function will be called to legalize `inst` and any return value is ignored. + pub fn custom_legalize(&mut self, inst: &Instruction, func_name: &'static str) { + assert!( + self.custom_legalizes + .insert(inst.camel_name.clone(), func_name) + .is_none(), + format!( + "custom legalization action for {} inserted twice", + inst.name + ) + ); + } + + /// Add a legalization pattern to this group. 
+ pub fn legalize(&mut self, src: DummyDef, dst: Vec<DummyDef>) { + let transform = Transform::new(src, dst); + transform.verify_legalize(); + self.transforms.push(transform); + } + + pub fn build_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex { + let next_id = owner.next_key(); + owner.add(TransformGroup { + name: self.name, + doc: self.doc, + isa_name: self.isa_name, + id: next_id, + chain_with: self.chain_with, + custom_legalizes: self.custom_legalizes, + transforms: self.transforms, + }) + } +} + +pub(crate) struct TransformGroups { + groups: PrimaryMap<TransformGroupIndex, TransformGroup>, +} + +impl TransformGroups { + pub fn new() -> Self { + Self { + groups: PrimaryMap::new(), + } + } + pub fn add(&mut self, new_group: TransformGroup) -> TransformGroupIndex { + for group in self.groups.values() { + assert!( + group.name != new_group.name, + format!("trying to insert {} for the second time", new_group.name) + ); + } + self.groups.push(new_group) + } + pub fn get(&self, id: TransformGroupIndex) -> &TransformGroup { + &self.groups[id] + } + fn next_key(&self) -> TransformGroupIndex { + self.groups.next_key() + } + pub fn by_name(&self, name: &'static str) -> &TransformGroup { + for group in self.groups.values() { + if group.name == name { + return group; + } + } + panic!(format!("transform group with name {} not found", name)); + } +} + +#[test] +#[should_panic] +fn test_double_custom_legalization() { + use crate::cdsl::formats::InstructionFormatBuilder; + use crate::cdsl::instructions::{AllInstructions, InstructionBuilder, InstructionGroupBuilder}; + + let nullary = InstructionFormatBuilder::new("nullary").build(); + + let mut dummy_all = AllInstructions::new(); + let mut inst_group = InstructionGroupBuilder::new(&mut dummy_all); + inst_group.push(InstructionBuilder::new("dummy", "doc", &nullary)); + + let inst_group = inst_group.build(); + let dummy_inst = inst_group.by_name("dummy"); + + let mut transform_group = 
TransformGroupBuilder::new("test", "doc"); + transform_group.custom_legalize(&dummy_inst, "custom 1"); + transform_group.custom_legalize(&dummy_inst, "custom 2"); +} |