summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/dataflow_const_prop.rs')
-rw-r--r--compiler/rustc_mir_transform/src/dataflow_const_prop.rs150
1 files changed, 122 insertions, 28 deletions
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
index c75fe2327..49ded10ba 100644
--- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
+++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
@@ -5,6 +5,7 @@
use rustc_const_eval::const_eval::CheckAlignment;
use rustc_const_eval::interpret::{ConstValue, ImmTy, Immediate, InterpCx, Scalar};
use rustc_data_structures::fx::FxHashMap;
+use rustc_hir::def::DefKind;
use rustc_middle::mir::visit::{MutVisitor, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -12,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
use rustc_span::DUMMY_SP;
use rustc_target::abi::Align;
+use rustc_target::abi::VariantIdx;
use crate::MirPass;
@@ -29,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
#[instrument(skip_all level = "debug")]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ debug!(def_id = ?body.source.def_id());
if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
debug!("aborted dataflow const prop due too many basic blocks");
return;
}
- // Decide which places to track during the analysis.
- let map = Map::from_filter(tcx, body, Ty::is_scalar);
-
// We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
// Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
// applications, where `h` is the height of the lattice. Because the height of our lattice
@@ -45,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
// `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
// map nodes is strongly correlated to the number of tracked places, this becomes more or
// less `O(n)` if we place a constant limit on the number of tracked places.
- if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT {
- debug!("aborted dataflow const prop due to too many tracked places");
- return;
- }
+ let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };
+
+ // Decide which places to track during the analysis.
+ let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit);
// Perform the actual dataflow analysis.
let analysis = ConstAnalysis::new(tcx, body, map);
@@ -62,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
}
}
-struct ConstAnalysis<'tcx> {
+struct ConstAnalysis<'a, 'tcx> {
map: Map,
tcx: TyCtxt<'tcx>,
+ local_decls: &'a LocalDecls<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
param_env: ty::ParamEnv<'tcx>,
}
-impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
+impl<'tcx> ConstAnalysis<'_, 'tcx> {
+ fn eval_discriminant(
+ &self,
+ enum_ty: Ty<'tcx>,
+ variant_index: VariantIdx,
+ ) -> Option<ScalarTy<'tcx>> {
+ if !enum_ty.is_enum() {
+ return None;
+ }
+ let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
+ let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
+ let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?;
+ Some(ScalarTy(discr_value, discr.ty))
+ }
+}
+
+impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
type Value = FlatSet<ScalarTy<'tcx>>;
const NAME: &'static str = "ConstAnalysis";
@@ -78,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
&self.map
}
+ fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
+ match statement.kind {
+ StatementKind::SetDiscriminant { box ref place, variant_index } => {
+ state.flood_discr(place.as_ref(), &self.map);
+ if self.map.find_discr(place.as_ref()).is_some() {
+ let enum_ty = place.ty(self.local_decls, self.tcx).ty;
+ if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
+ state.assign_discr(
+ place.as_ref(),
+ ValueOrPlace::Value(FlatSet::Elem(discr)),
+ &self.map,
+ );
+ }
+ }
+ }
+ _ => self.super_statement(statement, state),
+ }
+ }
+
fn handle_assign(
&self,
target: Place<'tcx>,
@@ -85,13 +121,59 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
state: &mut State<Self::Value>,
) {
match rvalue {
+ Rvalue::Aggregate(kind, operands) => {
+ // If we assign `target = Enum::Variant#0(operand)`,
+ // we must make sure that all `target as Variant#i` are `Top`.
+ state.flood(target.as_ref(), self.map());
+
+ if let Some(target_idx) = self.map().find(target.as_ref()) {
+ let (variant_target, variant_index) = match **kind {
+ AggregateKind::Tuple | AggregateKind::Closure(..) => {
+ (Some(target_idx), None)
+ }
+ AggregateKind::Adt(def_id, variant_index, ..) => {
+ match self.tcx.def_kind(def_id) {
+ DefKind::Struct => (Some(target_idx), None),
+ DefKind::Enum => (
+ self.map.apply(target_idx, TrackElem::Variant(variant_index)),
+ Some(variant_index),
+ ),
+ _ => (None, None),
+ }
+ }
+ _ => (None, None),
+ };
+ if let Some(variant_target_idx) = variant_target {
+ for (field_index, operand) in operands.iter().enumerate() {
+ if let Some(field) = self.map().apply(
+ variant_target_idx,
+ TrackElem::Field(Field::from_usize(field_index)),
+ ) {
+ let result = self.handle_operand(operand, state);
+ state.insert_idx(field, result, self.map());
+ }
+ }
+ }
+ if let Some(variant_index) = variant_index
+ && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
+ {
+ // We are assigning the discriminant as part of an aggregate.
+ // This discriminant can only alias a variant field's value if the operand
+ // had an invalid value for that type.
+ // Using invalid values is UB, so we are allowed to perform the assignment
+ // without extra flooding.
+ let enum_ty = target.ty(self.local_decls, self.tcx).ty;
+ if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) {
+ state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map);
+ }
+ }
+ }
+ }
Rvalue::CheckedBinaryOp(op, box (left, right)) => {
+ // Flood everything now, so we can use `insert_value_idx` directly later.
+ state.flood(target.as_ref(), self.map());
+
let target = self.map().find(target.as_ref());
- if let Some(target) = target {
- // We should not track any projections other than
- // what is overwritten below, but just in case...
- state.flood_idx(target, self.map());
- }
let value_target = target
.and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into())));
@@ -102,26 +184,19 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
let (val, overflow) = self.binary_op(state, *op, left, right);
if let Some(value_target) = value_target {
- state.assign_idx(value_target, ValueOrPlace::Value(val), self.map());
+ // We have flooded `target` earlier.
+ state.insert_value_idx(value_target, val, self.map());
}
if let Some(overflow_target) = overflow_target {
let overflow = match overflow {
FlatSet::Top => FlatSet::Top,
FlatSet::Elem(overflow) => {
- if overflow {
- // Overflow cannot be reliably propagated. See: https://github.com/rust-lang/rust/pull/101168#issuecomment-1288091446
- FlatSet::Top
- } else {
- self.wrap_scalar(Scalar::from_bool(false), self.tcx.types.bool)
- }
+ self.wrap_scalar(Scalar::from_bool(overflow), self.tcx.types.bool)
}
FlatSet::Bottom => FlatSet::Bottom,
};
- state.assign_idx(
- overflow_target,
- ValueOrPlace::Value(overflow),
- self.map(),
- );
+ // We have flooded `target` earlier.
+ state.insert_value_idx(overflow_target, overflow, self.map());
}
}
}
@@ -170,6 +245,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom),
FlatSet::Top => ValueOrPlace::Value(FlatSet::Top),
},
+ Rvalue::Discriminant(place) => {
+ ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map()))
+ }
_ => self.super_rvalue(rvalue, state),
}
}
@@ -243,12 +321,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
}
}
-impl<'tcx> ConstAnalysis<'tcx> {
- pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self {
+impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
+ pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self {
let param_env = tcx.param_env(body.source.def_id());
Self {
map,
tcx,
+ local_decls: &body.local_decls,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
param_env: param_env,
}
@@ -441,6 +520,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
_ => (),
}
}
+
+ fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
+ match rvalue {
+ Rvalue::Discriminant(place) => {
+ match self.state.get_discr(place.as_ref(), self.visitor.map) {
+ FlatSet::Top => (),
+ FlatSet::Elem(value) => {
+ self.visitor.before_effect.insert((location, *place), value);
+ }
+ FlatSet::Bottom => (),
+ }
+ }
+ _ => self.super_rvalue(rvalue, location),
+ }
+ }
}
struct DummyMachine;