summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_build/src/build/custom
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_build/src/build/custom')
-rw-r--r--compiler/rustc_mir_build/src/build/custom/mod.rs155
-rw-r--r--compiler/rustc_mir_build/src/build/custom/parse.rs256
-rw-r--r--compiler/rustc_mir_build/src/build/custom/parse/instruction.rs105
3 files changed, 516 insertions, 0 deletions
diff --git a/compiler/rustc_mir_build/src/build/custom/mod.rs b/compiler/rustc_mir_build/src/build/custom/mod.rs
new file mode 100644
index 000000000..eb021f477
--- /dev/null
+++ b/compiler/rustc_mir_build/src/build/custom/mod.rs
@@ -0,0 +1,155 @@
+//! Provides the implementation of the `custom_mir` attribute.
+//!
+//! Up until MIR building, this attribute has absolutely no effect. The `mir!` macro is a normal
+//! decl macro that expands like any other, and the code goes through parsing, name resolution and
+//! type checking like all other code. In MIR building we finally detect whether this attribute is
+//! present, and if so we branch off into this module, which implements the attribute by
+//! implementing a custom lowering from THIR to MIR.
+//!
+//! The result of this lowering is returned "normally" from the `mir_built` query, with the only
+//! notable difference being that the `injected` field in the body is set. Various components of the
+//! MIR pipeline, like borrowck and the pass manager will then consult this field (via
+//! `body.should_skip()`) to skip the parts of the MIR pipeline that precede the MIR phase the user
+//! specified.
+//!
+//! This file defines the general framework for the custom parsing. The parsing for all the
+//! "top-level" constructs can be found in the `parse` submodule, while the parsing for statements,
+//! terminators, and everything below can be found in the `parse::instruction` submodule.
+//!
+
+use rustc_ast::Attribute;
+use rustc_data_structures::fx::FxHashMap;
+use rustc_hir::def_id::DefId;
+use rustc_index::vec::IndexVec;
+use rustc_middle::{
+ mir::*,
+ thir::*,
+ ty::{Ty, TyCtxt},
+};
+use rustc_span::Span;
+
+mod parse;
+
+pub(super) fn build_custom_mir<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ did: DefId,
+ thir: &Thir<'tcx>,
+ expr: ExprId,
+ params: &IndexVec<ParamId, Param<'tcx>>,
+ return_ty: Ty<'tcx>,
+ return_ty_span: Span,
+ span: Span,
+ attr: &Attribute,
+) -> Body<'tcx> {
+ let mut body = Body {
+ basic_blocks: BasicBlocks::new(IndexVec::new()),
+ source: MirSource::item(did),
+ phase: MirPhase::Built,
+ source_scopes: IndexVec::new(),
+ generator: None,
+ local_decls: LocalDecls::new(),
+ user_type_annotations: IndexVec::new(),
+ arg_count: params.len(),
+ spread_arg: None,
+ var_debug_info: Vec::new(),
+ span,
+ required_consts: Vec::new(),
+ is_polymorphic: false,
+ tainted_by_errors: None,
+ injection_phase: None,
+ pass_count: 0,
+ };
+
+ body.local_decls.push(LocalDecl::new(return_ty, return_ty_span));
+ body.basic_blocks_mut().push(BasicBlockData::new(None));
+ body.source_scopes.push(SourceScopeData {
+ span,
+ parent_scope: None,
+ inlined: None,
+ inlined_parent_scope: None,
+ local_data: ClearCrossCrate::Clear,
+ });
+ body.injection_phase = Some(parse_attribute(attr));
+
+ let mut pctxt = ParseCtxt {
+ tcx,
+ thir,
+ source_scope: OUTERMOST_SOURCE_SCOPE,
+ body: &mut body,
+ local_map: FxHashMap::default(),
+ block_map: FxHashMap::default(),
+ };
+
+ let res = (|| {
+ pctxt.parse_args(&params)?;
+ pctxt.parse_body(expr)
+ })();
+ if let Err(err) = res {
+ tcx.sess.diagnostic().span_fatal(
+ err.span,
+ format!("Could not parse {}, found: {:?}", err.expected, err.item_description),
+ )
+ }
+
+ body
+}
+
+fn parse_attribute(attr: &Attribute) -> MirPhase {
+ let meta_items = attr.meta_item_list().unwrap();
+ let mut dialect: Option<String> = None;
+ let mut phase: Option<String> = None;
+
+ for nested in meta_items {
+ let name = nested.name_or_empty();
+ let value = nested.value_str().unwrap().as_str().to_string();
+ match name.as_str() {
+ "dialect" => {
+ assert!(dialect.is_none());
+ dialect = Some(value);
+ }
+ "phase" => {
+ assert!(phase.is_none());
+ phase = Some(value);
+ }
+ other => {
+ panic!("Unexpected key {}", other);
+ }
+ }
+ }
+
+ let Some(dialect) = dialect else {
+ assert!(phase.is_none());
+ return MirPhase::Built;
+ };
+
+ MirPhase::parse(dialect, phase)
+}
+
+struct ParseCtxt<'tcx, 'body> {
+ tcx: TyCtxt<'tcx>,
+ thir: &'body Thir<'tcx>,
+ source_scope: SourceScope,
+
+ body: &'body mut Body<'tcx>,
+ local_map: FxHashMap<LocalVarId, Local>,
+ block_map: FxHashMap<LocalVarId, BasicBlock>,
+}
+
+struct ParseError {
+ span: Span,
+ item_description: String,
+ expected: String,
+}
+
+impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
+ fn expr_error(&self, expr: ExprId, expected: &'static str) -> ParseError {
+ let expr = &self.thir[expr];
+ ParseError {
+ span: expr.span,
+ item_description: format!("{:?}", expr.kind),
+ expected: expected.to_string(),
+ }
+ }
+}
+
+type PResult<T> = Result<T, ParseError>;
diff --git a/compiler/rustc_mir_build/src/build/custom/parse.rs b/compiler/rustc_mir_build/src/build/custom/parse.rs
new file mode 100644
index 000000000..d72770e70
--- /dev/null
+++ b/compiler/rustc_mir_build/src/build/custom/parse.rs
@@ -0,0 +1,256 @@
+use rustc_index::vec::IndexVec;
+use rustc_middle::{mir::*, thir::*, ty::Ty};
+use rustc_span::Span;
+
+use super::{PResult, ParseCtxt, ParseError};
+
+mod instruction;
+
+/// Helper macro for parsing custom MIR.
+///
+/// Example usage looks something like:
+/// ```rust,ignore (incomplete example)
+/// parse_by_kind!(
+/// self, // : &ParseCtxt
+/// expr_id, // what you're matching against
+/// "assignment", // the thing you're trying to parse
+/// @call("mir_assign", args) => { args[0] }, // match invocations of the `mir_assign` special function
+/// ExprKind::Assign { lhs, .. } => { lhs }, // match thir assignment expressions
+/// // no need for fallthrough case - reasonable error is automatically generated
+/// )
+/// ```
+macro_rules! parse_by_kind {
+ (
+ $self:ident,
+ $expr_id:expr,
+ $expr_name:pat,
+ $expected:literal,
+ $(
+ @call($name:literal, $args:ident) => $call_expr:expr,
+ )*
+ $(
+ $pat:pat => $expr:expr,
+ )*
+ ) => {{
+ let expr_id = $self.preparse($expr_id);
+ let expr = &$self.thir[expr_id];
+ debug!("Trying to parse {:?} as {}", expr.kind, $expected);
+ let $expr_name = expr;
+ match &expr.kind {
+ $(
+ ExprKind::Call { ty, fun: _, args: $args, .. } if {
+ match ty.kind() {
+ ty::FnDef(did, _) => {
+ $self.tcx.is_diagnostic_item(rustc_span::Symbol::intern($name), *did)
+ }
+ _ => false,
+ }
+ } => $call_expr,
+ )*
+ $(
+ $pat => $expr,
+ )*
+ #[allow(unreachable_patterns)]
+ _ => return Err($self.expr_error(expr_id, $expected))
+ }
+ }};
+}
+pub(crate) use parse_by_kind;
+
+impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
+ /// Expressions should only ever be matched on after preparsing them. This removes extra scopes
+ /// we don't care about.
+ fn preparse(&self, expr_id: ExprId) -> ExprId {
+ let expr = &self.thir[expr_id];
+ match expr.kind {
+ ExprKind::Scope { value, .. } => self.preparse(value),
+ _ => expr_id,
+ }
+ }
+
+ fn statement_as_expr(&self, stmt_id: StmtId) -> PResult<ExprId> {
+ match &self.thir[stmt_id].kind {
+ StmtKind::Expr { expr, .. } => Ok(*expr),
+ kind @ StmtKind::Let { pattern, .. } => {
+ return Err(ParseError {
+ span: pattern.span,
+ item_description: format!("{:?}", kind),
+ expected: "expression".to_string(),
+ });
+ }
+ }
+ }
+
+ pub fn parse_args(&mut self, params: &IndexVec<ParamId, Param<'tcx>>) -> PResult<()> {
+ for param in params.iter() {
+ let (var, span) = {
+ let pat = param.pat.as_ref().unwrap();
+ match &pat.kind {
+ PatKind::Binding { var, .. } => (*var, pat.span),
+ _ => {
+ return Err(ParseError {
+ span: pat.span,
+ item_description: format!("{:?}", pat.kind),
+ expected: "local".to_string(),
+ });
+ }
+ }
+ };
+ let decl = LocalDecl::new(param.ty, span);
+ let local = self.body.local_decls.push(decl);
+ self.local_map.insert(var, local);
+ }
+
+ Ok(())
+ }
+
+ /// Bodies are of the form:
+ ///
+ /// ```text
+ /// {
+ /// let bb1: BasicBlock;
+ /// let bb2: BasicBlock;
+ /// {
+ /// let RET: _;
+ /// let local1;
+ /// let local2;
+ ///
+ /// {
+ /// { // entry block
+ /// statement1;
+ /// terminator1
+ /// };
+ ///
+ /// bb1 = {
+ /// statement2;
+ /// terminator2
+ /// };
+ ///
+ /// bb2 = {
+ /// statement3;
+ /// terminator3
+ /// }
+ ///
+ /// RET
+ /// }
+ /// }
+ /// }
+ /// ```
+ ///
+ /// This allows us to easily parse the basic blocks declarations, local declarations, and
+ /// basic block definitions in order.
+ pub fn parse_body(&mut self, expr_id: ExprId) -> PResult<()> {
+ let body = parse_by_kind!(self, expr_id, _, "whole body",
+ ExprKind::Block { block } => self.thir[*block].expr.unwrap(),
+ );
+ let (block_decls, rest) = parse_by_kind!(self, body, _, "body with block decls",
+ ExprKind::Block { block } => {
+ let block = &self.thir[*block];
+ (&block.stmts, block.expr.unwrap())
+ },
+ );
+ self.parse_block_decls(block_decls.iter().copied())?;
+
+ let (local_decls, rest) = parse_by_kind!(self, rest, _, "body with local decls",
+ ExprKind::Block { block } => {
+ let block = &self.thir[*block];
+ (&block.stmts, block.expr.unwrap())
+ },
+ );
+ self.parse_local_decls(local_decls.iter().copied())?;
+
+ let block_defs = parse_by_kind!(self, rest, _, "body with block defs",
+ ExprKind::Block { block } => &self.thir[*block].stmts,
+ );
+ for (i, block_def) in block_defs.iter().enumerate() {
+ let block = self.parse_block_def(self.statement_as_expr(*block_def)?)?;
+ self.body.basic_blocks_mut()[BasicBlock::from_usize(i)] = block;
+ }
+
+ Ok(())
+ }
+
+ fn parse_block_decls(&mut self, stmts: impl Iterator<Item = StmtId>) -> PResult<()> {
+ for stmt in stmts {
+ let (var, _, _) = self.parse_let_statement(stmt)?;
+ let data = BasicBlockData::new(None);
+ let block = self.body.basic_blocks_mut().push(data);
+ self.block_map.insert(var, block);
+ }
+
+ Ok(())
+ }
+
+ fn parse_local_decls(&mut self, mut stmts: impl Iterator<Item = StmtId>) -> PResult<()> {
+ let (ret_var, ..) = self.parse_let_statement(stmts.next().unwrap())?;
+ self.local_map.insert(ret_var, Local::from_u32(0));
+
+ for stmt in stmts {
+ let (var, ty, span) = self.parse_let_statement(stmt)?;
+ let decl = LocalDecl::new(ty, span);
+ let local = self.body.local_decls.push(decl);
+ self.local_map.insert(var, local);
+ }
+
+ Ok(())
+ }
+
+ fn parse_let_statement(&mut self, stmt_id: StmtId) -> PResult<(LocalVarId, Ty<'tcx>, Span)> {
+ let pattern = match &self.thir[stmt_id].kind {
+ StmtKind::Let { pattern, .. } => pattern,
+ StmtKind::Expr { expr, .. } => {
+ return Err(self.expr_error(*expr, "let statement"));
+ }
+ };
+
+ self.parse_var(pattern)
+ }
+
+ fn parse_var(&mut self, mut pat: &Pat<'tcx>) -> PResult<(LocalVarId, Ty<'tcx>, Span)> {
+ // Make sure we throw out any `AscribeUserType` we find
+ loop {
+ match &pat.kind {
+ PatKind::Binding { var, ty, .. } => break Ok((*var, *ty, pat.span)),
+ PatKind::AscribeUserType { subpattern, .. } => {
+ pat = subpattern;
+ }
+ _ => {
+ break Err(ParseError {
+ span: pat.span,
+ item_description: format!("{:?}", pat.kind),
+ expected: "local".to_string(),
+ });
+ }
+ }
+ }
+ }
+
+ fn parse_block_def(&self, expr_id: ExprId) -> PResult<BasicBlockData<'tcx>> {
+ let block = parse_by_kind!(self, expr_id, _, "basic block",
+ ExprKind::Block { block } => &self.thir[*block],
+ );
+
+ let mut data = BasicBlockData::new(None);
+ for stmt_id in &*block.stmts {
+ let stmt = self.statement_as_expr(*stmt_id)?;
+ let span = self.thir[stmt].span;
+ let statement = self.parse_statement(stmt)?;
+ data.statements.push(Statement {
+ source_info: SourceInfo { span, scope: self.source_scope },
+ kind: statement,
+ });
+ }
+
+ let Some(trailing) = block.expr else {
+ return Err(self.expr_error(expr_id, "terminator"))
+ };
+ let span = self.thir[trailing].span;
+ let terminator = self.parse_terminator(trailing)?;
+ data.terminator = Some(Terminator {
+ source_info: SourceInfo { span, scope: self.source_scope },
+ kind: terminator,
+ });
+
+ Ok(data)
+ }
+}
diff --git a/compiler/rustc_mir_build/src/build/custom/parse/instruction.rs b/compiler/rustc_mir_build/src/build/custom/parse/instruction.rs
new file mode 100644
index 000000000..03206af33
--- /dev/null
+++ b/compiler/rustc_mir_build/src/build/custom/parse/instruction.rs
@@ -0,0 +1,105 @@
+use rustc_middle::mir::interpret::{ConstValue, Scalar};
+use rustc_middle::{mir::*, thir::*, ty};
+
+use super::{parse_by_kind, PResult, ParseCtxt};
+
+impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
+ pub fn parse_statement(&self, expr_id: ExprId) -> PResult<StatementKind<'tcx>> {
+ parse_by_kind!(self, expr_id, _, "statement",
+ @call("mir_retag", args) => {
+ Ok(StatementKind::Retag(RetagKind::Default, Box::new(self.parse_place(args[0])?)))
+ },
+ @call("mir_retag_raw", args) => {
+ Ok(StatementKind::Retag(RetagKind::Raw, Box::new(self.parse_place(args[0])?)))
+ },
+ ExprKind::Assign { lhs, rhs } => {
+ let lhs = self.parse_place(*lhs)?;
+ let rhs = self.parse_rvalue(*rhs)?;
+ Ok(StatementKind::Assign(Box::new((lhs, rhs))))
+ },
+ )
+ }
+
+ pub fn parse_terminator(&self, expr_id: ExprId) -> PResult<TerminatorKind<'tcx>> {
+ parse_by_kind!(self, expr_id, _, "terminator",
+ @call("mir_return", _args) => {
+ Ok(TerminatorKind::Return)
+ },
+ @call("mir_goto", args) => {
+ Ok(TerminatorKind::Goto { target: self.parse_block(args[0])? } )
+ },
+ )
+ }
+
+ fn parse_rvalue(&self, expr_id: ExprId) -> PResult<Rvalue<'tcx>> {
+ parse_by_kind!(self, expr_id, _, "rvalue",
+ ExprKind::Borrow { borrow_kind, arg } => Ok(
+ Rvalue::Ref(self.tcx.lifetimes.re_erased, *borrow_kind, self.parse_place(*arg)?)
+ ),
+ ExprKind::AddressOf { mutability, arg } => Ok(
+ Rvalue::AddressOf(*mutability, self.parse_place(*arg)?)
+ ),
+ _ => self.parse_operand(expr_id).map(Rvalue::Use),
+ )
+ }
+
+ fn parse_operand(&self, expr_id: ExprId) -> PResult<Operand<'tcx>> {
+ parse_by_kind!(self, expr_id, expr, "operand",
+ @call("mir_move", args) => self.parse_place(args[0]).map(Operand::Move),
+ @call("mir_static", args) => self.parse_static(args[0]),
+ @call("mir_static_mut", args) => self.parse_static(args[0]),
+ ExprKind::Literal { .. }
+ | ExprKind::NamedConst { .. }
+ | ExprKind::NonHirLiteral { .. }
+ | ExprKind::ZstLiteral { .. }
+ | ExprKind::ConstParam { .. }
+ | ExprKind::ConstBlock { .. } => {
+ Ok(Operand::Constant(Box::new(
+ crate::build::expr::as_constant::as_constant_inner(expr, |_| None, self.tcx)
+ )))
+ },
+ _ => self.parse_place(expr_id).map(Operand::Copy),
+ )
+ }
+
+ fn parse_place(&self, expr_id: ExprId) -> PResult<Place<'tcx>> {
+ parse_by_kind!(self, expr_id, _, "place",
+ ExprKind::Deref { arg } => Ok(
+ self.parse_place(*arg)?.project_deeper(&[PlaceElem::Deref], self.tcx)
+ ),
+ _ => self.parse_local(expr_id).map(Place::from),
+ )
+ }
+
+ fn parse_local(&self, expr_id: ExprId) -> PResult<Local> {
+ parse_by_kind!(self, expr_id, _, "local",
+ ExprKind::VarRef { id } => Ok(self.local_map[id]),
+ )
+ }
+
+ fn parse_block(&self, expr_id: ExprId) -> PResult<BasicBlock> {
+ parse_by_kind!(self, expr_id, _, "basic block",
+ ExprKind::VarRef { id } => Ok(self.block_map[id]),
+ )
+ }
+
+ fn parse_static(&self, expr_id: ExprId) -> PResult<Operand<'tcx>> {
+ let expr_id = parse_by_kind!(self, expr_id, _, "static",
+ ExprKind::Deref { arg } => *arg,
+ );
+
+ parse_by_kind!(self, expr_id, expr, "static",
+ ExprKind::StaticRef { alloc_id, ty, .. } => {
+ let const_val =
+ ConstValue::Scalar(Scalar::from_pointer((*alloc_id).into(), &self.tcx));
+ let literal = ConstantKind::Val(const_val, *ty);
+
+ Ok(Operand::Constant(Box::new(Constant {
+ span: expr.span,
+ user_ty: None,
+ literal
+ })))
+ },
+ )
+ }
+}