summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/inline.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/inline.rs')
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs371
1 files changed, 259 insertions, 112 deletions
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 76b1522f3..d00a384cb 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -8,8 +8,11 @@ use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::subst::Subst;
-use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
+use rustc_middle::ty::{self, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
+use rustc_session::config::OptLevel;
+use rustc_span::def_id::DefId;
use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
+use rustc_target::abi::VariantIdx;
use rustc_target::spec::abi::Abi;
use super::simplify::{remove_dead_blocks, CfgSimplifier};
@@ -43,8 +46,15 @@ impl<'tcx> MirPass<'tcx> for Inline {
return enabled;
}
- // rust-lang/rust#101004: reverted to old inlining decision logic
- sess.mir_opt_level() >= 3
+ match sess.mir_opt_level() {
+ 0 | 1 => false,
+ 2 => {
+ (sess.opts.optimize == OptLevel::Default
+ || sess.opts.optimize == OptLevel::Aggressive)
+ && sess.opts.incremental == None
+ }
+ _ => true,
+ }
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
@@ -85,7 +95,7 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool {
history: Vec::new(),
changed: false,
};
- let blocks = BasicBlock::new(0)..body.basic_blocks().next_index();
+ let blocks = BasicBlock::new(0)..body.basic_blocks.next_index();
this.process_blocks(body, blocks);
this.changed
}
@@ -95,8 +105,12 @@ struct Inliner<'tcx> {
param_env: ParamEnv<'tcx>,
/// Caller codegen attributes.
codegen_fn_attrs: &'tcx CodegenFnAttrs,
- /// Stack of inlined Instances.
- history: Vec<ty::Instance<'tcx>>,
+ /// Stack of inlined instances.
+ /// We only check the `DefId` and not the substs because we want to
+ /// avoid inlining cases of polymorphic recursion.
+ /// The number of `DefId`s is finite, so checking history is enough
+ /// to ensure that we do not loop endlessly while inlining.
+ history: Vec<DefId>,
/// Indicates that the caller body has been modified.
changed: bool,
}
@@ -124,7 +138,7 @@ impl<'tcx> Inliner<'tcx> {
Ok(new_blocks) => {
debug!("inlined {}", callsite.callee);
self.changed = true;
- self.history.push(callsite.callee);
+ self.history.push(callsite.callee.def_id());
self.process_blocks(caller_body, new_blocks);
self.history.pop();
}
@@ -203,9 +217,9 @@ impl<'tcx> Inliner<'tcx> {
}
}
- let old_blocks = caller_body.basic_blocks().next_index();
+ let old_blocks = caller_body.basic_blocks.next_index();
self.inline_call(caller_body, &callsite, callee_body);
- let new_blocks = old_blocks..caller_body.basic_blocks().next_index();
+ let new_blocks = old_blocks..caller_body.basic_blocks.next_index();
Ok(new_blocks)
}
@@ -300,7 +314,7 @@ impl<'tcx> Inliner<'tcx> {
return None;
}
- if self.history.contains(&callee) {
+ if self.history.contains(&callee.def_id()) {
return None;
}
@@ -395,124 +409,66 @@ impl<'tcx> Inliner<'tcx> {
// Give a bonus functions with a small number of blocks,
// We normally have two or three blocks for even
// very small functions.
- if callee_body.basic_blocks().len() <= 3 {
+ if callee_body.basic_blocks.len() <= 3 {
threshold += threshold / 4;
}
debug!(" final inline threshold = {}", threshold);
// FIXME: Give a bonus to functions with only a single caller
- let mut first_block = true;
- let mut cost = 0;
+ let diverges = matches!(
+ callee_body.basic_blocks[START_BLOCK].terminator().kind,
+ TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
+ );
+ if diverges && !matches!(callee_attrs.inline, InlineAttr::Always) {
+ return Err("callee diverges unconditionally");
+ }
- // Traverse the MIR manually so we can account for the effects of
- // inlining on the CFG.
+ let mut checker = CostChecker {
+ tcx: self.tcx,
+ param_env: self.param_env,
+ instance: callsite.callee,
+ callee_body,
+ cost: 0,
+ validation: Ok(()),
+ };
+
+ // Traverse the MIR manually so we can account for the effects of inlining on the CFG.
let mut work_list = vec![START_BLOCK];
- let mut visited = BitSet::new_empty(callee_body.basic_blocks().len());
+ let mut visited = BitSet::new_empty(callee_body.basic_blocks.len());
while let Some(bb) = work_list.pop() {
if !visited.insert(bb.index()) {
continue;
}
- let blk = &callee_body.basic_blocks()[bb];
- for stmt in &blk.statements {
- // Don't count StorageLive/StorageDead in the inlining cost.
- match stmt.kind {
- StatementKind::StorageLive(_)
- | StatementKind::StorageDead(_)
- | StatementKind::Deinit(_)
- | StatementKind::Nop => {}
- _ => cost += INSTR_COST,
- }
- }
- let term = blk.terminator();
- let mut is_drop = false;
- match term.kind {
- TerminatorKind::Drop { ref place, target, unwind }
- | TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => {
- is_drop = true;
- work_list.push(target);
- // If the place doesn't actually need dropping, treat it like
- // a regular goto.
- let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
- if ty.needs_drop(tcx, self.param_env) {
- cost += CALL_PENALTY;
- if let Some(unwind) = unwind {
- cost += LANDINGPAD_PENALTY;
- work_list.push(unwind);
- }
- } else {
- cost += INSTR_COST;
- }
- }
+ let blk = &callee_body.basic_blocks[bb];
+ checker.visit_basic_block_data(bb, blk);
- TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
- if first_block =>
- {
- // If the function always diverges, don't inline
- // unless the cost is zero
- threshold = 0;
- }
-
- TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
- if let ty::FnDef(def_id, _) =
- *callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind()
- {
- // Don't give intrinsics the extra penalty for calls
- if tcx.is_intrinsic(def_id) {
- cost += INSTR_COST;
- } else {
- cost += CALL_PENALTY;
- }
- } else {
- cost += CALL_PENALTY;
- }
- if cleanup.is_some() {
- cost += LANDINGPAD_PENALTY;
- }
- }
- TerminatorKind::Assert { cleanup, .. } => {
- cost += CALL_PENALTY;
-
- if cleanup.is_some() {
- cost += LANDINGPAD_PENALTY;
- }
- }
- TerminatorKind::Resume => cost += RESUME_PENALTY,
- TerminatorKind::InlineAsm { cleanup, .. } => {
- cost += INSTR_COST;
+ let term = blk.terminator();
+ if let TerminatorKind::Drop { ref place, target, unwind }
+ | TerminatorKind::DropAndReplace { ref place, target, unwind, .. } = term.kind
+ {
+ work_list.push(target);
- if cleanup.is_some() {
- cost += LANDINGPAD_PENALTY;
+ // If the place doesn't actually need dropping, treat it like a regular goto.
+ let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
+ if ty.needs_drop(tcx, self.param_env) && let Some(unwind) = unwind {
+ work_list.push(unwind);
}
- }
- _ => cost += INSTR_COST,
- }
-
- if !is_drop {
- for succ in term.successors() {
- work_list.push(succ);
- }
+ } else {
+ work_list.extend(term.successors())
}
-
- first_block = false;
}
// Count up the cost of local variables and temps, if we know the size
// use that, otherwise we use a moderately-large dummy cost.
-
- let ptr_size = tcx.data_layout.pointer_size.bytes();
-
for v in callee_body.vars_and_temps_iter() {
- let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty);
- // Cost of the var is the size in machine-words, if we know
- // it.
- if let Some(size) = type_size_of(tcx, self.param_env, ty) {
- cost += ((size + ptr_size - 1) / ptr_size) as usize;
- } else {
- cost += UNKNOWN_SIZE_COST;
- }
+ checker.visit_local_decl(v, &callee_body.local_decls[v]);
}
+ // Abort if type validation found anything fishy.
+ checker.validation?;
+
+ let cost = checker.cost;
if let InlineAttr::Always = callee_attrs.inline {
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
Ok(())
@@ -585,7 +541,7 @@ impl<'tcx> Inliner<'tcx> {
args: &args,
new_locals: Local::new(caller_body.local_decls.len())..,
new_scopes: SourceScope::new(caller_body.source_scopes.len())..,
- new_blocks: BasicBlock::new(caller_body.basic_blocks().len())..,
+ new_blocks: BasicBlock::new(caller_body.basic_blocks.len())..,
destination: dest,
callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(),
callsite,
@@ -603,7 +559,9 @@ impl<'tcx> Inliner<'tcx> {
// If there are any locals without storage markers, give them storage only for the
// duration of the call.
for local in callee_body.vars_and_temps_iter() {
- if integrator.always_live_locals.contains(local) {
+ if !callee_body.local_decls[local].internal
+ && integrator.always_live_locals.contains(local)
+ {
let new_local = integrator.map_local(local);
caller_body[callsite.block].statements.push(Statement {
source_info: callsite.source_info,
@@ -616,7 +574,9 @@ impl<'tcx> Inliner<'tcx> {
// the slice once.
let mut n = 0;
for local in callee_body.vars_and_temps_iter().rev() {
- if integrator.always_live_locals.contains(local) {
+ if !callee_body.local_decls[local].internal
+ && integrator.always_live_locals.contains(local)
+ {
let new_local = integrator.map_local(local);
caller_body[block].statements.push(Statement {
source_info: callsite.source_info,
@@ -644,11 +604,11 @@ impl<'tcx> Inliner<'tcx> {
// `required_consts`, here we may not only have `ConstKind::Unevaluated`
// because we are calling `subst_and_normalize_erasing_regions`.
caller_body.required_consts.extend(
- callee_body.required_consts.iter().copied().filter(|&ct| {
- match ct.literal.const_for_ty() {
- Some(ct) => matches!(ct.kind(), ConstKind::Unevaluated(_)),
- None => true,
+ callee_body.required_consts.iter().copied().filter(|&ct| match ct.literal {
+ ConstantKind::Ty(_) => {
+ bug!("should never encounter ty::Unevaluated in `required_consts`")
}
+ ConstantKind::Val(..) | ConstantKind::Unevaluated(..) => true,
}),
);
}
@@ -782,6 +742,193 @@ fn type_size_of<'tcx>(
tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes())
}
+/// Verify that the callee body is compatible with the caller.
+///
+/// This visitor mostly computes the inlining cost,
+/// but also needs to verify that types match because of normalization failure.
+struct CostChecker<'b, 'tcx> {
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ cost: usize,
+ callee_body: &'b Body<'tcx>,
+ instance: ty::Instance<'tcx>,
+ validation: Result<(), &'static str>,
+}
+
+impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
+ fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+ // Don't count StorageLive/StorageDead in the inlining cost.
+ match statement.kind {
+ StatementKind::StorageLive(_)
+ | StatementKind::StorageDead(_)
+ | StatementKind::Deinit(_)
+ | StatementKind::Nop => {}
+ _ => self.cost += INSTR_COST,
+ }
+
+ self.super_statement(statement, location);
+ }
+
+ fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
+ let tcx = self.tcx;
+ match terminator.kind {
+ TerminatorKind::Drop { ref place, unwind, .. }
+ | TerminatorKind::DropAndReplace { ref place, unwind, .. } => {
+ // If the place doesn't actually need dropping, treat it like a regular goto.
+ let ty = self.instance.subst_mir(tcx, &place.ty(self.callee_body, tcx).ty);
+ if ty.needs_drop(tcx, self.param_env) {
+ self.cost += CALL_PENALTY;
+ if unwind.is_some() {
+ self.cost += LANDINGPAD_PENALTY;
+ }
+ } else {
+ self.cost += INSTR_COST;
+ }
+ }
+ TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
+ let fn_ty = self.instance.subst_mir(tcx, &f.literal.ty());
+ self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
+ // Don't give intrinsics the extra penalty for calls
+ INSTR_COST
+ } else {
+ CALL_PENALTY
+ };
+ if cleanup.is_some() {
+ self.cost += LANDINGPAD_PENALTY;
+ }
+ }
+ TerminatorKind::Assert { cleanup, .. } => {
+ self.cost += CALL_PENALTY;
+ if cleanup.is_some() {
+ self.cost += LANDINGPAD_PENALTY;
+ }
+ }
+ TerminatorKind::Resume => self.cost += RESUME_PENALTY,
+ TerminatorKind::InlineAsm { cleanup, .. } => {
+ self.cost += INSTR_COST;
+ if cleanup.is_some() {
+ self.cost += LANDINGPAD_PENALTY;
+ }
+ }
+ _ => self.cost += INSTR_COST,
+ }
+
+ self.super_terminator(terminator, location);
+ }
+
+ /// Count up the cost of local variables and temps, if we know the size
+ /// use that, otherwise we use a moderately-large dummy cost.
+ fn visit_local_decl(&mut self, local: Local, local_decl: &LocalDecl<'tcx>) {
+ let tcx = self.tcx;
+ let ptr_size = tcx.data_layout.pointer_size.bytes();
+
+ let ty = self.instance.subst_mir(tcx, &local_decl.ty);
+ // Cost of the var is the size in machine-words, if we know
+ // it.
+ if let Some(size) = type_size_of(tcx, self.param_env, ty) {
+ self.cost += ((size + ptr_size - 1) / ptr_size) as usize;
+ } else {
+ self.cost += UNKNOWN_SIZE_COST;
+ }
+
+ self.super_local_decl(local, local_decl)
+ }
+
+ /// This method duplicates code from MIR validation in an attempt to detect type mismatches due
+ /// to normalization failure.
+ fn visit_projection_elem(
+ &mut self,
+ local: Local,
+ proj_base: &[PlaceElem<'tcx>],
+ elem: PlaceElem<'tcx>,
+ context: PlaceContext,
+ location: Location,
+ ) {
+ if let ProjectionElem::Field(f, ty) = elem {
+ let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
+ let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
+ let check_equal = |this: &mut Self, f_ty| {
+ if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
+ trace!(?ty, ?f_ty);
+ this.validation = Err("failed to normalize projection type");
+ return;
+ }
+ };
+
+ let kind = match parent_ty.ty.kind() {
+ &ty::Opaque(def_id, substs) => {
+ self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind()
+ }
+ kind => kind,
+ };
+
+ match kind {
+ ty::Tuple(fields) => {
+ let Some(f_ty) = fields.get(f.as_usize()) else {
+ self.validation = Err("malformed MIR");
+ return;
+ };
+ check_equal(self, *f_ty);
+ }
+ ty::Adt(adt_def, substs) => {
+ let var = parent_ty.variant_index.unwrap_or(VariantIdx::from_u32(0));
+ let Some(field) = adt_def.variant(var).fields.get(f.as_usize()) else {
+ self.validation = Err("malformed MIR");
+ return;
+ };
+ check_equal(self, field.ty(self.tcx, substs));
+ }
+ ty::Closure(_, substs) => {
+ let substs = substs.as_closure();
+ let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else {
+ self.validation = Err("malformed MIR");
+ return;
+ };
+ check_equal(self, f_ty);
+ }
+ &ty::Generator(def_id, substs, _) => {
+ let f_ty = if let Some(var) = parent_ty.variant_index {
+ let gen_body = if def_id == self.callee_body.source.def_id() {
+ self.callee_body
+ } else {
+ self.tcx.optimized_mir(def_id)
+ };
+
+ let Some(layout) = gen_body.generator_layout() else {
+ self.validation = Err("malformed MIR");
+ return;
+ };
+
+ let Some(&local) = layout.variant_fields[var].get(f) else {
+ self.validation = Err("malformed MIR");
+ return;
+ };
+
+ let Some(&f_ty) = layout.field_tys.get(local) else {
+ self.validation = Err("malformed MIR");
+ return;
+ };
+
+ f_ty
+ } else {
+ let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
+ self.validation = Err("malformed MIR");
+ return;
+ };
+
+ f_ty
+ };
+
+ check_equal(self, f_ty);
+ }
+ _ => self.validation = Err("malformed MIR"),
+ }
+ }
+
+ self.super_projection_elem(local, proj_base, elem, context, location);
+ }
+}
+
/**
* Integrator.
*