summaryrefslogtreecommitdiffstats
path: root/third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs')
-rw-r--r--third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs484
1 files changed, 484 insertions, 0 deletions
diff --git a/third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs b/third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs
new file mode 100644
index 0000000000..d21e93128d
--- /dev/null
+++ b/third_party/rust/cranelift-codegen-meta/src/cdsl/xform.rs
@@ -0,0 +1,484 @@
+use crate::cdsl::ast::{
+ Apply, BlockPool, ConstPool, DefIndex, DefPool, DummyDef, DummyExpr, Expr, PatternPosition,
+ VarIndex, VarPool,
+};
+use crate::cdsl::instructions::Instruction;
+use crate::cdsl::type_inference::{infer_transform, TypeEnvironment};
+use crate::cdsl::typevar::TypeVar;
+
+use cranelift_entity::{entity_impl, PrimaryMap};
+
+use std::collections::{HashMap, HashSet};
+use std::iter::FromIterator;
+
+/// An instruction transformation consists of a source and destination pattern.
+///
+/// Patterns are expressed in *register transfer language* as tuples of Def or Expr nodes. A
+/// pattern may optionally have a sequence of TypeConstraints, that additionally limit the set of
+/// cases when it applies.
+///
+/// The source pattern can contain only a single instruction.
+pub(crate) struct Transform {
+ pub src: DefIndex,
+ pub dst: Vec<DefIndex>,
+ pub var_pool: VarPool,
+ pub def_pool: DefPool,
+ pub block_pool: BlockPool,
+ pub const_pool: ConstPool,
+ pub type_env: TypeEnvironment,
+}
+
+type SymbolTable = HashMap<String, VarIndex>;
+
+impl Transform {
+ fn new(src: DummyDef, dst: Vec<DummyDef>) -> Self {
+ let mut var_pool = VarPool::new();
+ let mut def_pool = DefPool::new();
+ let mut block_pool = BlockPool::new();
+ let mut const_pool = ConstPool::new();
+
+ let mut input_vars: Vec<VarIndex> = Vec::new();
+ let mut defined_vars: Vec<VarIndex> = Vec::new();
+
+ // Maps variable names to our own Var copies.
+ let mut symbol_table: SymbolTable = SymbolTable::new();
+
+ // Rewrite variables in src and dst using our own copies.
+ let src = rewrite_def_list(
+ PatternPosition::Source,
+ vec![src],
+ &mut symbol_table,
+ &mut input_vars,
+ &mut defined_vars,
+ &mut var_pool,
+ &mut def_pool,
+ &mut block_pool,
+ &mut const_pool,
+ )[0];
+
+ let num_src_inputs = input_vars.len();
+
+ let dst = rewrite_def_list(
+ PatternPosition::Destination,
+ dst,
+ &mut symbol_table,
+ &mut input_vars,
+ &mut defined_vars,
+ &mut var_pool,
+ &mut def_pool,
+ &mut block_pool,
+ &mut const_pool,
+ );
+
+ // Sanity checks.
+ for &var_index in &input_vars {
+ assert!(
+ var_pool.get(var_index).is_input(),
+ format!("'{:?}' used as both input and def", var_pool.get(var_index))
+ );
+ }
+ assert!(
+ input_vars.len() == num_src_inputs,
+ format!(
+ "extra input vars in dst pattern: {:?}",
+ input_vars
+ .iter()
+ .map(|&i| var_pool.get(i))
+ .skip(num_src_inputs)
+ .collect::<Vec<_>>()
+ )
+ );
+
+ // Perform type inference and cleanup.
+ let type_env = infer_transform(src, &dst, &def_pool, &mut var_pool).unwrap();
+
+ // Sanity check: the set of inferred free type variables should be a subset of the type
+ // variables corresponding to Vars appearing in the source pattern.
+ {
+ let free_typevars: HashSet<TypeVar> =
+ HashSet::from_iter(type_env.free_typevars(&mut var_pool));
+ let src_tvs = HashSet::from_iter(
+ input_vars
+ .clone()
+ .iter()
+ .chain(
+ defined_vars
+ .iter()
+ .filter(|&&var_index| !var_pool.get(var_index).is_temp()),
+ )
+ .map(|&var_index| var_pool.get(var_index).get_typevar())
+ .filter(|maybe_var| maybe_var.is_some())
+ .map(|var| var.unwrap()),
+ );
+ if !free_typevars.is_subset(&src_tvs) {
+ let missing_tvs = (&free_typevars - &src_tvs)
+ .iter()
+ .map(|tv| tv.name.clone())
+ .collect::<Vec<_>>()
+ .join(", ");
+ panic!("Some free vars don't appear in src: {}", missing_tvs);
+ }
+ }
+
+ for &var_index in input_vars.iter().chain(defined_vars.iter()) {
+ let var = var_pool.get_mut(var_index);
+ let canon_tv = type_env.get_equivalent(&var.get_or_create_typevar());
+ var.set_typevar(canon_tv);
+ }
+
+ Self {
+ src,
+ dst,
+ var_pool,
+ def_pool,
+ block_pool,
+ const_pool,
+ type_env,
+ }
+ }
+
+ fn verify_legalize(&self) {
+ let def = self.def_pool.get(self.src);
+ for &var_index in def.defined_vars.iter() {
+ let defined_var = self.var_pool.get(var_index);
+ assert!(
+ defined_var.is_output(),
+ format!("{:?} not defined in the destination pattern", defined_var)
+ );
+ }
+ }
+}
+
+/// Inserts, if not present, a name in the `symbol_table`. Then returns its index in the variable
+/// pool `var_pool`. If the variable was not present in the symbol table, then add it to the list of
+/// `defined_vars`.
+fn var_index(
+ name: &str,
+ symbol_table: &mut SymbolTable,
+ defined_vars: &mut Vec<VarIndex>,
+ var_pool: &mut VarPool,
+) -> VarIndex {
+ let name = name.to_string();
+ match symbol_table.get(&name) {
+ Some(&existing_var) => existing_var,
+ None => {
+ // Materialize the variable.
+ let new_var = var_pool.create(name.clone());
+ symbol_table.insert(name, new_var);
+ defined_vars.push(new_var);
+ new_var
+ }
+ }
+}
+
+/// Given a list of symbols defined in a Def, rewrite them to local symbols. Yield the new locals.
+fn rewrite_defined_vars(
+ position: PatternPosition,
+ dummy_def: &DummyDef,
+ def_index: DefIndex,
+ symbol_table: &mut SymbolTable,
+ defined_vars: &mut Vec<VarIndex>,
+ var_pool: &mut VarPool,
+) -> Vec<VarIndex> {
+ let mut new_defined_vars = Vec::new();
+ for var in &dummy_def.defined_vars {
+ let own_var = var_index(&var.name, symbol_table, defined_vars, var_pool);
+ var_pool.get_mut(own_var).set_def(position, def_index);
+ new_defined_vars.push(own_var);
+ }
+ new_defined_vars
+}
+
+/// Find all uses of variables in `expr` and replace them with our own local symbols.
+fn rewrite_expr(
+ position: PatternPosition,
+ dummy_expr: DummyExpr,
+ symbol_table: &mut SymbolTable,
+ input_vars: &mut Vec<VarIndex>,
+ var_pool: &mut VarPool,
+ const_pool: &mut ConstPool,
+) -> Apply {
+ let (apply_target, dummy_args) = if let DummyExpr::Apply(apply_target, dummy_args) = dummy_expr
+ {
+ (apply_target, dummy_args)
+ } else {
+ panic!("we only rewrite apply expressions");
+ };
+
+ assert_eq!(
+ apply_target.inst().operands_in.len(),
+ dummy_args.len(),
+ "number of arguments in instruction {} is incorrect\nexpected: {:?}",
+ apply_target.inst().name,
+ apply_target
+ .inst()
+ .operands_in
+ .iter()
+ .map(|operand| format!("{}: {}", operand.name, operand.kind.rust_type))
+ .collect::<Vec<_>>(),
+ );
+
+ let mut args = Vec::new();
+ for (i, arg) in dummy_args.into_iter().enumerate() {
+ match arg {
+ DummyExpr::Var(var) => {
+ let own_var = var_index(&var.name, symbol_table, input_vars, var_pool);
+ let var = var_pool.get(own_var);
+ assert!(
+ var.is_input() || var.get_def(position).is_some(),
+ format!("{:?} used as both input and def", var)
+ );
+ args.push(Expr::Var(own_var));
+ }
+ DummyExpr::Literal(literal) => {
+ assert!(!apply_target.inst().operands_in[i].is_value());
+ args.push(Expr::Literal(literal));
+ }
+ DummyExpr::Constant(constant) => {
+ let const_name = const_pool.insert(constant.0);
+ // Here we abuse var_index by passing an empty, immediately-dropped vector to
+ // `defined_vars`; the reason for this is that unlike the `Var` case above,
+ // constants will create a variable that is not an input variable (it is tracked
+ // instead by ConstPool).
+ let const_var = var_index(&const_name, symbol_table, &mut vec![], var_pool);
+ args.push(Expr::Var(const_var));
+ }
+ DummyExpr::Apply(..) => {
+ panic!("Recursive apply is not allowed.");
+ }
+ DummyExpr::Block(_block) => {
+ panic!("Blocks are not valid arguments.");
+ }
+ }
+ }
+
+ Apply::new(apply_target, args)
+}
+
+#[allow(clippy::too_many_arguments)]
+fn rewrite_def_list(
+ position: PatternPosition,
+ dummy_defs: Vec<DummyDef>,
+ symbol_table: &mut SymbolTable,
+ input_vars: &mut Vec<VarIndex>,
+ defined_vars: &mut Vec<VarIndex>,
+ var_pool: &mut VarPool,
+ def_pool: &mut DefPool,
+ block_pool: &mut BlockPool,
+ const_pool: &mut ConstPool,
+) -> Vec<DefIndex> {
+ let mut new_defs = Vec::new();
+ // Register variable names of new blocks first as a block name can be used to jump forward. Thus
+ // the name has to be registered first to avoid misinterpreting it as an input-var.
+ for dummy_def in dummy_defs.iter() {
+ if let DummyExpr::Block(ref var) = dummy_def.expr {
+ var_index(&var.name, symbol_table, defined_vars, var_pool);
+ }
+ }
+
+ // Iterate over the definitions and blocks, to map variables names to inputs or outputs.
+ for dummy_def in dummy_defs {
+ let def_index = def_pool.next_index();
+
+ let new_defined_vars = rewrite_defined_vars(
+ position,
+ &dummy_def,
+ def_index,
+ symbol_table,
+ defined_vars,
+ var_pool,
+ );
+ if let DummyExpr::Block(var) = dummy_def.expr {
+ let var_index = *symbol_table
+ .get(&var.name)
+ .or_else(|| {
+ panic!(
+ "Block {} was not registered during the first visit",
+ var.name
+ )
+ })
+ .unwrap();
+ var_pool.get_mut(var_index).set_def(position, def_index);
+ block_pool.create_block(var_index, def_index);
+ } else {
+ let new_apply = rewrite_expr(
+ position,
+ dummy_def.expr,
+ symbol_table,
+ input_vars,
+ var_pool,
+ const_pool,
+ );
+
+ assert!(
+ def_pool.next_index() == def_index,
+ "shouldn't have created new defs in the meanwhile"
+ );
+ assert_eq!(
+ new_apply.inst.value_results.len(),
+ new_defined_vars.len(),
+ "number of Var results in instruction is incorrect"
+ );
+
+ new_defs.push(def_pool.create_inst(new_apply, new_defined_vars));
+ }
+ }
+ new_defs
+}
+
+/// A group of related transformations.
+pub(crate) struct TransformGroup {
+ pub name: &'static str,
+ pub doc: &'static str,
+ pub chain_with: Option<TransformGroupIndex>,
+ pub isa_name: Option<&'static str>,
+ pub id: TransformGroupIndex,
+
+ /// Maps Instruction camel_case names to custom legalization functions names.
+ pub custom_legalizes: HashMap<String, &'static str>,
+ pub transforms: Vec<Transform>,
+}
+
+impl TransformGroup {
+ pub fn rust_name(&self) -> String {
+ match self.isa_name {
+ Some(_) => {
+ // This is a function in the same module as the LEGALIZE_ACTIONS table referring to
+ // it.
+ self.name.to_string()
+ }
+ None => format!("crate::legalizer::{}", self.name),
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub(crate) struct TransformGroupIndex(u32);
+entity_impl!(TransformGroupIndex);
+
+pub(crate) struct TransformGroupBuilder {
+ name: &'static str,
+ doc: &'static str,
+ chain_with: Option<TransformGroupIndex>,
+ isa_name: Option<&'static str>,
+ pub custom_legalizes: HashMap<String, &'static str>,
+ pub transforms: Vec<Transform>,
+}
+
+impl TransformGroupBuilder {
+ pub fn new(name: &'static str, doc: &'static str) -> Self {
+ Self {
+ name,
+ doc,
+ chain_with: None,
+ isa_name: None,
+ custom_legalizes: HashMap::new(),
+ transforms: Vec::new(),
+ }
+ }
+
+ pub fn chain_with(mut self, next_id: TransformGroupIndex) -> Self {
+ assert!(self.chain_with.is_none());
+ self.chain_with = Some(next_id);
+ self
+ }
+
+ pub fn isa(mut self, isa_name: &'static str) -> Self {
+ assert!(self.isa_name.is_none());
+ self.isa_name = Some(isa_name);
+ self
+ }
+
+ /// Add a custom legalization action for `inst`.
+ ///
+ /// The `func_name` parameter is the fully qualified name of a Rust function which takes the
+ /// same arguments as the `isa::Legalize` actions.
+ ///
+ /// The custom function will be called to legalize `inst` and any return value is ignored.
+ pub fn custom_legalize(&mut self, inst: &Instruction, func_name: &'static str) {
+ assert!(
+ self.custom_legalizes
+ .insert(inst.camel_name.clone(), func_name)
+ .is_none(),
+ format!(
+ "custom legalization action for {} inserted twice",
+ inst.name
+ )
+ );
+ }
+
+ /// Add a legalization pattern to this group.
+ pub fn legalize(&mut self, src: DummyDef, dst: Vec<DummyDef>) {
+ let transform = Transform::new(src, dst);
+ transform.verify_legalize();
+ self.transforms.push(transform);
+ }
+
+ pub fn build_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex {
+ let next_id = owner.next_key();
+ owner.add(TransformGroup {
+ name: self.name,
+ doc: self.doc,
+ isa_name: self.isa_name,
+ id: next_id,
+ chain_with: self.chain_with,
+ custom_legalizes: self.custom_legalizes,
+ transforms: self.transforms,
+ })
+ }
+}
+
+pub(crate) struct TransformGroups {
+ groups: PrimaryMap<TransformGroupIndex, TransformGroup>,
+}
+
+impl TransformGroups {
+ pub fn new() -> Self {
+ Self {
+ groups: PrimaryMap::new(),
+ }
+ }
+ pub fn add(&mut self, new_group: TransformGroup) -> TransformGroupIndex {
+ for group in self.groups.values() {
+ assert!(
+ group.name != new_group.name,
+ format!("trying to insert {} for the second time", new_group.name)
+ );
+ }
+ self.groups.push(new_group)
+ }
+ pub fn get(&self, id: TransformGroupIndex) -> &TransformGroup {
+ &self.groups[id]
+ }
+ fn next_key(&self) -> TransformGroupIndex {
+ self.groups.next_key()
+ }
+ pub fn by_name(&self, name: &'static str) -> &TransformGroup {
+ for group in self.groups.values() {
+ if group.name == name {
+ return group;
+ }
+ }
+ panic!(format!("transform group with name {} not found", name));
+ }
+}
+
+#[test]
+#[should_panic]
+fn test_double_custom_legalization() {
+ use crate::cdsl::formats::InstructionFormatBuilder;
+ use crate::cdsl::instructions::{AllInstructions, InstructionBuilder, InstructionGroupBuilder};
+
+ let nullary = InstructionFormatBuilder::new("nullary").build();
+
+ let mut dummy_all = AllInstructions::new();
+ let mut inst_group = InstructionGroupBuilder::new(&mut dummy_all);
+ inst_group.push(InstructionBuilder::new("dummy", "doc", &nullary));
+
+ let inst_group = inst_group.build();
+ let dummy_inst = inst_group.by_name("dummy");
+
+ let mut transform_group = TransformGroupBuilder::new("test", "doc");
+ transform_group.custom_legalize(&dummy_inst, "custom 1");
+ transform_group.custom_legalize(&dummy_inst, "custom 2");
+}