From 698f8c2f01ea549d77d7dc3338a12e04c11057b9 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 17 Apr 2024 14:02:58 +0200 Subject: Adding upstream version 1.64.0+dfsg1. Signed-off-by: Daniel Baumann --- .../rust-analyzer/crates/ide-assists/Cargo.toml | 31 + .../crates/ide-assists/src/assist_config.rs | 16 + .../crates/ide-assists/src/assist_context.rs | 347 ++ .../ide-assists/src/handlers/add_explicit_type.rs | 325 ++ .../ide-assists/src/handlers/add_label_to_loop.rs | 164 + .../src/handlers/add_lifetime_to_type.rs | 229 + .../src/handlers/add_missing_impl_members.rs | 1340 +++++ .../src/handlers/add_missing_match_arms.rs | 1709 +++++++ .../ide-assists/src/handlers/add_return_type.rs | 447 ++ .../ide-assists/src/handlers/add_turbo_fish.rs | 400 ++ .../ide-assists/src/handlers/apply_demorgan.rs | 234 + .../crates/ide-assists/src/handlers/auto_import.rs | 1292 +++++ .../ide-assists/src/handlers/change_visibility.rs | 216 + .../ide-assists/src/handlers/convert_bool_then.rs | 575 +++ .../src/handlers/convert_comment_block.rs | 395 ++ .../src/handlers/convert_integer_literal.rs | 268 + .../src/handlers/convert_into_to_from.rs | 351 ++ .../src/handlers/convert_iter_for_each_to_for.rs | 556 ++ .../src/handlers/convert_let_else_to_match.rs | 497 ++ .../src/handlers/convert_to_guarded_return.rs | 574 +++ .../convert_tuple_struct_to_named_struct.rs | 840 +++ .../src/handlers/convert_while_to_loop.rs | 188 + .../src/handlers/destructure_tuple_binding.rs | 2147 ++++++++ .../ide-assists/src/handlers/expand_glob_import.rs | 900 ++++ .../ide-assists/src/handlers/extract_function.rs | 5333 ++++++++++++++++++++ .../ide-assists/src/handlers/extract_module.rs | 1770 +++++++ .../handlers/extract_struct_from_enum_variant.rs | 1076 ++++ .../ide-assists/src/handlers/extract_type_alias.rs | 360 ++ .../ide-assists/src/handlers/extract_variable.rs | 1279 +++++ .../ide-assists/src/handlers/fix_visibility.rs | 606 +++ .../ide-assists/src/handlers/flip_binexpr.rs | 139 + .../crates/ide-assists/src/handlers/flip_comma.rs | 92 + .../ide-assists/src/handlers/flip_trait_bound.rs | 121 + .../ide-assists/src/handlers/generate_constant.rs | 255 + .../handlers/generate_default_from_enum_variant.rs | 179 + .../src/handlers/generate_default_from_new.rs | 657 +++ .../src/handlers/generate_delegate_methods.rs | 334 ++ .../ide-assists/src/handlers/generate_deref.rs | 343 ++ .../ide-assists/src/handlers/generate_derive.rs | 132 + .../handlers/generate_documentation_template.rs | 1328 +++++ .../src/handlers/generate_enum_is_method.rs | 316 ++ .../handlers/generate_enum_projection_method.rs | 342 ++ .../src/handlers/generate_enum_variant.rs | 227 + .../src/handlers/generate_from_impl_for_enum.rs | 310 ++ .../ide-assists/src/handlers/generate_function.rs | 1787 +++++++ .../ide-assists/src/handlers/generate_getter.rs | 492 ++ .../ide-assists/src/handlers/generate_impl.rs | 177 + .../src/handlers/generate_is_empty_from_len.rs | 295 ++ .../ide-assists/src/handlers/generate_new.rs | 495 ++ .../ide-assists/src/handlers/generate_setter.rs | 184 + .../crates/ide-assists/src/handlers/inline_call.rs | 1194 +++++ .../src/handlers/inline_local_variable.rs | 954 ++++ .../ide-assists/src/handlers/inline_type_alias.rs | 838 +++ .../src/handlers/introduce_named_generic.rs | 144 + .../src/handlers/introduce_named_lifetime.rs | 338 ++ .../crates/ide-assists/src/handlers/invert_if.rs | 144 + .../ide-assists/src/handlers/merge_imports.rs | 570 +++ .../ide-assists/src/handlers/merge_match_arms.rs | 822 +++ .../crates/ide-assists/src/handlers/move_bounds.rs | 122 + .../ide-assists/src/handlers/move_from_mod_rs.rs | 130 + .../crates/ide-assists/src/handlers/move_guard.rs | 997 ++++ .../src/handlers/move_module_to_file.rs | 337 ++ .../ide-assists/src/handlers/move_to_mod_rs.rs | 151 + .../src/handlers/number_representation.rs | 183 + .../src/handlers/promote_local_to_const.rs | 221 + .../ide-assists/src/handlers/pull_assignment_up.rs | 507 ++ .../src/handlers/qualify_method_call.rs | 548 ++ .../ide-assists/src/handlers/qualify_path.rs | 1297 +++++ .../crates/ide-assists/src/handlers/raw_string.rs | 509 ++ .../crates/ide-assists/src/handlers/remove_dbg.rs | 241 + .../crates/ide-assists/src/handlers/remove_mut.rs | 37 + .../src/handlers/remove_unused_param.rs | 409 ++ .../ide-assists/src/handlers/reorder_fields.rs | 212 + .../ide-assists/src/handlers/reorder_impl_items.rs | 284 ++ .../handlers/replace_derive_with_manual_impl.rs | 1250 +++++ .../src/handlers/replace_if_let_with_match.rs | 999 ++++ .../src/handlers/replace_let_with_if_let.rs | 100 + .../handlers/replace_qualified_name_with_use.rs | 438 ++ .../src/handlers/replace_string_with_char.rs | 307 ++ .../src/handlers/replace_try_expr_with_match.rs | 150 + .../replace_turbofish_with_explicit_type.rs | 243 + .../crates/ide-assists/src/handlers/sort_items.rs | 588 +++ .../ide-assists/src/handlers/split_import.rs | 82 + .../ide-assists/src/handlers/toggle_ignore.rs | 98 + .../crates/ide-assists/src/handlers/unmerge_use.rs | 237 + .../ide-assists/src/handlers/unnecessary_async.rs | 257 + .../ide-assists/src/handlers/unwrap_block.rs | 719 +++ .../src/handlers/unwrap_result_return_type.rs | 1020 ++++ .../src/handlers/wrap_return_type_in_result.rs | 980 ++++ .../rust-analyzer/crates/ide-assists/src/lib.rs | 309 ++ .../rust-analyzer/crates/ide-assists/src/tests.rs | 558 ++ .../crates/ide-assists/src/tests/generated.rs | 2259 +++++++++ .../crates/ide-assists/src/tests/sourcegen.rs | 195 + .../rust-analyzer/crates/ide-assists/src/utils.rs | 703 +++ .../ide-assists/src/utils/gen_trait_fn_body.rs | 661 +++ .../crates/ide-assists/src/utils/suggest_name.rs | 775 +++ 96 files changed, 56787 insertions(+) create mode 100644 src/tools/rust-analyzer/crates/ide-assists/Cargo.toml create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/assist_config.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/assist_context.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_explicit_type.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_label_to_loop.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_lifetime_to_type.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_impl_members.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_return_type.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/auto_import.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/change_visibility.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_comment_block.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_integer_literal.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_into_to_from.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_let_else_to_match.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_while_to_loop.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_tuple_binding.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_glob_import.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_function.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_variable.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/fix_visibility.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_binexpr.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_trait_bound.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_constant.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_enum_variant.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_new.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_methods.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_deref.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_derive.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_documentation_template.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_is_method.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_projection_method.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_getter.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_impl.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_is_empty_from_len.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_new.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_setter.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_call.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_type_alias.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_generic.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_lifetime.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/invert_if.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_imports.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_match_arms.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_from_mod_rs.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_guard.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_module_to_file.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_to_mod_rs.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/number_representation.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/promote_local_to_const.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_method_call.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_path.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/raw_string.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_dbg.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_mut.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_unused_param.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_impl_items.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_if_let_with_match.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_let_with_if_let.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_string_with_char.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_try_expr_with_match.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_turbofish_with_explicit_type.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/sort_items.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/split_import.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/toggle_ignore.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_use.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/unnecessary_async.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_result_return_type.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/lib.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/tests.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/tests/sourcegen.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/utils.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/utils/gen_trait_fn_body.rs create mode 100644 src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs (limited to 'src/tools/rust-analyzer/crates/ide-assists') diff --git a/src/tools/rust-analyzer/crates/ide-assists/Cargo.toml b/src/tools/rust-analyzer/crates/ide-assists/Cargo.toml new file mode 100644 index 000000000..fca09d384 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "ide-assists" +version = "0.0.0" +description = "TBD" +license = "MIT OR Apache-2.0" +edition = "2021" +rust-version = "1.57" + +[lib] +doctest = false + +[dependencies] +cov-mark = "2.0.0-pre.1" + +itertools = "0.10.3" +either = "1.7.0" + +stdx = { path = "../stdx", version = "0.0.0" } +syntax = { path = "../syntax", version = "0.0.0" } +text-edit = { path = "../text-edit", version = "0.0.0" } +profile = { path = "../profile", version = "0.0.0" } +ide-db = { path = "../ide-db", version = "0.0.0" } +hir = { path = "../hir", version = "0.0.0" } + +[dev-dependencies] +test-utils = { path = "../test-utils" } +sourcegen = { path = "../sourcegen" } +expect-test = "1.4.0" + +[features] +in-rust-tree = [] diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/assist_config.rs b/src/tools/rust-analyzer/crates/ide-assists/src/assist_config.rs new file mode 100644 index 000000000..d4d148c77 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/assist_config.rs @@ -0,0 +1,16 @@ +//! Settings for tweaking assists. +//! +//! The fun thing here is `SnippetCap` -- this type can only be created in this +//! module, and we use to statically check that we only produce snippet +//! assists if we are allowed to. + +use ide_db::{imports::insert_use::InsertUseConfig, SnippetCap}; + +use crate::AssistKind; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct AssistConfig { + pub snippet_cap: Option, + pub allowed: Option>, + pub insert_use: InsertUseConfig, +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/assist_context.rs b/src/tools/rust-analyzer/crates/ide-assists/src/assist_context.rs new file mode 100644 index 000000000..f9b426614 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/assist_context.rs @@ -0,0 +1,347 @@ +//! See [`AssistContext`]. + +use std::mem; + +use hir::Semantics; +use ide_db::{ + base_db::{AnchoredPathBuf, FileId, FileRange}, + SnippetCap, +}; +use ide_db::{ + label::Label, + source_change::{FileSystemEdit, SourceChange}, + RootDatabase, +}; +use syntax::{ + algo::{self, find_node_at_offset, find_node_at_range}, + AstNode, AstToken, Direction, SourceFile, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxNodePtr, + SyntaxToken, TextRange, TextSize, TokenAtOffset, +}; +use text_edit::{TextEdit, TextEditBuilder}; + +use crate::{ + assist_config::AssistConfig, Assist, AssistId, AssistKind, AssistResolveStrategy, GroupLabel, +}; + +/// `AssistContext` allows to apply an assist or check if it could be applied. +/// +/// Assists use a somewhat over-engineered approach, given the current needs. +/// The assists workflow consists of two phases. In the first phase, a user asks +/// for the list of available assists. In the second phase, the user picks a +/// particular assist and it gets applied. +/// +/// There are two peculiarities here: +/// +/// * first, we ideally avoid computing more things then necessary to answer "is +/// assist applicable" in the first phase. +/// * second, when we are applying assist, we don't have a guarantee that there +/// weren't any changes between the point when user asked for assists and when +/// they applied a particular assist. So, when applying assist, we need to do +/// all the checks from scratch. +/// +/// To avoid repeating the same code twice for both "check" and "apply" +/// functions, we use an approach reminiscent of that of Django's function based +/// views dealing with forms. Each assist receives a runtime parameter, +/// `resolve`. It first check if an edit is applicable (potentially computing +/// info required to compute the actual edit). If it is applicable, and +/// `resolve` is `true`, it then computes the actual edit. +/// +/// So, to implement the original assists workflow, we can first apply each edit +/// with `resolve = false`, and then applying the selected edit again, with +/// `resolve = true` this time. +/// +/// Note, however, that we don't actually use such two-phase logic at the +/// moment, because the LSP API is pretty awkward in this place, and it's much +/// easier to just compute the edit eagerly :-) +pub(crate) struct AssistContext<'a> { + pub(crate) config: &'a AssistConfig, + pub(crate) sema: Semantics<'a, RootDatabase>, + frange: FileRange, + trimmed_range: TextRange, + source_file: SourceFile, +} + +impl<'a> AssistContext<'a> { + pub(crate) fn new( + sema: Semantics<'a, RootDatabase>, + config: &'a AssistConfig, + frange: FileRange, + ) -> AssistContext<'a> { + let source_file = sema.parse(frange.file_id); + + let start = frange.range.start(); + let end = frange.range.end(); + let left = source_file.syntax().token_at_offset(start); + let right = source_file.syntax().token_at_offset(end); + let left = + left.right_biased().and_then(|t| algo::skip_whitespace_token(t, Direction::Next)); + let right = + right.left_biased().and_then(|t| algo::skip_whitespace_token(t, Direction::Prev)); + let left = left.map(|t| t.text_range().start().clamp(start, end)); + let right = right.map(|t| t.text_range().end().clamp(start, end)); + + let trimmed_range = match (left, right) { + (Some(left), Some(right)) if left <= right => TextRange::new(left, right), + // Selection solely consists of whitespace so just fall back to the original + _ => frange.range, + }; + + AssistContext { config, sema, frange, source_file, trimmed_range } + } + + pub(crate) fn db(&self) -> &RootDatabase { + self.sema.db + } + + // NB, this ignores active selection. + pub(crate) fn offset(&self) -> TextSize { + self.frange.range.start() + } + + pub(crate) fn file_id(&self) -> FileId { + self.frange.file_id + } + + pub(crate) fn has_empty_selection(&self) -> bool { + self.trimmed_range.is_empty() + } + + /// Returns the selected range trimmed for whitespace tokens, that is the range will be snapped + /// to the nearest enclosed token. + pub(crate) fn selection_trimmed(&self) -> TextRange { + self.trimmed_range + } + + pub(crate) fn token_at_offset(&self) -> TokenAtOffset { + self.source_file.syntax().token_at_offset(self.offset()) + } + pub(crate) fn find_token_syntax_at_offset(&self, kind: SyntaxKind) -> Option { + self.token_at_offset().find(|it| it.kind() == kind) + } + pub(crate) fn find_token_at_offset(&self) -> Option { + self.token_at_offset().find_map(T::cast) + } + pub(crate) fn find_node_at_offset(&self) -> Option { + find_node_at_offset(self.source_file.syntax(), self.offset()) + } + pub(crate) fn find_node_at_range(&self) -> Option { + find_node_at_range(self.source_file.syntax(), self.trimmed_range) + } + pub(crate) fn find_node_at_offset_with_descend(&self) -> Option { + self.sema.find_node_at_offset_with_descend(self.source_file.syntax(), self.offset()) + } + /// Returns the element covered by the selection range, this excludes trailing whitespace in the selection. + pub(crate) fn covering_element(&self) -> SyntaxElement { + self.source_file.syntax().covering_element(self.selection_trimmed()) + } +} + +pub(crate) struct Assists { + file: FileId, + resolve: AssistResolveStrategy, + buf: Vec, + allowed: Option>, +} + +impl Assists { + pub(crate) fn new(ctx: &AssistContext<'_>, resolve: AssistResolveStrategy) -> Assists { + Assists { + resolve, + file: ctx.frange.file_id, + buf: Vec::new(), + allowed: ctx.config.allowed.clone(), + } + } + + pub(crate) fn finish(mut self) -> Vec { + self.buf.sort_by_key(|assist| assist.target.len()); + self.buf + } + + pub(crate) fn add( + &mut self, + id: AssistId, + label: impl Into, + target: TextRange, + f: impl FnOnce(&mut AssistBuilder), + ) -> Option<()> { + let mut f = Some(f); + self.add_impl(None, id, label.into(), target, &mut |it| f.take().unwrap()(it)) + } + + pub(crate) fn add_group( + &mut self, + group: &GroupLabel, + id: AssistId, + label: impl Into, + target: TextRange, + f: impl FnOnce(&mut AssistBuilder), + ) -> Option<()> { + let mut f = Some(f); + self.add_impl(Some(group), id, label.into(), target, &mut |it| f.take().unwrap()(it)) + } + + fn add_impl( + &mut self, + group: Option<&GroupLabel>, + id: AssistId, + label: String, + target: TextRange, + f: &mut dyn FnMut(&mut AssistBuilder), + ) -> Option<()> { + if !self.is_allowed(&id) { + return None; + } + + let mut trigger_signature_help = false; + let source_change = if self.resolve.should_resolve(&id) { + let mut builder = AssistBuilder::new(self.file); + f(&mut builder); + trigger_signature_help = builder.trigger_signature_help; + Some(builder.finish()) + } else { + None + }; + + let label = Label::new(label); + let group = group.cloned(); + self.buf.push(Assist { id, label, group, target, source_change, trigger_signature_help }); + Some(()) + } + + fn is_allowed(&self, id: &AssistId) -> bool { + match &self.allowed { + Some(allowed) => allowed.iter().any(|kind| kind.contains(id.1)), + None => true, + } + } +} + +pub(crate) struct AssistBuilder { + edit: TextEditBuilder, + file_id: FileId, + source_change: SourceChange, + trigger_signature_help: bool, + + /// Maps the original, immutable `SyntaxNode` to a `clone_for_update` twin. + mutated_tree: Option, +} + +pub(crate) struct TreeMutator { + immutable: SyntaxNode, + mutable_clone: SyntaxNode, +} + +impl TreeMutator { + pub(crate) fn new(immutable: &SyntaxNode) -> TreeMutator { + let immutable = immutable.ancestors().last().unwrap(); + let mutable_clone = immutable.clone_for_update(); + TreeMutator { immutable, mutable_clone } + } + + pub(crate) fn make_mut(&self, node: &N) -> N { + N::cast(self.make_syntax_mut(node.syntax())).unwrap() + } + + pub(crate) fn make_syntax_mut(&self, node: &SyntaxNode) -> SyntaxNode { + let ptr = SyntaxNodePtr::new(node); + ptr.to_node(&self.mutable_clone) + } +} + +impl AssistBuilder { + pub(crate) fn new(file_id: FileId) -> AssistBuilder { + AssistBuilder { + edit: TextEdit::builder(), + file_id, + source_change: SourceChange::default(), + trigger_signature_help: false, + mutated_tree: None, + } + } + + pub(crate) fn edit_file(&mut self, file_id: FileId) { + self.commit(); + self.file_id = file_id; + } + + fn commit(&mut self) { + if let Some(tm) = self.mutated_tree.take() { + algo::diff(&tm.immutable, &tm.mutable_clone).into_text_edit(&mut self.edit) + } + + let edit = mem::take(&mut self.edit).finish(); + if !edit.is_empty() { + self.source_change.insert_source_edit(self.file_id, edit); + } + } + + pub(crate) fn make_mut(&mut self, node: N) -> N { + self.mutated_tree.get_or_insert_with(|| TreeMutator::new(node.syntax())).make_mut(&node) + } + /// Returns a copy of the `node`, suitable for mutation. + /// + /// Syntax trees in rust-analyzer are typically immutable, and mutating + /// operations panic at runtime. However, it is possible to make a copy of + /// the tree and mutate the copy freely. Mutation is based on interior + /// mutability, and different nodes in the same tree see the same mutations. + /// + /// The typical pattern for an assist is to find specific nodes in the read + /// phase, and then get their mutable couterparts using `make_mut` in the + /// mutable state. + pub(crate) fn make_syntax_mut(&mut self, node: SyntaxNode) -> SyntaxNode { + self.mutated_tree.get_or_insert_with(|| TreeMutator::new(&node)).make_syntax_mut(&node) + } + + /// Remove specified `range` of text. + pub(crate) fn delete(&mut self, range: TextRange) { + self.edit.delete(range) + } + /// Append specified `text` at the given `offset` + pub(crate) fn insert(&mut self, offset: TextSize, text: impl Into) { + self.edit.insert(offset, text.into()) + } + /// Append specified `snippet` at the given `offset` + pub(crate) fn insert_snippet( + &mut self, + _cap: SnippetCap, + offset: TextSize, + snippet: impl Into, + ) { + self.source_change.is_snippet = true; + self.insert(offset, snippet); + } + /// Replaces specified `range` of text with a given string. + pub(crate) fn replace(&mut self, range: TextRange, replace_with: impl Into) { + self.edit.replace(range, replace_with.into()) + } + /// Replaces specified `range` of text with a given `snippet`. + pub(crate) fn replace_snippet( + &mut self, + _cap: SnippetCap, + range: TextRange, + snippet: impl Into, + ) { + self.source_change.is_snippet = true; + self.replace(range, snippet); + } + pub(crate) fn replace_ast(&mut self, old: N, new: N) { + algo::diff(old.syntax(), new.syntax()).into_text_edit(&mut self.edit) + } + pub(crate) fn create_file(&mut self, dst: AnchoredPathBuf, content: impl Into) { + let file_system_edit = FileSystemEdit::CreateFile { dst, initial_contents: content.into() }; + self.source_change.push_file_system_edit(file_system_edit); + } + pub(crate) fn move_file(&mut self, src: FileId, dst: AnchoredPathBuf) { + let file_system_edit = FileSystemEdit::MoveFile { src, dst }; + self.source_change.push_file_system_edit(file_system_edit); + } + pub(crate) fn trigger_signature_help(&mut self) { + self.trigger_signature_help = true; + } + + fn finish(mut self) -> SourceChange { + self.commit(); + mem::take(&mut self.source_change) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_explicit_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_explicit_type.rs new file mode 100644 index 000000000..bfa9759ec --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_explicit_type.rs @@ -0,0 +1,325 @@ +use hir::HirDisplay; +use ide_db::syntax_helpers::node_ext::walk_ty; +use syntax::ast::{self, AstNode, LetStmt, Param}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: add_explicit_type +// +// Specify type for a let binding. +// +// ``` +// fn main() { +// let x$0 = 92; +// } +// ``` +// -> +// ``` +// fn main() { +// let x: i32 = 92; +// } +// ``` +pub(crate) fn add_explicit_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (ascribed_ty, expr, pat) = if let Some(let_stmt) = ctx.find_node_at_offset::() { + let cursor_in_range = { + let eq_range = let_stmt.eq_token()?.text_range(); + ctx.offset() < eq_range.start() + }; + if !cursor_in_range { + cov_mark::hit!(add_explicit_type_not_applicable_if_cursor_after_equals); + return None; + } + + (let_stmt.ty(), let_stmt.initializer(), let_stmt.pat()?) + } else if let Some(param) = ctx.find_node_at_offset::() { + if param.syntax().ancestors().nth(2).and_then(ast::ClosureExpr::cast).is_none() { + cov_mark::hit!(add_explicit_type_not_applicable_in_fn_param); + return None; + } + (param.ty(), None, param.pat()?) + } else { + return None; + }; + + let module = ctx.sema.scope(pat.syntax())?.module(); + let pat_range = pat.syntax().text_range(); + + // Don't enable the assist if there is a type ascription without any placeholders + if let Some(ty) = &ascribed_ty { + let mut contains_infer_ty = false; + walk_ty(ty, &mut |ty| contains_infer_ty |= matches!(ty, ast::Type::InferType(_))); + if !contains_infer_ty { + cov_mark::hit!(add_explicit_type_not_applicable_if_ty_already_specified); + return None; + } + } + + let ty = match (pat, expr) { + (ast::Pat::IdentPat(_), Some(expr)) => ctx.sema.type_of_expr(&expr)?, + (pat, _) => ctx.sema.type_of_pat(&pat)?, + } + .adjusted(); + + // Fully unresolved or unnameable types can't be annotated + if (ty.contains_unknown() && ty.type_arguments().count() == 0) || ty.is_closure() { + cov_mark::hit!(add_explicit_type_not_applicable_if_ty_not_inferred); + return None; + } + + let inferred_type = ty.display_source_code(ctx.db(), module.into()).ok()?; + acc.add( + AssistId("add_explicit_type", AssistKind::RefactorRewrite), + format!("Insert explicit type `{}`", inferred_type), + pat_range, + |builder| match ascribed_ty { + Some(ascribed_ty) => { + builder.replace(ascribed_ty.syntax().text_range(), inferred_type); + } + None => { + builder.insert(pat_range.end(), format!(": {}", inferred_type)); + } + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn add_explicit_type_target() { + check_assist_target(add_explicit_type, r#"fn f() { let a$0 = 1; }"#, "a"); + } + + #[test] + fn add_explicit_type_simple() { + check_assist( + add_explicit_type, + r#"fn f() { let a$0 = 1; }"#, + r#"fn f() { let a: i32 = 1; }"#, + ); + } + + #[test] + fn add_explicit_type_simple_on_infer_ty() { + check_assist( + add_explicit_type, + r#"fn f() { let a$0: _ = 1; }"#, + r#"fn f() { let a: i32 = 1; }"#, + ); + } + + #[test] + fn add_explicit_type_simple_nested_infer_ty() { + check_assist( + add_explicit_type, + r#" +//- minicore: option +fn f() { + let a$0: Option<_> = Option::Some(1); +} +"#, + r#" +fn f() { + let a: Option = Option::Some(1); +} +"#, + ); + } + + #[test] + fn add_explicit_type_macro_call_expr() { + check_assist( + add_explicit_type, + r"macro_rules! v { () => {0u64} } fn f() { let a$0 = v!(); }", + r"macro_rules! v { () => {0u64} } fn f() { let a: u64 = v!(); }", + ); + } + + #[test] + fn add_explicit_type_not_applicable_for_fully_unresolved() { + cov_mark::check!(add_explicit_type_not_applicable_if_ty_not_inferred); + check_assist_not_applicable(add_explicit_type, r#"fn f() { let a$0 = None; }"#); + } + + #[test] + fn add_explicit_type_applicable_for_partially_unresolved() { + check_assist( + add_explicit_type, + r#" + struct Vec { t: T, v: V } + impl Vec> { + fn new() -> Self { + panic!() + } + } + fn f() { let a$0 = Vec::new(); }"#, + r#" + struct Vec { t: T, v: V } + impl Vec> { + fn new() -> Self { + panic!() + } + } + fn f() { let a: Vec<_, Vec<_, i32>> = Vec::new(); }"#, + ); + } + + #[test] + fn add_explicit_type_not_applicable_closure_expr() { + check_assist_not_applicable(add_explicit_type, r#"fn f() { let a$0 = || {}; }"#); + } + + #[test] + fn add_explicit_type_not_applicable_ty_already_specified() { + cov_mark::check!(add_explicit_type_not_applicable_if_ty_already_specified); + check_assist_not_applicable(add_explicit_type, r#"fn f() { let a$0: i32 = 1; }"#); + } + + #[test] + fn add_explicit_type_not_applicable_cursor_after_equals_of_let() { + cov_mark::check!(add_explicit_type_not_applicable_if_cursor_after_equals); + check_assist_not_applicable( + add_explicit_type, + r#"fn f() {let a =$0 match 1 {2 => 3, 3 => 5};}"#, + ) + } + + /// https://github.com/rust-lang/rust-analyzer/issues/2922 + #[test] + fn regression_issue_2922() { + check_assist( + add_explicit_type, + r#" +fn main() { + let $0v = [0.0; 2]; +} +"#, + r#" +fn main() { + let v: [f64; 2] = [0.0; 2]; +} +"#, + ); + // note: this may break later if we add more consteval. it just needs to be something that our + // consteval engine doesn't understand + check_assist_not_applicable( + add_explicit_type, + r#" +//- minicore: option + +fn main() { + let $0l = [0.0; Some(2).unwrap()]; +} +"#, + ); + } + + #[test] + fn default_generics_should_not_be_added() { + check_assist( + add_explicit_type, + r#" +struct Test { k: K, t: T } + +fn main() { + let test$0 = Test { t: 23u8, k: 33 }; +} +"#, + r#" +struct Test { k: K, t: T } + +fn main() { + let test: Test = Test { t: 23u8, k: 33 }; +} +"#, + ); + } + + #[test] + fn type_should_be_added_after_pattern() { + // LetStmt = Attr* 'let' Pat (':' Type)? '=' initializer:Expr ';' + check_assist( + add_explicit_type, + r#" +fn main() { + let $0test @ () = (); +} +"#, + r#" +fn main() { + let test @ (): () = (); +} +"#, + ); + } + + #[test] + fn add_explicit_type_inserts_coercions() { + check_assist( + add_explicit_type, + r#" +//- minicore: coerce_unsized +fn f() { + let $0x: *const [_] = &[3]; +} +"#, + r#" +fn f() { + let x: *const [i32] = &[3]; +} +"#, + ); + } + + #[test] + fn add_explicit_type_not_applicable_fn_param() { + cov_mark::check!(add_explicit_type_not_applicable_in_fn_param); + check_assist_not_applicable(add_explicit_type, r#"fn f(x$0: ()) {}"#); + } + + #[test] + fn add_explicit_type_ascribes_closure_param() { + check_assist( + add_explicit_type, + r#" +fn f() { + |y$0| { + let x: i32 = y; + }; +} +"#, + r#" +fn f() { + |y: i32| { + let x: i32 = y; + }; +} +"#, + ); + } + + #[test] + fn add_explicit_type_ascribes_closure_param_already_ascribed() { + check_assist( + add_explicit_type, + r#" +//- minicore: option +fn f() { + |mut y$0: Option<_>| { + y = Some(3); + }; +} +"#, + r#" +fn f() { + |mut y: Option| { + y = Some(3); + }; +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_label_to_loop.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_label_to_loop.rs new file mode 100644 index 000000000..001f1e8bb --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_label_to_loop.rs @@ -0,0 +1,164 @@ +use ide_db::syntax_helpers::node_ext::for_each_break_and_continue_expr; +use syntax::{ + ast::{self, AstNode, HasLoopBody}, + T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: add_label_to_loop +// +// Adds a label to a loop. +// +// ``` +// fn main() { +// loop$0 { +// break; +// continue; +// } +// } +// ``` +// -> +// ``` +// fn main() { +// 'l: loop { +// break 'l; +// continue 'l; +// } +// } +// ``` +pub(crate) fn add_label_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let loop_kw = ctx.find_token_syntax_at_offset(T![loop])?; + let loop_expr = loop_kw.parent().and_then(ast::LoopExpr::cast)?; + if loop_expr.label().is_some() { + return None; + } + + acc.add( + AssistId("add_label_to_loop", AssistKind::Generate), + "Add Label", + loop_expr.syntax().text_range(), + |builder| { + builder.insert(loop_kw.text_range().start(), "'l: "); + + let loop_body = loop_expr.loop_body().and_then(|it| it.stmt_list()); + for_each_break_and_continue_expr( + loop_expr.label(), + loop_body, + &mut |expr| match expr { + ast::Expr::BreakExpr(break_expr) => { + if let Some(break_token) = break_expr.break_token() { + builder.insert(break_token.text_range().end(), " 'l") + } + } + ast::Expr::ContinueExpr(continue_expr) => { + if let Some(continue_token) = continue_expr.continue_token() { + builder.insert(continue_token.text_range().end(), " 'l") + } + } + _ => {} + }, + ); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn add_label() { + check_assist( + add_label_to_loop, + r#" +fn main() { + loop$0 { + break; + continue; + } +}"#, + r#" +fn main() { + 'l: loop { + break 'l; + continue 'l; + } +}"#, + ); + } + + #[test] + fn add_label_to_outer_loop() { + check_assist( + add_label_to_loop, + r#" +fn main() { + loop$0 { + break; + continue; + loop { + break; + continue; + } + } +}"#, + r#" +fn main() { + 'l: loop { + break 'l; + continue 'l; + loop { + break; + continue; + } + } +}"#, + ); + } + + #[test] + fn add_label_to_inner_loop() { + check_assist( + add_label_to_loop, + r#" +fn main() { + loop { + break; + continue; + loop$0 { + break; + continue; + } + } +}"#, + r#" +fn main() { + loop { + break; + continue; + 'l: loop { + break 'l; + continue 'l; + } + } +}"#, + ); + } + + #[test] + fn do_not_add_label_if_exists() { + check_assist_not_applicable( + add_label_to_loop, + r#" +fn main() { + 'l: loop$0 { + break 'l; + continue 'l; + } +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_lifetime_to_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_lifetime_to_type.rs new file mode 100644 index 000000000..12213c845 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_lifetime_to_type.rs @@ -0,0 +1,229 @@ +use syntax::ast::{self, AstNode, HasGenericParams, HasName}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: add_lifetime_to_type +// +// Adds a new lifetime to a struct, enum or union. +// +// ``` +// struct Point { +// x: &$0u32, +// y: u32, +// } +// ``` +// -> +// ``` +// struct Point<'a> { +// x: &'a u32, +// y: u32, +// } +// ``` +pub(crate) fn add_lifetime_to_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let ref_type_focused = ctx.find_node_at_offset::()?; + if ref_type_focused.lifetime().is_some() { + return None; + } + + let node = ctx.find_node_at_offset::()?; + let has_lifetime = node + .generic_param_list() + .map_or(false, |gen_list| gen_list.lifetime_params().next().is_some()); + + if has_lifetime { + return None; + } + + let ref_types = fetch_borrowed_types(&node)?; + let target = node.syntax().text_range(); + + acc.add( + AssistId("add_lifetime_to_type", AssistKind::Generate), + "Add lifetime", + target, + |builder| { + match node.generic_param_list() { + Some(gen_param) => { + if let Some(left_angle) = gen_param.l_angle_token() { + builder.insert(left_angle.text_range().end(), "'a, "); + } + } + None => { + if let Some(name) = node.name() { + builder.insert(name.syntax().text_range().end(), "<'a>"); + } + } + } + + for ref_type in ref_types { + if let Some(amp_token) = ref_type.amp_token() { + builder.insert(amp_token.text_range().end(), "'a "); + } + } + }, + ) +} + +fn fetch_borrowed_types(node: &ast::Adt) -> Option> { + let ref_types: Vec = match node { + ast::Adt::Enum(enum_) => { + let variant_list = enum_.variant_list()?; + variant_list + .variants() + .filter_map(|variant| { + let field_list = variant.field_list()?; + + find_ref_types_from_field_list(&field_list) + }) + .flatten() + .collect() + } + ast::Adt::Struct(strukt) => { + let field_list = strukt.field_list()?; + find_ref_types_from_field_list(&field_list)? + } + ast::Adt::Union(un) => { + let record_field_list = un.record_field_list()?; + record_field_list + .fields() + .filter_map(|r_field| { + if let ast::Type::RefType(ref_type) = r_field.ty()? { + if ref_type.lifetime().is_none() { + return Some(ref_type); + } + } + + None + }) + .collect() + } + }; + + if ref_types.is_empty() { + None + } else { + Some(ref_types) + } +} + +fn find_ref_types_from_field_list(field_list: &ast::FieldList) -> Option> { + let ref_types: Vec = match field_list { + ast::FieldList::RecordFieldList(record_list) => record_list + .fields() + .filter_map(|f| { + if let ast::Type::RefType(ref_type) = f.ty()? { + if ref_type.lifetime().is_none() { + return Some(ref_type); + } + } + + None + }) + .collect(), + ast::FieldList::TupleFieldList(tuple_field_list) => tuple_field_list + .fields() + .filter_map(|f| { + if let ast::Type::RefType(ref_type) = f.ty()? { + if ref_type.lifetime().is_none() { + return Some(ref_type); + } + } + + None + }) + .collect(), + }; + + if ref_types.is_empty() { + None + } else { + Some(ref_types) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn add_lifetime_to_struct() { + check_assist( + add_lifetime_to_type, + r#"struct Foo { a: &$0i32 }"#, + r#"struct Foo<'a> { a: &'a i32 }"#, + ); + + check_assist( + add_lifetime_to_type, + r#"struct Foo { a: &$0i32, b: &usize }"#, + r#"struct Foo<'a> { a: &'a i32, b: &'a usize }"#, + ); + + check_assist( + add_lifetime_to_type, + r#"struct Foo { a: &$0i32, b: usize }"#, + r#"struct Foo<'a> { a: &'a i32, b: usize }"#, + ); + + check_assist( + add_lifetime_to_type, + r#"struct Foo { a: &$0T, b: usize }"#, + r#"struct Foo<'a, T> { a: &'a T, b: usize }"#, + ); + + check_assist_not_applicable(add_lifetime_to_type, r#"struct Foo<'a> { a: &$0'a i32 }"#); + check_assist_not_applicable(add_lifetime_to_type, r#"struct Foo { a: &'a$0 i32 }"#); + } + + #[test] + fn add_lifetime_to_enum() { + check_assist( + add_lifetime_to_type, + r#"enum Foo { Bar { a: i32 }, Other, Tuple(u32, &$0u32)}"#, + r#"enum Foo<'a> { Bar { a: i32 }, Other, Tuple(u32, &'a u32)}"#, + ); + + check_assist( + add_lifetime_to_type, + r#"enum Foo { Bar { a: &$0i32 }}"#, + r#"enum Foo<'a> { Bar { a: &'a i32 }}"#, + ); + + check_assist( + add_lifetime_to_type, + r#"enum Foo { Bar { a: &$0i32, b: &T }}"#, + r#"enum Foo<'a, T> { Bar { a: &'a i32, b: &'a T }}"#, + ); + + check_assist_not_applicable( + add_lifetime_to_type, + r#"enum Foo<'a> { Bar { a: &$0'a i32 }}"#, + ); + check_assist_not_applicable(add_lifetime_to_type, r#"enum Foo { Bar, $0Misc }"#); + } + + #[test] + fn add_lifetime_to_union() { + check_assist( + add_lifetime_to_type, + r#"union Foo { a: &$0i32 }"#, + r#"union Foo<'a> { a: &'a i32 }"#, + ); + + check_assist( + add_lifetime_to_type, + r#"union Foo { a: &$0i32, b: &usize }"#, + r#"union Foo<'a> { a: &'a i32, b: &'a usize }"#, + ); + + check_assist( + add_lifetime_to_type, + r#"union Foo { a: &$0T, b: usize }"#, + r#"union Foo<'a, T> { a: &'a T, b: usize }"#, + ); + + check_assist_not_applicable(add_lifetime_to_type, r#"struct Foo<'a> { a: &'a $0i32 }"#); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_impl_members.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_impl_members.rs new file mode 100644 index 000000000..c808c010c --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_impl_members.rs @@ -0,0 +1,1340 @@ +use hir::HasSource; +use ide_db::{ + syntax_helpers::insert_whitespace_into_node::insert_ws_into, traits::resolve_target_trait, +}; +use syntax::ast::{self, make, AstNode}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::{ + add_trait_assoc_items_to_impl, filter_assoc_items, gen_trait_fn_body, render_snippet, + Cursor, DefaultMethods, + }, + AssistId, AssistKind, +}; + +// Assist: add_impl_missing_members +// +// Adds scaffold for required impl members. +// +// ``` +// trait Trait { +// type X; +// fn foo(&self) -> T; +// fn bar(&self) {} +// } +// +// impl Trait for () {$0 +// +// } +// ``` +// -> +// ``` +// trait Trait { +// type X; +// fn foo(&self) -> T; +// fn bar(&self) {} +// } +// +// impl Trait for () { +// $0type X; +// +// fn foo(&self) -> u32 { +// todo!() +// } +// } +// ``` +pub(crate) fn add_missing_impl_members(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + add_missing_impl_members_inner( + acc, + ctx, + DefaultMethods::No, + "add_impl_missing_members", + "Implement missing members", + ) +} + +// Assist: add_impl_default_members +// +// Adds scaffold for overriding default impl members. +// +// ``` +// trait Trait { +// type X; +// fn foo(&self); +// fn bar(&self) {} +// } +// +// impl Trait for () { +// type X = (); +// fn foo(&self) {}$0 +// } +// ``` +// -> +// ``` +// trait Trait { +// type X; +// fn foo(&self); +// fn bar(&self) {} +// } +// +// impl Trait for () { +// type X = (); +// fn foo(&self) {} +// +// $0fn bar(&self) {} +// } +// ``` +pub(crate) fn add_missing_default_members( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + add_missing_impl_members_inner( + acc, + ctx, + DefaultMethods::Only, + "add_impl_default_members", + "Implement default members", + ) +} + +fn add_missing_impl_members_inner( + acc: &mut Assists, + ctx: &AssistContext<'_>, + mode: DefaultMethods, + assist_id: &'static str, + label: &'static str, +) -> Option<()> { + let _p = profile::span("add_missing_impl_members_inner"); + let impl_def = ctx.find_node_at_offset::()?; + let target_scope = ctx.sema.scope(impl_def.syntax())?; + let trait_ = resolve_target_trait(&ctx.sema, &impl_def)?; + + let missing_items = filter_assoc_items( + &ctx.sema, + &ide_db::traits::get_missing_assoc_items(&ctx.sema, &impl_def), + mode, + ); + + if missing_items.is_empty() { + return None; + } + + let target = impl_def.syntax().text_range(); + acc.add(AssistId(assist_id, AssistKind::QuickFix), label, target, |builder| { + let missing_items = missing_items + .into_iter() + .map(|it| { + if ctx.sema.hir_file_for(it.syntax()).is_macro() { + if let Some(it) = ast::AssocItem::cast(insert_ws_into(it.syntax().clone())) { + return it; + } + } + it.clone_for_update() + }) + .collect(); + let (new_impl_def, first_new_item) = add_trait_assoc_items_to_impl( + &ctx.sema, + missing_items, + trait_, + impl_def.clone(), + target_scope, + ); + match ctx.config.snippet_cap { + None => builder.replace(target, new_impl_def.to_string()), + Some(cap) => { + let mut cursor = Cursor::Before(first_new_item.syntax()); + let placeholder; + if let DefaultMethods::No = mode { + if let ast::AssocItem::Fn(func) = &first_new_item { + if try_gen_trait_body(ctx, func, &trait_, &impl_def).is_none() { + if let Some(m) = + func.syntax().descendants().find_map(ast::MacroCall::cast) + { + if m.syntax().text() == "todo!()" { + placeholder = m; + cursor = Cursor::Replace(placeholder.syntax()); + } + } + } + } + } + builder.replace_snippet( + cap, + target, + render_snippet(cap, new_impl_def.syntax(), cursor), + ) + } + }; + }) +} + +fn try_gen_trait_body( + ctx: &AssistContext<'_>, + func: &ast::Fn, + trait_: &hir::Trait, + impl_def: &ast::Impl, +) -> Option<()> { + let trait_path = make::ext::ident_path(&trait_.name(ctx.db()).to_string()); + let hir_ty = ctx.sema.resolve_type(&impl_def.self_ty()?)?; + let adt = hir_ty.as_adt()?.source(ctx.db())?; + gen_trait_fn_body(func, &trait_path, &adt.value) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_add_missing_impl_members() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + type Output; + + const CONST: usize = 42; + + fn foo(&self); + fn bar(&self); + fn baz(&self); +} + +struct S; + +impl Foo for S { + fn bar(&self) {} +$0 +}"#, + r#" +trait Foo { + type Output; + + const CONST: usize = 42; + + fn foo(&self); + fn bar(&self); + fn baz(&self); +} + +struct S; + +impl Foo for S { + fn bar(&self) {} + + $0type Output; + + const CONST: usize = 42; + + fn foo(&self) { + todo!() + } + + fn baz(&self) { + todo!() + } + +}"#, + ); + } + + #[test] + fn test_copied_overriden_members() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + fn foo(&self); + fn bar(&self) -> bool { true } + fn baz(&self) -> u32 { 42 } +} + +struct S; + +impl Foo for S { + fn bar(&self) {} +$0 +}"#, + r#" +trait Foo { + fn foo(&self); + fn bar(&self) -> bool { true } + fn baz(&self) -> u32 { 42 } +} + +struct S; + +impl Foo for S { + fn bar(&self) {} + + fn foo(&self) { + ${0:todo!()} + } + +}"#, + ); + } + + #[test] + fn test_empty_impl_def() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S { $0 }"#, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S { + fn foo(&self) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_impl_def_without_braces() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S$0"#, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S { + fn foo(&self) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn fill_in_type_params_1() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self, t: T) -> &T; } +struct S; +impl Foo for S { $0 }"#, + r#" +trait Foo { fn foo(&self, t: T) -> &T; } +struct S; +impl Foo for S { + fn foo(&self, t: u32) -> &u32 { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn fill_in_type_params_2() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self, t: T) -> &T; } +struct S; +impl Foo for S { $0 }"#, + r#" +trait Foo { fn foo(&self, t: T) -> &T; } +struct S; +impl Foo for S { + fn foo(&self, t: U) -> &U { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_cursor_after_empty_impl_def() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S {}$0"#, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S { + fn foo(&self) { + ${0:todo!()} + } +}"#, + ) + } + + #[test] + fn test_qualify_path_1() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { $0 }"#, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_2() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub mod bar { + pub struct Bar; + pub trait Foo { fn foo(&self, bar: Bar); } + } +} + +use foo::bar; + +struct S; +impl bar::Foo for S { $0 }"#, + r#" +mod foo { + pub mod bar { + pub struct Bar; + pub trait Foo { fn foo(&self, bar: Bar); } + } +} + +use foo::bar; + +struct S; +impl bar::Foo for S { + fn foo(&self, bar: bar::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_generic() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { $0 }"#, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_and_substitute_param() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { $0 }"#, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_substitute_param_no_qualify() { + // when substituting params, the substituted param should not be qualified! + check_assist( + add_missing_impl_members, + r#" +mod foo { + trait Foo { fn foo(&self, bar: T); } + pub struct Param; +} +struct Param; +struct S; +impl foo::Foo for S { $0 }"#, + r#" +mod foo { + trait Foo { fn foo(&self, bar: T); } + pub struct Param; +} +struct Param; +struct S; +impl foo::Foo for S { + fn foo(&self, bar: Param) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_associated_item() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + impl Bar { type Assoc = u32; } + trait Foo { fn foo(&self, bar: Bar::Assoc); } +} +struct S; +impl foo::Foo for S { $0 }"#, + r#" +mod foo { + pub struct Bar; + impl Bar { type Assoc = u32; } + trait Foo { fn foo(&self, bar: Bar::Assoc); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar::Assoc) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_nested() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + pub struct Baz; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { $0 }"#, + r#" +mod foo { + pub struct Bar; + pub struct Baz; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_fn_trait_notation() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub trait Fn { type Output; } + trait Foo { fn foo(&self, bar: dyn Fn(u32) -> i32); } +} +struct S; +impl foo::Foo for S { $0 }"#, + r#" +mod foo { + pub trait Fn { type Output; } + trait Foo { fn foo(&self, bar: dyn Fn(u32) -> i32); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: dyn Fn(u32) -> i32) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_empty_trait() { + check_assist_not_applicable( + add_missing_impl_members, + r#" +trait Foo; +struct S; +impl Foo for S { $0 }"#, + ) + } + + #[test] + fn test_ignore_unnamed_trait_members_and_default_methods() { + check_assist_not_applicable( + add_missing_impl_members, + r#" +trait Foo { + fn (arg: u32); + fn valid(some: u32) -> bool { false } +} +struct S; +impl Foo for S { $0 }"#, + ) + } + + #[test] + fn test_with_docstring_and_attrs() { + check_assist( + add_missing_impl_members, + r#" +#[doc(alias = "test alias")] +trait Foo { + /// doc string + type Output; + + #[must_use] + fn foo(&self); +} +struct S; +impl Foo for S {}$0"#, + r#" +#[doc(alias = "test alias")] +trait Foo { + /// doc string + type Output; + + #[must_use] + fn foo(&self); +} +struct S; +impl Foo for S { + $0type Output; + + fn foo(&self) { + todo!() + } +}"#, + ) + } + + #[test] + fn test_default_methods() { + check_assist( + add_missing_default_members, + r#" +trait Foo { + type Output; + + const CONST: usize = 42; + + fn valid(some: u32) -> bool { false } + fn foo(some: u32) -> bool; +} +struct S; +impl Foo for S { $0 }"#, + r#" +trait Foo { + type Output; + + const CONST: usize = 42; + + fn valid(some: u32) -> bool { false } + fn foo(some: u32) -> bool; +} +struct S; +impl Foo for S { + $0fn valid(some: u32) -> bool { false } +}"#, + ) + } + + #[test] + fn test_generic_single_default_parameter() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + fn bar(&self, other: &T); +} + +struct S; +impl Foo for S { $0 }"#, + r#" +trait Foo { + fn bar(&self, other: &T); +} + +struct S; +impl Foo for S { + fn bar(&self, other: &Self) { + ${0:todo!()} + } +}"#, + ) + } + + #[test] + fn test_generic_default_parameter_is_second() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + fn bar(&self, this: &T1, that: &T2); +} + +struct S; +impl Foo for S { $0 }"#, + r#" +trait Foo { + fn bar(&self, this: &T1, that: &T2); +} + +struct S; +impl Foo for S { + fn bar(&self, this: &T, that: &Self) { + ${0:todo!()} + } +}"#, + ) + } + + #[test] + fn test_assoc_type_bounds_are_removed() { + check_assist( + add_missing_impl_members, + r#" +trait Tr { + type Ty: Copy + 'static; +} + +impl Tr for ()$0 { +}"#, + r#" +trait Tr { + type Ty: Copy + 'static; +} + +impl Tr for () { + $0type Ty; +}"#, + ) + } + + #[test] + fn test_whitespace_fixup_preserves_bad_tokens() { + check_assist( + add_missing_impl_members, + r#" +trait Tr { + fn foo(); +} + +impl Tr for ()$0 { + +++ +}"#, + r#" +trait Tr { + fn foo(); +} + +impl Tr for () { + fn foo() { + ${0:todo!()} + } + +++ +}"#, + ) + } + + #[test] + fn test_whitespace_fixup_preserves_comments() { + check_assist( + add_missing_impl_members, + r#" +trait Tr { + fn foo(); +} + +impl Tr for ()$0 { + // very important +}"#, + r#" +trait Tr { + fn foo(); +} + +impl Tr for () { + fn foo() { + ${0:todo!()} + } + // very important +}"#, + ) + } + + #[test] + fn weird_path() { + check_assist( + add_missing_impl_members, + r#" +trait Test { + fn foo(&self, x: crate) +} +impl Test for () { + $0 +} +"#, + r#" +trait Test { + fn foo(&self, x: crate) +} +impl Test for () { + fn foo(&self, x: crate) { + ${0:todo!()} + } +} +"#, + ) + } + + #[test] + fn missing_generic_type() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + fn foo(&self, bar: BAR); +} +impl Foo for () { + $0 +} +"#, + r#" +trait Foo { + fn foo(&self, bar: BAR); +} +impl Foo for () { + fn foo(&self, bar: BAR) { + ${0:todo!()} + } +} +"#, + ) + } + + #[test] + fn does_not_requalify_self_as_crate() { + check_assist( + add_missing_default_members, + r" +struct Wrapper(T); + +trait T { + fn f(self) -> Wrapper { + Wrapper(self) + } +} + +impl T for () { + $0 +} +", + r" +struct Wrapper(T); + +trait T { + fn f(self) -> Wrapper { + Wrapper(self) + } +} + +impl T for () { + $0fn f(self) -> Wrapper { + Wrapper(self) + } +} +", + ); + } + + #[test] + fn test_default_body_generation() { + check_assist( + add_missing_impl_members, + r#" +//- minicore: default +struct Foo(usize); + +impl Default for Foo { + $0 +} +"#, + r#" +struct Foo(usize); + +impl Default for Foo { + $0fn default() -> Self { + Self(Default::default()) + } +} +"#, + ) + } + + #[test] + fn test_from_macro() { + check_assist( + add_missing_default_members, + r#" +macro_rules! foo { + () => { + trait FooB { + fn foo<'lt>(&'lt self) {} + } + } +} +foo!(); +struct Foo(usize); + +impl FooB for Foo { + $0 +} +"#, + r#" +macro_rules! foo { + () => { + trait FooB { + fn foo<'lt>(&'lt self) {} + } + } +} +foo!(); +struct Foo(usize); + +impl FooB for Foo { + $0fn foo< 'lt>(& 'lt self){} +} +"#, + ) + } + + #[test] + fn test_assoc_type_when_trait_with_same_name_in_scope() { + check_assist( + add_missing_impl_members, + r#" +pub trait Foo {} + +pub trait Types { + type Foo; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Foo); +} + +pub struct Impl; + +impl Behavior for Impl { $0 }"#, + r#" +pub trait Foo {} + +pub trait Types { + type Foo; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Foo); +} + +pub struct Impl; + +impl Behavior for Impl { + fn reproduce(&self, foo: ::Foo) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_assoc_type_on_concrete_type() { + check_assist( + add_missing_impl_members, + r#" +pub trait Types { + type Foo; +} + +impl Types for u32 { + type Foo = bool; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Foo); +} + +pub struct Impl; + +impl Behavior for Impl { $0 }"#, + r#" +pub trait Types { + type Foo; +} + +impl Types for u32 { + type Foo = bool; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Foo); +} + +pub struct Impl; + +impl Behavior for Impl { + fn reproduce(&self, foo: ::Foo) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_assoc_type_on_concrete_type_qualified() { + check_assist( + add_missing_impl_members, + r#" +pub trait Types { + type Foo; +} + +impl Types for std::string::String { + type Foo = bool; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Foo); +} + +pub struct Impl; + +impl Behavior for Impl { $0 }"#, + r#" +pub trait Types { + type Foo; +} + +impl Types for std::string::String { + type Foo = bool; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Foo); +} + +pub struct Impl; + +impl Behavior for Impl { + fn reproduce(&self, foo: ::Foo) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_assoc_type_on_concrete_type_multi_option_ambiguous() { + check_assist( + add_missing_impl_members, + r#" +pub trait Types { + type Foo; +} + +pub trait Types2 { + type Foo; +} + +impl Types for u32 { + type Foo = bool; +} + +impl Types2 for u32 { + type Foo = String; +} + +pub trait Behavior { + fn reproduce(&self, foo: ::Foo); +} + +pub struct Impl; + +impl Behavior for Impl { $0 }"#, + r#" +pub trait Types { + type Foo; +} + +pub trait Types2 { + type Foo; +} + +impl Types for u32 { + type Foo = bool; +} + +impl Types2 for u32 { + type Foo = String; +} + +pub trait Behavior { + fn reproduce(&self, foo: ::Foo); +} + +pub struct Impl; + +impl Behavior for Impl { + fn reproduce(&self, foo: ::Foo) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_assoc_type_on_concrete_type_multi_option() { + check_assist( + add_missing_impl_members, + r#" +pub trait Types { + type Foo; +} + +pub trait Types2 { + type Bar; +} + +impl Types for u32 { + type Foo = bool; +} + +impl Types2 for u32 { + type Bar = String; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Bar); +} + +pub struct Impl; + +impl Behavior for Impl { $0 }"#, + r#" +pub trait Types { + type Foo; +} + +pub trait Types2 { + type Bar; +} + +impl Types for u32 { + type Foo = bool; +} + +impl Types2 for u32 { + type Bar = String; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Bar); +} + +pub struct Impl; + +impl Behavior for Impl { + fn reproduce(&self, foo: ::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_assoc_type_on_concrete_type_multi_option_foreign() { + check_assist( + add_missing_impl_members, + r#" +mod bar { + pub trait Types2 { + type Bar; + } +} + +pub trait Types { + type Foo; +} + +impl Types for u32 { + type Foo = bool; +} + +impl bar::Types2 for u32 { + type Bar = String; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Bar); +} + +pub struct Impl; + +impl Behavior for Impl { $0 }"#, + r#" +mod bar { + pub trait Types2 { + type Bar; + } +} + +pub trait Types { + type Foo; +} + +impl Types for u32 { + type Foo = bool; +} + +impl bar::Types2 for u32 { + type Bar = String; +} + +pub trait Behavior { + fn reproduce(&self, foo: T::Bar); +} + +pub struct Impl; + +impl Behavior for Impl { + fn reproduce(&self, foo: ::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_transform_path_in_path_expr() { + check_assist( + add_missing_default_members, + r#" +pub trait Const { + const FOO: u32; +} + +pub trait Trait { + fn foo() -> bool { + match T::FOO { + 0 => true, + _ => false, + } + } +} + +impl Const for u32 { + const FOO: u32 = 1; +} + +struct Impl; + +impl Trait for Impl { $0 }"#, + r#" +pub trait Const { + const FOO: u32; +} + +pub trait Trait { + fn foo() -> bool { + match T::FOO { + 0 => true, + _ => false, + } + } +} + +impl Const for u32 { + const FOO: u32 = 1; +} + +struct Impl; + +impl Trait for Impl { + $0fn foo() -> bool { + match ::FOO { + 0 => true, + _ => false, + } + } +}"#, + ); + } + + #[test] + fn test_default_partial_eq() { + check_assist( + add_missing_default_members, + r#" +//- minicore: eq +struct SomeStruct { + data: usize, + field: (usize, usize), +} +impl PartialEq for SomeStruct {$0} +"#, + r#" +struct SomeStruct { + data: usize, + field: (usize, usize), +} +impl PartialEq for SomeStruct { + $0fn ne(&self, other: &Self) -> bool { + !self.eq(other) + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs new file mode 100644 index 000000000..b16f6fe03 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs @@ -0,0 +1,1709 @@ +use std::iter::{self, Peekable}; + +use either::Either; +use hir::{Adt, Crate, HasAttrs, HasSource, ModuleDef, Semantics}; +use ide_db::RootDatabase; +use ide_db::{famous_defs::FamousDefs, helpers::mod_path_to_ast}; +use itertools::Itertools; +use syntax::ast::{self, make, AstNode, HasName, MatchArmList, MatchExpr, Pat}; + +use crate::{ + utils::{self, render_snippet, Cursor}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: add_missing_match_arms +// +// Adds missing clauses to a `match` expression. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// $0 +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// $0Action::Move { distance } => todo!(), +// Action::Stop => todo!(), +// } +// } +// ``` +pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let match_expr = ctx.find_node_at_offset_with_descend::()?; + let match_arm_list = match_expr.match_arm_list()?; + let target_range = ctx.sema.original_range(match_expr.syntax()).range; + + if let None = cursor_at_trivial_match_arm_list(ctx, &match_expr, &match_arm_list) { + let arm_list_range = ctx.sema.original_range(match_arm_list.syntax()).range; + let cursor_in_range = arm_list_range.contains_range(ctx.selection_trimmed()); + if cursor_in_range { + cov_mark::hit!(not_applicable_outside_of_range_right); + return None; + } + } + + let expr = match_expr.expr()?; + + let mut has_catch_all_arm = false; + + let top_lvl_pats: Vec<_> = match_arm_list + .arms() + .filter_map(|arm| Some((arm.pat()?, arm.guard().is_some()))) + .flat_map(|(pat, has_guard)| { + match pat { + // Special case OrPat as separate top-level pats + Pat::OrPat(or_pat) => Either::Left(or_pat.pats()), + _ => Either::Right(iter::once(pat)), + } + .map(move |pat| (pat, has_guard)) + }) + .map(|(pat, has_guard)| { + has_catch_all_arm |= !has_guard && matches!(pat, Pat::WildcardPat(_)); + pat + }) + // Exclude top level wildcards so that they are expanded by this assist, retains status quo in #8129. + .filter(|pat| !matches!(pat, Pat::WildcardPat(_))) + .collect(); + + let module = ctx.sema.scope(expr.syntax())?.module(); + let (mut missing_pats, is_non_exhaustive): ( + Peekable>>, + bool, + ) = if let Some(enum_def) = resolve_enum_def(&ctx.sema, &expr) { + let is_non_exhaustive = enum_def.is_non_exhaustive(ctx.db(), module.krate()); + + let variants = enum_def.variants(ctx.db()); + + let missing_pats = variants + .into_iter() + .filter_map(|variant| { + Some(( + build_pat(ctx.db(), module, variant)?, + variant.should_be_hidden(ctx.db(), module.krate()), + )) + }) + .filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat)); + + let option_enum = FamousDefs(&ctx.sema, module.krate()).core_option_Option().map(lift_enum); + let missing_pats: Box> = if Some(enum_def) == option_enum { + // Match `Some` variant first. + cov_mark::hit!(option_order); + Box::new(missing_pats.rev()) + } else { + Box::new(missing_pats) + }; + (missing_pats.peekable(), is_non_exhaustive) + } else if let Some(enum_defs) = resolve_tuple_of_enum_def(&ctx.sema, &expr) { + let is_non_exhaustive = + enum_defs.iter().any(|enum_def| enum_def.is_non_exhaustive(ctx.db(), module.krate())); + + let mut n_arms = 1; + let variants_of_enums: Vec> = enum_defs + .into_iter() + .map(|enum_def| enum_def.variants(ctx.db())) + .inspect(|variants| n_arms *= variants.len()) + .collect(); + + // When calculating the match arms for a tuple of enums, we want + // to create a match arm for each possible combination of enum + // values. The `multi_cartesian_product` method transforms + // Vec> into Vec<(EnumVariant, .., EnumVariant)> + // where each tuple represents a proposed match arm. + + // A number of arms grows very fast on even a small tuple of large enums. + // We skip the assist beyond an arbitrary threshold. + if n_arms > 256 { + return None; + } + let missing_pats = variants_of_enums + .into_iter() + .multi_cartesian_product() + .inspect(|_| cov_mark::hit!(add_missing_match_arms_lazy_computation)) + .map(|variants| { + let is_hidden = variants + .iter() + .any(|variant| variant.should_be_hidden(ctx.db(), module.krate())); + let patterns = + variants.into_iter().filter_map(|variant| build_pat(ctx.db(), module, variant)); + + (ast::Pat::from(make::tuple_pat(patterns)), is_hidden) + }) + .filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat)); + ((Box::new(missing_pats) as Box>).peekable(), is_non_exhaustive) + } else { + return None; + }; + + let mut needs_catch_all_arm = is_non_exhaustive && !has_catch_all_arm; + + if !needs_catch_all_arm && missing_pats.peek().is_none() { + return None; + } + + acc.add( + AssistId("add_missing_match_arms", AssistKind::QuickFix), + "Fill match arms", + target_range, + |builder| { + let new_match_arm_list = match_arm_list.clone_for_update(); + let missing_arms = missing_pats + .map(|(pat, hidden)| { + (make::match_arm(iter::once(pat), None, make::ext::expr_todo()), hidden) + }) + .map(|(it, hidden)| (it.clone_for_update(), hidden)); + + let catch_all_arm = new_match_arm_list + .arms() + .find(|arm| matches!(arm.pat(), Some(ast::Pat::WildcardPat(_)))); + if let Some(arm) = catch_all_arm { + let is_empty_expr = arm.expr().map_or(true, |e| match e { + ast::Expr::BlockExpr(b) => { + b.statements().next().is_none() && b.tail_expr().is_none() + } + ast::Expr::TupleExpr(t) => t.fields().next().is_none(), + _ => false, + }); + if is_empty_expr { + arm.remove(); + } else { + cov_mark::hit!(add_missing_match_arms_empty_expr); + } + } + let mut first_new_arm = None; + for (arm, hidden) in missing_arms { + if hidden { + needs_catch_all_arm = !has_catch_all_arm; + } else { + first_new_arm.get_or_insert_with(|| arm.clone()); + new_match_arm_list.add_arm(arm); + } + } + if needs_catch_all_arm && !has_catch_all_arm { + cov_mark::hit!(added_wildcard_pattern); + let arm = make::match_arm( + iter::once(make::wildcard_pat().into()), + None, + make::ext::expr_todo(), + ) + .clone_for_update(); + first_new_arm.get_or_insert_with(|| arm.clone()); + new_match_arm_list.add_arm(arm); + } + + let old_range = ctx.sema.original_range(match_arm_list.syntax()).range; + match (first_new_arm, ctx.config.snippet_cap) { + (Some(first_new_arm), Some(cap)) => { + let extend_lifetime; + let cursor = + match first_new_arm.syntax().descendants().find_map(ast::WildcardPat::cast) + { + Some(it) => { + extend_lifetime = it.syntax().clone(); + Cursor::Replace(&extend_lifetime) + } + None => Cursor::Before(first_new_arm.syntax()), + }; + let snippet = render_snippet(cap, new_match_arm_list.syntax(), cursor); + builder.replace_snippet(cap, old_range, snippet); + } + _ => builder.replace(old_range, new_match_arm_list.to_string()), + } + }, + ) +} + +fn cursor_at_trivial_match_arm_list( + ctx: &AssistContext<'_>, + match_expr: &MatchExpr, + match_arm_list: &MatchArmList, +) -> Option<()> { + // match x { $0 } + if match_arm_list.arms().next() == None { + cov_mark::hit!(add_missing_match_arms_empty_body); + return Some(()); + } + + // match x { + // bar => baz, + // $0 + // } + if let Some(last_arm) = match_arm_list.arms().last() { + let last_arm_range = last_arm.syntax().text_range(); + let match_expr_range = match_expr.syntax().text_range(); + if last_arm_range.end() <= ctx.offset() && ctx.offset() < match_expr_range.end() { + cov_mark::hit!(add_missing_match_arms_end_of_last_arm); + return Some(()); + } + } + + // match { _$0 => {...} } + let wild_pat = ctx.find_node_at_offset_with_descend::()?; + let arm = wild_pat.syntax().parent().and_then(ast::MatchArm::cast)?; + let arm_match_expr = arm.syntax().ancestors().nth(2).and_then(ast::MatchExpr::cast)?; + if arm_match_expr == *match_expr { + cov_mark::hit!(add_missing_match_arms_trivial_arm); + return Some(()); + } + + None +} + +fn is_variant_missing(existing_pats: &[Pat], var: &Pat) -> bool { + !existing_pats.iter().any(|pat| does_pat_match_variant(pat, var)) +} + +// Fixme: this is still somewhat limited, use hir_ty::diagnostics::match_check? +fn does_pat_match_variant(pat: &Pat, var: &Pat) -> bool { + match (pat, var) { + (Pat::WildcardPat(_), _) => true, + (Pat::TuplePat(tpat), Pat::TuplePat(tvar)) => { + tpat.fields().zip(tvar.fields()).all(|(p, v)| does_pat_match_variant(&p, &v)) + } + _ => utils::does_pat_match_variant(pat, var), + } +} + +#[derive(Eq, PartialEq, Clone, Copy)] +enum ExtendedEnum { + Bool, + Enum(hir::Enum), +} + +#[derive(Eq, PartialEq, Clone, Copy)] +enum ExtendedVariant { + True, + False, + Variant(hir::Variant), +} + +impl ExtendedVariant { + fn should_be_hidden(self, db: &RootDatabase, krate: Crate) -> bool { + match self { + ExtendedVariant::Variant(var) => { + var.attrs(db).has_doc_hidden() && var.module(db).krate() != krate + } + _ => false, + } + } +} + +fn lift_enum(e: hir::Enum) -> ExtendedEnum { + ExtendedEnum::Enum(e) +} + +impl ExtendedEnum { + fn is_non_exhaustive(self, db: &RootDatabase, krate: Crate) -> bool { + match self { + ExtendedEnum::Enum(e) => { + e.attrs(db).by_key("non_exhaustive").exists() && e.module(db).krate() != krate + } + _ => false, + } + } + + fn variants(self, db: &RootDatabase) -> Vec { + match self { + ExtendedEnum::Enum(e) => { + e.variants(db).into_iter().map(ExtendedVariant::Variant).collect::>() + } + ExtendedEnum::Bool => { + Vec::::from([ExtendedVariant::True, ExtendedVariant::False]) + } + } + } +} + +fn resolve_enum_def(sema: &Semantics<'_, RootDatabase>, expr: &ast::Expr) -> Option { + sema.type_of_expr(expr)?.adjusted().autoderef(sema.db).find_map(|ty| match ty.as_adt() { + Some(Adt::Enum(e)) => Some(ExtendedEnum::Enum(e)), + _ => ty.is_bool().then(|| ExtendedEnum::Bool), + }) +} + +fn resolve_tuple_of_enum_def( + sema: &Semantics<'_, RootDatabase>, + expr: &ast::Expr, +) -> Option> { + sema.type_of_expr(expr)? + .adjusted() + .tuple_fields(sema.db) + .iter() + .map(|ty| { + ty.autoderef(sema.db).find_map(|ty| match ty.as_adt() { + Some(Adt::Enum(e)) => Some(lift_enum(e)), + // For now we only handle expansion for a tuple of enums. Here + // we map non-enum items to None and rely on `collect` to + // convert Vec> into Option>. + _ => ty.is_bool().then(|| ExtendedEnum::Bool), + }) + }) + .collect() +} + +fn build_pat(db: &RootDatabase, module: hir::Module, var: ExtendedVariant) -> Option { + match var { + ExtendedVariant::Variant(var) => { + let path = mod_path_to_ast(&module.find_use_path(db, ModuleDef::from(var))?); + + // FIXME: use HIR for this; it doesn't currently expose struct vs. tuple vs. unit variants though + let pat: ast::Pat = match var.source(db)?.value.kind() { + ast::StructKind::Tuple(field_list) => { + let pats = + iter::repeat(make::wildcard_pat().into()).take(field_list.fields().count()); + make::tuple_struct_pat(path, pats).into() + } + ast::StructKind::Record(field_list) => { + let pats = field_list + .fields() + .map(|f| make::ext::simple_ident_pat(f.name().unwrap()).into()); + make::record_pat(path, pats).into() + } + ast::StructKind::Unit => make::path_pat(path), + }; + + Some(pat) + } + ExtendedVariant::True => Some(ast::Pat::from(make::literal_pat("true"))), + ExtendedVariant::False => Some(ast::Pat::from(make::literal_pat("false"))), + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{ + check_assist, check_assist_not_applicable, check_assist_target, check_assist_unresolved, + }; + + use super::add_missing_match_arms; + + #[test] + fn all_match_arms_provided() { + check_assist_not_applicable( + add_missing_match_arms, + r#" +enum A { + As, + Bs{x:i32, y:Option}, + Cs(i32, Option), +} +fn main() { + match A::As$0 { + A::As, + A::Bs{x,y:Some(_)} => {} + A::Cs(_, Some(_)) => {} + } +} + "#, + ); + } + + #[test] + fn not_applicable_outside_of_range_left() { + check_assist_not_applicable( + add_missing_match_arms, + r#" +enum A { X, Y } + +fn foo(a: A) { + $0 match a { + A::X => { } + } +} + "#, + ); + } + + #[test] + fn not_applicable_outside_of_range_right() { + cov_mark::check!(not_applicable_outside_of_range_right); + check_assist_not_applicable( + add_missing_match_arms, + r#" +enum A { X, Y } + +fn foo(a: A) { + match a {$0 + A::X => { } + } +} + "#, + ); + } + + #[test] + fn all_boolean_match_arms_provided() { + check_assist_not_applicable( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match a$0 { + true => {} + false => {} + } +} +"#, + ) + } + + #[test] + fn tuple_of_non_enum() { + // for now this case is not handled, although it potentially could be + // in the future + check_assist_not_applicable( + add_missing_match_arms, + r#" +fn main() { + match (0, false)$0 { + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_boolean() { + check_assist( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match a$0 { + } +} +"#, + r#" +fn foo(a: bool) { + match a { + $0true => todo!(), + false => todo!(), + } +} +"#, + ) + } + + #[test] + fn partial_fill_boolean() { + check_assist( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match a$0 { + true => {} + } +} +"#, + r#" +fn foo(a: bool) { + match a { + true => {} + $0false => todo!(), + } +} +"#, + ) + } + + #[test] + fn all_boolean_tuple_arms_provided() { + check_assist_not_applicable( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match (a, a)$0 { + (true, true) => {} + (true, false) => {} + (false, true) => {} + (false, false) => {} + } +} +"#, + ) + } + + #[test] + fn fill_boolean_tuple() { + check_assist( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match (a, a)$0 { + } +} +"#, + r#" +fn foo(a: bool) { + match (a, a) { + $0(true, true) => todo!(), + (true, false) => todo!(), + (false, true) => todo!(), + (false, false) => todo!(), + } +} +"#, + ) + } + + #[test] + fn partial_fill_boolean_tuple() { + check_assist( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match (a, a)$0 { + (false, true) => {} + } +} +"#, + r#" +fn foo(a: bool) { + match (a, a) { + (false, true) => {} + $0(true, true) => todo!(), + (true, false) => todo!(), + (false, false) => todo!(), + } +} +"#, + ) + } + + #[test] + fn partial_fill_record_tuple() { + check_assist( + add_missing_match_arms, + r#" +enum A { + As, + Bs { x: i32, y: Option }, + Cs(i32, Option), +} +fn main() { + match A::As$0 { + A::Bs { x, y: Some(_) } => {} + A::Cs(_, Some(_)) => {} + } +} +"#, + r#" +enum A { + As, + Bs { x: i32, y: Option }, + Cs(i32, Option), +} +fn main() { + match A::As { + A::Bs { x, y: Some(_) } => {} + A::Cs(_, Some(_)) => {} + $0A::As => todo!(), + } +} +"#, + ); + } + + #[test] + fn partial_fill_option() { + check_assist( + add_missing_match_arms, + r#" +//- minicore: option +fn main() { + match None$0 { + None => {} + } +} +"#, + r#" +fn main() { + match None { + None => {} + Some(${0:_}) => todo!(), + } +} +"#, + ); + } + + #[test] + fn partial_fill_or_pat() { + check_assist( + add_missing_match_arms, + r#" +enum A { As, Bs, Cs(Option) } +fn main() { + match A::As$0 { + A::Cs(_) | A::Bs => {} + } +} +"#, + r#" +enum A { As, Bs, Cs(Option) } +fn main() { + match A::As { + A::Cs(_) | A::Bs => {} + $0A::As => todo!(), + } +} +"#, + ); + } + + #[test] + fn partial_fill() { + check_assist( + add_missing_match_arms, + r#" +enum A { As, Bs, Cs, Ds(String), Es(B) } +enum B { Xs, Ys } +fn main() { + match A::As$0 { + A::Bs if 0 < 1 => {} + A::Ds(_value) => { let x = 1; } + A::Es(B::Xs) => (), + } +} +"#, + r#" +enum A { As, Bs, Cs, Ds(String), Es(B) } +enum B { Xs, Ys } +fn main() { + match A::As { + A::Bs if 0 < 1 => {} + A::Ds(_value) => { let x = 1; } + A::Es(B::Xs) => (), + $0A::As => todo!(), + A::Cs => todo!(), + } +} +"#, + ); + } + + #[test] + fn partial_fill_bind_pat() { + check_assist( + add_missing_match_arms, + r#" +enum A { As, Bs, Cs(Option) } +fn main() { + match A::As$0 { + A::As(_) => {} + a @ A::Bs(_) => {} + } +} +"#, + r#" +enum A { As, Bs, Cs(Option) } +fn main() { + match A::As { + A::As(_) => {} + a @ A::Bs(_) => {} + A::Cs(${0:_}) => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_empty_body() { + cov_mark::check!(add_missing_match_arms_empty_body); + check_assist( + add_missing_match_arms, + r#" +enum A { As, Bs, Cs(String), Ds(String, String), Es { x: usize, y: usize } } + +fn main() { + let a = A::As; + match a {$0} +} +"#, + r#" +enum A { As, Bs, Cs(String), Ds(String, String), Es { x: usize, y: usize } } + +fn main() { + let a = A::As; + match a { + $0A::As => todo!(), + A::Bs => todo!(), + A::Cs(_) => todo!(), + A::Ds(_, _) => todo!(), + A::Es { x, y } => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_end_of_last_arm() { + cov_mark::check!(add_missing_match_arms_end_of_last_arm); + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => {},$0 + } +} +"#, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => {}, + $0(A::One, B::One) => todo!(), + (A::One, B::Two) => todo!(), + (A::Two, B::Two) => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_tuple_of_enum() { + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a$0, b) {} +} +"#, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + $0(A::One, B::One) => todo!(), + (A::One, B::Two) => todo!(), + (A::Two, B::One) => todo!(), + (A::Two, B::Two) => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_tuple_of_enum_ref() { + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (&a$0, &b) {} +} +"#, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (&a, &b) { + $0(A::One, B::One) => todo!(), + (A::One, B::Two) => todo!(), + (A::Two, B::One) => todo!(), + (A::Two, B::Two) => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_tuple_of_enum_partial() { + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a$0, b) { + (A::Two, B::One) => {} + } +} +"#, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + (A::Two, B::One) => {} + $0(A::One, B::One) => todo!(), + (A::One, B::Two) => todo!(), + (A::Two, B::Two) => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_tuple_of_enum_partial_with_wildcards() { + check_assist( + add_missing_match_arms, + r#" +//- minicore: option +fn main() { + let a = Some(1); + let b = Some(()); + match (a$0, b) { + (Some(_), _) => {} + (None, Some(_)) => {} + } +} +"#, + r#" +fn main() { + let a = Some(1); + let b = Some(()); + match (a, b) { + (Some(_), _) => {} + (None, Some(_)) => {} + $0(None, None) => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_partial_with_deep_pattern() { + // Fixme: cannot handle deep patterns + check_assist_not_applicable( + add_missing_match_arms, + r#" +//- minicore: option +fn main() { + match $0Some(true) { + Some(true) => {} + None => {} + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_tuple_of_enum_not_applicable() { + check_assist_not_applicable( + add_missing_match_arms, + r#" +enum A { One, Two } +enum B { One, Two } + +fn main() { + let a = A::One; + let b = B::One; + match (a$0, b) { + (A::Two, B::One) => {} + (A::One, B::One) => {} + (A::One, B::Two) => {} + (A::Two, B::Two) => {} + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_single_element_tuple_of_enum() { + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } + +fn main() { + let a = A::One; + match (a$0, ) { + } +} +"#, + r#" +enum A { One, Two } + +fn main() { + let a = A::One; + match (a, ) { + $0(A::One,) => todo!(), + (A::Two,) => todo!(), + } +} +"#, + ); + } + + #[test] + fn test_fill_match_arm_refs() { + check_assist( + add_missing_match_arms, + r#" +enum A { As } + +fn foo(a: &A) { + match a$0 { + } +} +"#, + r#" +enum A { As } + +fn foo(a: &A) { + match a { + $0A::As => todo!(), + } +} +"#, + ); + + check_assist( + add_missing_match_arms, + r#" +enum A { + Es { x: usize, y: usize } +} + +fn foo(a: &mut A) { + match a$0 { + } +} +"#, + r#" +enum A { + Es { x: usize, y: usize } +} + +fn foo(a: &mut A) { + match a { + $0A::Es { x, y } => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_target_simple() { + check_assist_target( + add_missing_match_arms, + r#" +enum E { X, Y } + +fn main() { + match E::X$0 {} +} +"#, + "match E::X {}", + ); + } + + #[test] + fn add_missing_match_arms_target_complex() { + check_assist_target( + add_missing_match_arms, + r#" +enum E { X, Y } + +fn main() { + match E::X$0 { + E::X => {} + } +} +"#, + "match E::X { + E::X => {} + }", + ); + } + + #[test] + fn add_missing_match_arms_trivial_arm() { + cov_mark::check!(add_missing_match_arms_trivial_arm); + check_assist( + add_missing_match_arms, + r#" +enum E { X, Y } + +fn main() { + match E::X { + $0_ => {} + } +} +"#, + r#" +enum E { X, Y } + +fn main() { + match E::X { + $0E::X => todo!(), + E::Y => todo!(), + } +} +"#, + ); + } + + #[test] + fn wildcard_inside_expression_not_applicable() { + check_assist_not_applicable( + add_missing_match_arms, + r#" +enum E { X, Y } + +fn foo(e : E) { + match e { + _ => { + println!("1");$0 + println!("2"); + } + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_qualifies_path() { + check_assist( + add_missing_match_arms, + r#" +mod foo { pub enum E { X, Y } } +use foo::E::X; + +fn main() { + match X { + $0 + } +} +"#, + r#" +mod foo { pub enum E { X, Y } } +use foo::E::X; + +fn main() { + match X { + $0X => todo!(), + foo::E::Y => todo!(), + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_preserves_comments() { + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +fn foo(a: A) { + match a $0 { + // foo bar baz + A::One => {} + // This is where the rest should be + } +} +"#, + r#" +enum A { One, Two } +fn foo(a: A) { + match a { + // foo bar baz + A::One => {} + $0A::Two => todo!(), + // This is where the rest should be + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_preserves_comments_empty() { + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two } +fn foo(a: A) { + match a { + // foo bar baz$0 + } +} +"#, + r#" +enum A { One, Two } +fn foo(a: A) { + match a { + $0A::One => todo!(), + A::Two => todo!(), + // foo bar baz + } +} +"#, + ); + } + + #[test] + fn add_missing_match_arms_placeholder() { + check_assist( + add_missing_match_arms, + r#" +enum A { One, Two, } +fn foo(a: A) { + match a$0 { + _ => (), + } +} +"#, + r#" +enum A { One, Two, } +fn foo(a: A) { + match a { + $0A::One => todo!(), + A::Two => todo!(), + } +} +"#, + ); + } + + #[test] + fn option_order() { + cov_mark::check!(option_order); + check_assist( + add_missing_match_arms, + r#" +//- minicore: option +fn foo(opt: Option) { + match opt$0 { + } +} +"#, + r#" +fn foo(opt: Option) { + match opt { + Some(${0:_}) => todo!(), + None => todo!(), + } +} +"#, + ); + } + + #[test] + fn works_inside_macro_call() { + check_assist( + add_missing_match_arms, + r#" +macro_rules! m { ($expr:expr) => {$expr}} +enum Test { + A, + B, + C, +} + +fn foo(t: Test) { + m!(match t$0 {}); +}"#, + r#" +macro_rules! m { ($expr:expr) => {$expr}} +enum Test { + A, + B, + C, +} + +fn foo(t: Test) { + m!(match t { + $0Test::A => todo!(), + Test::B => todo!(), + Test::C => todo!(), +}); +}"#, + ); + } + + #[test] + fn lazy_computation() { + // Computing a single missing arm is enough to determine applicability of the assist. + cov_mark::check_count!(add_missing_match_arms_lazy_computation, 1); + check_assist_unresolved( + add_missing_match_arms, + r#" +enum A { One, Two, } +fn foo(tuple: (A, A)) { + match $0tuple {}; +} +"#, + ); + } + + #[test] + fn adds_comma_before_new_arms() { + check_assist( + add_missing_match_arms, + r#" +fn foo(t: bool) { + match $0t { + true => 1 + 2 + } +}"#, + r#" +fn foo(t: bool) { + match t { + true => 1 + 2, + $0false => todo!(), + } +}"#, + ); + } + + #[test] + fn does_not_add_extra_comma() { + check_assist( + add_missing_match_arms, + r#" +fn foo(t: bool) { + match $0t { + true => 1 + 2, + } +}"#, + r#" +fn foo(t: bool) { + match t { + true => 1 + 2, + $0false => todo!(), + } +}"#, + ); + } + + #[test] + fn does_not_remove_catch_all_with_non_empty_expr() { + cov_mark::check!(add_missing_match_arms_empty_expr); + check_assist( + add_missing_match_arms, + r#" +fn foo(t: bool) { + match $0t { + _ => 1 + 2, + } +}"#, + r#" +fn foo(t: bool) { + match t { + _ => 1 + 2, + $0true => todo!(), + false => todo!(), + } +}"#, + ); + } + + #[test] + fn does_not_fill_hidden_variants() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + } +} +//- /e.rs crate:e +pub enum E { A, #[doc(hidden)] B, } +"#, + r#" +fn foo(t: ::e::E) { + match t { + $0e::E::A => todo!(), + _ => todo!(), + } +} +"#, + ); + } + + #[test] + fn does_not_fill_hidden_variants_tuple() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: (bool, ::e::E)) { + match $0t { + } +} +//- /e.rs crate:e +pub enum E { A, #[doc(hidden)] B, } +"#, + r#" +fn foo(t: (bool, ::e::E)) { + match t { + $0(true, e::E::A) => todo!(), + (false, e::E::A) => todo!(), + _ => todo!(), + } +} +"#, + ); + } + + #[test] + fn fills_wildcard_with_only_hidden_variants() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + } +} +//- /e.rs crate:e +pub enum E { #[doc(hidden)] A, } +"#, + r#" +fn foo(t: ::e::E) { + match t { + ${0:_} => todo!(), + } +} +"#, + ); + } + + #[test] + fn does_not_fill_wildcard_when_hidden_variants_are_explicit() { + check_assist_not_applicable( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + e::E::A => todo!(), + } +} +//- /e.rs crate:e +pub enum E { #[doc(hidden)] A, } +"#, + ); + } + + // FIXME: I don't think the assist should be applicable in this case + #[test] + fn does_not_fill_wildcard_with_wildcard() { + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + _ => todo!(), + } +} +//- /e.rs crate:e +pub enum E { #[doc(hidden)] A, } +"#, + r#" +fn foo(t: ::e::E) { + match t { + _ => todo!(), + } +} +"#, + ); + } + + #[test] + fn fills_wildcard_on_non_exhaustive_with_explicit_matches() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + e::E::A => todo!(), + } +} +//- /e.rs crate:e +#[non_exhaustive] +pub enum E { A, } +"#, + r#" +fn foo(t: ::e::E) { + match t { + e::E::A => todo!(), + ${0:_} => todo!(), + } +} +"#, + ); + } + + #[test] + fn fills_wildcard_on_non_exhaustive_without_matches() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + } +} +//- /e.rs crate:e +#[non_exhaustive] +pub enum E { A, } +"#, + r#" +fn foo(t: ::e::E) { + match t { + $0e::E::A => todo!(), + _ => todo!(), + } +} +"#, + ); + } + + #[test] + fn fills_wildcard_on_non_exhaustive_with_doc_hidden() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + } +} +//- /e.rs crate:e +#[non_exhaustive] +pub enum E { A, #[doc(hidden)] B }"#, + r#" +fn foo(t: ::e::E) { + match t { + $0e::E::A => todo!(), + _ => todo!(), + } +} +"#, + ); + } + + #[test] + fn fills_wildcard_on_non_exhaustive_with_doc_hidden_with_explicit_arms() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + e::E::A => todo!(), + } +} +//- /e.rs crate:e +#[non_exhaustive] +pub enum E { A, #[doc(hidden)] B }"#, + r#" +fn foo(t: ::e::E) { + match t { + e::E::A => todo!(), + ${0:_} => todo!(), + } +} +"#, + ); + } + + #[test] + fn fill_wildcard_with_partial_wildcard() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E, b: bool) { + match $0t { + _ if b => todo!(), + } +} +//- /e.rs crate:e +pub enum E { #[doc(hidden)] A, }"#, + r#" +fn foo(t: ::e::E, b: bool) { + match t { + _ if b => todo!(), + ${0:_} => todo!(), + } +} +"#, + ); + } + + #[test] + fn does_not_fill_wildcard_with_partial_wildcard_and_wildcard() { + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E, b: bool) { + match $0t { + _ if b => todo!(), + _ => todo!(), + } +} +//- /e.rs crate:e +pub enum E { #[doc(hidden)] A, }"#, + r#" +fn foo(t: ::e::E, b: bool) { + match t { + _ if b => todo!(), + _ => todo!(), + } +} +"#, + ); + } + + #[test] + fn non_exhaustive_doc_hidden_tuple_fills_wildcard() { + cov_mark::check!(added_wildcard_pattern); + check_assist( + add_missing_match_arms, + r#" +//- /main.rs crate:main deps:e +fn foo(t: ::e::E) { + match $0t { + } +} +//- /e.rs crate:e +#[non_exhaustive] +pub enum E { A, #[doc(hidden)] B, }"#, + r#" +fn foo(t: ::e::E) { + match t { + $0e::E::A => todo!(), + _ => todo!(), + } +} +"#, + ); + } + + #[test] + fn ignores_doc_hidden_for_crate_local_enums() { + check_assist( + add_missing_match_arms, + r#" +enum E { A, #[doc(hidden)] B, } + +fn foo(t: E) { + match $0t { + } +}"#, + r#" +enum E { A, #[doc(hidden)] B, } + +fn foo(t: E) { + match t { + $0E::A => todo!(), + E::B => todo!(), + } +}"#, + ); + } + + #[test] + fn ignores_non_exhaustive_for_crate_local_enums() { + check_assist( + add_missing_match_arms, + r#" +#[non_exhaustive] +enum E { A, B, } + +fn foo(t: E) { + match $0t { + } +}"#, + r#" +#[non_exhaustive] +enum E { A, B, } + +fn foo(t: E) { + match t { + $0E::A => todo!(), + E::B => todo!(), + } +}"#, + ); + } + + #[test] + fn ignores_doc_hidden_and_non_exhaustive_for_crate_local_enums() { + check_assist( + add_missing_match_arms, + r#" +#[non_exhaustive] +enum E { A, #[doc(hidden)] B, } + +fn foo(t: E) { + match $0t { + } +}"#, + r#" +#[non_exhaustive] +enum E { A, #[doc(hidden)] B, } + +fn foo(t: E) { + match t { + $0E::A => todo!(), + E::B => todo!(), + } +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_return_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_return_type.rs new file mode 100644 index 000000000..f858d7a15 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_return_type.rs @@ -0,0 +1,447 @@ +use hir::HirDisplay; +use syntax::{ast, match_ast, AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: add_return_type +// +// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return +// type specified. This assists is useable in a functions or closures tail expression or return type position. +// +// ``` +// fn foo() { 4$02i32 } +// ``` +// -> +// ``` +// fn foo() -> i32 { 42i32 } +// ``` +pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?; + let module = ctx.sema.scope(tail_expr.syntax())?.module(); + let ty = ctx.sema.type_of_expr(&peel_blocks(tail_expr.clone()))?.original(); + if ty.is_unit() { + return None; + } + let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; + + acc.add( + AssistId("add_return_type", AssistKind::RefactorRewrite), + match fn_type { + FnType::Function => "Add this function's return type", + FnType::Closure { .. } => "Add this closure's return type", + }, + tail_expr.syntax().text_range(), + |builder| { + match builder_edit_pos { + InsertOrReplace::Insert(insert_pos, needs_whitespace) => { + let preceeding_whitespace = if needs_whitespace { " " } else { "" }; + builder.insert(insert_pos, &format!("{}-> {} ", preceeding_whitespace, ty)) + } + InsertOrReplace::Replace(text_range) => { + builder.replace(text_range, &format!("-> {}", ty)) + } + } + if let FnType::Closure { wrap_expr: true } = fn_type { + cov_mark::hit!(wrap_closure_non_block_expr); + // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block + builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr)); + } + }, + ) +} + +enum InsertOrReplace { + Insert(TextSize, bool), + Replace(TextRange), +} + +/// Check the potentially already specified return type and reject it or turn it into a builder command +/// if allowed. +fn ret_ty_to_action( + ret_ty: Option, + insert_after: SyntaxToken, +) -> Option { + match ret_ty { + Some(ret_ty) => match ret_ty.ty() { + Some(ast::Type::InferType(_)) | None => { + cov_mark::hit!(existing_infer_ret_type); + cov_mark::hit!(existing_infer_ret_type_closure); + Some(InsertOrReplace::Replace(ret_ty.syntax().text_range())) + } + _ => { + cov_mark::hit!(existing_ret_type); + cov_mark::hit!(existing_ret_type_closure); + None + } + }, + None => { + let insert_after_pos = insert_after.text_range().end(); + let (insert_pos, needs_whitespace) = match insert_after.next_token() { + Some(it) if it.kind() == SyntaxKind::WHITESPACE => { + (insert_after_pos + TextSize::from(1), false) + } + _ => (insert_after_pos, true), + }; + + Some(InsertOrReplace::Insert(insert_pos, needs_whitespace)) + } + } +} + +enum FnType { + Function, + Closure { wrap_expr: bool }, +} + +/// If we're looking at a block that is supposed to return `()`, type inference +/// will just tell us it has type `()`. We have to look at the tail expression +/// to see the mismatched actual type. This 'unpeels' the various blocks to +/// hopefully let us see the type the user intends. (This still doesn't handle +/// all situations fully correctly; the 'ideal' way to handle this would be to +/// run type inference on the function again, but with a variable as the return +/// type.) +fn peel_blocks(mut expr: ast::Expr) -> ast::Expr { + loop { + match_ast! { + match (expr.syntax()) { + ast::BlockExpr(it) => { + if let Some(tail) = it.tail_expr() { + expr = tail.clone(); + } else { + break; + } + }, + ast::IfExpr(it) => { + if let Some(then_branch) = it.then_branch() { + expr = ast::Expr::BlockExpr(then_branch.clone()); + } else { + break; + } + }, + ast::MatchExpr(it) => { + if let Some(arm_expr) = it.match_arm_list().and_then(|l| l.arms().next()).and_then(|a| a.expr()) { + expr = arm_expr; + } else { + break; + } + }, + _ => break, + } + } + } + expr +} + +fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrReplace)> { + let (fn_type, tail_expr, return_type_range, action) = + if let Some(closure) = ctx.find_node_at_offset::() { + let rpipe = closure.param_list()?.syntax().last_token()?; + let rpipe_pos = rpipe.text_range().end(); + + let action = ret_ty_to_action(closure.ret_type(), rpipe)?; + + let body = closure.body()?; + let body_start = body.syntax().first_token()?.text_range().start(); + let (tail_expr, wrap_expr) = match body { + ast::Expr::BlockExpr(block) => (block.tail_expr()?, false), + body => (body, true), + }; + + let ret_range = TextRange::new(rpipe_pos, body_start); + (FnType::Closure { wrap_expr }, tail_expr, ret_range, action) + } else { + let func = ctx.find_node_at_offset::()?; + + let rparen = func.param_list()?.r_paren_token()?; + let rparen_pos = rparen.text_range().end(); + let action = ret_ty_to_action(func.ret_type(), rparen)?; + + let body = func.body()?; + let stmt_list = body.stmt_list()?; + let tail_expr = stmt_list.tail_expr()?; + + let ret_range_end = stmt_list.l_curly_token()?.text_range().start(); + let ret_range = TextRange::new(rparen_pos, ret_range_end); + (FnType::Function, tail_expr, ret_range, action) + }; + let range = ctx.selection_trimmed(); + if return_type_range.contains_range(range) { + cov_mark::hit!(cursor_in_ret_position); + cov_mark::hit!(cursor_in_ret_position_closure); + } else if tail_expr.syntax().text_range().contains_range(range) { + cov_mark::hit!(cursor_on_tail); + cov_mark::hit!(cursor_on_tail_closure); + } else { + return None; + } + Some((fn_type, tail_expr, action)) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn infer_return_type_specified_inferred() { + cov_mark::check!(existing_infer_ret_type); + check_assist( + add_return_type, + r#"fn foo() -> $0_ { + 45 +}"#, + r#"fn foo() -> i32 { + 45 +}"#, + ); + } + + #[test] + fn infer_return_type_specified_inferred_closure() { + cov_mark::check!(existing_infer_ret_type_closure); + check_assist( + add_return_type, + r#"fn foo() { + || -> _ {$045}; +}"#, + r#"fn foo() { + || -> i32 {45}; +}"#, + ); + } + + #[test] + fn infer_return_type_cursor_at_return_type_pos() { + cov_mark::check!(cursor_in_ret_position); + check_assist( + add_return_type, + r#"fn foo() $0{ + 45 +}"#, + r#"fn foo() -> i32 { + 45 +}"#, + ); + } + + #[test] + fn infer_return_type_cursor_at_return_type_pos_closure() { + cov_mark::check!(cursor_in_ret_position_closure); + check_assist( + add_return_type, + r#"fn foo() { + || $045 +}"#, + r#"fn foo() { + || -> i32 {45} +}"#, + ); + } + + #[test] + fn infer_return_type() { + cov_mark::check!(cursor_on_tail); + check_assist( + add_return_type, + r#"fn foo() { + 45$0 +}"#, + r#"fn foo() -> i32 { + 45 +}"#, + ); + } + + #[test] + fn infer_return_type_no_whitespace() { + check_assist( + add_return_type, + r#"fn foo(){ + 45$0 +}"#, + r#"fn foo() -> i32 { + 45 +}"#, + ); + } + + #[test] + fn infer_return_type_nested() { + check_assist( + add_return_type, + r#"fn foo() { + if true { + 3$0 + } else { + 5 + } +}"#, + r#"fn foo() -> i32 { + if true { + 3 + } else { + 5 + } +}"#, + ); + } + + #[test] + fn infer_return_type_nested_match() { + check_assist( + add_return_type, + r#"fn foo() { + match true { + true => { 3$0 }, + false => { 5 }, + } +}"#, + r#"fn foo() -> i32 { + match true { + true => { 3 }, + false => { 5 }, + } +}"#, + ); + } + + #[test] + fn not_applicable_ret_type_specified() { + cov_mark::check!(existing_ret_type); + check_assist_not_applicable( + add_return_type, + r#"fn foo() -> i32 { + ( 45$0 + 32 ) * 123 +}"#, + ); + } + + #[test] + fn not_applicable_non_tail_expr() { + check_assist_not_applicable( + add_return_type, + r#"fn foo() { + let x = $03; + ( 45 + 32 ) * 123 +}"#, + ); + } + + #[test] + fn not_applicable_unit_return_type() { + check_assist_not_applicable( + add_return_type, + r#"fn foo() { + ($0) +}"#, + ); + } + + #[test] + fn infer_return_type_closure_block() { + cov_mark::check!(cursor_on_tail_closure); + check_assist( + add_return_type, + r#"fn foo() { + |x: i32| { + x$0 + }; +}"#, + r#"fn foo() { + |x: i32| -> i32 { + x + }; +}"#, + ); + } + + #[test] + fn infer_return_type_closure() { + check_assist( + add_return_type, + r#"fn foo() { + |x: i32| { x$0 }; +}"#, + r#"fn foo() { + |x: i32| -> i32 { x }; +}"#, + ); + } + + #[test] + fn infer_return_type_closure_no_whitespace() { + check_assist( + add_return_type, + r#"fn foo() { + |x: i32|{ x$0 }; +}"#, + r#"fn foo() { + |x: i32| -> i32 { x }; +}"#, + ); + } + + #[test] + fn infer_return_type_closure_wrap() { + cov_mark::check!(wrap_closure_non_block_expr); + check_assist( + add_return_type, + r#"fn foo() { + |x: i32| x$0; +}"#, + r#"fn foo() { + |x: i32| -> i32 {x}; +}"#, + ); + } + + #[test] + fn infer_return_type_nested_closure() { + check_assist( + add_return_type, + r#"fn foo() { + || { + if true { + 3$0 + } else { + 5 + } + } +}"#, + r#"fn foo() { + || -> i32 { + if true { + 3 + } else { + 5 + } + } +}"#, + ); + } + + #[test] + fn not_applicable_ret_type_specified_closure() { + cov_mark::check!(existing_ret_type_closure); + check_assist_not_applicable( + add_return_type, + r#"fn foo() { + || -> i32 { 3$0 } +}"#, + ); + } + + #[test] + fn not_applicable_non_tail_expr_closure() { + check_assist_not_applicable( + add_return_type, + r#"fn foo() { + || -> i32 { + let x = 3$0; + 6 + } +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs new file mode 100644 index 000000000..c0bf238db --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs @@ -0,0 +1,400 @@ +use ide_db::defs::{Definition, NameRefClass}; +use itertools::Itertools; +use syntax::{ast, AstNode, SyntaxKind, T}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: add_turbo_fish +// +// Adds `::<_>` to a call of a generic method or function. +// +// ``` +// fn make() -> T { todo!() } +// fn main() { +// let x = make$0(); +// } +// ``` +// -> +// ``` +// fn make() -> T { todo!() } +// fn main() { +// let x = make::<${0:_}>(); +// } +// ``` +pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let ident = ctx.find_token_syntax_at_offset(SyntaxKind::IDENT).or_else(|| { + let arg_list = ctx.find_node_at_offset::()?; + if arg_list.args().next().is_some() { + return None; + } + cov_mark::hit!(add_turbo_fish_after_call); + cov_mark::hit!(add_type_ascription_after_call); + arg_list.l_paren_token()?.prev_token().filter(|it| it.kind() == SyntaxKind::IDENT) + })?; + let next_token = ident.next_token()?; + if next_token.kind() == T![::] { + cov_mark::hit!(add_turbo_fish_one_fish_is_enough); + return None; + } + let name_ref = ast::NameRef::cast(ident.parent()?)?; + let def = match NameRefClass::classify(&ctx.sema, &name_ref)? { + NameRefClass::Definition(def) => def, + NameRefClass::FieldShorthand { .. } => return None, + }; + let fun = match def { + Definition::Function(it) => it, + _ => return None, + }; + let generics = hir::GenericDef::Function(fun).params(ctx.sema.db); + if generics.is_empty() { + cov_mark::hit!(add_turbo_fish_non_generic); + return None; + } + + if let Some(let_stmt) = ctx.find_node_at_offset::() { + if let_stmt.colon_token().is_none() { + let type_pos = let_stmt.pat()?.syntax().last_token()?.text_range().end(); + let semi_pos = let_stmt.syntax().last_token()?.text_range().end(); + + acc.add( + AssistId("add_type_ascription", AssistKind::RefactorRewrite), + "Add `: _` before assignment operator", + ident.text_range(), + |builder| { + if let_stmt.semicolon_token().is_none() { + builder.insert(semi_pos, ";"); + } + match ctx.config.snippet_cap { + Some(cap) => builder.insert_snippet(cap, type_pos, ": ${0:_}"), + None => builder.insert(type_pos, ": _"), + } + }, + )? + } else { + cov_mark::hit!(add_type_ascription_already_typed); + } + } + + let number_of_arguments = generics + .iter() + .filter(|param| { + matches!(param, hir::GenericParam::TypeParam(_) | hir::GenericParam::ConstParam(_)) + }) + .count(); + + acc.add( + AssistId("add_turbo_fish", AssistKind::RefactorRewrite), + "Add `::<>`", + ident.text_range(), + |builder| { + builder.trigger_signature_help(); + match ctx.config.snippet_cap { + Some(cap) => { + let snip = format!("::<{}>", get_snippet_fish_head(number_of_arguments)); + builder.insert_snippet(cap, ident.text_range().end(), snip) + } + None => { + let fish_head = std::iter::repeat("_").take(number_of_arguments).format(", "); + let snip = format!("::<{}>", fish_head); + builder.insert(ident.text_range().end(), snip); + } + } + }, + ) +} + +/// This will create a snippet string with tabstops marked +fn get_snippet_fish_head(number_of_arguments: usize) -> String { + let mut fish_head = (1..number_of_arguments) + .format_with("", |i, f| f(&format_args!("${{{}:_}}, ", i))) + .to_string(); + + // tabstop 0 is a special case and always the last one + fish_head.push_str("${0:_}"); + fish_head +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_by_label, check_assist_not_applicable}; + + use super::*; + + #[test] + fn add_turbo_fish_function() { + check_assist( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + make$0(); +} +"#, + r#" +fn make() -> T {} +fn main() { + make::<${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_function_multiple_generic_types() { + check_assist( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + make$0(); +} +"#, + r#" +fn make() -> T {} +fn main() { + make::<${1:_}, ${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_function_many_generic_types() { + check_assist( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + make$0(); +} +"#, + r#" +fn make() -> T {} +fn main() { + make::<${1:_}, ${2:_}, ${3:_}, ${4:_}, ${5:_}, ${6:_}, ${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_after_call() { + cov_mark::check!(add_turbo_fish_after_call); + check_assist( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + make()$0; +} +"#, + r#" +fn make() -> T {} +fn main() { + make::<${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_method() { + check_assist( + add_turbo_fish, + r#" +struct S; +impl S { + fn make(&self) -> T {} +} +fn main() { + S.make$0(); +} +"#, + r#" +struct S; +impl S { + fn make(&self) -> T {} +} +fn main() { + S.make::<${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_one_fish_is_enough() { + cov_mark::check!(add_turbo_fish_one_fish_is_enough); + check_assist_not_applicable( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + make$0::<()>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_non_generic() { + cov_mark::check!(add_turbo_fish_non_generic); + check_assist_not_applicable( + add_turbo_fish, + r#" +fn make() -> () {} +fn main() { + make$0(); +} +"#, + ); + } + + #[test] + fn add_type_ascription_function() { + check_assist_by_label( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + let x = make$0(); +} +"#, + r#" +fn make() -> T {} +fn main() { + let x: ${0:_} = make(); +} +"#, + "Add `: _` before assignment operator", + ); + } + + #[test] + fn add_type_ascription_after_call() { + cov_mark::check!(add_type_ascription_after_call); + check_assist_by_label( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + let x = make()$0; +} +"#, + r#" +fn make() -> T {} +fn main() { + let x: ${0:_} = make(); +} +"#, + "Add `: _` before assignment operator", + ); + } + + #[test] + fn add_type_ascription_method() { + check_assist_by_label( + add_turbo_fish, + r#" +struct S; +impl S { + fn make(&self) -> T {} +} +fn main() { + let x = S.make$0(); +} +"#, + r#" +struct S; +impl S { + fn make(&self) -> T {} +} +fn main() { + let x: ${0:_} = S.make(); +} +"#, + "Add `: _` before assignment operator", + ); + } + + #[test] + fn add_type_ascription_already_typed() { + cov_mark::check!(add_type_ascription_already_typed); + check_assist( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + let x: () = make$0(); +} +"#, + r#" +fn make() -> T {} +fn main() { + let x: () = make::<${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_type_ascription_append_semicolon() { + check_assist_by_label( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + let x = make$0() +} +"#, + r#" +fn make() -> T {} +fn main() { + let x: ${0:_} = make(); +} +"#, + "Add `: _` before assignment operator", + ); + } + + #[test] + fn add_turbo_fish_function_lifetime_parameter() { + check_assist( + add_turbo_fish, + r#" +fn make<'a, T, A>(t: T, a: A) {} +fn main() { + make$0(5, 2); +} +"#, + r#" +fn make<'a, T, A>(t: T, a: A) {} +fn main() { + make::<${1:_}, ${0:_}>(5, 2); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_function_const_parameter() { + check_assist( + add_turbo_fish, + r#" +fn make(t: T) {} +fn main() { + make$0(3); +} +"#, + r#" +fn make(t: T) {} +fn main() { + make::<${1:_}, ${0:_}>(3); +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs new file mode 100644 index 000000000..2853d1d1b --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs @@ -0,0 +1,234 @@ +use std::collections::VecDeque; + +use syntax::ast::{self, AstNode}; + +use crate::{utils::invert_boolean_expression, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: apply_demorgan +// +// Apply https://en.wikipedia.org/wiki/De_Morgan%27s_laws[De Morgan's law]. +// This transforms expressions of the form `!l || !r` into `!(l && r)`. +// This also works with `&&`. This assist can only be applied with the cursor +// on either `||` or `&&`. +// +// ``` +// fn main() { +// if x != 4 ||$0 y < 3.14 {} +// } +// ``` +// -> +// ``` +// fn main() { +// if !(x == 4 && y >= 3.14) {} +// } +// ``` +pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let expr = ctx.find_node_at_offset::()?; + let op = expr.op_kind()?; + let op_range = expr.op_token()?.text_range(); + + let opposite_op = match op { + ast::BinaryOp::LogicOp(ast::LogicOp::And) => "||", + ast::BinaryOp::LogicOp(ast::LogicOp::Or) => "&&", + _ => return None, + }; + + let cursor_in_range = op_range.contains_range(ctx.selection_trimmed()); + if !cursor_in_range { + return None; + } + + let mut expr = expr; + + // Walk up the tree while we have the same binary operator + while let Some(parent_expr) = expr.syntax().parent().and_then(ast::BinExpr::cast) { + match expr.op_kind() { + Some(parent_op) if parent_op == op => { + expr = parent_expr; + } + _ => break, + } + } + + let mut expr_stack = vec![expr.clone()]; + let mut terms = Vec::new(); + let mut op_ranges = Vec::new(); + + // Find all the children with the same binary operator + while let Some(expr) = expr_stack.pop() { + let mut traverse_bin_expr_arm = |expr| { + if let ast::Expr::BinExpr(bin_expr) = expr { + if let Some(expr_op) = bin_expr.op_kind() { + if expr_op == op { + expr_stack.push(bin_expr); + } else { + terms.push(ast::Expr::BinExpr(bin_expr)); + } + } else { + terms.push(ast::Expr::BinExpr(bin_expr)); + } + } else { + terms.push(expr); + } + }; + + op_ranges.extend(expr.op_token().map(|t| t.text_range())); + traverse_bin_expr_arm(expr.lhs()?); + traverse_bin_expr_arm(expr.rhs()?); + } + + acc.add( + AssistId("apply_demorgan", AssistKind::RefactorRewrite), + "Apply De Morgan's law", + op_range, + |edit| { + terms.sort_by_key(|t| t.syntax().text_range().start()); + let mut terms = VecDeque::from(terms); + + let paren_expr = expr.syntax().parent().and_then(ast::ParenExpr::cast); + + let neg_expr = paren_expr + .clone() + .and_then(|paren_expr| paren_expr.syntax().parent()) + .and_then(ast::PrefixExpr::cast) + .and_then(|prefix_expr| { + if prefix_expr.op_kind().unwrap() == ast::UnaryOp::Not { + Some(prefix_expr) + } else { + None + } + }); + + for op_range in op_ranges { + edit.replace(op_range, opposite_op); + } + + if let Some(paren_expr) = paren_expr { + for term in terms { + let range = term.syntax().text_range(); + let not_term = invert_boolean_expression(term); + + edit.replace(range, not_term.syntax().text()); + } + + if let Some(neg_expr) = neg_expr { + cov_mark::hit!(demorgan_double_negation); + edit.replace(neg_expr.op_token().unwrap().text_range(), ""); + } else { + cov_mark::hit!(demorgan_double_parens); + edit.replace(paren_expr.l_paren_token().unwrap().text_range(), "!("); + } + } else { + if let Some(lhs) = terms.pop_front() { + let lhs_range = lhs.syntax().text_range(); + let not_lhs = invert_boolean_expression(lhs); + + edit.replace(lhs_range, format!("!({}", not_lhs.syntax().text())); + } + + if let Some(rhs) = terms.pop_back() { + let rhs_range = rhs.syntax().text_range(); + let not_rhs = invert_boolean_expression(rhs); + + edit.replace(rhs_range, format!("{})", not_rhs.syntax().text())); + } + + for term in terms { + let term_range = term.syntax().text_range(); + let not_term = invert_boolean_expression(term); + edit.replace(term_range, not_term.syntax().text()); + } + } + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn demorgan_handles_leq() { + check_assist( + apply_demorgan, + r#" +struct S; +fn f() { S < S &&$0 S <= S } +"#, + r#" +struct S; +fn f() { !(S >= S || S > S) } +"#, + ); + } + + #[test] + fn demorgan_handles_geq() { + check_assist( + apply_demorgan, + r#" +struct S; +fn f() { S > S &&$0 S >= S } +"#, + r#" +struct S; +fn f() { !(S <= S || S < S) } +"#, + ); + } + + #[test] + fn demorgan_turns_and_into_or() { + check_assist(apply_demorgan, "fn f() { !x &&$0 !x }", "fn f() { !(x || x) }") + } + + #[test] + fn demorgan_turns_or_into_and() { + check_assist(apply_demorgan, "fn f() { !x ||$0 !x }", "fn f() { !(x && x) }") + } + + #[test] + fn demorgan_removes_inequality() { + check_assist(apply_demorgan, "fn f() { x != x ||$0 !x }", "fn f() { !(x == x && x) }") + } + + #[test] + fn demorgan_general_case() { + check_assist(apply_demorgan, "fn f() { x ||$0 x }", "fn f() { !(!x && !x) }") + } + + #[test] + fn demorgan_multiple_terms() { + check_assist(apply_demorgan, "fn f() { x ||$0 y || z }", "fn f() { !(!x && !y && !z) }"); + check_assist(apply_demorgan, "fn f() { x || y ||$0 z }", "fn f() { !(!x && !y && !z) }"); + } + + #[test] + fn demorgan_doesnt_apply_with_cursor_not_on_op() { + check_assist_not_applicable(apply_demorgan, "fn f() { $0 !x || !x }") + } + + #[test] + fn demorgan_doesnt_double_negation() { + cov_mark::check!(demorgan_double_negation); + check_assist(apply_demorgan, "fn f() { !(x ||$0 x) }", "fn f() { (!x && !x) }") + } + + #[test] + fn demorgan_doesnt_double_parens() { + cov_mark::check!(demorgan_double_parens); + check_assist(apply_demorgan, "fn f() { (x ||$0 x) }", "fn f() { !(!x && !x) }") + } + + // https://github.com/rust-lang/rust-analyzer/issues/10963 + #[test] + fn demorgan_doesnt_hang() { + check_assist( + apply_demorgan, + "fn f() { 1 || 3 &&$0 4 || 5 }", + "fn f() { !(!1 || !3 || !4) || 5 }", + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/auto_import.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/auto_import.rs new file mode 100644 index 000000000..949cf3167 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/auto_import.rs @@ -0,0 +1,1292 @@ +use std::cmp::Reverse; + +use hir::{db::HirDatabase, Module}; +use ide_db::{ + helpers::mod_path_to_ast, + imports::{ + import_assets::{ImportAssets, ImportCandidate, LocatedImport}, + insert_use::{insert_use, ImportScope}, + }, +}; +use syntax::{ast, AstNode, NodeOrToken, SyntaxElement}; + +use crate::{AssistContext, AssistId, AssistKind, Assists, GroupLabel}; + +// Feature: Auto Import +// +// Using the `auto-import` assist it is possible to insert missing imports for unresolved items. +// When inserting an import it will do so in a structured manner by keeping imports grouped, +// separated by a newline in the following order: +// +// - `std` and `core` +// - External Crates +// - Current Crate, paths prefixed by `crate` +// - Current Module, paths prefixed by `self` +// - Super Module, paths prefixed by `super` +// +// Example: +// ```rust +// use std::fs::File; +// +// use itertools::Itertools; +// use syntax::ast; +// +// use crate::utils::insert_use; +// +// use self::auto_import; +// +// use super::AssistContext; +// ``` +// +// .Import Granularity +// +// It is possible to configure how use-trees are merged with the `imports.granularity.group` setting. +// It has the following configurations: +// +// - `crate`: Merge imports from the same crate into a single use statement. This kind of +// nesting is only supported in Rust versions later than 1.24. +// - `module`: Merge imports from the same module into a single use statement. +// - `item`: Don't merge imports at all, creating one import per item. +// - `preserve`: Do not change the granularity of any imports. For auto-import this has the same +// effect as `item`. +// +// In `VS Code` the configuration for this is `rust-analyzer.imports.granularity.group`. +// +// .Import Prefix +// +// The style of imports in the same crate is configurable through the `imports.prefix` setting. +// It has the following configurations: +// +// - `crate`: This setting will force paths to be always absolute, starting with the `crate` +// prefix, unless the item is defined outside of the current crate. +// - `self`: This setting will force paths that are relative to the current module to always +// start with `self`. This will result in paths that always start with either `crate`, `self`, +// `super` or an extern crate identifier. +// - `plain`: This setting does not impose any restrictions in imports. +// +// In `VS Code` the configuration for this is `rust-analyzer.imports.prefix`. +// +// image::https://user-images.githubusercontent.com/48062697/113020673-b85be580-917a-11eb-9022-59585f35d4f8.gif[] + +// Assist: auto_import +// +// If the name is unresolved, provides all possible imports for it. +// +// ``` +// fn main() { +// let map = HashMap$0::new(); +// } +// # pub mod std { pub mod collections { pub struct HashMap { } } } +// ``` +// -> +// ``` +// use std::collections::HashMap; +// +// fn main() { +// let map = HashMap::new(); +// } +// # pub mod std { pub mod collections { pub struct HashMap { } } } +// ``` +pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (import_assets, syntax_under_caret) = find_importable_node(ctx)?; + let mut proposed_imports = + import_assets.search_for_imports(&ctx.sema, ctx.config.insert_use.prefix_kind); + if proposed_imports.is_empty() { + return None; + } + + let range = match &syntax_under_caret { + NodeOrToken::Node(node) => ctx.sema.original_range(node).range, + NodeOrToken::Token(token) => token.text_range(), + }; + let group_label = group_label(import_assets.import_candidate()); + let scope = ImportScope::find_insert_use_container( + &match syntax_under_caret { + NodeOrToken::Node(it) => it, + NodeOrToken::Token(it) => it.parent()?, + }, + &ctx.sema, + )?; + + // we aren't interested in different namespaces + proposed_imports.dedup_by(|a, b| a.import_path == b.import_path); + + let current_node = match ctx.covering_element() { + NodeOrToken::Node(node) => Some(node), + NodeOrToken::Token(token) => token.parent(), + }; + + let current_module = + current_node.as_ref().and_then(|node| ctx.sema.scope(node)).map(|scope| scope.module()); + + // prioritize more relevant imports + proposed_imports + .sort_by_key(|import| Reverse(relevance_score(ctx, import, current_module.as_ref()))); + + for import in proposed_imports { + acc.add_group( + &group_label, + AssistId("auto_import", AssistKind::QuickFix), + format!("Import `{}`", import.import_path), + range, + |builder| { + let scope = match scope.clone() { + ImportScope::File(it) => ImportScope::File(builder.make_mut(it)), + ImportScope::Module(it) => ImportScope::Module(builder.make_mut(it)), + ImportScope::Block(it) => ImportScope::Block(builder.make_mut(it)), + }; + insert_use(&scope, mod_path_to_ast(&import.import_path), &ctx.config.insert_use); + }, + ); + } + Some(()) +} + +pub(super) fn find_importable_node( + ctx: &AssistContext<'_>, +) -> Option<(ImportAssets, SyntaxElement)> { + if let Some(path_under_caret) = ctx.find_node_at_offset_with_descend::() { + ImportAssets::for_exact_path(&path_under_caret, &ctx.sema) + .zip(Some(path_under_caret.syntax().clone().into())) + } else if let Some(method_under_caret) = + ctx.find_node_at_offset_with_descend::() + { + ImportAssets::for_method_call(&method_under_caret, &ctx.sema) + .zip(Some(method_under_caret.syntax().clone().into())) + } else if let Some(pat) = ctx + .find_node_at_offset_with_descend::() + .filter(ast::IdentPat::is_simple_ident) + { + ImportAssets::for_ident_pat(&ctx.sema, &pat).zip(Some(pat.syntax().clone().into())) + } else { + None + } +} + +fn group_label(import_candidate: &ImportCandidate) -> GroupLabel { + let name = match import_candidate { + ImportCandidate::Path(candidate) => format!("Import {}", candidate.name.text()), + ImportCandidate::TraitAssocItem(candidate) => { + format!("Import a trait for item {}", candidate.assoc_item_name.text()) + } + ImportCandidate::TraitMethod(candidate) => { + format!("Import a trait for method {}", candidate.assoc_item_name.text()) + } + }; + GroupLabel(name) +} + +/// Determine how relevant a given import is in the current context. Higher scores are more +/// relevant. +fn relevance_score( + ctx: &AssistContext<'_>, + import: &LocatedImport, + current_module: Option<&Module>, +) -> i32 { + let mut score = 0; + + let db = ctx.db(); + + let item_module = match import.item_to_import { + hir::ItemInNs::Types(item) | hir::ItemInNs::Values(item) => item.module(db), + hir::ItemInNs::Macros(makro) => Some(makro.module(db)), + }; + + match item_module.zip(current_module) { + // get the distance between the imported path and the current module + // (prefer items that are more local) + Some((item_module, current_module)) => { + score -= module_distance_hueristic(db, ¤t_module, &item_module) as i32; + } + + // could not find relevant modules, so just use the length of the path as an estimate + None => return -(2 * import.import_path.len() as i32), + } + + score +} + +/// A heuristic that gives a higher score to modules that are more separated. +fn module_distance_hueristic(db: &dyn HirDatabase, current: &Module, item: &Module) -> usize { + // get the path starting from the item to the respective crate roots + let mut current_path = current.path_to_root(db); + let mut item_path = item.path_to_root(db); + + // we want paths going from the root to the item + current_path.reverse(); + item_path.reverse(); + + // length of the common prefix of the two paths + let prefix_length = current_path.iter().zip(&item_path).take_while(|(a, b)| a == b).count(); + + // how many modules differ between the two paths (all modules, removing any duplicates) + let distinct_length = current_path.len() + item_path.len() - 2 * prefix_length; + + // cost of importing from another crate + let crate_boundary_cost = if current.krate() == item.krate() { + 0 + } else if item.krate().is_builtin(db) { + 2 + } else { + 4 + }; + + distinct_length + crate_boundary_cost +} + +#[cfg(test)] +mod tests { + use super::*; + + use hir::Semantics; + use ide_db::{ + assists::AssistResolveStrategy, + base_db::{fixture::WithFixture, FileRange}, + RootDatabase, + }; + + use crate::tests::{ + check_assist, check_assist_not_applicable, check_assist_target, TEST_CONFIG, + }; + + fn check_auto_import_order(before: &str, order: &[&str]) { + let (db, file_id, range_or_offset) = RootDatabase::with_range_or_offset(before); + let frange = FileRange { file_id, range: range_or_offset.into() }; + + let sema = Semantics::new(&db); + let config = TEST_CONFIG; + let ctx = AssistContext::new(sema, &config, frange); + let mut acc = Assists::new(&ctx, AssistResolveStrategy::All); + auto_import(&mut acc, &ctx); + let assists = acc.finish(); + + let labels = assists.iter().map(|assist| assist.label.to_string()).collect::>(); + + assert_eq!(labels, order); + } + + #[test] + fn prefer_shorter_paths() { + let before = r" +//- /main.rs crate:main deps:foo,bar +HashMap$0::new(); + +//- /lib.rs crate:foo +pub mod collections { pub struct HashMap; } + +//- /lib.rs crate:bar +pub mod collections { pub mod hash_map { pub struct HashMap; } } + "; + + check_auto_import_order( + before, + &["Import `foo::collections::HashMap`", "Import `bar::collections::hash_map::HashMap`"], + ) + } + + #[test] + fn prefer_same_crate() { + let before = r" +//- /main.rs crate:main deps:foo +HashMap$0::new(); + +mod collections { + pub mod hash_map { + pub struct HashMap; + } +} + +//- /lib.rs crate:foo +pub struct HashMap; + "; + + check_auto_import_order( + before, + &["Import `collections::hash_map::HashMap`", "Import `foo::HashMap`"], + ) + } + + #[test] + fn not_applicable_if_scope_inside_macro() { + check_assist_not_applicable( + auto_import, + r" +mod bar { + pub struct Baz; +} +macro_rules! foo { + ($it:ident) => { + mod __ { + fn __(x: $it) {} + } + }; +} +foo! { + Baz$0 +} +", + ); + } + + #[test] + fn applicable_in_attributes() { + check_assist( + auto_import, + r" +//- proc_macros: identity +#[proc_macros::identity] +mod foo { + mod bar { + const _: Baz$0 = (); + } +} +mod baz { + pub struct Baz; +} +", + r" +#[proc_macros::identity] +mod foo { + mod bar { + use crate::baz::Baz; + + const _: Baz = (); + } +} +mod baz { + pub struct Baz; +} +", + ); + } + + #[test] + fn applicable_when_found_an_import_partial() { + check_assist( + auto_import, + r" + mod std { + pub mod fmt { + pub struct Formatter; + } + } + + use std::fmt; + + $0Formatter + ", + r" + mod std { + pub mod fmt { + pub struct Formatter; + } + } + + use std::fmt::{self, Formatter}; + + Formatter + ", + ); + } + + #[test] + fn applicable_when_found_an_import() { + check_assist( + auto_import, + r" + $0PubStruct + + pub mod PubMod { + pub struct PubStruct; + } + ", + r" + use PubMod::PubStruct; + + PubStruct + + pub mod PubMod { + pub struct PubStruct; + } + ", + ); + } + + #[test] + fn applicable_when_found_an_import_in_macros() { + check_assist( + auto_import, + r" + macro_rules! foo { + ($i:ident) => { fn foo(a: $i) {} } + } + foo!(Pub$0Struct); + + pub mod PubMod { + pub struct PubStruct; + } + ", + r" + use PubMod::PubStruct; + + macro_rules! foo { + ($i:ident) => { fn foo(a: $i) {} } + } + foo!(PubStruct); + + pub mod PubMod { + pub struct PubStruct; + } + ", + ); + } + + #[test] + fn applicable_when_found_multiple_imports() { + check_assist( + auto_import, + r" + PubSt$0ruct + + pub mod PubMod1 { + pub struct PubStruct; + } + pub mod PubMod2 { + pub struct PubStruct; + } + pub mod PubMod3 { + pub struct PubStruct; + } + ", + r" + use PubMod3::PubStruct; + + PubStruct + + pub mod PubMod1 { + pub struct PubStruct; + } + pub mod PubMod2 { + pub struct PubStruct; + } + pub mod PubMod3 { + pub struct PubStruct; + } + ", + ); + } + + #[test] + fn not_applicable_for_already_imported_types() { + check_assist_not_applicable( + auto_import, + r" + use PubMod::PubStruct; + + PubStruct$0 + + pub mod PubMod { + pub struct PubStruct; + } + ", + ); + } + + #[test] + fn not_applicable_for_types_with_private_paths() { + check_assist_not_applicable( + auto_import, + r" + PrivateStruct$0 + + pub mod PubMod { + struct PrivateStruct; + } + ", + ); + } + + #[test] + fn not_applicable_when_no_imports_found() { + check_assist_not_applicable( + auto_import, + " + PubStruct$0", + ); + } + + #[test] + fn function_import() { + check_assist( + auto_import, + r" + test_function$0 + + pub mod PubMod { + pub fn test_function() {}; + } + ", + r" + use PubMod::test_function; + + test_function + + pub mod PubMod { + pub fn test_function() {}; + } + ", + ); + } + + #[test] + fn macro_import() { + check_assist( + auto_import, + r" +//- /lib.rs crate:crate_with_macro +#[macro_export] +macro_rules! foo { + () => () +} + +//- /main.rs crate:main deps:crate_with_macro +fn main() { + foo$0 +} +", + r"use crate_with_macro::foo; + +fn main() { + foo +} +", + ); + } + + #[test] + fn auto_import_target() { + check_assist_target( + auto_import, + r" + struct AssistInfo { + group_label: Option<$0GroupLabel>, + } + + mod m { pub struct GroupLabel; } + ", + "GroupLabel", + ) + } + + #[test] + fn not_applicable_when_path_start_is_imported() { + check_assist_not_applicable( + auto_import, + r" + pub mod mod1 { + pub mod mod2 { + pub mod mod3 { + pub struct TestStruct; + } + } + } + + use mod1::mod2; + fn main() { + mod2::mod3::TestStruct$0 + } + ", + ); + } + + #[test] + fn not_applicable_for_imported_function() { + check_assist_not_applicable( + auto_import, + r" + pub mod test_mod { + pub fn test_function() {} + } + + use test_mod::test_function; + fn main() { + test_function$0 + } + ", + ); + } + + #[test] + fn associated_struct_function() { + check_assist( + auto_import, + r" + mod test_mod { + pub struct TestStruct {} + impl TestStruct { + pub fn test_function() {} + } + } + + fn main() { + TestStruct::test_function$0 + } + ", + r" + use test_mod::TestStruct; + + mod test_mod { + pub struct TestStruct {} + impl TestStruct { + pub fn test_function() {} + } + } + + fn main() { + TestStruct::test_function + } + ", + ); + } + + #[test] + fn associated_struct_const() { + check_assist( + auto_import, + r" + mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } + } + + fn main() { + TestStruct::TEST_CONST$0 + } + ", + r" + use test_mod::TestStruct; + + mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } + } + + fn main() { + TestStruct::TEST_CONST + } + ", + ); + } + + #[test] + fn associated_trait_function() { + check_assist( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } + } + + fn main() { + test_mod::TestStruct::test_function$0 + } + ", + r" + use test_mod::TestTrait; + + mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } + } + + fn main() { + test_mod::TestStruct::test_function + } + ", + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_function() { + check_assist_not_applicable( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub trait TestTrait2 { + fn test_function(); + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + fn test_function() {} + } + impl TestTrait for TestEnum { + fn test_function() {} + } + } + + use test_mod::TestTrait2; + fn main() { + test_mod::TestEnum::test_function$0; + } + ", + ) + } + + #[test] + fn associated_trait_const() { + check_assist( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const TEST_CONST: u8 = 42; + } + } + + fn main() { + test_mod::TestStruct::TEST_CONST$0 + } + ", + r" + use test_mod::TestTrait; + + mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const TEST_CONST: u8 = 42; + } + } + + fn main() { + test_mod::TestStruct::TEST_CONST + } + ", + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_const() { + check_assist_not_applicable( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub trait TestTrait2 { + const TEST_CONST: f64; + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + const TEST_CONST: f64 = 42.0; + } + impl TestTrait for TestEnum { + const TEST_CONST: u8 = 42; + } + } + + use test_mod::TestTrait2; + fn main() { + test_mod::TestEnum::TEST_CONST$0; + } + ", + ) + } + + #[test] + fn trait_method() { + check_assist( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } + } + + fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od() + } + ", + r" + use test_mod::TestTrait; + + mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } + } + + fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_method() + } + ", + ); + } + + #[test] + fn trait_method_cross_crate() { + check_assist( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_meth$0od() + } + //- /dep.rs crate:dep + pub mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } + } + ", + r" + use dep::test_mod::TestTrait; + + fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_method() + } + ", + ); + } + + #[test] + fn assoc_fn_cross_crate() { + check_assist( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + dep::test_mod::TestStruct::test_func$0tion + } + //- /dep.rs crate:dep + pub mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } + } + ", + r" + use dep::test_mod::TestTrait; + + fn main() { + dep::test_mod::TestStruct::test_function + } + ", + ); + } + + #[test] + fn assoc_const_cross_crate() { + check_assist( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + dep::test_mod::TestStruct::CONST$0 + } + //- /dep.rs crate:dep + pub mod test_mod { + pub trait TestTrait { + const CONST: bool; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const CONST: bool = true; + } + } + ", + r" + use dep::test_mod::TestTrait; + + fn main() { + dep::test_mod::TestStruct::CONST + } + ", + ); + } + + #[test] + fn assoc_fn_as_method_cross_crate() { + check_assist_not_applicable( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_func$0tion() + } + //- /dep.rs crate:dep + pub mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } + } + ", + ); + } + + #[test] + fn private_trait_cross_crate() { + check_assist_not_applicable( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_meth$0od() + } + //- /dep.rs crate:dep + pub mod test_mod { + trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } + } + ", + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_method() { + check_assist_not_applicable( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub trait TestTrait2 { + fn test_method(&self); + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + fn test_method(&self) {} + } + impl TestTrait for TestEnum { + fn test_method(&self) {} + } + } + + use test_mod::TestTrait2; + fn main() { + let one = test_mod::TestEnum::One; + one.test$0_method(); + } + ", + ) + } + + #[test] + fn dep_import() { + check_assist( + auto_import, + r" +//- /lib.rs crate:dep +pub struct Struct; + +//- /main.rs crate:main deps:dep +fn main() { + Struct$0 +} +", + r"use dep::Struct; + +fn main() { + Struct +} +", + ); + } + + #[test] + fn whole_segment() { + // Tests that only imports whose last segment matches the identifier get suggested. + check_assist( + auto_import, + r" +//- /lib.rs crate:dep +pub mod fmt { + pub trait Display {} +} + +pub fn panic_fmt() {} + +//- /main.rs crate:main deps:dep +struct S; + +impl f$0mt::Display for S {} +", + r"use dep::fmt; + +struct S; + +impl fmt::Display for S {} +", + ); + } + + #[test] + fn macro_generated() { + // Tests that macro-generated items are suggested from external crates. + check_assist( + auto_import, + r" +//- /lib.rs crate:dep +macro_rules! mac { + () => { + pub struct Cheese; + }; +} + +mac!(); + +//- /main.rs crate:main deps:dep +fn main() { + Cheese$0; +} +", + r"use dep::Cheese; + +fn main() { + Cheese; +} +", + ); + } + + #[test] + fn casing() { + // Tests that differently cased names don't interfere and we only suggest the matching one. + check_assist( + auto_import, + r" +//- /lib.rs crate:dep +pub struct FMT; +pub struct fmt; + +//- /main.rs crate:main deps:dep +fn main() { + FMT$0; +} +", + r"use dep::FMT; + +fn main() { + FMT; +} +", + ); + } + + #[test] + fn inner_items() { + check_assist( + auto_import, + r#" +mod baz { + pub struct Foo {} +} + +mod bar { + fn bar() { + Foo$0; + println!("Hallo"); + } +} +"#, + r#" +mod baz { + pub struct Foo {} +} + +mod bar { + use crate::baz::Foo; + + fn bar() { + Foo; + println!("Hallo"); + } +} +"#, + ); + } + + #[test] + fn uses_abs_path_with_extern_crate_clash() { + cov_mark::check!(ambiguous_crate_start); + check_assist( + auto_import, + r#" +//- /main.rs crate:main deps:foo +mod foo {} + +const _: () = { + Foo$0 +}; +//- /foo.rs crate:foo +pub struct Foo +"#, + r#" +use ::foo::Foo; + +mod foo {} + +const _: () = { + Foo +}; +"#, + ); + } + + #[test] + fn works_on_ident_patterns() { + check_assist( + auto_import, + r#" +mod foo { + pub struct Foo {} +} +fn foo() { + let Foo$0; +} +"#, + r#" +use foo::Foo; + +mod foo { + pub struct Foo {} +} +fn foo() { + let Foo; +} +"#, + ); + } + + #[test] + fn works_in_derives() { + check_assist( + auto_import, + r#" +//- minicore:derive +mod foo { + #[rustc_builtin_macro] + pub macro Copy {} +} +#[derive(Copy$0)] +struct Foo; +"#, + r#" +use foo::Copy; + +mod foo { + #[rustc_builtin_macro] + pub macro Copy {} +} +#[derive(Copy)] +struct Foo; +"#, + ); + } + + #[test] + fn works_in_use_start() { + check_assist( + auto_import, + r#" +mod bar { + pub mod foo { + pub struct Foo; + } +} +use foo$0::Foo; +"#, + r#" +mod bar { + pub mod foo { + pub struct Foo; + } +} +use bar::foo; +use foo::Foo; +"#, + ); + } + + #[test] + fn not_applicable_in_non_start_use() { + check_assist_not_applicable( + auto_import, + r" +mod bar { + pub mod foo { + pub struct Foo; + } +} +use foo::Foo$0; +", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/change_visibility.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/change_visibility.rs new file mode 100644 index 000000000..2b1d8f6f0 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/change_visibility.rs @@ -0,0 +1,216 @@ +use syntax::{ + ast::{self, HasName, HasVisibility}, + AstNode, + SyntaxKind::{ + CONST, ENUM, FN, MACRO_DEF, MODULE, STATIC, STRUCT, TRAIT, TYPE_ALIAS, USE, VISIBILITY, + }, + T, +}; + +use crate::{utils::vis_offset, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: change_visibility +// +// Adds or changes existing visibility specifier. +// +// ``` +// $0fn frobnicate() {} +// ``` +// -> +// ``` +// pub(crate) fn frobnicate() {} +// ``` +pub(crate) fn change_visibility(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + if let Some(vis) = ctx.find_node_at_offset::() { + return change_vis(acc, vis); + } + add_vis(acc, ctx) +} + +fn add_vis(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let item_keyword = ctx.token_at_offset().find(|leaf| { + matches!( + leaf.kind(), + T![const] + | T![static] + | T![fn] + | T![mod] + | T![struct] + | T![enum] + | T![trait] + | T![type] + | T![use] + | T![macro] + ) + }); + + let (offset, target) = if let Some(keyword) = item_keyword { + let parent = keyword.parent()?; + let def_kws = + vec![CONST, STATIC, TYPE_ALIAS, FN, MODULE, STRUCT, ENUM, TRAIT, USE, MACRO_DEF]; + // Parent is not a definition, can't add visibility + if !def_kws.iter().any(|&def_kw| def_kw == parent.kind()) { + return None; + } + // Already have visibility, do nothing + if parent.children().any(|child| child.kind() == VISIBILITY) { + return None; + } + (vis_offset(&parent), keyword.text_range()) + } else if let Some(field_name) = ctx.find_node_at_offset::() { + let field = field_name.syntax().ancestors().find_map(ast::RecordField::cast)?; + if field.name()? != field_name { + cov_mark::hit!(change_visibility_field_false_positive); + return None; + } + if field.visibility().is_some() { + return None; + } + (vis_offset(field.syntax()), field_name.syntax().text_range()) + } else if let Some(field) = ctx.find_node_at_offset::() { + if field.visibility().is_some() { + return None; + } + (vis_offset(field.syntax()), field.syntax().text_range()) + } else { + return None; + }; + + acc.add( + AssistId("change_visibility", AssistKind::RefactorRewrite), + "Change visibility to pub(crate)", + target, + |edit| { + edit.insert(offset, "pub(crate) "); + }, + ) +} + +fn change_vis(acc: &mut Assists, vis: ast::Visibility) -> Option<()> { + if vis.syntax().text() == "pub" { + let target = vis.syntax().text_range(); + return acc.add( + AssistId("change_visibility", AssistKind::RefactorRewrite), + "Change Visibility to pub(crate)", + target, + |edit| { + edit.replace(vis.syntax().text_range(), "pub(crate)"); + }, + ); + } + if vis.syntax().text() == "pub(crate)" { + let target = vis.syntax().text_range(); + return acc.add( + AssistId("change_visibility", AssistKind::RefactorRewrite), + "Change visibility to pub", + target, + |edit| { + edit.replace(vis.syntax().text_range(), "pub"); + }, + ); + } + None +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn change_visibility_adds_pub_crate_to_items() { + check_assist(change_visibility, "$0fn foo() {}", "pub(crate) fn foo() {}"); + check_assist(change_visibility, "f$0n foo() {}", "pub(crate) fn foo() {}"); + check_assist(change_visibility, "$0struct Foo {}", "pub(crate) struct Foo {}"); + check_assist(change_visibility, "$0mod foo {}", "pub(crate) mod foo {}"); + check_assist(change_visibility, "$0trait Foo {}", "pub(crate) trait Foo {}"); + check_assist(change_visibility, "m$0od {}", "pub(crate) mod {}"); + check_assist(change_visibility, "unsafe f$0n foo() {}", "pub(crate) unsafe fn foo() {}"); + check_assist(change_visibility, "$0macro foo() {}", "pub(crate) macro foo() {}"); + check_assist(change_visibility, "$0use foo;", "pub(crate) use foo;"); + } + + #[test] + fn change_visibility_works_with_struct_fields() { + check_assist( + change_visibility, + r"struct S { $0field: u32 }", + r"struct S { pub(crate) field: u32 }", + ); + check_assist(change_visibility, r"struct S ( $0u32 )", r"struct S ( pub(crate) u32 )"); + } + + #[test] + fn change_visibility_field_false_positive() { + cov_mark::check!(change_visibility_field_false_positive); + check_assist_not_applicable( + change_visibility, + r"struct S { field: [(); { let $0x = ();}] }", + ) + } + + #[test] + fn change_visibility_pub_to_pub_crate() { + check_assist(change_visibility, "$0pub fn foo() {}", "pub(crate) fn foo() {}") + } + + #[test] + fn change_visibility_pub_crate_to_pub() { + check_assist(change_visibility, "$0pub(crate) fn foo() {}", "pub fn foo() {}") + } + + #[test] + fn change_visibility_const() { + check_assist(change_visibility, "$0const FOO = 3u8;", "pub(crate) const FOO = 3u8;"); + } + + #[test] + fn change_visibility_static() { + check_assist(change_visibility, "$0static FOO = 3u8;", "pub(crate) static FOO = 3u8;"); + } + + #[test] + fn change_visibility_type_alias() { + check_assist(change_visibility, "$0type T = ();", "pub(crate) type T = ();"); + } + + #[test] + fn change_visibility_handles_comment_attrs() { + check_assist( + change_visibility, + r" + /// docs + + // comments + + #[derive(Debug)] + $0struct Foo; + ", + r" + /// docs + + // comments + + #[derive(Debug)] + pub(crate) struct Foo; + ", + ) + } + + #[test] + fn not_applicable_for_enum_variants() { + check_assist_not_applicable( + change_visibility, + r"mod foo { pub enum Foo {Foo1} } + fn main() { foo::Foo::Foo1$0 } ", + ); + } + + #[test] + fn change_visibility_target() { + check_assist_target(change_visibility, "$0fn foo() {}", "fn"); + check_assist_target(change_visibility, "pub(crate)$0 fn foo() {}", "pub(crate)"); + check_assist_target(change_visibility, "struct S { $0field: u32 }", "field"); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs new file mode 100644 index 000000000..db96ad330 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs @@ -0,0 +1,575 @@ +use hir::{known, AsAssocItem, Semantics}; +use ide_db::{ + famous_defs::FamousDefs, + syntax_helpers::node_ext::{ + block_as_lone_tail, for_each_tail_expr, is_pattern_cond, preorder_expr, + }, + RootDatabase, +}; +use itertools::Itertools; +use syntax::{ + ast::{self, edit::AstNodeEdit, make, HasArgList}, + ted, AstNode, SyntaxNode, +}; + +use crate::{ + utils::{invert_boolean_expression, unwrap_trivial_block}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: convert_if_to_bool_then +// +// Converts an if expression into a corresponding `bool::then` call. +// +// ``` +// # //- minicore: option +// fn main() { +// if$0 cond { +// Some(val) +// } else { +// None +// } +// } +// ``` +// -> +// ``` +// fn main() { +// cond.then(|| val) +// } +// ``` +pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + // FIXME applies to match as well + let expr = ctx.find_node_at_offset::()?; + if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) { + return None; + } + + let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?; + let then = expr.then_branch()?; + let else_ = match expr.else_branch()? { + ast::ElseBranch::Block(b) => b, + ast::ElseBranch::IfExpr(_) => { + cov_mark::hit!(convert_if_to_bool_then_chain); + return None; + } + }; + + let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?; + + let (invert_cond, closure_body) = match ( + block_is_none_variant(&ctx.sema, &then, none_variant), + block_is_none_variant(&ctx.sema, &else_, none_variant), + ) { + (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)), + (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)), + _ => return None, + }; + + if is_invalid_body(&ctx.sema, some_variant, &closure_body) { + cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body); + return None; + } + + let target = expr.syntax().text_range(); + acc.add( + AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite), + "Convert `if` expression to `bool::then` call", + target, + |builder| { + let closure_body = closure_body.clone_for_update(); + // Rewrite all `Some(e)` in tail position to `e` + let mut replacements = Vec::new(); + for_each_tail_expr(&closure_body, &mut |e| { + let e = match e { + ast::Expr::BreakExpr(e) => e.expr(), + e @ ast::Expr::CallExpr(_) => Some(e.clone()), + _ => None, + }; + if let Some(ast::Expr::CallExpr(call)) = e { + if let Some(arg_list) = call.arg_list() { + if let Some(arg) = arg_list.args().next() { + replacements.push((call.syntax().clone(), arg.syntax().clone())); + } + } + } + }); + replacements.into_iter().for_each(|(old, new)| ted::replace(old, new)); + let closure_body = match closure_body { + ast::Expr::BlockExpr(block) => unwrap_trivial_block(block), + e => e, + }; + + let parenthesize = matches!( + cond, + ast::Expr::BinExpr(_) + | ast::Expr::BlockExpr(_) + | ast::Expr::BoxExpr(_) + | ast::Expr::BreakExpr(_) + | ast::Expr::CastExpr(_) + | ast::Expr::ClosureExpr(_) + | ast::Expr::ContinueExpr(_) + | ast::Expr::ForExpr(_) + | ast::Expr::IfExpr(_) + | ast::Expr::LoopExpr(_) + | ast::Expr::MacroExpr(_) + | ast::Expr::MatchExpr(_) + | ast::Expr::PrefixExpr(_) + | ast::Expr::RangeExpr(_) + | ast::Expr::RefExpr(_) + | ast::Expr::ReturnExpr(_) + | ast::Expr::WhileExpr(_) + | ast::Expr::YieldExpr(_) + ); + let cond = if invert_cond { invert_boolean_expression(cond) } else { cond }; + let cond = if parenthesize { make::expr_paren(cond) } else { cond }; + let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body))); + let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list); + builder.replace(target, mcall.to_string()); + }, + ) +} + +// Assist: convert_bool_then_to_if +// +// Converts a `bool::then` method call to an equivalent if expression. +// +// ``` +// # //- minicore: bool_impl +// fn main() { +// (0 == 0).then$0(|| val) +// } +// ``` +// -> +// ``` +// fn main() { +// if 0 == 0 { +// Some(val) +// } else { +// None +// } +// } +// ``` +pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let name_ref = ctx.find_node_at_offset::()?; + let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?; + let receiver = mcall.receiver()?; + let closure_body = mcall.arg_list()?.args().exactly_one().ok()?; + let closure_body = match closure_body { + ast::Expr::ClosureExpr(expr) => expr.body()?, + _ => return None, + }; + // Verify this is `bool::then` that is being called. + let func = ctx.sema.resolve_method_call(&mcall)?; + if func.name(ctx.sema.db).to_string() != "then" { + return None; + } + let assoc = func.as_assoc_item(ctx.sema.db)?; + match assoc.container(ctx.sema.db) { + hir::AssocItemContainer::Impl(impl_) if impl_.self_ty(ctx.sema.db).is_bool() => {} + _ => return None, + } + + let target = mcall.syntax().text_range(); + acc.add( + AssistId("convert_bool_then_to_if", AssistKind::RefactorRewrite), + "Convert `bool::then` call to `if`", + target, + |builder| { + let closure_body = match closure_body { + ast::Expr::BlockExpr(block) => block, + e => make::block_expr(None, Some(e)), + }; + + let closure_body = closure_body.clone_for_update(); + // Wrap all tails in `Some(...)` + let none_path = make::expr_path(make::ext::ident_path("None")); + let some_path = make::expr_path(make::ext::ident_path("Some")); + let mut replacements = Vec::new(); + for_each_tail_expr(&ast::Expr::BlockExpr(closure_body.clone()), &mut |e| { + let e = match e { + ast::Expr::BreakExpr(e) => e.expr(), + ast::Expr::ReturnExpr(e) => e.expr(), + _ => Some(e.clone()), + }; + if let Some(expr) = e { + replacements.push(( + expr.syntax().clone(), + make::expr_call(some_path.clone(), make::arg_list(Some(expr))) + .syntax() + .clone_for_update(), + )); + } + }); + replacements.into_iter().for_each(|(old, new)| ted::replace(old, new)); + + let cond = match &receiver { + ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver), + _ => receiver, + }; + let if_expr = make::expr_if( + cond, + closure_body.reset_indent(), + Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))), + ) + .indent(mcall.indent_level()); + + builder.replace(target, if_expr.to_string()); + }, + ) +} + +fn option_variants( + sema: &Semantics<'_, RootDatabase>, + expr: &SyntaxNode, +) -> Option<(hir::Variant, hir::Variant)> { + let fam = FamousDefs(sema, sema.scope(expr)?.krate()); + let option_variants = fam.core_option_Option()?.variants(sema.db); + match &*option_variants { + &[variant0, variant1] => Some(if variant0.name(sema.db) == known::None { + (variant0, variant1) + } else { + (variant1, variant0) + }), + _ => None, + } +} + +/// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression. +/// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call. +fn is_invalid_body( + sema: &Semantics<'_, RootDatabase>, + some_variant: hir::Variant, + expr: &ast::Expr, +) -> bool { + let mut invalid = false; + preorder_expr(expr, &mut |e| { + invalid |= + matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_))); + invalid + }); + if !invalid { + for_each_tail_expr(expr, &mut |e| { + if invalid { + return; + } + let e = match e { + ast::Expr::BreakExpr(e) => e.expr(), + e @ ast::Expr::CallExpr(_) => Some(e.clone()), + _ => None, + }; + if let Some(ast::Expr::CallExpr(call)) = e { + if let Some(ast::Expr::PathExpr(p)) = call.expr() { + let res = p.path().and_then(|p| sema.resolve_path(&p)); + if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res { + return invalid |= v != some_variant; + } + } + } + invalid = true + }); + } + invalid +} + +fn block_is_none_variant( + sema: &Semantics<'_, RootDatabase>, + block: &ast::BlockExpr, + none_variant: hir::Variant, +) -> bool { + block_as_lone_tail(block).and_then(|e| match e { + ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? { + hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v), + _ => None, + }, + _ => None, + }) == Some(none_variant) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn convert_if_to_bool_then_simple() { + check_assist( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + Some(15) + } else { + None + } +} +", + r" +fn main() { + true.then(|| 15) +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_invert() { + check_assist( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + None + } else { + Some(15) + } +} +", + r" +fn main() { + false.then(|| 15) +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_none_none() { + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + None + } else { + None + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_some_some() { + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + Some(15) + } else { + Some(15) + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_mixed() { + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + if true { + Some(15) + } else { + None + } + } else { + None + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_chain() { + cov_mark::check!(convert_if_to_bool_then_chain); + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + Some(15) + } else if true { + None + } else { + None + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_pattern_cond() { + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 let true = true { + Some(15) + } else { + None + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_pattern_invalid_body() { + cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2); + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn make_me_an_option() -> Option { None } +fn main() { + if$0 true { + if true { + make_me_an_option() + } else { + Some(15) + } + } else { + None + } +} +", + ); + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + if true { + return; + } + Some(15) + } else { + None + } +} +", + ); + } + + #[test] + fn convert_bool_then_to_if_inapplicable() { + check_assist_not_applicable( + convert_bool_then_to_if, + r" +//- minicore:bool_impl +fn main() { + 0.t$0hen(|| 15); +} +", + ); + check_assist_not_applicable( + convert_bool_then_to_if, + r" +//- minicore:bool_impl +fn main() { + true.t$0hen(15); +} +", + ); + check_assist_not_applicable( + convert_bool_then_to_if, + r" +//- minicore:bool_impl +fn main() { + true.t$0hen(|| 15, 15); +} +", + ); + } + + #[test] + fn convert_bool_then_to_if_simple() { + check_assist( + convert_bool_then_to_if, + r" +//- minicore:bool_impl +fn main() { + true.t$0hen(|| 15) +} +", + r" +fn main() { + if true { + Some(15) + } else { + None + } +} +", + ); + check_assist( + convert_bool_then_to_if, + r" +//- minicore:bool_impl +fn main() { + true.t$0hen(|| { + 15 + }) +} +", + r" +fn main() { + if true { + Some(15) + } else { + None + } +} +", + ); + } + + #[test] + fn convert_bool_then_to_if_tails() { + check_assist( + convert_bool_then_to_if, + r" +//- minicore:bool_impl +fn main() { + true.t$0hen(|| { + loop { + if false { + break 0; + } + break 15; + } + }) +} +", + r" +fn main() { + if true { + loop { + if false { + break Some(0); + } + break Some(15); + } + } else { + None + } +} +", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_comment_block.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_comment_block.rs new file mode 100644 index 000000000..f171dd81a --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_comment_block.rs @@ -0,0 +1,395 @@ +use itertools::Itertools; +use syntax::{ + ast::{self, edit::IndentLevel, Comment, CommentKind, CommentShape, Whitespace}, + AstToken, Direction, SyntaxElement, TextRange, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: line_to_block +// +// Converts comments between block and single-line form. +// +// ``` +// // Multi-line$0 +// // comment +// ``` +// -> +// ``` +// /* +// Multi-line +// comment +// */ +// ``` +pub(crate) fn convert_comment_block(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let comment = ctx.find_token_at_offset::()?; + // Only allow comments which are alone on their line + if let Some(prev) = comment.syntax().prev_token() { + if Whitespace::cast(prev).filter(|w| w.text().contains('\n')).is_none() { + return None; + } + } + + match comment.kind().shape { + ast::CommentShape::Block => block_to_line(acc, comment), + ast::CommentShape::Line => line_to_block(acc, comment), + } +} + +fn block_to_line(acc: &mut Assists, comment: ast::Comment) -> Option<()> { + let target = comment.syntax().text_range(); + + acc.add( + AssistId("block_to_line", AssistKind::RefactorRewrite), + "Replace block comment with line comments", + target, + |edit| { + let indentation = IndentLevel::from_token(comment.syntax()); + let line_prefix = CommentKind { shape: CommentShape::Line, ..comment.kind() }.prefix(); + + let text = comment.text(); + let text = &text[comment.prefix().len()..(text.len() - "*/".len())].trim(); + + let lines = text.lines().peekable(); + + let indent_spaces = indentation.to_string(); + let output = lines + .map(|l| l.trim_start_matches(&indent_spaces)) + .map(|l| { + // Don't introduce trailing whitespace + if l.is_empty() { + line_prefix.to_string() + } else { + format!("{} {}", line_prefix, l.trim_start_matches(&indent_spaces)) + } + }) + .join(&format!("\n{}", indent_spaces)); + + edit.replace(target, output) + }, + ) +} + +fn line_to_block(acc: &mut Assists, comment: ast::Comment) -> Option<()> { + // Find all the comments we'll be collapsing into a block + let comments = relevant_line_comments(&comment); + + // Establish the target of our edit based on the comments we found + let target = TextRange::new( + comments[0].syntax().text_range().start(), + comments.last().unwrap().syntax().text_range().end(), + ); + + acc.add( + AssistId("line_to_block", AssistKind::RefactorRewrite), + "Replace line comments with a single block comment", + target, + |edit| { + // We pick a single indentation level for the whole block comment based on the + // comment where the assist was invoked. This will be prepended to the + // contents of each line comment when they're put into the block comment. + let indentation = IndentLevel::from_token(comment.syntax()); + + let block_comment_body = + comments.into_iter().map(|c| line_comment_text(indentation, c)).join("\n"); + + let block_prefix = + CommentKind { shape: CommentShape::Block, ..comment.kind() }.prefix(); + + let output = format!("{}\n{}\n{}*/", block_prefix, block_comment_body, indentation); + + edit.replace(target, output) + }, + ) +} + +/// The line -> block assist can be invoked from anywhere within a sequence of line comments. +/// relevant_line_comments crawls backwards and forwards finding the complete sequence of comments that will +/// be joined. +fn relevant_line_comments(comment: &ast::Comment) -> Vec { + // The prefix identifies the kind of comment we're dealing with + let prefix = comment.prefix(); + let same_prefix = |c: &ast::Comment| c.prefix() == prefix; + + // These tokens are allowed to exist between comments + let skippable = |not: &SyntaxElement| { + not.clone() + .into_token() + .and_then(Whitespace::cast) + .map(|w| !w.spans_multiple_lines()) + .unwrap_or(false) + }; + + // Find all preceding comments (in reverse order) that have the same prefix + let prev_comments = comment + .syntax() + .siblings_with_tokens(Direction::Prev) + .filter(|s| !skippable(s)) + .map(|not| not.into_token().and_then(Comment::cast).filter(same_prefix)) + .take_while(|opt_com| opt_com.is_some()) + .flatten() + .skip(1); // skip the first element so we don't duplicate it in next_comments + + let next_comments = comment + .syntax() + .siblings_with_tokens(Direction::Next) + .filter(|s| !skippable(s)) + .map(|not| not.into_token().and_then(Comment::cast).filter(same_prefix)) + .take_while(|opt_com| opt_com.is_some()) + .flatten(); + + let mut comments: Vec<_> = prev_comments.collect(); + comments.reverse(); + comments.extend(next_comments); + comments +} + +// Line comments usually begin with a single space character following the prefix as seen here: +//^ +// But comments can also include indented text: +// > Hello there +// +// We handle this by stripping *AT MOST* one space character from the start of the line +// This has its own problems because it can cause alignment issues: +// +// /* +// a ----> a +//b ----> b +// */ +// +// But since such comments aren't idiomatic we're okay with this. +fn line_comment_text(indentation: IndentLevel, comm: ast::Comment) -> String { + let contents_without_prefix = comm.text().strip_prefix(comm.prefix()).unwrap(); + let contents = contents_without_prefix.strip_prefix(' ').unwrap_or(contents_without_prefix); + + // Don't add the indentation if the line is empty + if contents.is_empty() { + contents.to_owned() + } else { + indentation.to_string() + contents + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn single_line_to_block() { + check_assist( + convert_comment_block, + r#" +// line$0 comment +fn main() { + foo(); +} +"#, + r#" +/* +line comment +*/ +fn main() { + foo(); +} +"#, + ); + } + + #[test] + fn single_line_to_block_indented() { + check_assist( + convert_comment_block, + r#" +fn main() { + // line$0 comment + foo(); +} +"#, + r#" +fn main() { + /* + line comment + */ + foo(); +} +"#, + ); + } + + #[test] + fn multiline_to_block() { + check_assist( + convert_comment_block, + r#" +fn main() { + // above + // line$0 comment + // + // below + foo(); +} +"#, + r#" +fn main() { + /* + above + line comment + + below + */ + foo(); +} +"#, + ); + } + + #[test] + fn end_of_line_to_block() { + check_assist_not_applicable( + convert_comment_block, + r#" +fn main() { + foo(); // end-of-line$0 comment +} +"#, + ); + } + + #[test] + fn single_line_different_kinds() { + check_assist( + convert_comment_block, + r#" +fn main() { + /// different prefix + // line$0 comment + // below + foo(); +} +"#, + r#" +fn main() { + /// different prefix + /* + line comment + below + */ + foo(); +} +"#, + ); + } + + #[test] + fn single_line_separate_chunks() { + check_assist( + convert_comment_block, + r#" +fn main() { + // different chunk + + // line$0 comment + // below + foo(); +} +"#, + r#" +fn main() { + // different chunk + + /* + line comment + below + */ + foo(); +} +"#, + ); + } + + #[test] + fn doc_block_comment_to_lines() { + check_assist( + convert_comment_block, + r#" +/** + hi$0 there +*/ +"#, + r#" +/// hi there +"#, + ); + } + + #[test] + fn block_comment_to_lines() { + check_assist( + convert_comment_block, + r#" +/* + hi$0 there +*/ +"#, + r#" +// hi there +"#, + ); + } + + #[test] + fn inner_doc_block_to_lines() { + check_assist( + convert_comment_block, + r#" +/*! + hi$0 there +*/ +"#, + r#" +//! hi there +"#, + ); + } + + #[test] + fn block_to_lines_indent() { + check_assist( + convert_comment_block, + r#" +fn main() { + /*! + hi$0 there + + ``` + code_sample + ``` + */ +} +"#, + r#" +fn main() { + //! hi there + //! + //! ``` + //! code_sample + //! ``` +} +"#, + ); + } + + #[test] + fn end_of_line_block_to_line() { + check_assist_not_applicable( + convert_comment_block, + r#" +fn main() { + foo(); /* end-of-line$0 comment */ +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_integer_literal.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_integer_literal.rs new file mode 100644 index 000000000..9060696cd --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_integer_literal.rs @@ -0,0 +1,268 @@ +use syntax::{ast, ast::Radix, AstToken}; + +use crate::{AssistContext, AssistId, AssistKind, Assists, GroupLabel}; + +// Assist: convert_integer_literal +// +// Converts the base of integer literals to other bases. +// +// ``` +// const _: i32 = 10$0; +// ``` +// -> +// ``` +// const _: i32 = 0b1010; +// ``` +pub(crate) fn convert_integer_literal(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let literal = ctx.find_node_at_offset::()?; + let literal = match literal.kind() { + ast::LiteralKind::IntNumber(it) => it, + _ => return None, + }; + let radix = literal.radix(); + let value = literal.value()?; + let suffix = literal.suffix(); + + let range = literal.syntax().text_range(); + let group_id = GroupLabel("Convert integer base".into()); + + for &target_radix in Radix::ALL { + if target_radix == radix { + continue; + } + + let mut converted = match target_radix { + Radix::Binary => format!("0b{:b}", value), + Radix::Octal => format!("0o{:o}", value), + Radix::Decimal => value.to_string(), + Radix::Hexadecimal => format!("0x{:X}", value), + }; + + let label = format!("Convert {} to {}{}", literal, converted, suffix.unwrap_or_default()); + + // Appends the type suffix back into the new literal if it exists. + if let Some(suffix) = suffix { + converted.push_str(suffix); + } + + acc.add_group( + &group_id, + AssistId("convert_integer_literal", AssistKind::RefactorInline), + label, + range, + |builder| builder.replace(range, converted), + ); + } + + Some(()) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist_by_label, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn binary_target() { + check_assist_target(convert_integer_literal, "const _: i32 = 0b1010$0;", "0b1010"); + } + + #[test] + fn octal_target() { + check_assist_target(convert_integer_literal, "const _: i32 = 0o12$0;", "0o12"); + } + + #[test] + fn decimal_target() { + check_assist_target(convert_integer_literal, "const _: i32 = 10$0;", "10"); + } + + #[test] + fn hexadecimal_target() { + check_assist_target(convert_integer_literal, "const _: i32 = 0xA$0;", "0xA"); + } + + #[test] + fn binary_target_with_underscores() { + check_assist_target(convert_integer_literal, "const _: i32 = 0b10_10$0;", "0b10_10"); + } + + #[test] + fn octal_target_with_underscores() { + check_assist_target(convert_integer_literal, "const _: i32 = 0o1_2$0;", "0o1_2"); + } + + #[test] + fn decimal_target_with_underscores() { + check_assist_target(convert_integer_literal, "const _: i32 = 1_0$0;", "1_0"); + } + + #[test] + fn hexadecimal_target_with_underscores() { + check_assist_target(convert_integer_literal, "const _: i32 = 0x_A$0;", "0x_A"); + } + + #[test] + fn convert_decimal_integer() { + let before = "const _: i32 = 1000$0;"; + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0b1111101000;", + "Convert 1000 to 0b1111101000", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0o1750;", + "Convert 1000 to 0o1750", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0x3E8;", + "Convert 1000 to 0x3E8", + ); + } + + #[test] + fn convert_hexadecimal_integer() { + let before = "const _: i32 = 0xFF$0;"; + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0b11111111;", + "Convert 0xFF to 0b11111111", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0o377;", + "Convert 0xFF to 0o377", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 255;", + "Convert 0xFF to 255", + ); + } + + #[test] + fn convert_binary_integer() { + let before = "const _: i32 = 0b11111111$0;"; + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0o377;", + "Convert 0b11111111 to 0o377", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 255;", + "Convert 0b11111111 to 255", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0xFF;", + "Convert 0b11111111 to 0xFF", + ); + } + + #[test] + fn convert_octal_integer() { + let before = "const _: i32 = 0o377$0;"; + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0b11111111;", + "Convert 0o377 to 0b11111111", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 255;", + "Convert 0o377 to 255", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0xFF;", + "Convert 0o377 to 0xFF", + ); + } + + #[test] + fn convert_integer_with_underscores() { + let before = "const _: i32 = 1_00_0$0;"; + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0b1111101000;", + "Convert 1_00_0 to 0b1111101000", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0o1750;", + "Convert 1_00_0 to 0o1750", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0x3E8;", + "Convert 1_00_0 to 0x3E8", + ); + } + + #[test] + fn convert_integer_with_suffix() { + let before = "const _: i32 = 1000i32$0;"; + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0b1111101000i32;", + "Convert 1000i32 to 0b1111101000i32", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0o1750i32;", + "Convert 1000i32 to 0o1750i32", + ); + + check_assist_by_label( + convert_integer_literal, + before, + "const _: i32 = 0x3E8i32;", + "Convert 1000i32 to 0x3E8i32", + ); + } + + #[test] + fn convert_overflowing_literal() { + let before = "const _: i32 = + 111111111111111111111111111111111111111111111111111111111111111111111111$0;"; + check_assist_not_applicable(convert_integer_literal, before); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_into_to_from.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_into_to_from.rs new file mode 100644 index 000000000..30f6dd41a --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_into_to_from.rs @@ -0,0 +1,351 @@ +use ide_db::{famous_defs::FamousDefs, helpers::mod_path_to_ast, traits::resolve_target_trait}; +use syntax::ast::{self, AstNode, HasName}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// FIXME: this should be a diagnostic + +// Assist: convert_into_to_from +// +// Converts an Into impl to an equivalent From impl. +// +// ``` +// # //- minicore: from +// impl $0Into for usize { +// fn into(self) -> Thing { +// Thing { +// b: self.to_string(), +// a: self +// } +// } +// } +// ``` +// -> +// ``` +// impl From for Thing { +// fn from(val: usize) -> Self { +// Thing { +// b: val.to_string(), +// a: val +// } +// } +// } +// ``` +pub(crate) fn convert_into_to_from(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let impl_ = ctx.find_node_at_offset::()?; + let src_type = impl_.self_ty()?; + let ast_trait = impl_.trait_()?; + + let module = ctx.sema.scope(impl_.syntax())?.module(); + + let trait_ = resolve_target_trait(&ctx.sema, &impl_)?; + if trait_ != FamousDefs(&ctx.sema, module.krate()).core_convert_Into()? { + return None; + } + + let src_type_path = { + let src_type_path = src_type.syntax().descendants().find_map(ast::Path::cast)?; + let src_type_def = match ctx.sema.resolve_path(&src_type_path) { + Some(hir::PathResolution::Def(module_def)) => module_def, + _ => return None, + }; + + mod_path_to_ast(&module.find_use_path(ctx.db(), src_type_def)?) + }; + + let dest_type = match &ast_trait { + ast::Type::PathType(path) => { + path.path()?.segment()?.generic_arg_list()?.generic_args().next()? + } + _ => return None, + }; + + let into_fn = impl_.assoc_item_list()?.assoc_items().find_map(|item| { + if let ast::AssocItem::Fn(f) = item { + if f.name()?.text() == "into" { + return Some(f); + } + }; + None + })?; + + let into_fn_name = into_fn.name()?; + let into_fn_params = into_fn.param_list()?; + let into_fn_return = into_fn.ret_type()?; + + let selfs = into_fn + .body()? + .syntax() + .descendants() + .filter_map(ast::NameRef::cast) + .filter(|name| name.text() == "self" || name.text() == "Self"); + + acc.add( + AssistId("convert_into_to_from", AssistKind::RefactorRewrite), + "Convert Into to From", + impl_.syntax().text_range(), + |builder| { + builder.replace(src_type.syntax().text_range(), dest_type.to_string()); + builder.replace(ast_trait.syntax().text_range(), format!("From<{}>", src_type)); + builder.replace(into_fn_return.syntax().text_range(), "-> Self"); + builder.replace(into_fn_params.syntax().text_range(), format!("(val: {})", src_type)); + builder.replace(into_fn_name.syntax().text_range(), "from"); + + for s in selfs { + match s.text().as_ref() { + "self" => builder.replace(s.syntax().text_range(), "val"), + "Self" => builder.replace(s.syntax().text_range(), src_type_path.to_string()), + _ => {} + } + } + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn convert_into_to_from_converts_a_struct() { + check_assist( + convert_into_to_from, + r#" +//- minicore: from +struct Thing { + a: String, + b: usize +} + +impl $0core::convert::Into for usize { + fn into(self) -> Thing { + Thing { + b: self.to_string(), + a: self + } + } +} +"#, + r#" +struct Thing { + a: String, + b: usize +} + +impl From for Thing { + fn from(val: usize) -> Self { + Thing { + b: val.to_string(), + a: val + } + } +} +"#, + ) + } + + #[test] + fn convert_into_to_from_converts_enums() { + check_assist( + convert_into_to_from, + r#" +//- minicore: from +enum Thing { + Foo(String), + Bar(String) +} + +impl $0core::convert::Into for Thing { + fn into(self) -> String { + match self { + Self::Foo(s) => s, + Self::Bar(s) => s + } + } +} +"#, + r#" +enum Thing { + Foo(String), + Bar(String) +} + +impl From for String { + fn from(val: Thing) -> Self { + match val { + Thing::Foo(s) => s, + Thing::Bar(s) => s + } + } +} +"#, + ) + } + + #[test] + fn convert_into_to_from_on_enum_with_lifetimes() { + check_assist( + convert_into_to_from, + r#" +//- minicore: from +enum Thing<'a> { + Foo(&'a str), + Bar(&'a str) +} + +impl<'a> $0core::convert::Into<&'a str> for Thing<'a> { + fn into(self) -> &'a str { + match self { + Self::Foo(s) => s, + Self::Bar(s) => s + } + } +} +"#, + r#" +enum Thing<'a> { + Foo(&'a str), + Bar(&'a str) +} + +impl<'a> From> for &'a str { + fn from(val: Thing<'a>) -> Self { + match val { + Thing::Foo(s) => s, + Thing::Bar(s) => s + } + } +} +"#, + ) + } + + #[test] + fn convert_into_to_from_works_on_references() { + check_assist( + convert_into_to_from, + r#" +//- minicore: from +struct Thing(String); + +impl $0core::convert::Into for &Thing { + fn into(self) -> Thing { + self.0.clone() + } +} +"#, + r#" +struct Thing(String); + +impl From<&Thing> for String { + fn from(val: &Thing) -> Self { + val.0.clone() + } +} +"#, + ) + } + + #[test] + fn convert_into_to_from_works_on_qualified_structs() { + check_assist( + convert_into_to_from, + r#" +//- minicore: from +mod things { + pub struct Thing(String); + pub struct BetterThing(String); +} + +impl $0core::convert::Into for &things::Thing { + fn into(self) -> Thing { + things::BetterThing(self.0.clone()) + } +} +"#, + r#" +mod things { + pub struct Thing(String); + pub struct BetterThing(String); +} + +impl From<&things::Thing> for things::BetterThing { + fn from(val: &things::Thing) -> Self { + things::BetterThing(val.0.clone()) + } +} +"#, + ) + } + + #[test] + fn convert_into_to_from_works_on_qualified_enums() { + check_assist( + convert_into_to_from, + r#" +//- minicore: from +mod things { + pub enum Thing { + A(String) + } + pub struct BetterThing { + B(String) + } +} + +impl $0core::convert::Into for &things::Thing { + fn into(self) -> Thing { + match self { + Self::A(s) => things::BetterThing::B(s) + } + } +} +"#, + r#" +mod things { + pub enum Thing { + A(String) + } + pub struct BetterThing { + B(String) + } +} + +impl From<&things::Thing> for things::BetterThing { + fn from(val: &things::Thing) -> Self { + match val { + things::Thing::A(s) => things::BetterThing::B(s) + } + } +} +"#, + ) + } + + #[test] + fn convert_into_to_from_not_applicable_on_any_trait_named_into() { + check_assist_not_applicable( + convert_into_to_from, + r#" +//- minicore: from +pub trait Into { + pub fn into(self) -> T; +} + +struct Thing { + a: String, +} + +impl $0Into for String { + fn into(self) -> Thing { + Thing { + a: self + } + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs new file mode 100644 index 000000000..2cf370c09 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs @@ -0,0 +1,556 @@ +use hir::known; +use ide_db::famous_defs::FamousDefs; +use stdx::format_to; +use syntax::{ + ast::{self, edit_in_place::Indent, make, HasArgList, HasLoopBody}, + AstNode, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: convert_iter_for_each_to_for +// +// Converts an Iterator::for_each function into a for loop. +// +// ``` +// # //- minicore: iterators +// # use core::iter; +// fn main() { +// let iter = iter::repeat((9, 2)); +// iter.for_each$0(|(x, y)| { +// println!("x: {}, y: {}", x, y); +// }); +// } +// ``` +// -> +// ``` +// # use core::iter; +// fn main() { +// let iter = iter::repeat((9, 2)); +// for (x, y) in iter { +// println!("x: {}, y: {}", x, y); +// } +// } +// ``` +pub(crate) fn convert_iter_for_each_to_for( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let method = ctx.find_node_at_offset::()?; + + let closure = match method.arg_list()?.args().next()? { + ast::Expr::ClosureExpr(expr) => expr, + _ => return None, + }; + + let (method, receiver) = validate_method_call_expr(ctx, method)?; + + let param_list = closure.param_list()?; + let param = param_list.params().next()?.pat()?; + let body = closure.body()?; + + let stmt = method.syntax().parent().and_then(ast::ExprStmt::cast); + let range = stmt.as_ref().map_or(method.syntax(), AstNode::syntax).text_range(); + + acc.add( + AssistId("convert_iter_for_each_to_for", AssistKind::RefactorRewrite), + "Replace this `Iterator::for_each` with a for loop", + range, + |builder| { + let indent = + stmt.as_ref().map_or_else(|| method.indent_level(), ast::ExprStmt::indent_level); + + let block = match body { + ast::Expr::BlockExpr(block) => block, + _ => make::block_expr(Vec::new(), Some(body)), + } + .clone_for_update(); + block.reindent_to(indent); + + let expr_for_loop = make::expr_for_loop(param, receiver, block); + builder.replace(range, expr_for_loop.to_string()) + }, + ) +} + +// Assist: convert_for_loop_with_for_each +// +// Converts a for loop into a for_each loop on the Iterator. +// +// ``` +// fn main() { +// let x = vec![1, 2, 3]; +// for$0 v in x { +// let y = v * 2; +// } +// } +// ``` +// -> +// ``` +// fn main() { +// let x = vec![1, 2, 3]; +// x.into_iter().for_each(|v| { +// let y = v * 2; +// }); +// } +// ``` +pub(crate) fn convert_for_loop_with_for_each( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let for_loop = ctx.find_node_at_offset::()?; + let iterable = for_loop.iterable()?; + let pat = for_loop.pat()?; + let body = for_loop.loop_body()?; + if body.syntax().text_range().start() < ctx.offset() { + cov_mark::hit!(not_available_in_body); + return None; + } + + acc.add( + AssistId("convert_for_loop_with_for_each", AssistKind::RefactorRewrite), + "Replace this for loop with `Iterator::for_each`", + for_loop.syntax().text_range(), + |builder| { + let mut buf = String::new(); + + if let Some((expr_behind_ref, method)) = + is_ref_and_impls_iter_method(&ctx.sema, &iterable) + { + // We have either "for x in &col" and col implements a method called iter + // or "for x in &mut col" and col implements a method called iter_mut + format_to!(buf, "{}.{}()", expr_behind_ref, method); + } else if let ast::Expr::RangeExpr(..) = iterable { + // range expressions need to be parenthesized for the syntax to be correct + format_to!(buf, "({})", iterable); + } else if impls_core_iter(&ctx.sema, &iterable) { + format_to!(buf, "{}", iterable); + } else if let ast::Expr::RefExpr(_) = iterable { + format_to!(buf, "({}).into_iter()", iterable); + } else { + format_to!(buf, "{}.into_iter()", iterable); + } + + format_to!(buf, ".for_each(|{}| {});", pat, body); + + builder.replace(for_loop.syntax().text_range(), buf) + }, + ) +} + +/// If iterable is a reference where the expression behind the reference implements a method +/// returning an Iterator called iter or iter_mut (depending on the type of reference) then return +/// the expression behind the reference and the method name +fn is_ref_and_impls_iter_method( + sema: &hir::Semantics<'_, ide_db::RootDatabase>, + iterable: &ast::Expr, +) -> Option<(ast::Expr, hir::Name)> { + let ref_expr = match iterable { + ast::Expr::RefExpr(r) => r, + _ => return None, + }; + let wanted_method = if ref_expr.mut_token().is_some() { known::iter_mut } else { known::iter }; + let expr_behind_ref = ref_expr.expr()?; + let ty = sema.type_of_expr(&expr_behind_ref)?.adjusted(); + let scope = sema.scope(iterable.syntax())?; + let krate = scope.krate(); + let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?; + + let has_wanted_method = ty + .iterate_method_candidates( + sema.db, + &scope, + &scope.visible_traits().0, + None, + Some(&wanted_method), + |func| { + if func.ret_type(sema.db).impls_trait(sema.db, iter_trait, &[]) { + return Some(()); + } + None + }, + ) + .is_some(); + if !has_wanted_method { + return None; + } + + Some((expr_behind_ref, wanted_method)) +} + +/// Whether iterable implements core::Iterator +fn impls_core_iter(sema: &hir::Semantics<'_, ide_db::RootDatabase>, iterable: &ast::Expr) -> bool { + (|| { + let it_typ = sema.type_of_expr(iterable)?.adjusted(); + + let module = sema.scope(iterable.syntax())?.module(); + + let krate = module.krate(); + let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?; + cov_mark::hit!(test_already_impls_iterator); + Some(it_typ.impls_trait(sema.db, iter_trait, &[])) + })() + .unwrap_or(false) +} + +fn validate_method_call_expr( + ctx: &AssistContext<'_>, + expr: ast::MethodCallExpr, +) -> Option<(ast::Expr, ast::Expr)> { + let name_ref = expr.name_ref()?; + if !name_ref.syntax().text_range().contains_range(ctx.selection_trimmed()) { + cov_mark::hit!(test_for_each_not_applicable_invalid_cursor_pos); + return None; + } + if name_ref.text() != "for_each" { + return None; + } + + let sema = &ctx.sema; + + let receiver = expr.receiver()?; + let expr = ast::Expr::MethodCallExpr(expr); + + let it_type = sema.type_of_expr(&receiver)?.adjusted(); + let module = sema.scope(receiver.syntax())?.module(); + let krate = module.krate(); + + let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?; + it_type.impls_trait(sema.db, iter_trait, &[]).then(|| (expr, receiver)) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_for_each_in_method_stmt() { + check_assist( + convert_iter_for_each_to_for, + r#" +//- minicore: iterators +fn main() { + let it = core::iter::repeat(92); + it.$0for_each(|(x, y)| { + println!("x: {}, y: {}", x, y); + }); +} +"#, + r#" +fn main() { + let it = core::iter::repeat(92); + for (x, y) in it { + println!("x: {}, y: {}", x, y); + } +} +"#, + ) + } + + #[test] + fn test_for_each_in_method() { + check_assist( + convert_iter_for_each_to_for, + r#" +//- minicore: iterators +fn main() { + let it = core::iter::repeat(92); + it.$0for_each(|(x, y)| { + println!("x: {}, y: {}", x, y); + }) +} +"#, + r#" +fn main() { + let it = core::iter::repeat(92); + for (x, y) in it { + println!("x: {}, y: {}", x, y); + } +} +"#, + ) + } + + #[test] + fn test_for_each_without_braces_stmt() { + check_assist( + convert_iter_for_each_to_for, + r#" +//- minicore: iterators +fn main() { + let it = core::iter::repeat(92); + it.$0for_each(|(x, y)| println!("x: {}, y: {}", x, y)); +} +"#, + r#" +fn main() { + let it = core::iter::repeat(92); + for (x, y) in it { + println!("x: {}, y: {}", x, y) + } +} +"#, + ) + } + + #[test] + fn test_for_each_not_applicable() { + check_assist_not_applicable( + convert_iter_for_each_to_for, + r#" +//- minicore: iterators +fn main() { + ().$0for_each(|x| println!("{}", x)); +}"#, + ) + } + + #[test] + fn test_for_each_not_applicable_invalid_cursor_pos() { + cov_mark::check!(test_for_each_not_applicable_invalid_cursor_pos); + check_assist_not_applicable( + convert_iter_for_each_to_for, + r#" +//- minicore: iterators +fn main() { + core::iter::repeat(92).for_each(|(x, y)| $0println!("x: {}, y: {}", x, y)); +}"#, + ) + } + + #[test] + fn each_to_for_not_for() { + check_assist_not_applicable( + convert_for_loop_with_for_each, + r" +let mut x = vec![1, 2, 3]; +x.iter_mut().$0for_each(|v| *v *= 2); + ", + ) + } + + #[test] + fn each_to_for_simple_for() { + check_assist( + convert_for_loop_with_for_each, + r" +fn main() { + let x = vec![1, 2, 3]; + for $0v in x { + v *= 2; + } +}", + r" +fn main() { + let x = vec![1, 2, 3]; + x.into_iter().for_each(|v| { + v *= 2; + }); +}", + ) + } + + #[test] + fn each_to_for_for_in_range() { + check_assist( + convert_for_loop_with_for_each, + r#" +//- minicore: range, iterators +impl core::iter::Iterator for core::ops::Range { + type Item = T; + + fn next(&mut self) -> Option { + None + } +} + +fn main() { + for $0x in 0..92 { + print!("{}", x); + } +}"#, + r#" +impl core::iter::Iterator for core::ops::Range { + type Item = T; + + fn next(&mut self) -> Option { + None + } +} + +fn main() { + (0..92).for_each(|x| { + print!("{}", x); + }); +}"#, + ) + } + + #[test] + fn each_to_for_not_available_in_body() { + cov_mark::check!(not_available_in_body); + check_assist_not_applicable( + convert_for_loop_with_for_each, + r" +fn main() { + let x = vec![1, 2, 3]; + for v in x { + $0v *= 2; + } +}", + ) + } + + #[test] + fn each_to_for_for_borrowed() { + check_assist( + convert_for_loop_with_for_each, + r#" +//- minicore: iterators +use core::iter::{Repeat, repeat}; + +struct S; +impl S { + fn iter(&self) -> Repeat { repeat(92) } + fn iter_mut(&mut self) -> Repeat { repeat(92) } +} + +fn main() { + let x = S; + for $0v in &x { + let a = v * 2; + } +} +"#, + r#" +use core::iter::{Repeat, repeat}; + +struct S; +impl S { + fn iter(&self) -> Repeat { repeat(92) } + fn iter_mut(&mut self) -> Repeat { repeat(92) } +} + +fn main() { + let x = S; + x.iter().for_each(|v| { + let a = v * 2; + }); +} +"#, + ) + } + + #[test] + fn each_to_for_for_borrowed_no_iter_method() { + check_assist( + convert_for_loop_with_for_each, + r" +struct NoIterMethod; +fn main() { + let x = NoIterMethod; + for $0v in &x { + let a = v * 2; + } +} +", + r" +struct NoIterMethod; +fn main() { + let x = NoIterMethod; + (&x).into_iter().for_each(|v| { + let a = v * 2; + }); +} +", + ) + } + + #[test] + fn each_to_for_for_borrowed_mut() { + check_assist( + convert_for_loop_with_for_each, + r#" +//- minicore: iterators +use core::iter::{Repeat, repeat}; + +struct S; +impl S { + fn iter(&self) -> Repeat { repeat(92) } + fn iter_mut(&mut self) -> Repeat { repeat(92) } +} + +fn main() { + let x = S; + for $0v in &mut x { + let a = v * 2; + } +} +"#, + r#" +use core::iter::{Repeat, repeat}; + +struct S; +impl S { + fn iter(&self) -> Repeat { repeat(92) } + fn iter_mut(&mut self) -> Repeat { repeat(92) } +} + +fn main() { + let x = S; + x.iter_mut().for_each(|v| { + let a = v * 2; + }); +} +"#, + ) + } + + #[test] + fn each_to_for_for_borrowed_mut_behind_var() { + check_assist( + convert_for_loop_with_for_each, + r" +fn main() { + let x = vec![1, 2, 3]; + let y = &mut x; + for $0v in y { + *v *= 2; + } +}", + r" +fn main() { + let x = vec![1, 2, 3]; + let y = &mut x; + y.into_iter().for_each(|v| { + *v *= 2; + }); +}", + ) + } + + #[test] + fn each_to_for_already_impls_iterator() { + cov_mark::check!(test_already_impls_iterator); + check_assist( + convert_for_loop_with_for_each, + r#" +//- minicore: iterators +fn main() { + for$0 a in core::iter::repeat(92).take(1) { + println!("{}", a); + } +} +"#, + r#" +fn main() { + core::iter::repeat(92).take(1).for_each(|a| { + println!("{}", a); + }); +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_let_else_to_match.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_let_else_to_match.rs new file mode 100644 index 000000000..00095de25 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_let_else_to_match.rs @@ -0,0 +1,497 @@ +use hir::Semantics; +use ide_db::RootDatabase; +use syntax::ast::{edit::AstNodeEdit, AstNode, HasName, LetStmt, Name, Pat}; +use syntax::T; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +/// Gets a list of binders in a pattern, and whether they are mut. +fn binders_in_pat( + acc: &mut Vec<(Name, bool)>, + pat: &Pat, + sem: &Semantics<'_, RootDatabase>, +) -> Option<()> { + use Pat::*; + match pat { + IdentPat(p) => { + let ident = p.name()?; + let ismut = p.ref_token().is_none() && p.mut_token().is_some(); + // check for const reference + if sem.resolve_bind_pat_to_const(p).is_none() { + acc.push((ident, ismut)); + } + if let Some(inner) = p.pat() { + binders_in_pat(acc, &inner, sem)?; + } + Some(()) + } + BoxPat(p) => p.pat().and_then(|p| binders_in_pat(acc, &p, sem)), + RestPat(_) | LiteralPat(_) | PathPat(_) | WildcardPat(_) | ConstBlockPat(_) => Some(()), + OrPat(p) => { + for p in p.pats() { + binders_in_pat(acc, &p, sem)?; + } + Some(()) + } + ParenPat(p) => p.pat().and_then(|p| binders_in_pat(acc, &p, sem)), + RangePat(p) => { + if let Some(st) = p.start() { + binders_in_pat(acc, &st, sem)? + } + if let Some(ed) = p.end() { + binders_in_pat(acc, &ed, sem)? + } + Some(()) + } + RecordPat(p) => { + for f in p.record_pat_field_list()?.fields() { + let pat = f.pat()?; + binders_in_pat(acc, &pat, sem)?; + } + Some(()) + } + RefPat(p) => p.pat().and_then(|p| binders_in_pat(acc, &p, sem)), + SlicePat(p) => { + for p in p.pats() { + binders_in_pat(acc, &p, sem)?; + } + Some(()) + } + TuplePat(p) => { + for p in p.fields() { + binders_in_pat(acc, &p, sem)?; + } + Some(()) + } + TupleStructPat(p) => { + for p in p.fields() { + binders_in_pat(acc, &p, sem)?; + } + Some(()) + } + // don't support macro pat yet + MacroPat(_) => None, + } +} + +fn binders_to_str(binders: &[(Name, bool)], addmut: bool) -> String { + let vars = binders + .iter() + .map( + |(ident, ismut)| { + if *ismut && addmut { + format!("mut {}", ident) + } else { + ident.to_string() + } + }, + ) + .collect::>() + .join(", "); + if binders.is_empty() { + String::from("{}") + } else if binders.len() == 1 { + vars + } else { + format!("({})", vars) + } +} + +// Assist: convert_let_else_to_match +// +// Converts let-else statement to let statement and match expression. +// +// ``` +// fn main() { +// let Ok(mut x) = f() else$0 { return }; +// } +// ``` +// -> +// ``` +// fn main() { +// let mut x = match f() { +// Ok(x) => x, +// _ => return, +// }; +// } +// ``` +pub(crate) fn convert_let_else_to_match(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + // should focus on else token to trigger + let else_token = ctx.find_token_syntax_at_offset(T![else])?; + let let_stmt = LetStmt::cast(else_token.parent()?.parent()?)?; + let let_else_block = let_stmt.let_else()?.block_expr()?; + let let_init = let_stmt.initializer()?; + if let_stmt.ty().is_some() { + // don't support let with type annotation + return None; + } + let pat = let_stmt.pat()?; + let mut binders = Vec::new(); + binders_in_pat(&mut binders, &pat, &ctx.sema)?; + + let target = let_stmt.syntax().text_range(); + acc.add( + AssistId("convert_let_else_to_match", AssistKind::RefactorRewrite), + "Convert let-else to let and match", + target, + |edit| { + let indent_level = let_stmt.indent_level().0 as usize; + let indent = " ".repeat(indent_level); + let indent1 = " ".repeat(indent_level + 1); + + let binders_str = binders_to_str(&binders, false); + let binders_str_mut = binders_to_str(&binders, true); + + let init_expr = let_init.syntax().text(); + let mut pat_no_mut = pat.syntax().text().to_string(); + // remove the mut from the pattern + for (b, ismut) in binders.iter() { + if *ismut { + pat_no_mut = pat_no_mut.replace(&format!("mut {b}"), &b.to_string()); + } + } + + let only_expr = let_else_block.statements().next().is_none(); + let branch2 = match &let_else_block.tail_expr() { + Some(tail) if only_expr => format!("{},", tail.syntax().text()), + _ => let_else_block.syntax().text().to_string(), + }; + let replace = if binders.is_empty() { + format!( + "match {init_expr} {{ +{indent1}{pat_no_mut} => {binders_str} +{indent1}_ => {branch2} +{indent}}}" + ) + } else { + format!( + "let {binders_str_mut} = match {init_expr} {{ +{indent1}{pat_no_mut} => {binders_str}, +{indent1}_ => {branch2} +{indent}}};" + ) + }; + edit.replace(target, replace); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn convert_let_else_to_match_no_type_let() { + check_assist_not_applicable( + convert_let_else_to_match, + r#" +fn main() { + let 1: u32 = v.iter().sum() else$0 { return }; +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_on_else() { + check_assist_not_applicable( + convert_let_else_to_match, + r#" +fn main() { + let Ok(x) = f() else {$0 return }; +} + "#, + ); + } + + #[test] + fn convert_let_else_to_match_no_macropat() { + check_assist_not_applicable( + convert_let_else_to_match, + r#" +fn main() { + let m!() = g() else$0 { return }; +} + "#, + ); + } + + #[test] + fn convert_let_else_to_match_target() { + check_assist_target( + convert_let_else_to_match, + r" +fn main() { + let Ok(x) = f() else$0 { continue }; +}", + "let Ok(x) = f() else { continue };", + ); + } + + #[test] + fn convert_let_else_to_match_basic() { + check_assist( + convert_let_else_to_match, + r" +fn main() { + let Ok(x) = f() else$0 { continue }; +}", + r" +fn main() { + let x = match f() { + Ok(x) => x, + _ => continue, + }; +}", + ); + } + + #[test] + fn convert_let_else_to_match_const_ref() { + check_assist( + convert_let_else_to_match, + r" +enum Option { + Some(T), + None, +} +use Option::*; +fn main() { + let None = f() el$0se { continue }; +}", + r" +enum Option { + Some(T), + None, +} +use Option::*; +fn main() { + match f() { + None => {} + _ => continue, + } +}", + ); + } + + #[test] + fn convert_let_else_to_match_const_ref_const() { + check_assist( + convert_let_else_to_match, + r" +const NEG1: i32 = -1; +fn main() { + let NEG1 = f() el$0se { continue }; +}", + r" +const NEG1: i32 = -1; +fn main() { + match f() { + NEG1 => {} + _ => continue, + } +}", + ); + } + + #[test] + fn convert_let_else_to_match_mut() { + check_assist( + convert_let_else_to_match, + r" +fn main() { + let Ok(mut x) = f() el$0se { continue }; +}", + r" +fn main() { + let mut x = match f() { + Ok(x) => x, + _ => continue, + }; +}", + ); + } + + #[test] + fn convert_let_else_to_match_multi_binders() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let ControlFlow::Break((x, "tag", y, ..)) = f() else$0 { g(); return }; +}"#, + r#" +fn main() { + let (x, y) = match f() { + ControlFlow::Break((x, "tag", y, ..)) => (x, y), + _ => { g(); return } + }; +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_slice() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let [one, 1001, other] = f() else$0 { break }; +}"#, + r#" +fn main() { + let (one, other) = match f() { + [one, 1001, other] => (one, other), + _ => break, + }; +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_struct() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let [Struct { inner: Some(it) }, 1001, other] = f() else$0 { break }; +}"#, + r#" +fn main() { + let (it, other) = match f() { + [Struct { inner: Some(it) }, 1001, other] => (it, other), + _ => break, + }; +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_struct_ident_pat() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let [Struct { inner }, 1001, other] = f() else$0 { break }; +}"#, + r#" +fn main() { + let (inner, other) = match f() { + [Struct { inner }, 1001, other] => (inner, other), + _ => break, + }; +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_no_binder() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let (8 | 9) = f() else$0 { panic!() }; +}"#, + r#" +fn main() { + match f() { + (8 | 9) => {} + _ => panic!(), + } +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_range() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let 1.. = f() e$0lse { return }; +}"#, + r#" +fn main() { + match f() { + 1.. => {} + _ => return, + } +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_refpat() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let Ok(&mut x) = f(&mut 0) else$0 { return }; +}"#, + r#" +fn main() { + let x = match f(&mut 0) { + Ok(&mut x) => x, + _ => return, + }; +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_refmut() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let Ok(ref mut x) = f() else$0 { return }; +}"#, + r#" +fn main() { + let x = match f() { + Ok(ref mut x) => x, + _ => return, + }; +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_atpat() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let out @ Ok(ins) = f() else$0 { return }; +}"#, + r#" +fn main() { + let (out, ins) = match f() { + out @ Ok(ins) => (out, ins), + _ => return, + }; +}"#, + ); + } + + #[test] + fn convert_let_else_to_match_complex_init() { + check_assist( + convert_let_else_to_match, + r#" +fn main() { + let v = vec![1, 2, 3]; + let &[mut x, y, ..] = &v.iter().collect::>()[..] else$0 { return }; +}"#, + r#" +fn main() { + let v = vec![1, 2, 3]; + let (mut x, y) = match &v.iter().collect::>()[..] { + &[x, y, ..] => (x, y), + _ => return, + }; +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs new file mode 100644 index 000000000..cb75619ce --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs @@ -0,0 +1,574 @@ +use std::iter::once; + +use ide_db::syntax_helpers::node_ext::{is_pattern_cond, single_let}; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, + }, + ted, AstNode, + SyntaxKind::{FN, LOOP_EXPR, WHILE_EXPR, WHITESPACE}, + T, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::invert_boolean_expression, + AssistId, AssistKind, +}; + +// Assist: convert_to_guarded_return +// +// Replace a large conditional with a guarded return. +// +// ``` +// fn main() { +// $0if cond { +// foo(); +// bar(); +// } +// } +// ``` +// -> +// ``` +// fn main() { +// if !cond { +// return; +// } +// foo(); +// bar(); +// } +// ``` +pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let if_expr: ast::IfExpr = ctx.find_node_at_offset()?; + if if_expr.else_branch().is_some() { + return None; + } + + let cond = if_expr.condition()?; + + // Check if there is an IfLet that we can handle. + let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) { + let let_ = single_let(cond)?; + match let_.pat() { + Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => { + let path = pat.path()?; + if path.qualifier().is_some() { + return None; + } + + let bound_ident = pat.fields().next().unwrap(); + if !ast::IdentPat::can_cast(bound_ident.syntax().kind()) { + return None; + } + + (Some((path, bound_ident)), let_.expr()?) + } + _ => return None, // Unsupported IfLet. + } + } else { + (None, cond) + }; + + let then_block = if_expr.then_branch()?; + let then_block = then_block.stmt_list()?; + + let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?; + + if parent_block.tail_expr()? != if_expr.clone().into() { + return None; + } + + // FIXME: This relies on untyped syntax tree and casts to much. It should be + // rewritten to use strongly-typed APIs. + + // check for early return and continue + let first_in_then_block = then_block.syntax().first_child()?; + if ast::ReturnExpr::can_cast(first_in_then_block.kind()) + || ast::ContinueExpr::can_cast(first_in_then_block.kind()) + || first_in_then_block + .children() + .any(|x| ast::ReturnExpr::can_cast(x.kind()) || ast::ContinueExpr::can_cast(x.kind())) + { + return None; + } + + let parent_container = parent_block.syntax().parent()?; + + let early_expression: ast::Expr = match parent_container.kind() { + WHILE_EXPR | LOOP_EXPR => make::expr_continue(None), + FN => make::expr_return(None), + _ => return None, + }; + + if then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{']).is_none() { + return None; + } + + then_block.syntax().last_child_or_token().filter(|t| t.kind() == T!['}'])?; + + let target = if_expr.syntax().text_range(); + acc.add( + AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite), + "Convert to guarded return", + target, + |edit| { + let if_expr = edit.make_mut(if_expr); + let if_indent_level = IndentLevel::from_node(if_expr.syntax()); + let replacement = match if_let_pat { + None => { + // If. + let new_expr = { + let then_branch = + make::block_expr(once(make::expr_stmt(early_expression).into()), None); + let cond = invert_boolean_expression(cond_expr); + make::expr_if(cond, then_branch, None).indent(if_indent_level) + }; + new_expr.syntax().clone_for_update() + } + Some((path, bound_ident)) => { + // If-let. + let match_expr = { + let happy_arm = { + let pat = make::tuple_struct_pat( + path, + once(make::ext::simple_ident_pat(make::name("it")).into()), + ); + let expr = { + let path = make::ext::ident_path("it"); + make::expr_path(path) + }; + make::match_arm(once(pat.into()), None, expr) + }; + + let sad_arm = make::match_arm( + // FIXME: would be cool to use `None` or `Err(_)` if appropriate + once(make::wildcard_pat().into()), + None, + early_expression, + ); + + make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm])) + }; + + let let_stmt = make::let_stmt(bound_ident, None, Some(match_expr)); + let let_stmt = let_stmt.indent(if_indent_level); + let_stmt.syntax().clone_for_update() + } + }; + + let then_block_items = then_block.dedent(IndentLevel(1)).clone_for_update(); + + let end_of_then = then_block_items.syntax().last_child_or_token().unwrap(); + let end_of_then = + if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) { + end_of_then.prev_sibling_or_token().unwrap() + } else { + end_of_then + }; + + let then_statements = replacement + .children_with_tokens() + .chain( + then_block_items + .syntax() + .children_with_tokens() + .skip(1) + .take_while(|i| *i != end_of_then), + ) + .collect(); + + ted::replace_with_many(if_expr.syntax(), then_statements) + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn convert_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" +fn main() { + bar(); + if$0 true { + foo(); + + // comment + bar(); + } +} +"#, + r#" +fn main() { + bar(); + if false { + return; + } + foo(); + + // comment + bar(); +} +"#, + ); + } + + #[test] + fn convert_let_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" +fn main(n: Option) { + bar(); + if$0 let Some(n) = n { + foo(n); + + // comment + bar(); + } +} +"#, + r#" +fn main(n: Option) { + bar(); + let n = match n { + Some(it) => it, + _ => return, + }; + foo(n); + + // comment + bar(); +} +"#, + ); + } + + #[test] + fn convert_if_let_result() { + check_assist( + convert_to_guarded_return, + r#" +fn main() { + if$0 let Ok(x) = Err(92) { + foo(x); + } +} +"#, + r#" +fn main() { + let x = match Err(92) { + Ok(it) => it, + _ => return, + }; + foo(x); +} +"#, + ); + } + + #[test] + fn convert_let_ok_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" +fn main(n: Option) { + bar(); + if$0 let Some(n) = n { + foo(n); + + // comment + bar(); + } +} +"#, + r#" +fn main(n: Option) { + bar(); + let n = match n { + Some(it) => it, + _ => return, + }; + foo(n); + + // comment + bar(); +} +"#, + ); + } + + #[test] + fn convert_let_mut_ok_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" +fn main(n: Option) { + bar(); + if$0 let Some(mut n) = n { + foo(n); + + // comment + bar(); + } +} +"#, + r#" +fn main(n: Option) { + bar(); + let mut n = match n { + Some(it) => it, + _ => return, + }; + foo(n); + + // comment + bar(); +} +"#, + ); + } + + #[test] + fn convert_let_ref_ok_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" +fn main(n: Option<&str>) { + bar(); + if$0 let Some(ref n) = n { + foo(n); + + // comment + bar(); + } +} +"#, + r#" +fn main(n: Option<&str>) { + bar(); + let ref n = match n { + Some(it) => it, + _ => return, + }; + foo(n); + + // comment + bar(); +} +"#, + ); + } + + #[test] + fn convert_inside_while() { + check_assist( + convert_to_guarded_return, + r#" +fn main() { + while true { + if$0 true { + foo(); + bar(); + } + } +} +"#, + r#" +fn main() { + while true { + if false { + continue; + } + foo(); + bar(); + } +} +"#, + ); + } + + #[test] + fn convert_let_inside_while() { + check_assist( + convert_to_guarded_return, + r#" +fn main() { + while true { + if$0 let Some(n) = n { + foo(n); + bar(); + } + } +} +"#, + r#" +fn main() { + while true { + let n = match n { + Some(it) => it, + _ => continue, + }; + foo(n); + bar(); + } +} +"#, + ); + } + + #[test] + fn convert_inside_loop() { + check_assist( + convert_to_guarded_return, + r#" +fn main() { + loop { + if$0 true { + foo(); + bar(); + } + } +} +"#, + r#" +fn main() { + loop { + if false { + continue; + } + foo(); + bar(); + } +} +"#, + ); + } + + #[test] + fn convert_let_inside_loop() { + check_assist( + convert_to_guarded_return, + r#" +fn main() { + loop { + if$0 let Some(n) = n { + foo(n); + bar(); + } + } +} +"#, + r#" +fn main() { + loop { + let n = match n { + Some(it) => it, + _ => continue, + }; + foo(n); + bar(); + } +} +"#, + ); + } + + #[test] + fn ignore_already_converted_if() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" +fn main() { + if$0 true { + return; + } +} +"#, + ); + } + + #[test] + fn ignore_already_converted_loop() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" +fn main() { + loop { + if$0 true { + continue; + } + } +} +"#, + ); + } + + #[test] + fn ignore_return() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" +fn main() { + if$0 true { + return + } +} +"#, + ); + } + + #[test] + fn ignore_else_branch() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" +fn main() { + if$0 true { + foo(); + } else { + bar() + } +} +"#, + ); + } + + #[test] + fn ignore_statements_aftert_if() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" +fn main() { + if$0 true { + foo(); + } + bar(); +} +"#, + ); + } + + #[test] + fn ignore_statements_inside_if() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" +fn main() { + if false { + if$0 true { + foo(); + } + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs new file mode 100644 index 000000000..4ab8e93a2 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs @@ -0,0 +1,840 @@ +use either::Either; +use ide_db::defs::{Definition, NameRefClass}; +use syntax::{ + ast::{self, AstNode, HasGenericParams, HasVisibility}, + match_ast, SyntaxNode, +}; + +use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: convert_tuple_struct_to_named_struct +// +// Converts tuple struct to struct with named fields, and analogously for tuple enum variants. +// +// ``` +// struct Point$0(f32, f32); +// +// impl Point { +// pub fn new(x: f32, y: f32) -> Self { +// Point(x, y) +// } +// +// pub fn x(&self) -> f32 { +// self.0 +// } +// +// pub fn y(&self) -> f32 { +// self.1 +// } +// } +// ``` +// -> +// ``` +// struct Point { field1: f32, field2: f32 } +// +// impl Point { +// pub fn new(x: f32, y: f32) -> Self { +// Point { field1: x, field2: y } +// } +// +// pub fn x(&self) -> f32 { +// self.field1 +// } +// +// pub fn y(&self) -> f32 { +// self.field2 +// } +// } +// ``` +pub(crate) fn convert_tuple_struct_to_named_struct( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let strukt = ctx + .find_node_at_offset::() + .map(Either::Left) + .or_else(|| ctx.find_node_at_offset::().map(Either::Right))?; + let field_list = strukt.as_ref().either(|s| s.field_list(), |v| v.field_list())?; + let tuple_fields = match field_list { + ast::FieldList::TupleFieldList(it) => it, + ast::FieldList::RecordFieldList(_) => return None, + }; + let strukt_def = match &strukt { + Either::Left(s) => Either::Left(ctx.sema.to_def(s)?), + Either::Right(v) => Either::Right(ctx.sema.to_def(v)?), + }; + let target = strukt.as_ref().either(|s| s.syntax(), |v| v.syntax()).text_range(); + + acc.add( + AssistId("convert_tuple_struct_to_named_struct", AssistKind::RefactorRewrite), + "Convert to named struct", + target, + |edit| { + let names = generate_names(tuple_fields.fields()); + edit_field_references(ctx, edit, tuple_fields.fields(), &names); + edit_struct_references(ctx, edit, strukt_def, &names); + edit_struct_def(ctx, edit, &strukt, tuple_fields, names); + }, + ) +} + +fn edit_struct_def( + ctx: &AssistContext<'_>, + edit: &mut AssistBuilder, + strukt: &Either, + tuple_fields: ast::TupleFieldList, + names: Vec, +) { + let record_fields = tuple_fields + .fields() + .zip(names) + .filter_map(|(f, name)| Some(ast::make::record_field(f.visibility(), name, f.ty()?))); + let record_fields = ast::make::record_field_list(record_fields); + let tuple_fields_text_range = tuple_fields.syntax().text_range(); + + edit.edit_file(ctx.file_id()); + + if let Either::Left(strukt) = strukt { + if let Some(w) = strukt.where_clause() { + edit.delete(w.syntax().text_range()); + edit.insert( + tuple_fields_text_range.start(), + ast::make::tokens::single_newline().text(), + ); + edit.insert(tuple_fields_text_range.start(), w.syntax().text()); + edit.insert(tuple_fields_text_range.start(), ","); + edit.insert( + tuple_fields_text_range.start(), + ast::make::tokens::single_newline().text(), + ); + } else { + edit.insert(tuple_fields_text_range.start(), ast::make::tokens::single_space().text()); + } + if let Some(t) = strukt.semicolon_token() { + edit.delete(t.text_range()); + } + } else { + edit.insert(tuple_fields_text_range.start(), ast::make::tokens::single_space().text()); + } + + edit.replace(tuple_fields_text_range, record_fields.to_string()); +} + +fn edit_struct_references( + ctx: &AssistContext<'_>, + edit: &mut AssistBuilder, + strukt: Either, + names: &[ast::Name], +) { + let strukt_def = match strukt { + Either::Left(s) => Definition::Adt(hir::Adt::Struct(s)), + Either::Right(v) => Definition::Variant(v), + }; + let usages = strukt_def.usages(&ctx.sema).include_self_refs().all(); + + let edit_node = |edit: &mut AssistBuilder, node: SyntaxNode| -> Option<()> { + match_ast! { + match node { + ast::TupleStructPat(tuple_struct_pat) => { + edit.replace( + tuple_struct_pat.syntax().text_range(), + ast::make::record_pat_with_fields( + tuple_struct_pat.path()?, + ast::make::record_pat_field_list(tuple_struct_pat.fields().zip(names).map( + |(pat, name)| { + ast::make::record_pat_field( + ast::make::name_ref(&name.to_string()), + pat, + ) + }, + )), + ) + .to_string(), + ); + }, + // for tuple struct creations like Foo(42) + ast::CallExpr(call_expr) => { + let path = call_expr.syntax().descendants().find_map(ast::PathExpr::cast).and_then(|expr| expr.path())?; + + // this also includes method calls like Foo::new(42), we should skip them + if let Some(name_ref) = path.segment().and_then(|s| s.name_ref()) { + match NameRefClass::classify(&ctx.sema, &name_ref) { + Some(NameRefClass::Definition(Definition::SelfType(_))) => {}, + Some(NameRefClass::Definition(def)) if def == strukt_def => {}, + _ => return None, + }; + } + + let arg_list = call_expr.syntax().descendants().find_map(ast::ArgList::cast)?; + + edit.replace( + call_expr.syntax().text_range(), + ast::make::record_expr( + path, + ast::make::record_expr_field_list(arg_list.args().zip(names).map( + |(expr, name)| { + ast::make::record_expr_field( + ast::make::name_ref(&name.to_string()), + Some(expr), + ) + }, + )), + ) + .to_string(), + ); + }, + _ => return None, + } + } + Some(()) + }; + + for (file_id, refs) in usages { + edit.edit_file(file_id); + for r in refs { + for node in r.name.syntax().ancestors() { + if edit_node(edit, node).is_some() { + break; + } + } + } + } +} + +fn edit_field_references( + ctx: &AssistContext<'_>, + edit: &mut AssistBuilder, + fields: impl Iterator, + names: &[ast::Name], +) { + for (field, name) in fields.zip(names) { + let field = match ctx.sema.to_def(&field) { + Some(it) => it, + None => continue, + }; + let def = Definition::Field(field); + let usages = def.usages(&ctx.sema).all(); + for (file_id, refs) in usages { + edit.edit_file(file_id); + for r in refs { + if let Some(name_ref) = r.name.as_name_ref() { + edit.replace(name_ref.syntax().text_range(), name.text()); + } + } + } + } +} + +fn generate_names(fields: impl Iterator) -> Vec { + fields.enumerate().map(|(i, _)| ast::make::name(&format!("field{}", i + 1))).collect() +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn not_applicable_other_than_tuple_struct() { + check_assist_not_applicable( + convert_tuple_struct_to_named_struct, + r#"struct Foo$0 { bar: u32 };"#, + ); + check_assist_not_applicable(convert_tuple_struct_to_named_struct, r#"struct Foo$0;"#); + } + + #[test] + fn convert_simple_struct() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +struct Inner; +struct A$0(Inner); + +impl A { + fn new(inner: Inner) -> A { + A(inner) + } + + fn new_with_default() -> A { + A::new(Inner) + } + + fn into_inner(self) -> Inner { + self.0 + } +}"#, + r#" +struct Inner; +struct A { field1: Inner } + +impl A { + fn new(inner: Inner) -> A { + A { field1: inner } + } + + fn new_with_default() -> A { + A::new(Inner) + } + + fn into_inner(self) -> Inner { + self.field1 + } +}"#, + ); + } + + #[test] + fn convert_struct_referenced_via_self_kw() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +struct Inner; +struct A$0(Inner); + +impl A { + fn new(inner: Inner) -> Self { + Self(inner) + } + + fn new_with_default() -> Self { + Self::new(Inner) + } + + fn into_inner(self) -> Inner { + self.0 + } +}"#, + r#" +struct Inner; +struct A { field1: Inner } + +impl A { + fn new(inner: Inner) -> Self { + Self { field1: inner } + } + + fn new_with_default() -> Self { + Self::new(Inner) + } + + fn into_inner(self) -> Inner { + self.field1 + } +}"#, + ); + } + + #[test] + fn convert_destructured_struct() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +struct Inner; +struct A$0(Inner); + +impl A { + fn into_inner(self) -> Inner { + let A(first) = self; + first + } + + fn into_inner_via_self(self) -> Inner { + let Self(first) = self; + first + } +}"#, + r#" +struct Inner; +struct A { field1: Inner } + +impl A { + fn into_inner(self) -> Inner { + let A { field1: first } = self; + first + } + + fn into_inner_via_self(self) -> Inner { + let Self { field1: first } = self; + first + } +}"#, + ); + } + + #[test] + fn convert_struct_with_visibility() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +struct A$0(pub u32, pub(crate) u64); + +impl A { + fn new() -> A { + A(42, 42) + } + + fn into_first(self) -> u32 { + self.0 + } + + fn into_second(self) -> u64 { + self.1 + } +}"#, + r#" +struct A { pub field1: u32, pub(crate) field2: u64 } + +impl A { + fn new() -> A { + A { field1: 42, field2: 42 } + } + + fn into_first(self) -> u32 { + self.field1 + } + + fn into_second(self) -> u64 { + self.field2 + } +}"#, + ); + } + + #[test] + fn convert_struct_with_wrapped_references() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +struct Inner$0(u32); +struct Outer(Inner); + +impl Outer { + fn new() -> Self { + Self(Inner(42)) + } + + fn into_inner(self) -> u32 { + (self.0).0 + } + + fn into_inner_destructed(self) -> u32 { + let Outer(Inner(x)) = self; + x + } +}"#, + r#" +struct Inner { field1: u32 } +struct Outer(Inner); + +impl Outer { + fn new() -> Self { + Self(Inner { field1: 42 }) + } + + fn into_inner(self) -> u32 { + (self.0).field1 + } + + fn into_inner_destructed(self) -> u32 { + let Outer(Inner { field1: x }) = self; + x + } +}"#, + ); + + check_assist( + convert_tuple_struct_to_named_struct, + r#" +struct Inner(u32); +struct Outer$0(Inner); + +impl Outer { + fn new() -> Self { + Self(Inner(42)) + } + + fn into_inner(self) -> u32 { + (self.0).0 + } + + fn into_inner_destructed(self) -> u32 { + let Outer(Inner(x)) = self; + x + } +}"#, + r#" +struct Inner(u32); +struct Outer { field1: Inner } + +impl Outer { + fn new() -> Self { + Self { field1: Inner(42) } + } + + fn into_inner(self) -> u32 { + (self.field1).0 + } + + fn into_inner_destructed(self) -> u32 { + let Outer { field1: Inner(x) } = self; + x + } +}"#, + ); + } + + #[test] + fn convert_struct_with_multi_file_references() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +//- /main.rs +struct Inner; +struct A$0(Inner); + +mod foo; + +//- /foo.rs +use crate::{A, Inner}; +fn f() { + let a = A(Inner); +} +"#, + r#" +//- /main.rs +struct Inner; +struct A { field1: Inner } + +mod foo; + +//- /foo.rs +use crate::{A, Inner}; +fn f() { + let a = A { field1: Inner }; +} +"#, + ); + } + + #[test] + fn convert_struct_with_where_clause() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +struct Wrap$0(T) +where + T: Display; +"#, + r#" +struct Wrap +where + T: Display, +{ field1: T } + +"#, + ); + } + #[test] + fn not_applicable_other_than_tuple_variant() { + check_assist_not_applicable( + convert_tuple_struct_to_named_struct, + r#"enum Enum { Variant$0 { value: usize } };"#, + ); + check_assist_not_applicable( + convert_tuple_struct_to_named_struct, + r#"enum Enum { Variant$0 }"#, + ); + } + + #[test] + fn convert_simple_variant() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +enum A { + $0Variant(usize), +} + +impl A { + fn new(value: usize) -> A { + A::Variant(value) + } + + fn new_with_default() -> A { + A::new(Default::default()) + } + + fn value(self) -> usize { + match self { + A::Variant(value) => value, + } + } +}"#, + r#" +enum A { + Variant { field1: usize }, +} + +impl A { + fn new(value: usize) -> A { + A::Variant { field1: value } + } + + fn new_with_default() -> A { + A::new(Default::default()) + } + + fn value(self) -> usize { + match self { + A::Variant { field1: value } => value, + } + } +}"#, + ); + } + + #[test] + fn convert_variant_referenced_via_self_kw() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +enum A { + $0Variant(usize), +} + +impl A { + fn new(value: usize) -> A { + Self::Variant(value) + } + + fn new_with_default() -> A { + Self::new(Default::default()) + } + + fn value(self) -> usize { + match self { + Self::Variant(value) => value, + } + } +}"#, + r#" +enum A { + Variant { field1: usize }, +} + +impl A { + fn new(value: usize) -> A { + Self::Variant { field1: value } + } + + fn new_with_default() -> A { + Self::new(Default::default()) + } + + fn value(self) -> usize { + match self { + Self::Variant { field1: value } => value, + } + } +}"#, + ); + } + + #[test] + fn convert_destructured_variant() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +enum A { + $0Variant(usize), +} + +impl A { + fn into_inner(self) -> usize { + let A::Variant(first) = self; + first + } + + fn into_inner_via_self(self) -> usize { + let Self::Variant(first) = self; + first + } +}"#, + r#" +enum A { + Variant { field1: usize }, +} + +impl A { + fn into_inner(self) -> usize { + let A::Variant { field1: first } = self; + first + } + + fn into_inner_via_self(self) -> usize { + let Self::Variant { field1: first } = self; + first + } +}"#, + ); + } + + #[test] + fn convert_variant_with_wrapped_references() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +enum Inner { + $0Variant(usize), +} +enum Outer { + Variant(Inner), +} + +impl Outer { + fn new() -> Self { + Self::Variant(Inner::Variant(42)) + } + + fn into_inner_destructed(self) -> u32 { + let Outer::Variant(Inner::Variant(x)) = self; + x + } +}"#, + r#" +enum Inner { + Variant { field1: usize }, +} +enum Outer { + Variant(Inner), +} + +impl Outer { + fn new() -> Self { + Self::Variant(Inner::Variant { field1: 42 }) + } + + fn into_inner_destructed(self) -> u32 { + let Outer::Variant(Inner::Variant { field1: x }) = self; + x + } +}"#, + ); + + check_assist( + convert_tuple_struct_to_named_struct, + r#" +enum Inner { + Variant(usize), +} +enum Outer { + $0Variant(Inner), +} + +impl Outer { + fn new() -> Self { + Self::Variant(Inner::Variant(42)) + } + + fn into_inner_destructed(self) -> u32 { + let Outer::Variant(Inner::Variant(x)) = self; + x + } +}"#, + r#" +enum Inner { + Variant(usize), +} +enum Outer { + Variant { field1: Inner }, +} + +impl Outer { + fn new() -> Self { + Self::Variant { field1: Inner::Variant(42) } + } + + fn into_inner_destructed(self) -> u32 { + let Outer::Variant { field1: Inner::Variant(x) } = self; + x + } +}"#, + ); + } + + #[test] + fn convert_variant_with_multi_file_references() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +//- /main.rs +struct Inner; +enum A { + $0Variant(Inner), +} + +mod foo; + +//- /foo.rs +use crate::{A, Inner}; +fn f() { + let a = A::Variant(Inner); +} +"#, + r#" +//- /main.rs +struct Inner; +enum A { + Variant { field1: Inner }, +} + +mod foo; + +//- /foo.rs +use crate::{A, Inner}; +fn f() { + let a = A::Variant { field1: Inner }; +} +"#, + ); + } + + #[test] + fn convert_directly_used_variant() { + check_assist( + convert_tuple_struct_to_named_struct, + r#" +//- /main.rs +struct Inner; +enum A { + $0Variant(Inner), +} + +mod foo; + +//- /foo.rs +use crate::{A::Variant, Inner}; +fn f() { + let a = Variant(Inner); +} +"#, + r#" +//- /main.rs +struct Inner; +enum A { + Variant { field1: Inner }, +} + +mod foo; + +//- /foo.rs +use crate::{A::Variant, Inner}; +fn f() { + let a = Variant { field1: Inner }; +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_while_to_loop.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_while_to_loop.rs new file mode 100644 index 000000000..c34b68411 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_while_to_loop.rs @@ -0,0 +1,188 @@ +use std::iter::once; + +use ide_db::syntax_helpers::node_ext::is_pattern_cond; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, HasLoopBody, + }, + AstNode, T, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::invert_boolean_expression, + AssistId, AssistKind, +}; + +// Assist: convert_while_to_loop +// +// Replace a while with a loop. +// +// ``` +// fn main() { +// $0while cond { +// foo(); +// } +// } +// ``` +// -> +// ``` +// fn main() { +// loop { +// if !cond { +// break; +// } +// foo(); +// } +// } +// ``` +pub(crate) fn convert_while_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let while_kw = ctx.find_token_syntax_at_offset(T![while])?; + let while_expr = while_kw.parent().and_then(ast::WhileExpr::cast)?; + let while_body = while_expr.loop_body()?; + let while_cond = while_expr.condition()?; + + let target = while_expr.syntax().text_range(); + acc.add( + AssistId("convert_while_to_loop", AssistKind::RefactorRewrite), + "Convert while to loop", + target, + |edit| { + let while_indent_level = IndentLevel::from_node(while_expr.syntax()); + + let break_block = + make::block_expr(once(make::expr_stmt(make::expr_break(None, None)).into()), None) + .indent(while_indent_level); + let block_expr = if is_pattern_cond(while_cond.clone()) { + let if_expr = make::expr_if(while_cond, while_body, Some(break_block.into())); + let stmts = once(make::expr_stmt(if_expr).into()); + make::block_expr(stmts, None) + } else { + let if_cond = invert_boolean_expression(while_cond); + let if_expr = make::expr_if(if_cond, break_block, None); + let stmts = once(make::expr_stmt(if_expr).into()).chain(while_body.statements()); + make::block_expr(stmts, while_body.tail_expr()) + }; + + let replacement = make::expr_loop(block_expr.indent(while_indent_level)); + edit.replace(target, replacement.syntax().text()) + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn convert_inside_fn() { + check_assist( + convert_while_to_loop, + r#" +fn main() { + while$0 cond { + foo(); + } +} +"#, + r#" +fn main() { + loop { + if !cond { + break; + } + foo(); + } +} +"#, + ); + } + + #[test] + fn convert_busy_wait() { + check_assist( + convert_while_to_loop, + r#" +fn main() { + while$0 cond() {} +} +"#, + r#" +fn main() { + loop { + if !cond() { + break; + } + } +} +"#, + ); + } + + #[test] + fn convert_trailing_expr() { + check_assist( + convert_while_to_loop, + r#" +fn main() { + while$0 cond() { + bar() + } +} +"#, + r#" +fn main() { + loop { + if !cond() { + break; + } + bar() + } +} +"#, + ); + } + + #[test] + fn convert_while_let() { + check_assist( + convert_while_to_loop, + r#" +fn main() { + while$0 let Some(_) = foo() { + bar(); + } +} +"#, + r#" +fn main() { + loop { + if let Some(_) = foo() { + bar(); + } else { + break; + } + } +} +"#, + ); + } + + #[test] + fn ignore_cursor_in_body() { + check_assist_not_applicable( + convert_while_to_loop, + r#" +fn main() { + while cond {$0 + bar(); + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_tuple_binding.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_tuple_binding.rs new file mode 100644 index 000000000..c1f57532b --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_tuple_binding.rs @@ -0,0 +1,2147 @@ +use ide_db::{ + assists::{AssistId, AssistKind}, + defs::Definition, + search::{FileReference, SearchScope, UsageSearchResult}, +}; +use syntax::{ + ast::{self, AstNode, FieldExpr, HasName, IdentPat, MethodCallExpr}, + TextRange, +}; + +use crate::assist_context::{AssistBuilder, AssistContext, Assists}; + +// Assist: destructure_tuple_binding +// +// Destructures a tuple binding in place. +// +// ``` +// fn main() { +// let $0t = (1,2); +// let v = t.0; +// } +// ``` +// -> +// ``` +// fn main() { +// let ($0_0, _1) = (1,2); +// let v = _0; +// } +// ``` +pub(crate) fn destructure_tuple_binding(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + destructure_tuple_binding_impl(acc, ctx, false) +} + +// And when `with_sub_pattern` enabled (currently disabled): +// Assist: destructure_tuple_binding_in_sub_pattern +// +// Destructures tuple items in sub-pattern (after `@`). +// +// ``` +// fn main() { +// let $0t = (1,2); +// let v = t.0; +// } +// ``` +// -> +// ``` +// fn main() { +// let t @ ($0_0, _1) = (1,2); +// let v = _0; +// } +// ``` +pub(crate) fn destructure_tuple_binding_impl( + acc: &mut Assists, + ctx: &AssistContext<'_>, + with_sub_pattern: bool, +) -> Option<()> { + let ident_pat = ctx.find_node_at_offset::()?; + let data = collect_data(ident_pat, ctx)?; + + if with_sub_pattern { + acc.add( + AssistId("destructure_tuple_binding_in_sub_pattern", AssistKind::RefactorRewrite), + "Destructure tuple in sub-pattern", + data.range, + |builder| { + edit_tuple_assignment(ctx, builder, &data, true); + edit_tuple_usages(&data, builder, ctx, true); + }, + ); + } + + acc.add( + AssistId("destructure_tuple_binding", AssistKind::RefactorRewrite), + if with_sub_pattern { "Destructure tuple in place" } else { "Destructure tuple" }, + data.range, + |builder| { + edit_tuple_assignment(ctx, builder, &data, false); + edit_tuple_usages(&data, builder, ctx, false); + }, + ); + + Some(()) +} + +fn collect_data(ident_pat: IdentPat, ctx: &AssistContext<'_>) -> Option { + if ident_pat.at_token().is_some() { + // Cannot destructure pattern with sub-pattern: + // Only IdentPat can have sub-pattern, + // but not TuplePat (`(a,b)`). + cov_mark::hit!(destructure_tuple_subpattern); + return None; + } + + let ty = ctx.sema.type_of_pat(&ident_pat.clone().into())?.adjusted(); + let ref_type = if ty.is_mutable_reference() { + Some(RefType::Mutable) + } else if ty.is_reference() { + Some(RefType::ReadOnly) + } else { + None + }; + // might be reference + let ty = ty.strip_references(); + // must be tuple + let field_types = ty.tuple_fields(ctx.db()); + if field_types.is_empty() { + cov_mark::hit!(destructure_tuple_no_tuple); + return None; + } + + let name = ident_pat.name()?.to_string(); + let range = ident_pat.syntax().text_range(); + + let usages = ctx.sema.to_def(&ident_pat).map(|def| { + Definition::Local(def) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.file_id())) + .all() + }); + + let field_names = (0..field_types.len()) + .map(|i| generate_name(ctx, i, &name, &ident_pat, &usages)) + .collect::>(); + + Some(TupleData { ident_pat, range, ref_type, field_names, usages }) +} + +fn generate_name( + _ctx: &AssistContext<'_>, + index: usize, + _tuple_name: &str, + _ident_pat: &IdentPat, + _usages: &Option, +) -> String { + // FIXME: detect if name already used + format!("_{}", index) +} + +enum RefType { + ReadOnly, + Mutable, +} +struct TupleData { + ident_pat: IdentPat, + // name: String, + range: TextRange, + ref_type: Option, + field_names: Vec, + // field_types: Vec, + usages: Option, +} +fn edit_tuple_assignment( + ctx: &AssistContext<'_>, + builder: &mut AssistBuilder, + data: &TupleData, + in_sub_pattern: bool, +) { + let tuple_pat = { + let original = &data.ident_pat; + let is_ref = original.ref_token().is_some(); + let is_mut = original.mut_token().is_some(); + let fields = data.field_names.iter().map(|name| { + ast::Pat::from(ast::make::ident_pat(is_ref, is_mut, ast::make::name(name))) + }); + ast::make::tuple_pat(fields) + }; + + let add_cursor = |text: &str| { + // place cursor on first tuple item + let first_tuple = &data.field_names[0]; + text.replacen(first_tuple, &format!("$0{}", first_tuple), 1) + }; + + // with sub_pattern: keep original tuple and add subpattern: `tup @ (_0, _1)` + if in_sub_pattern { + let text = format!(" @ {}", tuple_pat); + match ctx.config.snippet_cap { + Some(cap) => { + let snip = add_cursor(&text); + builder.insert_snippet(cap, data.range.end(), snip); + } + None => builder.insert(data.range.end(), text), + }; + } else { + let text = tuple_pat.to_string(); + match ctx.config.snippet_cap { + Some(cap) => { + let snip = add_cursor(&text); + builder.replace_snippet(cap, data.range, snip); + } + None => builder.replace(data.range, text), + }; + } +} + +fn edit_tuple_usages( + data: &TupleData, + builder: &mut AssistBuilder, + ctx: &AssistContext<'_>, + in_sub_pattern: bool, +) { + if let Some(usages) = data.usages.as_ref() { + for (file_id, refs) in usages.iter() { + builder.edit_file(*file_id); + + for r in refs { + edit_tuple_usage(ctx, builder, r, data, in_sub_pattern); + } + } + } +} +fn edit_tuple_usage( + ctx: &AssistContext<'_>, + builder: &mut AssistBuilder, + usage: &FileReference, + data: &TupleData, + in_sub_pattern: bool, +) { + match detect_tuple_index(usage, data) { + Some(index) => edit_tuple_field_usage(ctx, builder, data, index), + None => { + if in_sub_pattern { + cov_mark::hit!(destructure_tuple_call_with_subpattern); + return; + } + + // no index access -> make invalid -> requires handling by user + // -> put usage in block comment + // + // Note: For macro invocations this might result in still valid code: + // When a macro accepts the tuple as argument, as well as no arguments at all, + // uncommenting the tuple still leaves the macro call working (see `tests::in_macro_call::empty_macro`). + // But this is an unlikely case. Usually the resulting macro call will become erroneous. + builder.insert(usage.range.start(), "/*"); + builder.insert(usage.range.end(), "*/"); + } + } +} + +fn edit_tuple_field_usage( + ctx: &AssistContext<'_>, + builder: &mut AssistBuilder, + data: &TupleData, + index: TupleIndex, +) { + let field_name = &data.field_names[index.index]; + + if data.ref_type.is_some() { + let ref_data = handle_ref_field_usage(ctx, &index.field_expr); + builder.replace(ref_data.range, ref_data.format(field_name)); + } else { + builder.replace(index.range, field_name); + } +} +struct TupleIndex { + index: usize, + range: TextRange, + field_expr: FieldExpr, +} +fn detect_tuple_index(usage: &FileReference, data: &TupleData) -> Option { + // usage is IDENT + // IDENT + // NAME_REF + // PATH_SEGMENT + // PATH + // PATH_EXPR + // PAREN_EXRP* + // FIELD_EXPR + + let node = usage + .name + .syntax() + .ancestors() + .skip_while(|s| !ast::PathExpr::can_cast(s.kind())) + .skip(1) // PATH_EXPR + .find(|s| !ast::ParenExpr::can_cast(s.kind()))?; // skip parentheses + + if let Some(field_expr) = ast::FieldExpr::cast(node) { + let idx = field_expr.name_ref()?.as_tuple_field()?; + if idx < data.field_names.len() { + // special case: in macro call -> range of `field_expr` in applied macro, NOT range in actual file! + if field_expr.syntax().ancestors().any(|a| ast::MacroStmts::can_cast(a.kind())) { + cov_mark::hit!(destructure_tuple_macro_call); + + // issue: cannot differentiate between tuple index passed into macro or tuple index as result of macro: + // ```rust + // macro_rules! m { + // ($t1:expr, $t2:expr) => { $t1; $t2.0 } + // } + // let t = (1,2); + // m!(t.0, t) + // ``` + // -> 2 tuple index usages detected! + // + // -> only handle `t` + return None; + } + + Some(TupleIndex { index: idx, range: field_expr.syntax().text_range(), field_expr }) + } else { + // tuple index out of range + None + } + } else { + None + } +} + +struct RefData { + range: TextRange, + needs_deref: bool, + needs_parentheses: bool, +} +impl RefData { + fn format(&self, field_name: &str) -> String { + match (self.needs_deref, self.needs_parentheses) { + (true, true) => format!("(*{})", field_name), + (true, false) => format!("*{}", field_name), + (false, true) => format!("({})", field_name), + (false, false) => field_name.to_string(), + } + } +} +fn handle_ref_field_usage(ctx: &AssistContext<'_>, field_expr: &FieldExpr) -> RefData { + let s = field_expr.syntax(); + let mut ref_data = + RefData { range: s.text_range(), needs_deref: true, needs_parentheses: true }; + + let parent = match s.parent().map(ast::Expr::cast) { + Some(Some(parent)) => parent, + Some(None) => { + ref_data.needs_parentheses = false; + return ref_data; + } + None => return ref_data, + }; + + match parent { + ast::Expr::ParenExpr(it) => { + // already parens in place -> don't replace + ref_data.needs_parentheses = false; + // there might be a ref outside: `&(t.0)` -> can be removed + if let Some(it) = it.syntax().parent().and_then(ast::RefExpr::cast) { + ref_data.needs_deref = false; + ref_data.range = it.syntax().text_range(); + } + } + ast::Expr::RefExpr(it) => { + // `&*` -> cancel each other out + ref_data.needs_deref = false; + ref_data.needs_parentheses = false; + // might be surrounded by parens -> can be removed too + match it.syntax().parent().and_then(ast::ParenExpr::cast) { + Some(parent) => ref_data.range = parent.syntax().text_range(), + None => ref_data.range = it.syntax().text_range(), + }; + } + // higher precedence than deref `*` + // https://doc.rust-lang.org/reference/expressions.html#expression-precedence + // -> requires parentheses + ast::Expr::PathExpr(_it) => {} + ast::Expr::MethodCallExpr(it) => { + // `field_expr` is `self_param` (otherwise it would be in `ArgList`) + + // test if there's already auto-ref in place (`value` -> `&value`) + // -> no method accepting `self`, but `&self` -> no need for deref + // + // other combinations (`&value` -> `value`, `&&value` -> `&value`, `&value` -> `&&value`) might or might not be able to auto-ref/deref, + // but there might be trait implementations an added `&` might resolve to + // -> ONLY handle auto-ref from `value` to `&value` + fn is_auto_ref(ctx: &AssistContext<'_>, call_expr: &MethodCallExpr) -> bool { + fn impl_(ctx: &AssistContext<'_>, call_expr: &MethodCallExpr) -> Option { + let rec = call_expr.receiver()?; + let rec_ty = ctx.sema.type_of_expr(&rec)?.original(); + // input must be actual value + if rec_ty.is_reference() { + return Some(false); + } + + // doesn't resolve trait impl + let f = ctx.sema.resolve_method_call(call_expr)?; + let self_param = f.self_param(ctx.db())?; + // self must be ref + match self_param.access(ctx.db()) { + hir::Access::Shared | hir::Access::Exclusive => Some(true), + hir::Access::Owned => Some(false), + } + } + impl_(ctx, call_expr).unwrap_or(false) + } + + if is_auto_ref(ctx, &it) { + ref_data.needs_deref = false; + ref_data.needs_parentheses = false; + } + } + ast::Expr::FieldExpr(_it) => { + // `t.0.my_field` + ref_data.needs_deref = false; + ref_data.needs_parentheses = false; + } + ast::Expr::IndexExpr(_it) => { + // `t.0[1]` + ref_data.needs_deref = false; + ref_data.needs_parentheses = false; + } + ast::Expr::TryExpr(_it) => { + // `t.0?` + // requires deref and parens: `(*_0)` + } + // lower precedence than deref `*` -> no parens + _ => { + ref_data.needs_parentheses = false; + } + }; + + ref_data +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + // Tests for direct tuple destructure: + // `let $0t = (1,2);` -> `let (_0, _1) = (1,2);` + + fn assist(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + destructure_tuple_binding_impl(acc, ctx, false) + } + + #[test] + fn dont_trigger_on_unit() { + cov_mark::check!(destructure_tuple_no_tuple); + check_assist_not_applicable( + assist, + r#" +fn main() { +let $0v = (); +} + "#, + ) + } + #[test] + fn dont_trigger_on_number() { + cov_mark::check!(destructure_tuple_no_tuple); + check_assist_not_applicable( + assist, + r#" +fn main() { +let $0v = 32; +} + "#, + ) + } + + #[test] + fn destructure_3_tuple() { + check_assist( + assist, + r#" +fn main() { + let $0tup = (1,2,3); +} + "#, + r#" +fn main() { + let ($0_0, _1, _2) = (1,2,3); +} + "#, + ) + } + #[test] + fn destructure_2_tuple() { + check_assist( + assist, + r#" +fn main() { + let $0tup = (1,2); +} + "#, + r#" +fn main() { + let ($0_0, _1) = (1,2); +} + "#, + ) + } + #[test] + fn replace_indices() { + check_assist( + assist, + r#" +fn main() { + let $0tup = (1,2,3); + let v1 = tup.0; + let v2 = tup.1; + let v3 = tup.2; +} + "#, + r#" +fn main() { + let ($0_0, _1, _2) = (1,2,3); + let v1 = _0; + let v2 = _1; + let v3 = _2; +} + "#, + ) + } + + #[test] + fn replace_usage_in_parentheses() { + check_assist( + assist, + r#" +fn main() { + let $0tup = (1,2,3); + let a = (tup).1; + let b = ((tup)).1; +} + "#, + r#" +fn main() { + let ($0_0, _1, _2) = (1,2,3); + let a = _1; + let b = _1; +} + "#, + ) + } + + #[test] + fn handle_function_call() { + check_assist( + assist, + r#" +fn main() { + let $0tup = (1,2); + let v = tup.into(); +} + "#, + r#" +fn main() { + let ($0_0, _1) = (1,2); + let v = /*tup*/.into(); +} + "#, + ) + } + + #[test] + fn handle_invalid_index() { + check_assist( + assist, + r#" +fn main() { + let $0tup = (1,2); + let v = tup.3; +} + "#, + r#" +fn main() { + let ($0_0, _1) = (1,2); + let v = /*tup*/.3; +} + "#, + ) + } + + #[test] + fn dont_replace_variable_with_same_name_as_tuple() { + check_assist( + assist, + r#" +fn main() { + let tup = (1,2); + let v = tup.1; + let $0tup = (1,2,3); + let v = tup.1; + let tup = (1,2,3); + let v = tup.1; +} + "#, + r#" +fn main() { + let tup = (1,2); + let v = tup.1; + let ($0_0, _1, _2) = (1,2,3); + let v = _1; + let tup = (1,2,3); + let v = tup.1; +} + "#, + ) + } + + #[test] + fn keep_function_call_in_tuple_item() { + check_assist( + assist, + r#" +fn main() { + let $0t = ("3.14", 0); + let pi: f32 = t.0.parse().unwrap_or(0.0); +} + "#, + r#" +fn main() { + let ($0_0, _1) = ("3.14", 0); + let pi: f32 = _0.parse().unwrap_or(0.0); +} + "#, + ) + } + + #[test] + fn keep_type() { + check_assist( + assist, + r#" +fn main() { + let $0t: (usize, i32) = (1,2); +} + "#, + r#" +fn main() { + let ($0_0, _1): (usize, i32) = (1,2); +} + "#, + ) + } + + #[test] + fn destructure_reference() { + check_assist( + assist, + r#" +fn main() { + let t = (1,2); + let $0t = &t; + let v = t.0; +} + "#, + r#" +fn main() { + let t = (1,2); + let ($0_0, _1) = &t; + let v = *_0; +} + "#, + ) + } + + #[test] + fn destructure_multiple_reference() { + check_assist( + assist, + r#" +fn main() { + let t = (1,2); + let $0t = &&t; + let v = t.0; +} + "#, + r#" +fn main() { + let t = (1,2); + let ($0_0, _1) = &&t; + let v = *_0; +} + "#, + ) + } + + #[test] + fn keep_reference() { + check_assist( + assist, + r#" +fn foo(t: &(usize, usize)) -> usize { + match t { + &$0t => t.0 + } +} + "#, + r#" +fn foo(t: &(usize, usize)) -> usize { + match t { + &($0_0, _1) => _0 + } +} + "#, + ) + } + + #[test] + fn with_ref() { + check_assist( + assist, + r#" +fn main() { + let ref $0t = (1,2); + let v = t.0; +} + "#, + r#" +fn main() { + let (ref $0_0, ref _1) = (1,2); + let v = *_0; +} + "#, + ) + } + + #[test] + fn with_mut() { + check_assist( + assist, + r#" +fn main() { + let mut $0t = (1,2); + t.0 = 42; + let v = t.0; +} + "#, + r#" +fn main() { + let (mut $0_0, mut _1) = (1,2); + _0 = 42; + let v = _0; +} + "#, + ) + } + + #[test] + fn with_ref_mut() { + check_assist( + assist, + r#" +fn main() { + let ref mut $0t = (1,2); + t.0 = 42; + let v = t.0; +} + "#, + r#" +fn main() { + let (ref mut $0_0, ref mut _1) = (1,2); + *_0 = 42; + let v = *_0; +} + "#, + ) + } + + #[test] + fn dont_trigger_for_non_tuple_reference() { + check_assist_not_applicable( + assist, + r#" +fn main() { + let v = 42; + let $0v = &42; +} + "#, + ) + } + + #[test] + fn dont_trigger_on_static_tuple() { + check_assist_not_applicable( + assist, + r#" +static $0TUP: (usize, usize) = (1,2); + "#, + ) + } + + #[test] + fn dont_trigger_on_wildcard() { + check_assist_not_applicable( + assist, + r#" +fn main() { + let $0_ = (1,2); +} + "#, + ) + } + + #[test] + fn dont_trigger_in_struct() { + check_assist_not_applicable( + assist, + r#" +struct S { + $0tup: (usize, usize), +} + "#, + ) + } + + #[test] + fn dont_trigger_in_struct_creation() { + check_assist_not_applicable( + assist, + r#" +struct S { + tup: (usize, usize), +} +fn main() { + let s = S { + $0tup: (1,2), + }; +} + "#, + ) + } + + #[test] + fn dont_trigger_on_tuple_struct() { + check_assist_not_applicable( + assist, + r#" +struct S(usize, usize); +fn main() { + let $0s = S(1,2); +} + "#, + ) + } + + #[test] + fn dont_trigger_when_subpattern_exists() { + // sub-pattern is only allowed with IdentPat (name), not other patterns (like TuplePat) + cov_mark::check!(destructure_tuple_subpattern); + check_assist_not_applicable( + assist, + r#" +fn sum(t: (usize, usize)) -> usize { + match t { + $0t @ (1..=3,1..=3) => t.0 + t.1, + _ => 0, + } +} + "#, + ) + } + + #[test] + fn in_subpattern() { + check_assist( + assist, + r#" +fn main() { + let t1 @ (_, $0t2) = (1, (2,3)); + let v = t1.0 + t2.0 + t2.1; +} + "#, + r#" +fn main() { + let t1 @ (_, ($0_0, _1)) = (1, (2,3)); + let v = t1.0 + _0 + _1; +} + "#, + ) + } + + #[test] + fn in_nested_tuple() { + check_assist( + assist, + r#" +fn main() { + let ($0tup, v) = ((1,2),3); +} + "#, + r#" +fn main() { + let (($0_0, _1), v) = ((1,2),3); +} + "#, + ) + } + + #[test] + fn in_closure() { + check_assist( + assist, + r#" +fn main() { + let $0tup = (1,2,3); + let f = |v| v + tup.1; +} + "#, + r#" +fn main() { + let ($0_0, _1, _2) = (1,2,3); + let f = |v| v + _1; +} + "#, + ) + } + + #[test] + fn in_closure_args() { + check_assist( + assist, + r#" +fn main() { + let f = |$0t| t.0 + t.1; + let v = f((1,2)); +} + "#, + r#" +fn main() { + let f = |($0_0, _1)| _0 + _1; + let v = f((1,2)); +} + "#, + ) + } + + #[test] + fn in_function_args() { + check_assist( + assist, + r#" +fn f($0t: (usize, usize)) { + let v = t.0; +} + "#, + r#" +fn f(($0_0, _1): (usize, usize)) { + let v = _0; +} + "#, + ) + } + + #[test] + fn in_if_let() { + check_assist( + assist, + r#" +fn f(t: (usize, usize)) { + if let $0t = t { + let v = t.0; + } +} + "#, + r#" +fn f(t: (usize, usize)) { + if let ($0_0, _1) = t { + let v = _0; + } +} + "#, + ) + } + #[test] + fn in_if_let_option() { + check_assist( + assist, + r#" +//- minicore: option +fn f(o: Option<(usize, usize)>) { + if let Some($0t) = o { + let v = t.0; + } +} + "#, + r#" +fn f(o: Option<(usize, usize)>) { + if let Some(($0_0, _1)) = o { + let v = _0; + } +} + "#, + ) + } + + #[test] + fn in_match() { + check_assist( + assist, + r#" +fn main() { + match (1,2) { + $0t => t.1, + }; +} + "#, + r#" +fn main() { + match (1,2) { + ($0_0, _1) => _1, + }; +} + "#, + ) + } + #[test] + fn in_match_option() { + check_assist( + assist, + r#" +//- minicore: option +fn main() { + match Some((1,2)) { + Some($0t) => t.1, + _ => 0, + }; +} + "#, + r#" +fn main() { + match Some((1,2)) { + Some(($0_0, _1)) => _1, + _ => 0, + }; +} + "#, + ) + } + #[test] + fn in_match_reference_option() { + check_assist( + assist, + r#" +//- minicore: option +fn main() { + let t = (1,2); + match Some(&t) { + Some($0t) => t.1, + _ => 0, + }; +} + "#, + r#" +fn main() { + let t = (1,2); + match Some(&t) { + Some(($0_0, _1)) => *_1, + _ => 0, + }; +} + "#, + ) + } + + #[test] + fn in_for() { + check_assist( + assist, + r#" +//- minicore: iterators +fn main() { + for $0t in core::iter::repeat((1,2)) { + let v = t.1; + } +} + "#, + r#" +fn main() { + for ($0_0, _1) in core::iter::repeat((1,2)) { + let v = _1; + } +} + "#, + ) + } + #[test] + fn in_for_nested() { + check_assist( + assist, + r#" +//- minicore: iterators +fn main() { + for (a, $0b) in core::iter::repeat((1,(2,3))) { + let v = b.1; + } +} + "#, + r#" +fn main() { + for (a, ($0_0, _1)) in core::iter::repeat((1,(2,3))) { + let v = _1; + } +} + "#, + ) + } + + #[test] + fn not_applicable_on_tuple_usage() { + //Improvement: might be reasonable to allow & implement + check_assist_not_applicable( + assist, + r#" +fn main() { + let t = (1,2); + let v = $0t.0; +} + "#, + ) + } + + #[test] + fn replace_all() { + check_assist( + assist, + r#" +fn main() { + let $0t = (1,2); + let v = t.1; + let s = (t.0 + t.1) / 2; + let f = |v| v + t.0; + let r = f(t.1); + let e = t == (9,0); + let m = + match t { + (_,2) if t.0 > 2 => 1, + _ => 0, + }; +} + "#, + r#" +fn main() { + let ($0_0, _1) = (1,2); + let v = _1; + let s = (_0 + _1) / 2; + let f = |v| v + _0; + let r = f(_1); + let e = /*t*/ == (9,0); + let m = + match /*t*/ { + (_,2) if _0 > 2 => 1, + _ => 0, + }; +} + "#, + ) + } + + #[test] + fn non_trivial_tuple_assignment() { + check_assist( + assist, + r#" +fn main { + let $0t = + if 1 > 2 { + (1,2) + } else { + (5,6) + }; + let v1 = t.0; + let v2 = + if t.0 > t.1 { + t.0 - t.1 + } else { + t.1 - t.0 + }; +} + "#, + r#" +fn main { + let ($0_0, _1) = + if 1 > 2 { + (1,2) + } else { + (5,6) + }; + let v1 = _0; + let v2 = + if _0 > _1 { + _0 - _1 + } else { + _1 - _0 + }; +} + "#, + ) + } + + mod assist { + use super::*; + use crate::tests::check_assist_by_label; + + fn assist(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + destructure_tuple_binding_impl(acc, ctx, true) + } + fn in_place_assist(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + destructure_tuple_binding_impl(acc, ctx, false) + } + + pub(crate) fn check_in_place_assist(ra_fixture_before: &str, ra_fixture_after: &str) { + check_assist_by_label( + in_place_assist, + ra_fixture_before, + ra_fixture_after, + // "Destructure tuple in place", + "Destructure tuple", + ); + } + + pub(crate) fn check_sub_pattern_assist(ra_fixture_before: &str, ra_fixture_after: &str) { + check_assist_by_label( + assist, + ra_fixture_before, + ra_fixture_after, + "Destructure tuple in sub-pattern", + ); + } + + pub(crate) fn check_both_assists( + ra_fixture_before: &str, + ra_fixture_after_in_place: &str, + ra_fixture_after_in_sub_pattern: &str, + ) { + check_in_place_assist(ra_fixture_before, ra_fixture_after_in_place); + check_sub_pattern_assist(ra_fixture_before, ra_fixture_after_in_sub_pattern); + } + } + + /// Tests for destructure of tuple in sub-pattern: + /// `let $0t = (1,2);` -> `let t @ (_0, _1) = (1,2);` + mod sub_pattern { + use super::assist::*; + use super::*; + use crate::tests::check_assist_by_label; + + #[test] + fn destructure_in_sub_pattern() { + check_sub_pattern_assist( + r#" +#![feature(bindings_after_at)] + +fn main() { + let $0t = (1,2); +} + "#, + r#" +#![feature(bindings_after_at)] + +fn main() { + let t @ ($0_0, _1) = (1,2); +} + "#, + ) + } + + #[test] + fn trigger_both_destructure_tuple_assists() { + fn assist(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + destructure_tuple_binding_impl(acc, ctx, true) + } + let text = r#" +fn main() { + let $0t = (1,2); +} + "#; + check_assist_by_label( + assist, + text, + r#" +fn main() { + let ($0_0, _1) = (1,2); +} + "#, + "Destructure tuple in place", + ); + check_assist_by_label( + assist, + text, + r#" +fn main() { + let t @ ($0_0, _1) = (1,2); +} + "#, + "Destructure tuple in sub-pattern", + ); + } + + #[test] + fn replace_indices() { + check_sub_pattern_assist( + r#" +fn main() { + let $0t = (1,2); + let v1 = t.0; + let v2 = t.1; +} + "#, + r#" +fn main() { + let t @ ($0_0, _1) = (1,2); + let v1 = _0; + let v2 = _1; +} + "#, + ) + } + + #[test] + fn keep_function_call() { + cov_mark::check!(destructure_tuple_call_with_subpattern); + check_sub_pattern_assist( + r#" +fn main() { + let $0t = (1,2); + let v = t.into(); +} + "#, + r#" +fn main() { + let t @ ($0_0, _1) = (1,2); + let v = t.into(); +} + "#, + ) + } + + #[test] + fn keep_type() { + check_sub_pattern_assist( + r#" +fn main() { + let $0t: (usize, i32) = (1,2); + let v = t.1; + let f = t.into(); +} + "#, + r#" +fn main() { + let t @ ($0_0, _1): (usize, i32) = (1,2); + let v = _1; + let f = t.into(); +} + "#, + ) + } + + #[test] + fn in_function_args() { + check_sub_pattern_assist( + r#" +fn f($0t: (usize, usize)) { + let v = t.0; + let f = t.into(); +} + "#, + r#" +fn f(t @ ($0_0, _1): (usize, usize)) { + let v = _0; + let f = t.into(); +} + "#, + ) + } + + #[test] + fn with_ref() { + check_sub_pattern_assist( + r#" +fn main() { + let ref $0t = (1,2); + let v = t.1; + let f = t.into(); +} + "#, + r#" +fn main() { + let ref t @ (ref $0_0, ref _1) = (1,2); + let v = *_1; + let f = t.into(); +} + "#, + ) + } + #[test] + fn with_mut() { + check_sub_pattern_assist( + r#" +fn main() { + let mut $0t = (1,2); + let v = t.1; + let f = t.into(); +} + "#, + r#" +fn main() { + let mut t @ (mut $0_0, mut _1) = (1,2); + let v = _1; + let f = t.into(); +} + "#, + ) + } + #[test] + fn with_ref_mut() { + check_sub_pattern_assist( + r#" +fn main() { + let ref mut $0t = (1,2); + let v = t.1; + let f = t.into(); +} + "#, + r#" +fn main() { + let ref mut t @ (ref mut $0_0, ref mut _1) = (1,2); + let v = *_1; + let f = t.into(); +} + "#, + ) + } + } + + /// Tests for tuple usage in macro call: + /// `println!("{}", t.0)` + mod in_macro_call { + use super::assist::*; + + #[test] + fn detect_macro_call() { + cov_mark::check!(destructure_tuple_macro_call); + check_in_place_assist( + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let $0t = (1,2); + m!(t.0); +} + "#, + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let ($0_0, _1) = (1,2); + m!(/*t*/.0); +} + "#, + ) + } + + #[test] + fn tuple_usage() { + check_both_assists( + // leading `"foo"` to ensure `$e` doesn't start at position `0` + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let $0t = (1,2); + m!(t); +} + "#, + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let ($0_0, _1) = (1,2); + m!(/*t*/); +} + "#, + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let t @ ($0_0, _1) = (1,2); + m!(t); +} + "#, + ) + } + + #[test] + fn tuple_function_usage() { + check_both_assists( + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let $0t = (1,2); + m!(t.into()); +} + "#, + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let ($0_0, _1) = (1,2); + m!(/*t*/.into()); +} + "#, + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let t @ ($0_0, _1) = (1,2); + m!(t.into()); +} + "#, + ) + } + + #[test] + fn tuple_index_usage() { + check_both_assists( + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let $0t = (1,2); + m!(t.0); +} + "#, + // FIXME: replace `t.0` with `_0` (cannot detect range of tuple index in macro call) + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let ($0_0, _1) = (1,2); + m!(/*t*/.0); +} + "#, + // FIXME: replace `t.0` with `_0` + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let t @ ($0_0, _1) = (1,2); + m!(t.0); +} + "#, + ) + } + + #[test] + fn tuple_in_parentheses_index_usage() { + check_both_assists( + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let $0t = (1,2); + m!((t).0); +} + "#, + // FIXME: replace `(t).0` with `_0` + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let ($0_0, _1) = (1,2); + m!((/*t*/).0); +} + "#, + // FIXME: replace `(t).0` with `_0` + r#" +macro_rules! m { + ($e:expr) => { "foo"; $e }; +} + +fn main() { + let t @ ($0_0, _1) = (1,2); + m!((t).0); +} + "#, + ) + } + + #[test] + fn empty_macro() { + check_in_place_assist( + r#" +macro_rules! m { + () => { "foo" }; + ($e:expr) => { $e; "foo" }; +} + +fn main() { + let $0t = (1,2); + m!(t); +} + "#, + // FIXME: macro allows no arg -> is valid. But assist should result in invalid code + r#" +macro_rules! m { + () => { "foo" }; + ($e:expr) => { $e; "foo" }; +} + +fn main() { + let ($0_0, _1) = (1,2); + m!(/*t*/); +} + "#, + ) + } + + #[test] + fn tuple_index_in_macro() { + check_both_assists( + r#" +macro_rules! m { + ($t:expr, $i:expr) => { $t.0 + $i }; +} + +fn main() { + let $0t = (1,2); + m!(t, t.0); +} + "#, + // FIXME: replace `t.0` in macro call (not IN macro) with `_0` + r#" +macro_rules! m { + ($t:expr, $i:expr) => { $t.0 + $i }; +} + +fn main() { + let ($0_0, _1) = (1,2); + m!(/*t*/, /*t*/.0); +} + "#, + // FIXME: replace `t.0` in macro call with `_0` + r#" +macro_rules! m { + ($t:expr, $i:expr) => { $t.0 + $i }; +} + +fn main() { + let t @ ($0_0, _1) = (1,2); + m!(t, t.0); +} + "#, + ) + } + } + + mod refs { + use super::assist::*; + + #[test] + fn no_ref() { + check_in_place_assist( + r#" +fn main() { + let $0t = &(1,2); + let v: i32 = t.0; +} + "#, + r#" +fn main() { + let ($0_0, _1) = &(1,2); + let v: i32 = *_0; +} + "#, + ) + } + #[test] + fn no_ref_with_parens() { + check_in_place_assist( + r#" +fn main() { + let $0t = &(1,2); + let v: i32 = (t.0); +} + "#, + r#" +fn main() { + let ($0_0, _1) = &(1,2); + let v: i32 = (*_0); +} + "#, + ) + } + #[test] + fn with_ref() { + check_in_place_assist( + r#" +fn main() { + let $0t = &(1,2); + let v: &i32 = &t.0; +} + "#, + r#" +fn main() { + let ($0_0, _1) = &(1,2); + let v: &i32 = _0; +} + "#, + ) + } + #[test] + fn with_ref_in_parens_ref() { + check_in_place_assist( + r#" +fn main() { + let $0t = &(1,2); + let v: &i32 = &(t.0); +} + "#, + r#" +fn main() { + let ($0_0, _1) = &(1,2); + let v: &i32 = _0; +} + "#, + ) + } + #[test] + fn with_ref_in_ref_parens() { + check_in_place_assist( + r#" +fn main() { + let $0t = &(1,2); + let v: &i32 = (&t.0); +} + "#, + r#" +fn main() { + let ($0_0, _1) = &(1,2); + let v: &i32 = _0; +} + "#, + ) + } + + #[test] + fn deref_and_parentheses() { + // Operator/Expressions with higher precedence than deref (`*`): + // https://doc.rust-lang.org/reference/expressions.html#expression-precedence + // * Path + // * Method call + // * Field expression + // * Function calls, array indexing + // * `?` + check_in_place_assist( + r#" +//- minicore: option +fn f1(v: i32) {} +fn f2(v: &i32) {} +trait T { + fn do_stuff(self) {} +} +impl T for i32 { + fn do_stuff(self) {} +} +impl T for &i32 { + fn do_stuff(self) {} +} +struct S4 { + value: i32, +} + +fn foo() -> Option<()> { + let $0t = &(0, (1,"1"), Some(2), [3;3], S4 { value: 4 }, &5); + let v: i32 = t.0; // deref, no parens + let v: &i32 = &t.0; // no deref, no parens, remove `&` + f1(t.0); // deref, no parens + f2(&t.0); // `&*` -> cancel out -> no deref, no parens + // https://github.com/rust-lang/rust-analyzer/issues/1109#issuecomment-658868639 + // let v: i32 = t.1.0; // no deref, no parens + let v: i32 = t.4.value; // no deref, no parens + t.0.do_stuff(); // deref, parens + let v: i32 = t.2?; // deref, parens + let v: i32 = t.3[0]; // no deref, no parens + (t.0).do_stuff(); // deref, no additional parens + let v: i32 = *t.5; // deref (-> 2), no parens + + None +} + "#, + r#" +fn f1(v: i32) {} +fn f2(v: &i32) {} +trait T { + fn do_stuff(self) {} +} +impl T for i32 { + fn do_stuff(self) {} +} +impl T for &i32 { + fn do_stuff(self) {} +} +struct S4 { + value: i32, +} + +fn foo() -> Option<()> { + let ($0_0, _1, _2, _3, _4, _5) = &(0, (1,"1"), Some(2), [3;3], S4 { value: 4 }, &5); + let v: i32 = *_0; // deref, no parens + let v: &i32 = _0; // no deref, no parens, remove `&` + f1(*_0); // deref, no parens + f2(_0); // `&*` -> cancel out -> no deref, no parens + // https://github.com/rust-lang/rust-analyzer/issues/1109#issuecomment-658868639 + // let v: i32 = t.1.0; // no deref, no parens + let v: i32 = _4.value; // no deref, no parens + (*_0).do_stuff(); // deref, parens + let v: i32 = (*_2)?; // deref, parens + let v: i32 = _3[0]; // no deref, no parens + (*_0).do_stuff(); // deref, no additional parens + let v: i32 = **_5; // deref (-> 2), no parens + + None +} + "#, + ) + } + + // --------- + // auto-ref/deref + + #[test] + fn self_auto_ref_doesnt_need_deref() { + check_in_place_assist( + r#" +#[derive(Clone, Copy)] +struct S; +impl S { + fn f(&self) {} +} + +fn main() { + let $0t = &(S,2); + let s = t.0.f(); +} + "#, + r#" +#[derive(Clone, Copy)] +struct S; +impl S { + fn f(&self) {} +} + +fn main() { + let ($0_0, _1) = &(S,2); + let s = _0.f(); +} + "#, + ) + } + + #[test] + fn self_owned_requires_deref() { + check_in_place_assist( + r#" +#[derive(Clone, Copy)] +struct S; +impl S { + fn f(self) {} +} + +fn main() { + let $0t = &(S,2); + let s = t.0.f(); +} + "#, + r#" +#[derive(Clone, Copy)] +struct S; +impl S { + fn f(self) {} +} + +fn main() { + let ($0_0, _1) = &(S,2); + let s = (*_0).f(); +} + "#, + ) + } + + #[test] + fn self_auto_ref_in_trait_call_doesnt_require_deref() { + check_in_place_assist( + r#" +trait T { + fn f(self); +} +#[derive(Clone, Copy)] +struct S; +impl T for &S { + fn f(self) {} +} + +fn main() { + let $0t = &(S,2); + let s = t.0.f(); +} + "#, + // FIXME: doesn't need deref * parens. But `ctx.sema.resolve_method_call` doesn't resolve trait implementations + r#" +trait T { + fn f(self); +} +#[derive(Clone, Copy)] +struct S; +impl T for &S { + fn f(self) {} +} + +fn main() { + let ($0_0, _1) = &(S,2); + let s = (*_0).f(); +} + "#, + ) + } + #[test] + fn no_auto_deref_because_of_owned_and_ref_trait_impl() { + check_in_place_assist( + r#" +trait T { + fn f(self); +} +#[derive(Clone, Copy)] +struct S; +impl T for S { + fn f(self) {} +} +impl T for &S { + fn f(self) {} +} + +fn main() { + let $0t = &(S,2); + let s = t.0.f(); +} + "#, + r#" +trait T { + fn f(self); +} +#[derive(Clone, Copy)] +struct S; +impl T for S { + fn f(self) {} +} +impl T for &S { + fn f(self) {} +} + +fn main() { + let ($0_0, _1) = &(S,2); + let s = (*_0).f(); +} + "#, + ) + } + + #[test] + fn no_outer_parens_when_ref_deref() { + check_in_place_assist( + r#" +#[derive(Clone, Copy)] +struct S; +impl S { + fn do_stuff(&self) -> i32 { 42 } +} +fn main() { + let $0t = &(S,&S); + let v = (&t.0).do_stuff(); +} + "#, + r#" +#[derive(Clone, Copy)] +struct S; +impl S { + fn do_stuff(&self) -> i32 { 42 } +} +fn main() { + let ($0_0, _1) = &(S,&S); + let v = _0.do_stuff(); +} + "#, + ) + } + + #[test] + fn auto_ref_deref() { + check_in_place_assist( + r#" +#[derive(Clone, Copy)] +struct S; +impl S { + fn do_stuff(&self) -> i32 { 42 } +} +fn main() { + let $0t = &(S,&S); + let v = (&t.0).do_stuff(); // no deref, remove parens + // `t.0` gets auto-refed -> no deref needed -> no parens + let v = t.0.do_stuff(); // no deref, no parens + let v = &t.0.do_stuff(); // `&` is for result -> no deref, no parens + // deref: `_1` is `&&S`, but method called is on `&S` -> there might be a method accepting `&&S` + let v = t.1.do_stuff(); // deref, parens +} + "#, + r#" +#[derive(Clone, Copy)] +struct S; +impl S { + fn do_stuff(&self) -> i32 { 42 } +} +fn main() { + let ($0_0, _1) = &(S,&S); + let v = _0.do_stuff(); // no deref, remove parens + // `t.0` gets auto-refed -> no deref needed -> no parens + let v = _0.do_stuff(); // no deref, no parens + let v = &_0.do_stuff(); // `&` is for result -> no deref, no parens + // deref: `_1` is `&&S`, but method called is on `&S` -> there might be a method accepting `&&S` + let v = (*_1).do_stuff(); // deref, parens +} + "#, + ) + } + + #[test] + fn mutable() { + check_in_place_assist( + r#" +fn f_owned(v: i32) {} +fn f(v: &i32) {} +fn f_mut(v: &mut i32) { *v = 42; } + +fn main() { + let $0t = &mut (1,2); + let v = t.0; + t.0 = 42; + f_owned(t.0); + f(&t.0); + f_mut(&mut t.0); +} + "#, + r#" +fn f_owned(v: i32) {} +fn f(v: &i32) {} +fn f_mut(v: &mut i32) { *v = 42; } + +fn main() { + let ($0_0, _1) = &mut (1,2); + let v = *_0; + *_0 = 42; + f_owned(*_0); + f(_0); + f_mut(_0); +} + "#, + ) + } + + #[test] + fn with_ref_keyword() { + check_in_place_assist( + r#" +fn f_owned(v: i32) {} +fn f(v: &i32) {} + +fn main() { + let ref $0t = (1,2); + let v = t.0; + f_owned(t.0); + f(&t.0); +} + "#, + r#" +fn f_owned(v: i32) {} +fn f(v: &i32) {} + +fn main() { + let (ref $0_0, ref _1) = (1,2); + let v = *_0; + f_owned(*_0); + f(_0); +} + "#, + ) + } + #[test] + fn with_ref_mut_keywords() { + check_in_place_assist( + r#" +fn f_owned(v: i32) {} +fn f(v: &i32) {} +fn f_mut(v: &mut i32) { *v = 42; } + +fn main() { + let ref mut $0t = (1,2); + let v = t.0; + t.0 = 42; + f_owned(t.0); + f(&t.0); + f_mut(&mut t.0); +} + "#, + r#" +fn f_owned(v: i32) {} +fn f(v: &i32) {} +fn f_mut(v: &mut i32) { *v = 42; } + +fn main() { + let (ref mut $0_0, ref mut _1) = (1,2); + let v = *_0; + *_0 = 42; + f_owned(*_0); + f(_0); + f_mut(_0); +} + "#, + ) + } + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_glob_import.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_glob_import.rs new file mode 100644 index 000000000..87f5018fb --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_glob_import.rs @@ -0,0 +1,900 @@ +use either::Either; +use hir::{AssocItem, HasVisibility, Module, ModuleDef, Name, PathResolution, ScopeDef}; +use ide_db::{ + defs::{Definition, NameRefClass}, + search::SearchScope, +}; +use stdx::never; +use syntax::{ + ast::{self, make}, + ted, AstNode, Direction, SyntaxNode, SyntaxToken, T, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: expand_glob_import +// +// Expands glob imports. +// +// ``` +// mod foo { +// pub struct Bar; +// pub struct Baz; +// } +// +// use foo::*$0; +// +// fn qux(bar: Bar, baz: Baz) {} +// ``` +// -> +// ``` +// mod foo { +// pub struct Bar; +// pub struct Baz; +// } +// +// use foo::{Bar, Baz}; +// +// fn qux(bar: Bar, baz: Baz) {} +// ``` +pub(crate) fn expand_glob_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let star = ctx.find_token_syntax_at_offset(T![*])?; + let use_tree = star.parent().and_then(ast::UseTree::cast)?; + let (parent, mod_path) = find_parent_and_path(&star)?; + let target_module = match ctx.sema.resolve_path(&mod_path)? { + PathResolution::Def(ModuleDef::Module(it)) => it, + _ => return None, + }; + + let current_scope = ctx.sema.scope(&star.parent()?)?; + let current_module = current_scope.module(); + + let refs_in_target = find_refs_in_mod(ctx, target_module, current_module)?; + let imported_defs = find_imported_defs(ctx, star)?; + + let target = parent.either(|n| n.syntax().clone(), |n| n.syntax().clone()); + acc.add( + AssistId("expand_glob_import", AssistKind::RefactorRewrite), + "Expand glob import", + target.text_range(), + |builder| { + let use_tree = builder.make_mut(use_tree); + + let names_to_import = find_names_to_import(ctx, refs_in_target, imported_defs); + let expanded = make::use_tree_list(names_to_import.iter().map(|n| { + let path = make::ext::ident_path(&n.to_string()); + make::use_tree(path, None, None, false) + })) + .clone_for_update(); + + match use_tree.star_token() { + Some(star) => { + let needs_braces = use_tree.path().is_some() && names_to_import.len() != 1; + if needs_braces { + ted::replace(star, expanded.syntax()) + } else { + let without_braces = expanded + .syntax() + .children_with_tokens() + .filter(|child| !matches!(child.kind(), T!['{'] | T!['}'])) + .collect(); + ted::replace_with_many(star, without_braces) + } + } + None => never!(), + } + }, + ) +} + +fn find_parent_and_path( + star: &SyntaxToken, +) -> Option<(Either, ast::Path)> { + return star.parent_ancestors().find_map(|n| { + find_use_tree_list(n.clone()) + .map(|(u, p)| (Either::Right(u), p)) + .or_else(|| find_use_tree(n).map(|(u, p)| (Either::Left(u), p))) + }); + + fn find_use_tree_list(n: SyntaxNode) -> Option<(ast::UseTreeList, ast::Path)> { + let use_tree_list = ast::UseTreeList::cast(n)?; + let path = use_tree_list.parent_use_tree().path()?; + Some((use_tree_list, path)) + } + + fn find_use_tree(n: SyntaxNode) -> Option<(ast::UseTree, ast::Path)> { + let use_tree = ast::UseTree::cast(n)?; + let path = use_tree.path()?; + Some((use_tree, path)) + } +} + +fn def_is_referenced_in(def: Definition, ctx: &AssistContext<'_>) -> bool { + let search_scope = SearchScope::single_file(ctx.file_id()); + def.usages(&ctx.sema).in_scope(search_scope).at_least_one() +} + +#[derive(Debug, Clone)] +struct Ref { + // could be alias + visible_name: Name, + def: Definition, +} + +impl Ref { + fn from_scope_def(name: Name, scope_def: ScopeDef) -> Option { + match scope_def { + ScopeDef::ModuleDef(def) => { + Some(Ref { visible_name: name, def: Definition::from(def) }) + } + _ => None, + } + } +} + +#[derive(Debug, Clone)] +struct Refs(Vec); + +impl Refs { + fn used_refs(&self, ctx: &AssistContext<'_>) -> Refs { + Refs( + self.0 + .clone() + .into_iter() + .filter(|r| { + if let Definition::Trait(tr) = r.def { + if tr.items(ctx.db()).into_iter().any(|ai| { + if let AssocItem::Function(f) = ai { + def_is_referenced_in(Definition::Function(f), ctx) + } else { + false + } + }) { + return true; + } + } + + def_is_referenced_in(r.def, ctx) + }) + .collect(), + ) + } + + fn filter_out_by_defs(&self, defs: Vec) -> Refs { + Refs(self.0.clone().into_iter().filter(|r| !defs.contains(&r.def)).collect()) + } +} + +fn find_refs_in_mod(ctx: &AssistContext<'_>, module: Module, visible_from: Module) -> Option { + if !is_mod_visible_from(ctx, module, visible_from) { + return None; + } + + let module_scope = module.scope(ctx.db(), Some(visible_from)); + let refs = module_scope.into_iter().filter_map(|(n, d)| Ref::from_scope_def(n, d)).collect(); + Some(Refs(refs)) +} + +fn is_mod_visible_from(ctx: &AssistContext<'_>, module: Module, from: Module) -> bool { + match module.parent(ctx.db()) { + Some(parent) => { + module.visibility(ctx.db()).is_visible_from(ctx.db(), from.into()) + && is_mod_visible_from(ctx, parent, from) + } + None => true, + } +} + +// looks for name refs in parent use block's siblings +// +// mod bar { +// mod qux { +// struct Qux; +// } +// +// pub use qux::Qux; +// } +// +// ↓ --------------- +// use foo::*$0; +// use baz::Baz; +// ↑ --------------- +fn find_imported_defs(ctx: &AssistContext<'_>, star: SyntaxToken) -> Option> { + let parent_use_item_syntax = star.parent_ancestors().find_map(|n| { + if ast::Use::can_cast(n.kind()) { + Some(n) + } else { + None + } + })?; + + Some( + [Direction::Prev, Direction::Next] + .into_iter() + .flat_map(|dir| { + parent_use_item_syntax + .siblings(dir.to_owned()) + .filter(|n| ast::Use::can_cast(n.kind())) + }) + .flat_map(|n| n.descendants().filter_map(ast::NameRef::cast)) + .filter_map(|r| match NameRefClass::classify(&ctx.sema, &r)? { + NameRefClass::Definition( + def @ (Definition::Macro(_) + | Definition::Module(_) + | Definition::Function(_) + | Definition::Adt(_) + | Definition::Variant(_) + | Definition::Const(_) + | Definition::Static(_) + | Definition::Trait(_) + | Definition::TypeAlias(_)), + ) => Some(def), + _ => None, + }) + .collect(), + ) +} + +fn find_names_to_import( + ctx: &AssistContext<'_>, + refs_in_target: Refs, + imported_defs: Vec, +) -> Vec { + let used_refs = refs_in_target.used_refs(ctx).filter_out_by_defs(imported_defs); + used_refs.0.iter().map(|r| r.visible_name.clone()).collect() +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn expanding_glob_import() { + check_assist( + expand_glob_import, + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::*$0; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::{Bar, Baz, f}; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + ) + } + + #[test] + fn expanding_glob_import_unused() { + check_assist( + expand_glob_import, + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::*$0; + +fn qux() {} +", + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::{}; + +fn qux() {} +", + ) + } + + #[test] + fn expanding_glob_import_with_existing_explicit_names() { + check_assist( + expand_glob_import, + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::{*$0, f}; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::{Bar, Baz, f}; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + ) + } + + #[test] + fn expanding_glob_import_with_existing_uses_in_same_module() { + check_assist( + expand_glob_import, + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::Bar; +use foo::{*$0, f}; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::Bar; +use foo::{Baz, f}; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + ) + } + + #[test] + fn expanding_nested_glob_import() { + check_assist( + expand_glob_import, + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + } +} + +use foo::{bar::{*$0, f}, baz::*}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); +} +", + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + } +} + +use foo::{bar::{Bar, Baz, f}, baz::*}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); +} +", + ); + + check_assist( + expand_glob_import, + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + } +} + +use foo::{bar::{Bar, Baz, f}, baz::*$0}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); +} +", + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + } +} + +use foo::{bar::{Bar, Baz, f}, baz::g}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); +} +", + ); + + check_assist( + expand_glob_import, + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + + pub mod qux { + pub fn h() {} + pub fn m() {} + + pub mod q { + pub fn j() {} + } + } + } +} + +use foo::{ + bar::{*, f}, + baz::{g, qux::*$0} +}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); + h(); + q::j(); +} +", + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + + pub mod qux { + pub fn h() {} + pub fn m() {} + + pub mod q { + pub fn j() {} + } + } + } +} + +use foo::{ + bar::{*, f}, + baz::{g, qux::{h, q}} +}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); + h(); + q::j(); +} +", + ); + + check_assist( + expand_glob_import, + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + + pub mod qux { + pub fn h() {} + pub fn m() {} + + pub mod q { + pub fn j() {} + } + } + } +} + +use foo::{ + bar::{*, f}, + baz::{g, qux::{h, q::*$0}} +}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); + h(); + j(); +} +", + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + + pub mod qux { + pub fn h() {} + pub fn m() {} + + pub mod q { + pub fn j() {} + } + } + } +} + +use foo::{ + bar::{*, f}, + baz::{g, qux::{h, q::j}} +}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); + h(); + j(); +} +", + ); + + check_assist( + expand_glob_import, + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + + pub mod qux { + pub fn h() {} + pub fn m() {} + + pub mod q { + pub fn j() {} + } + } + } +} + +use foo::{ + bar::{*, f}, + baz::{g, qux::{q::j, *$0}} +}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); + h(); + j(); +} +", + r" +mod foo { + pub mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + pub mod baz { + pub fn g() {} + + pub mod qux { + pub fn h() {} + pub fn m() {} + + pub mod q { + pub fn j() {} + } + } + } +} + +use foo::{ + bar::{*, f}, + baz::{g, qux::{q::j, h}} +}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); + h(); + j(); +} +", + ); + } + + #[test] + fn expanding_glob_import_with_macro_defs() { + check_assist( + expand_glob_import, + r#" +//- /lib.rs crate:foo +#[macro_export] +macro_rules! bar { + () => () +} + +pub fn baz() {} + +//- /main.rs crate:main deps:foo +use foo::*$0; + +fn main() { + bar!(); + baz(); +} +"#, + r#" +use foo::{bar, baz}; + +fn main() { + bar!(); + baz(); +} +"#, + ); + } + + #[test] + fn expanding_glob_import_with_trait_method_uses() { + check_assist( + expand_glob_import, + r" +//- /lib.rs crate:foo +pub trait Tr { + fn method(&self) {} +} +impl Tr for () {} + +//- /main.rs crate:main deps:foo +use foo::*$0; + +fn main() { + ().method(); +} +", + r" +use foo::Tr; + +fn main() { + ().method(); +} +", + ); + + check_assist( + expand_glob_import, + r" +//- /lib.rs crate:foo +pub trait Tr { + fn method(&self) {} +} +impl Tr for () {} + +pub trait Tr2 { + fn method2(&self) {} +} +impl Tr2 for () {} + +//- /main.rs crate:main deps:foo +use foo::*$0; + +fn main() { + ().method(); +} +", + r" +use foo::Tr; + +fn main() { + ().method(); +} +", + ); + } + + #[test] + fn expanding_is_not_applicable_if_target_module_is_not_accessible_from_current_scope() { + check_assist_not_applicable( + expand_glob_import, + r" +mod foo { + mod bar { + pub struct Bar; + } +} + +use foo::bar::*$0; + +fn baz(bar: Bar) {} +", + ); + + check_assist_not_applicable( + expand_glob_import, + r" +mod foo { + mod bar { + pub mod baz { + pub struct Baz; + } + } +} + +use foo::bar::baz::*$0; + +fn qux(baz: Baz) {} +", + ); + } + + #[test] + fn expanding_is_not_applicable_if_cursor_is_not_in_star_token() { + check_assist_not_applicable( + expand_glob_import, + r" + mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + } + + use foo::Bar$0; + + fn qux(bar: Bar, baz: Baz) {} + ", + ) + } + + #[test] + fn expanding_glob_import_single_nested_glob_only() { + check_assist( + expand_glob_import, + r" +mod foo { + pub struct Bar; +} + +use foo::{*$0}; + +struct Baz { + bar: Bar +} +", + r" +mod foo { + pub struct Bar; +} + +use foo::{Bar}; + +struct Baz { + bar: Bar +} +", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_function.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_function.rs new file mode 100644 index 000000000..52a55ead3 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_function.rs @@ -0,0 +1,5333 @@ +use std::iter; + +use ast::make; +use either::Either; +use hir::{ + HasSource, HirDisplay, InFile, Local, ModuleDef, PathResolution, Semantics, TypeInfo, TypeParam, +}; +use ide_db::{ + defs::{Definition, NameRefClass}, + famous_defs::FamousDefs, + helpers::mod_path_to_ast, + imports::insert_use::{insert_use, ImportScope}, + search::{FileReference, ReferenceCategory, SearchScope}, + syntax_helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr}, + FxIndexSet, RootDatabase, +}; +use itertools::Itertools; +use stdx::format_to; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + AstNode, HasGenericParams, + }, + match_ast, ted, SyntaxElement, + SyntaxKind::{self, COMMENT}, + SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, WalkEvent, T, +}; + +use crate::{ + assist_context::{AssistContext, Assists, TreeMutator}, + utils::generate_impl_text, + AssistId, +}; + +// Assist: extract_function +// +// Extracts selected statements and comments into new function. +// +// ``` +// fn main() { +// let n = 1; +// $0let m = n + 2; +// // calculate +// let k = m + n;$0 +// let g = 3; +// } +// ``` +// -> +// ``` +// fn main() { +// let n = 1; +// fun_name(n); +// let g = 3; +// } +// +// fn $0fun_name(n: i32) { +// let m = n + 2; +// // calculate +// let k = m + n; +// } +// ``` +pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let range = ctx.selection_trimmed(); + if range.is_empty() { + return None; + } + + let node = ctx.covering_element(); + if node.kind() == COMMENT { + cov_mark::hit!(extract_function_in_comment_is_not_applicable); + return None; + } + + let node = match node { + syntax::NodeOrToken::Node(n) => n, + syntax::NodeOrToken::Token(t) => t.parent()?, + }; + + let body = extraction_target(&node, range)?; + let container_info = body.analyze_container(&ctx.sema)?; + + let (locals_used, self_param) = body.analyze(&ctx.sema); + + let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; + let insert_after = node_to_insert_after(&body, anchor)?; + let semantics_scope = ctx.sema.scope(&insert_after)?; + let module = semantics_scope.module(); + + let ret_ty = body.return_ty(ctx)?; + let control_flow = body.external_control_flow(ctx, &container_info)?; + let ret_values = body.ret_values(ctx, node.parent().as_ref().unwrap_or(&node)); + + let target_range = body.text_range(); + + let scope = ImportScope::find_insert_use_container(&node, &ctx.sema)?; + + acc.add( + AssistId("extract_function", crate::AssistKind::RefactorExtract), + "Extract into function", + target_range, + move |builder| { + let outliving_locals: Vec<_> = ret_values.collect(); + if stdx::never!(!outliving_locals.is_empty() && !ret_ty.is_unit()) { + // We should not have variables that outlive body if we have expression block + return; + } + + let params = + body.extracted_function_params(ctx, &container_info, locals_used.iter().copied()); + + let extracted_from_trait_impl = body.extracted_from_trait_impl(); + + let name = make_function_name(&semantics_scope); + + let fun = Function { + name, + self_param, + params, + control_flow, + ret_ty, + body, + outliving_locals, + mods: container_info, + }; + + let new_indent = IndentLevel::from_node(&insert_after); + let old_indent = fun.body.indent_level(); + + builder.replace(target_range, make_call(ctx, &fun, old_indent)); + + let fn_def = match fun.self_param_adt(ctx) { + Some(adt) if extracted_from_trait_impl => { + let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1); + generate_impl_text(&adt, &fn_def).replace("{\n\n", "{") + } + _ => format_function(ctx, module, &fun, old_indent, new_indent), + }; + + if fn_def.contains("ControlFlow") { + let scope = match scope { + ImportScope::File(it) => ImportScope::File(builder.make_mut(it)), + ImportScope::Module(it) => ImportScope::Module(builder.make_mut(it)), + ImportScope::Block(it) => ImportScope::Block(builder.make_mut(it)), + }; + + let control_flow_enum = + FamousDefs(&ctx.sema, module.krate()).core_ops_ControlFlow(); + + if let Some(control_flow_enum) = control_flow_enum { + let mod_path = module.find_use_path_prefixed( + ctx.sema.db, + ModuleDef::from(control_flow_enum), + ctx.config.insert_use.prefix_kind, + ); + + if let Some(mod_path) = mod_path { + insert_use(&scope, mod_path_to_ast(&mod_path), &ctx.config.insert_use); + } + } + } + + let insert_offset = insert_after.text_range().end(); + + match ctx.config.snippet_cap { + Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def), + None => builder.insert(insert_offset, fn_def), + }; + }, + ) +} + +fn make_function_name(semantics_scope: &hir::SemanticsScope<'_>) -> ast::NameRef { + let mut names_in_scope = vec![]; + semantics_scope.process_all_names(&mut |name, _| names_in_scope.push(name.to_string())); + + let default_name = "fun_name"; + + let mut name = default_name.to_string(); + let mut counter = 0; + while names_in_scope.contains(&name) { + counter += 1; + name = format!("{}{}", &default_name, counter) + } + make::name_ref(&name) +} + +/// Try to guess what user wants to extract +/// +/// We have basically have two cases: +/// * We want whole node, like `loop {}`, `2 + 2`, `{ let n = 1; }` exprs. +/// Then we can use `ast::Expr` +/// * We want a few statements for a block. E.g. +/// ```rust,no_run +/// fn foo() -> i32 { +/// let m = 1; +/// $0 +/// let n = 2; +/// let k = 3; +/// k + n +/// $0 +/// } +/// ``` +/// +fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option { + if let Some(stmt) = ast::Stmt::cast(node.clone()) { + return match stmt { + ast::Stmt::Item(_) => None, + ast::Stmt::ExprStmt(_) | ast::Stmt::LetStmt(_) => Some(FunctionBody::from_range( + node.parent().and_then(ast::StmtList::cast)?, + node.text_range(), + )), + }; + } + + // Covering element returned the parent block of one or multiple statements that have been selected + if let Some(stmt_list) = ast::StmtList::cast(node.clone()) { + if let Some(block_expr) = stmt_list.syntax().parent().and_then(ast::BlockExpr::cast) { + if block_expr.syntax().text_range() == selection_range { + return FunctionBody::from_expr(block_expr.into()); + } + } + + // Extract the full statements. + return Some(FunctionBody::from_range(stmt_list, selection_range)); + } + + let expr = ast::Expr::cast(node.clone())?; + // A node got selected fully + if node.text_range() == selection_range { + return FunctionBody::from_expr(expr); + } + + node.ancestors().find_map(ast::Expr::cast).and_then(FunctionBody::from_expr) +} + +#[derive(Debug)] +struct Function { + name: ast::NameRef, + self_param: Option, + params: Vec, + control_flow: ControlFlow, + ret_ty: RetType, + body: FunctionBody, + outliving_locals: Vec, + mods: ContainerInfo, +} + +#[derive(Debug)] +struct Param { + var: Local, + ty: hir::Type, + move_local: bool, + requires_mut: bool, + is_copy: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ParamKind { + Value, + MutValue, + SharedRef, + MutRef, +} + +#[derive(Debug, Eq, PartialEq)] +enum FunType { + Unit, + Single(hir::Type), + Tuple(Vec), +} + +/// Where to put extracted function definition +#[derive(Debug)] +enum Anchor { + /// Extract free function and put right after current top-level function + Freestanding, + /// Extract method and put right after current function in the impl-block + Method, +} + +// FIXME: ControlFlow and ContainerInfo both track some function modifiers, feels like these two should +// probably be merged somehow. +#[derive(Debug)] +struct ControlFlow { + kind: Option, + is_async: bool, + is_unsafe: bool, +} + +/// The thing whose expression we are extracting from. Can be a function, const, static, const arg, ... +#[derive(Clone, Debug)] +struct ContainerInfo { + is_const: bool, + is_in_tail: bool, + parent_loop: Option, + /// The function's return type, const's type etc. + ret_type: Option, + generic_param_lists: Vec, + where_clauses: Vec, +} + +/// Control flow that is exported from extracted function +/// +/// E.g.: +/// ```rust,no_run +/// loop { +/// $0 +/// if 42 == 42 { +/// break; +/// } +/// $0 +/// } +/// ``` +#[derive(Debug, Clone)] +enum FlowKind { + /// Return with value (`return $expr;`) + Return(Option), + Try { + kind: TryKind, + }, + /// Break with label and value (`break 'label $expr;`) + Break(Option, Option), + /// Continue with label (`continue 'label;`) + Continue(Option), +} + +#[derive(Debug, Clone)] +enum TryKind { + Option, + Result { ty: hir::Type }, +} + +#[derive(Debug)] +enum RetType { + Expr(hir::Type), + Stmt, +} + +impl RetType { + fn is_unit(&self) -> bool { + match self { + RetType::Expr(ty) => ty.is_unit(), + RetType::Stmt => true, + } + } +} + +/// Semantically same as `ast::Expr`, but preserves identity when using only part of the Block +/// This is the future function body, the part that is being extracted. +#[derive(Debug)] +enum FunctionBody { + Expr(ast::Expr), + Span { parent: ast::StmtList, text_range: TextRange }, +} + +#[derive(Debug)] +struct OutlivedLocal { + local: Local, + mut_usage_outside_body: bool, +} + +/// Container of local variable usages +/// +/// Semanticall same as `UsageSearchResult`, but provides more convenient interface +struct LocalUsages(ide_db::search::UsageSearchResult); + +impl LocalUsages { + fn find_local_usages(ctx: &AssistContext<'_>, var: Local) -> Self { + Self( + Definition::Local(var) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.file_id())) + .all(), + ) + } + + fn iter(&self) -> impl Iterator + '_ { + self.0.iter().flat_map(|(_, rs)| rs) + } +} + +impl Function { + fn return_type(&self, ctx: &AssistContext<'_>) -> FunType { + match &self.ret_ty { + RetType::Expr(ty) if ty.is_unit() => FunType::Unit, + RetType::Expr(ty) => FunType::Single(ty.clone()), + RetType::Stmt => match self.outliving_locals.as_slice() { + [] => FunType::Unit, + [var] => FunType::Single(var.local.ty(ctx.db())), + vars => { + let types = vars.iter().map(|v| v.local.ty(ctx.db())).collect(); + FunType::Tuple(types) + } + }, + } + } + + fn self_param_adt(&self, ctx: &AssistContext<'_>) -> Option { + let self_param = self.self_param.as_ref()?; + let def = ctx.sema.to_def(self_param)?; + let adt = def.ty(ctx.db()).strip_references().as_adt()?; + let InFile { file_id: _, value } = adt.source(ctx.db())?; + Some(value) + } +} + +impl ParamKind { + fn is_ref(&self) -> bool { + matches!(self, ParamKind::SharedRef | ParamKind::MutRef) + } +} + +impl Param { + fn kind(&self) -> ParamKind { + match (self.move_local, self.requires_mut, self.is_copy) { + (false, true, _) => ParamKind::MutRef, + (false, false, false) => ParamKind::SharedRef, + (true, true, _) => ParamKind::MutValue, + (_, false, _) => ParamKind::Value, + } + } + + fn to_arg(&self, ctx: &AssistContext<'_>) -> ast::Expr { + let var = path_expr_from_local(ctx, self.var); + match self.kind() { + ParamKind::Value | ParamKind::MutValue => var, + ParamKind::SharedRef => make::expr_ref(var, false), + ParamKind::MutRef => make::expr_ref(var, true), + } + } + + fn to_param(&self, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Param { + let var = self.var.name(ctx.db()).to_string(); + let var_name = make::name(&var); + let pat = match self.kind() { + ParamKind::MutValue => make::ident_pat(false, true, var_name), + ParamKind::Value | ParamKind::SharedRef | ParamKind::MutRef => { + make::ext::simple_ident_pat(var_name) + } + }; + + let ty = make_ty(&self.ty, ctx, module); + let ty = match self.kind() { + ParamKind::Value | ParamKind::MutValue => ty, + ParamKind::SharedRef => make::ty_ref(ty, false), + ParamKind::MutRef => make::ty_ref(ty, true), + }; + + make::param(pat.into(), ty) + } +} + +impl TryKind { + fn of_ty(ty: hir::Type, ctx: &AssistContext<'_>) -> Option { + if ty.is_unknown() { + // We favour Result for `expr?` + return Some(TryKind::Result { ty }); + } + let adt = ty.as_adt()?; + let name = adt.name(ctx.db()); + // FIXME: use lang items to determine if it is std type or user defined + // E.g. if user happens to define type named `Option`, we would have false positive + match name.to_string().as_str() { + "Option" => Some(TryKind::Option), + "Result" => Some(TryKind::Result { ty }), + _ => None, + } + } +} + +impl FlowKind { + fn make_result_handler(&self, expr: Option) -> ast::Expr { + match self { + FlowKind::Return(_) => make::expr_return(expr), + FlowKind::Break(label, _) => make::expr_break(label.clone(), expr), + FlowKind::Try { .. } => { + stdx::never!("cannot have result handler with try"); + expr.unwrap_or_else(|| make::expr_return(None)) + } + FlowKind::Continue(label) => { + stdx::always!(expr.is_none(), "continue with value is not possible"); + make::expr_continue(label.clone()) + } + } + } + + fn expr_ty(&self, ctx: &AssistContext<'_>) -> Option { + match self { + FlowKind::Return(Some(expr)) | FlowKind::Break(_, Some(expr)) => { + ctx.sema.type_of_expr(expr).map(TypeInfo::adjusted) + } + FlowKind::Try { .. } => { + stdx::never!("try does not have defined expr_ty"); + None + } + _ => None, + } + } +} + +impl FunctionBody { + fn parent(&self) -> Option { + match self { + FunctionBody::Expr(expr) => expr.syntax().parent(), + FunctionBody::Span { parent, .. } => Some(parent.syntax().clone()), + } + } + + fn node(&self) -> &SyntaxNode { + match self { + FunctionBody::Expr(e) => e.syntax(), + FunctionBody::Span { parent, .. } => parent.syntax(), + } + } + + fn extracted_from_trait_impl(&self) -> bool { + match self.node().ancestors().find_map(ast::Impl::cast) { + Some(c) => return c.trait_().is_some(), + None => false, + } + } + + fn descendants(&self) -> impl Iterator { + match self { + FunctionBody::Expr(expr) => expr.syntax().descendants(), + FunctionBody::Span { parent, .. } => parent.syntax().descendants(), + } + } + + fn descendant_paths(&self) -> impl Iterator { + self.descendants().filter_map(|node| { + match_ast! { + match node { + ast::Path(it) => Some(it), + _ => None + } + } + }) + } + + fn from_expr(expr: ast::Expr) -> Option { + match expr { + ast::Expr::BreakExpr(it) => it.expr().map(Self::Expr), + ast::Expr::ReturnExpr(it) => it.expr().map(Self::Expr), + ast::Expr::BlockExpr(it) if !it.is_standalone() => None, + expr => Some(Self::Expr(expr)), + } + } + + fn from_range(parent: ast::StmtList, selected: TextRange) -> FunctionBody { + let full_body = parent.syntax().children_with_tokens(); + + let mut text_range = full_body + .filter(|it| ast::Stmt::can_cast(it.kind()) || it.kind() == COMMENT) + .map(|element| element.text_range()) + .filter(|&range| selected.intersect(range).filter(|it| !it.is_empty()).is_some()) + .reduce(|acc, stmt| acc.cover(stmt)); + + if let Some(tail_range) = parent + .tail_expr() + .map(|it| it.syntax().text_range()) + .filter(|&it| selected.intersect(it).is_some()) + { + text_range = Some(match text_range { + Some(text_range) => text_range.cover(tail_range), + None => tail_range, + }); + } + Self::Span { parent, text_range: text_range.unwrap_or(selected) } + } + + fn indent_level(&self) -> IndentLevel { + match &self { + FunctionBody::Expr(expr) => IndentLevel::from_node(expr.syntax()), + FunctionBody::Span { parent, .. } => IndentLevel::from_node(parent.syntax()) + 1, + } + } + + fn tail_expr(&self) -> Option { + match &self { + FunctionBody::Expr(expr) => Some(expr.clone()), + FunctionBody::Span { parent, text_range } => { + let tail_expr = parent.tail_expr()?; + text_range.contains_range(tail_expr.syntax().text_range()).then(|| tail_expr) + } + } + } + + fn walk_expr(&self, cb: &mut dyn FnMut(ast::Expr)) { + match self { + FunctionBody::Expr(expr) => walk_expr(expr, cb), + FunctionBody::Span { parent, text_range } => { + parent + .statements() + .filter(|stmt| text_range.contains_range(stmt.syntax().text_range())) + .filter_map(|stmt| match stmt { + ast::Stmt::ExprStmt(expr_stmt) => expr_stmt.expr(), + ast::Stmt::Item(_) => None, + ast::Stmt::LetStmt(stmt) => stmt.initializer(), + }) + .for_each(|expr| walk_expr(&expr, cb)); + if let Some(expr) = parent + .tail_expr() + .filter(|it| text_range.contains_range(it.syntax().text_range())) + { + walk_expr(&expr, cb); + } + } + } + } + + fn preorder_expr(&self, cb: &mut dyn FnMut(WalkEvent) -> bool) { + match self { + FunctionBody::Expr(expr) => preorder_expr(expr, cb), + FunctionBody::Span { parent, text_range } => { + parent + .statements() + .filter(|stmt| text_range.contains_range(stmt.syntax().text_range())) + .filter_map(|stmt| match stmt { + ast::Stmt::ExprStmt(expr_stmt) => expr_stmt.expr(), + ast::Stmt::Item(_) => None, + ast::Stmt::LetStmt(stmt) => stmt.initializer(), + }) + .for_each(|expr| preorder_expr(&expr, cb)); + if let Some(expr) = parent + .tail_expr() + .filter(|it| text_range.contains_range(it.syntax().text_range())) + { + preorder_expr(&expr, cb); + } + } + } + } + + fn walk_pat(&self, cb: &mut dyn FnMut(ast::Pat)) { + match self { + FunctionBody::Expr(expr) => walk_patterns_in_expr(expr, cb), + FunctionBody::Span { parent, text_range } => { + parent + .statements() + .filter(|stmt| text_range.contains_range(stmt.syntax().text_range())) + .for_each(|stmt| match stmt { + ast::Stmt::ExprStmt(expr_stmt) => { + if let Some(expr) = expr_stmt.expr() { + walk_patterns_in_expr(&expr, cb) + } + } + ast::Stmt::Item(_) => (), + ast::Stmt::LetStmt(stmt) => { + if let Some(pat) = stmt.pat() { + walk_pat(&pat, cb); + } + if let Some(expr) = stmt.initializer() { + walk_patterns_in_expr(&expr, cb); + } + } + }); + if let Some(expr) = parent + .tail_expr() + .filter(|it| text_range.contains_range(it.syntax().text_range())) + { + walk_patterns_in_expr(&expr, cb); + } + } + } + } + + fn text_range(&self) -> TextRange { + match self { + FunctionBody::Expr(expr) => expr.syntax().text_range(), + &FunctionBody::Span { text_range, .. } => text_range, + } + } + + fn contains_range(&self, range: TextRange) -> bool { + self.text_range().contains_range(range) + } + + fn precedes_range(&self, range: TextRange) -> bool { + self.text_range().end() <= range.start() + } + + fn contains_node(&self, node: &SyntaxNode) -> bool { + self.contains_range(node.text_range()) + } +} + +impl FunctionBody { + /// Analyzes a function body, returning the used local variables that are referenced in it as well as + /// whether it contains an await expression. + fn analyze( + &self, + sema: &Semantics<'_, RootDatabase>, + ) -> (FxIndexSet, Option) { + let mut self_param = None; + let mut res = FxIndexSet::default(); + let mut cb = |name_ref: Option<_>| { + let local_ref = + match name_ref.and_then(|name_ref| NameRefClass::classify(sema, &name_ref)) { + Some( + NameRefClass::Definition(Definition::Local(local_ref)) + | NameRefClass::FieldShorthand { local_ref, field_ref: _ }, + ) => local_ref, + _ => return, + }; + let InFile { file_id, value } = local_ref.source(sema.db); + // locals defined inside macros are not relevant to us + if !file_id.is_macro() { + match value { + Either::Right(it) => { + self_param.replace(it); + } + Either::Left(_) => { + res.insert(local_ref); + } + } + } + }; + self.walk_expr(&mut |expr| match expr { + ast::Expr::PathExpr(path_expr) => { + cb(path_expr.path().and_then(|it| it.as_single_name_ref())) + } + ast::Expr::ClosureExpr(closure_expr) => { + if let Some(body) = closure_expr.body() { + body.syntax().descendants().map(ast::NameRef::cast).for_each(|it| cb(it)); + } + } + ast::Expr::MacroExpr(expr) => { + if let Some(tt) = expr.macro_call().and_then(|call| call.token_tree()) { + tt.syntax() + .children_with_tokens() + .flat_map(SyntaxElement::into_token) + .filter(|it| it.kind() == SyntaxKind::IDENT) + .flat_map(|t| sema.descend_into_macros(t)) + .for_each(|t| cb(t.parent().and_then(ast::NameRef::cast))); + } + } + _ => (), + }); + (res, self_param) + } + + fn analyze_container(&self, sema: &Semantics<'_, RootDatabase>) -> Option { + let mut ancestors = self.parent()?.ancestors(); + let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted); + let mut parent_loop = None; + let mut set_parent_loop = |loop_: &dyn ast::HasLoopBody| { + if loop_ + .loop_body() + .map_or(false, |it| it.syntax().text_range().contains_range(self.text_range())) + { + parent_loop.get_or_insert(loop_.syntax().clone()); + } + }; + + let (is_const, expr, ty) = loop { + let anc = ancestors.next()?; + break match_ast! { + match anc { + ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body())), + ast::BlockExpr(block_expr) => { + let (constness, block) = match block_expr.modifier() { + Some(ast::BlockModifier::Const(_)) => (true, block_expr), + Some(ast::BlockModifier::Try(_)) => (false, block_expr), + Some(ast::BlockModifier::Label(label)) if label.lifetime().is_some() => (false, block_expr), + _ => continue, + }; + let expr = Some(ast::Expr::BlockExpr(block)); + (constness, expr.clone(), infer_expr_opt(expr)) + }, + ast::Fn(fn_) => { + let func = sema.to_def(&fn_)?; + let mut ret_ty = func.ret_type(sema.db); + if func.is_async(sema.db) { + if let Some(async_ret) = func.async_ret_type(sema.db) { + ret_ty = async_ret; + } + } + (fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty)) + }, + ast::Static(statik) => { + (true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db))) + }, + ast::ConstArg(ca) => { + (true, ca.expr(), infer_expr_opt(ca.expr())) + }, + ast::Const(konst) => { + (true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db))) + }, + ast::ConstParam(cp) => { + (true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db))) + }, + ast::ConstBlockPat(cbp) => { + let expr = cbp.block_expr().map(ast::Expr::BlockExpr); + (true, expr.clone(), infer_expr_opt(expr)) + }, + ast::Variant(__) => return None, + ast::Meta(__) => return None, + ast::LoopExpr(it) => { + set_parent_loop(&it); + continue; + }, + ast::ForExpr(it) => { + set_parent_loop(&it); + continue; + }, + ast::WhileExpr(it) => { + set_parent_loop(&it); + continue; + }, + _ => continue, + } + }; + }; + let container_tail = match expr? { + ast::Expr::BlockExpr(block) => block.tail_expr(), + expr => Some(expr), + }; + let is_in_tail = + container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| { + container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range()) + }); + + let parent = self.parent()?; + let parents = generic_parents(&parent); + let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect(); + let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect(); + + Some(ContainerInfo { + is_in_tail, + is_const, + parent_loop, + ret_type: ty, + generic_param_lists, + where_clauses, + }) + } + + fn return_ty(&self, ctx: &AssistContext<'_>) -> Option { + match self.tail_expr() { + Some(expr) => ctx.sema.type_of_expr(&expr).map(TypeInfo::original).map(RetType::Expr), + None => Some(RetType::Stmt), + } + } + + /// Local variables defined inside `body` that are accessed outside of it + fn ret_values<'a>( + &self, + ctx: &'a AssistContext<'_>, + parent: &SyntaxNode, + ) -> impl Iterator + 'a { + let parent = parent.clone(); + let range = self.text_range(); + locals_defined_in_body(&ctx.sema, self) + .into_iter() + .filter_map(move |local| local_outlives_body(ctx, range, local, &parent)) + } + + /// Analyses the function body for external control flow. + fn external_control_flow( + &self, + ctx: &AssistContext<'_>, + container_info: &ContainerInfo, + ) -> Option { + let mut ret_expr = None; + let mut try_expr = None; + let mut break_expr = None; + let mut continue_expr = None; + let mut is_async = false; + let mut _is_unsafe = false; + + let mut unsafe_depth = 0; + let mut loop_depth = 0; + + self.preorder_expr(&mut |expr| { + let expr = match expr { + WalkEvent::Enter(e) => e, + WalkEvent::Leave(expr) => { + match expr { + ast::Expr::LoopExpr(_) + | ast::Expr::ForExpr(_) + | ast::Expr::WhileExpr(_) => loop_depth -= 1, + ast::Expr::BlockExpr(block_expr) if block_expr.unsafe_token().is_some() => { + unsafe_depth -= 1 + } + _ => (), + } + return false; + } + }; + match expr { + ast::Expr::LoopExpr(_) | ast::Expr::ForExpr(_) | ast::Expr::WhileExpr(_) => { + loop_depth += 1; + } + ast::Expr::BlockExpr(block_expr) if block_expr.unsafe_token().is_some() => { + unsafe_depth += 1 + } + ast::Expr::ReturnExpr(it) => { + ret_expr = Some(it); + } + ast::Expr::TryExpr(it) => { + try_expr = Some(it); + } + ast::Expr::BreakExpr(it) if loop_depth == 0 => { + break_expr = Some(it); + } + ast::Expr::ContinueExpr(it) if loop_depth == 0 => { + continue_expr = Some(it); + } + ast::Expr::AwaitExpr(_) => is_async = true, + // FIXME: Do unsafe analysis on expression, sem highlighting knows this so we should be able + // to just lift that out of there + // expr if unsafe_depth ==0 && expr.is_unsafe => is_unsafe = true, + _ => {} + } + false + }); + + let kind = match (try_expr, ret_expr, break_expr, continue_expr) { + (Some(_), _, None, None) => { + let ret_ty = container_info.ret_type.clone()?; + let kind = TryKind::of_ty(ret_ty, ctx)?; + + Some(FlowKind::Try { kind }) + } + (Some(_), _, _, _) => { + cov_mark::hit!(external_control_flow_try_and_bc); + return None; + } + (None, Some(r), None, None) => Some(FlowKind::Return(r.expr())), + (None, Some(_), _, _) => { + cov_mark::hit!(external_control_flow_return_and_bc); + return None; + } + (None, None, Some(_), Some(_)) => { + cov_mark::hit!(external_control_flow_break_and_continue); + return None; + } + (None, None, Some(b), None) => Some(FlowKind::Break(b.lifetime(), b.expr())), + (None, None, None, Some(c)) => Some(FlowKind::Continue(c.lifetime())), + (None, None, None, None) => None, + }; + + Some(ControlFlow { kind, is_async, is_unsafe: _is_unsafe }) + } + + /// find variables that should be extracted as params + /// + /// Computes additional info that affects param type and mutability + fn extracted_function_params( + &self, + ctx: &AssistContext<'_>, + container_info: &ContainerInfo, + locals: impl Iterator, + ) -> Vec { + locals + .map(|local| (local, local.source(ctx.db()))) + .filter(|(_, src)| is_defined_outside_of_body(ctx, self, src)) + .filter_map(|(local, src)| match src.value { + Either::Left(src) => Some((local, src)), + Either::Right(_) => { + stdx::never!(false, "Local::is_self returned false, but source is SelfParam"); + None + } + }) + .map(|(var, src)| { + let usages = LocalUsages::find_local_usages(ctx, var); + let ty = var.ty(ctx.db()); + + let defined_outside_parent_loop = container_info + .parent_loop + .as_ref() + .map_or(true, |it| it.text_range().contains_range(src.syntax().text_range())); + + let is_copy = ty.is_copy(ctx.db()); + let has_usages = self.has_usages_after_body(&usages); + let requires_mut = + !ty.is_mutable_reference() && has_exclusive_usages(ctx, &usages, self); + // We can move the value into the function call if it's not used after the call, + // if the var is not used but defined outside a loop we are extracting from we can't move it either + // as the function will reuse it in the next iteration. + let move_local = (!has_usages && defined_outside_parent_loop) || ty.is_reference(); + Param { var, ty, move_local, requires_mut, is_copy } + }) + .collect() + } + + fn has_usages_after_body(&self, usages: &LocalUsages) -> bool { + usages.iter().any(|reference| self.precedes_range(reference.range)) + } +} + +enum GenericParent { + Fn(ast::Fn), + Impl(ast::Impl), + Trait(ast::Trait), +} + +impl GenericParent { + fn generic_param_list(&self) -> Option { + match self { + GenericParent::Fn(fn_) => fn_.generic_param_list(), + GenericParent::Impl(impl_) => impl_.generic_param_list(), + GenericParent::Trait(trait_) => trait_.generic_param_list(), + } + } + + fn where_clause(&self) -> Option { + match self { + GenericParent::Fn(fn_) => fn_.where_clause(), + GenericParent::Impl(impl_) => impl_.where_clause(), + GenericParent::Trait(trait_) => trait_.where_clause(), + } + } +} + +/// Search `parent`'s ancestors for items with potentially applicable generic parameters +fn generic_parents(parent: &SyntaxNode) -> Vec { + let mut list = Vec::new(); + if let Some(parent_item) = parent.ancestors().find_map(ast::Item::cast) { + match parent_item { + ast::Item::Fn(ref fn_) => { + if let Some(parent_parent) = parent_item + .syntax() + .parent() + .and_then(|it| it.parent()) + .and_then(ast::Item::cast) + { + match parent_parent { + ast::Item::Impl(impl_) => list.push(GenericParent::Impl(impl_)), + ast::Item::Trait(trait_) => list.push(GenericParent::Trait(trait_)), + _ => (), + } + } + list.push(GenericParent::Fn(fn_.clone())); + } + _ => (), + } + } + list +} + +/// checks if relevant var is used with `&mut` access inside body +fn has_exclusive_usages( + ctx: &AssistContext<'_>, + usages: &LocalUsages, + body: &FunctionBody, +) -> bool { + usages + .iter() + .filter(|reference| body.contains_range(reference.range)) + .any(|reference| reference_is_exclusive(reference, body, ctx)) +} + +/// checks if this reference requires `&mut` access inside node +fn reference_is_exclusive( + reference: &FileReference, + node: &dyn HasTokenAtOffset, + ctx: &AssistContext<'_>, +) -> bool { + // we directly modify variable with set: `n = 0`, `n += 1` + if reference.category == Some(ReferenceCategory::Write) { + return true; + } + + // we take `&mut` reference to variable: `&mut v` + let path = match path_element_of_reference(node, reference) { + Some(path) => path, + None => return false, + }; + + expr_require_exclusive_access(ctx, &path).unwrap_or(false) +} + +/// checks if this expr requires `&mut` access, recurses on field access +fn expr_require_exclusive_access(ctx: &AssistContext<'_>, expr: &ast::Expr) -> Option { + if let ast::Expr::MacroExpr(_) = expr { + // FIXME: expand macro and check output for mutable usages of the variable? + return None; + } + + let parent = expr.syntax().parent()?; + + if let Some(bin_expr) = ast::BinExpr::cast(parent.clone()) { + if matches!(bin_expr.op_kind()?, ast::BinaryOp::Assignment { .. }) { + return Some(bin_expr.lhs()?.syntax() == expr.syntax()); + } + return Some(false); + } + + if let Some(ref_expr) = ast::RefExpr::cast(parent.clone()) { + return Some(ref_expr.mut_token().is_some()); + } + + if let Some(method_call) = ast::MethodCallExpr::cast(parent.clone()) { + let func = ctx.sema.resolve_method_call(&method_call)?; + let self_param = func.self_param(ctx.db())?; + let access = self_param.access(ctx.db()); + + return Some(matches!(access, hir::Access::Exclusive)); + } + + if let Some(field) = ast::FieldExpr::cast(parent) { + return expr_require_exclusive_access(ctx, &field.into()); + } + + Some(false) +} + +trait HasTokenAtOffset { + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset; +} + +impl HasTokenAtOffset for SyntaxNode { + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { + SyntaxNode::token_at_offset(self, offset) + } +} + +impl HasTokenAtOffset for FunctionBody { + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { + match self { + FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset), + FunctionBody::Span { parent, text_range } => { + match parent.syntax().token_at_offset(offset) { + TokenAtOffset::None => TokenAtOffset::None, + TokenAtOffset::Single(t) => { + if text_range.contains_range(t.text_range()) { + TokenAtOffset::Single(t) + } else { + TokenAtOffset::None + } + } + TokenAtOffset::Between(a, b) => { + match ( + text_range.contains_range(a.text_range()), + text_range.contains_range(b.text_range()), + ) { + (true, true) => TokenAtOffset::Between(a, b), + (true, false) => TokenAtOffset::Single(a), + (false, true) => TokenAtOffset::Single(b), + (false, false) => TokenAtOffset::None, + } + } + } + } + } + } +} + +/// find relevant `ast::Expr` for reference +/// +/// # Preconditions +/// +/// `node` must cover `reference`, that is `node.text_range().contains_range(reference.range)` +fn path_element_of_reference( + node: &dyn HasTokenAtOffset, + reference: &FileReference, +) -> Option { + let token = node.token_at_offset(reference.range.start()).right_biased().or_else(|| { + stdx::never!(false, "cannot find token at variable usage: {:?}", reference); + None + })?; + let path = token.parent_ancestors().find_map(ast::Expr::cast).or_else(|| { + stdx::never!(false, "cannot find path parent of variable usage: {:?}", token); + None + })?; + stdx::always!( + matches!(path, ast::Expr::PathExpr(_) | ast::Expr::MacroExpr(_)), + "unexpected expression type for variable usage: {:?}", + path + ); + Some(path) +} + +/// list local variables defined inside `body` +fn locals_defined_in_body( + sema: &Semantics<'_, RootDatabase>, + body: &FunctionBody, +) -> FxIndexSet { + // FIXME: this doesn't work well with macros + // see https://github.com/rust-lang/rust-analyzer/pull/7535#discussion_r570048550 + let mut res = FxIndexSet::default(); + body.walk_pat(&mut |pat| { + if let ast::Pat::IdentPat(pat) = pat { + if let Some(local) = sema.to_def(&pat) { + res.insert(local); + } + } + }); + res +} + +/// Returns usage details if local variable is used after(outside of) body +fn local_outlives_body( + ctx: &AssistContext<'_>, + body_range: TextRange, + local: Local, + parent: &SyntaxNode, +) -> Option { + let usages = LocalUsages::find_local_usages(ctx, local); + let mut has_mut_usages = false; + let mut any_outlives = false; + for usage in usages.iter() { + if body_range.end() <= usage.range.start() { + has_mut_usages |= reference_is_exclusive(usage, parent, ctx); + any_outlives |= true; + if has_mut_usages { + break; // no need to check more elements we have all the info we wanted + } + } + } + if !any_outlives { + return None; + } + Some(OutlivedLocal { local, mut_usage_outside_body: has_mut_usages }) +} + +/// checks if the relevant local was defined before(outside of) body +fn is_defined_outside_of_body( + ctx: &AssistContext<'_>, + body: &FunctionBody, + src: &hir::InFile>, +) -> bool { + src.file_id.original_file(ctx.db()) == ctx.file_id() + && !body.contains_node(either_syntax(&src.value)) +} + +fn either_syntax(value: &Either) -> &SyntaxNode { + match value { + Either::Left(pat) => pat.syntax(), + Either::Right(it) => it.syntax(), + } +} + +/// find where to put extracted function definition +/// +/// Function should be put right after returned node +fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option { + let node = body.node(); + let mut ancestors = node.ancestors().peekable(); + let mut last_ancestor = None; + while let Some(next_ancestor) = ancestors.next() { + match next_ancestor.kind() { + SyntaxKind::SOURCE_FILE => break, + SyntaxKind::ITEM_LIST if !matches!(anchor, Anchor::Freestanding) => continue, + SyntaxKind::ITEM_LIST => { + if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::MODULE) { + break; + } + } + SyntaxKind::ASSOC_ITEM_LIST if !matches!(anchor, Anchor::Method) => continue, + SyntaxKind::ASSOC_ITEM_LIST if body.extracted_from_trait_impl() => continue, + SyntaxKind::ASSOC_ITEM_LIST => { + if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::IMPL) { + break; + } + } + _ => (), + } + last_ancestor = Some(next_ancestor); + } + last_ancestor +} + +fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> String { + let ret_ty = fun.return_type(ctx); + + let args = make::arg_list(fun.params.iter().map(|param| param.to_arg(ctx))); + let name = fun.name.clone(); + let mut call_expr = if fun.self_param.is_some() { + let self_arg = make::expr_path(make::ext::ident_path("self")); + make::expr_method_call(self_arg, name, args) + } else { + let func = make::expr_path(make::path_unqualified(make::path_segment(name))); + make::expr_call(func, args) + }; + + let handler = FlowHandler::from_ret_ty(fun, &ret_ty); + + if fun.control_flow.is_async { + call_expr = make::expr_await(call_expr); + } + let expr = handler.make_call_expr(call_expr).indent(indent); + + let mut_modifier = |var: &OutlivedLocal| if var.mut_usage_outside_body { "mut " } else { "" }; + + let mut buf = String::new(); + match fun.outliving_locals.as_slice() { + [] => {} + [var] => { + format_to!(buf, "let {}{} = ", mut_modifier(var), var.local.name(ctx.db())) + } + vars => { + buf.push_str("let ("); + let bindings = vars.iter().format_with(", ", |local, f| { + f(&format_args!("{}{}", mut_modifier(local), local.local.name(ctx.db()))) + }); + format_to!(buf, "{}", bindings); + buf.push_str(") = "); + } + } + + format_to!(buf, "{}", expr); + let insert_comma = fun + .body + .parent() + .and_then(ast::MatchArm::cast) + .map_or(false, |it| it.comma_token().is_none()); + if insert_comma { + buf.push(','); + } else if fun.ret_ty.is_unit() && (!fun.outliving_locals.is_empty() || !expr.is_block_like()) { + buf.push(';'); + } + buf +} + +enum FlowHandler { + None, + Try { kind: TryKind }, + If { action: FlowKind }, + IfOption { action: FlowKind }, + MatchOption { none: FlowKind }, + MatchResult { err: FlowKind }, +} + +impl FlowHandler { + fn from_ret_ty(fun: &Function, ret_ty: &FunType) -> FlowHandler { + match &fun.control_flow.kind { + None => FlowHandler::None, + Some(flow_kind) => { + let action = flow_kind.clone(); + if *ret_ty == FunType::Unit { + match flow_kind { + FlowKind::Return(None) + | FlowKind::Break(_, None) + | FlowKind::Continue(_) => FlowHandler::If { action }, + FlowKind::Return(_) | FlowKind::Break(_, _) => { + FlowHandler::IfOption { action } + } + FlowKind::Try { kind } => FlowHandler::Try { kind: kind.clone() }, + } + } else { + match flow_kind { + FlowKind::Return(None) + | FlowKind::Break(_, None) + | FlowKind::Continue(_) => FlowHandler::MatchOption { none: action }, + FlowKind::Return(_) | FlowKind::Break(_, _) => { + FlowHandler::MatchResult { err: action } + } + FlowKind::Try { kind } => FlowHandler::Try { kind: kind.clone() }, + } + } + } + } + } + + fn make_call_expr(&self, call_expr: ast::Expr) -> ast::Expr { + match self { + FlowHandler::None => call_expr, + FlowHandler::Try { kind: _ } => make::expr_try(call_expr), + FlowHandler::If { action } => { + let action = action.make_result_handler(None); + let stmt = make::expr_stmt(action); + let block = make::block_expr(iter::once(stmt.into()), None); + let controlflow_break_path = make::path_from_text("ControlFlow::Break"); + let condition = make::expr_let( + make::tuple_struct_pat( + controlflow_break_path, + iter::once(make::wildcard_pat().into()), + ) + .into(), + call_expr, + ); + make::expr_if(condition.into(), block, None) + } + FlowHandler::IfOption { action } => { + let path = make::ext::ident_path("Some"); + let value_pat = make::ext::simple_ident_pat(make::name("value")); + let pattern = make::tuple_struct_pat(path, iter::once(value_pat.into())); + let cond = make::expr_let(pattern.into(), call_expr); + let value = make::expr_path(make::ext::ident_path("value")); + let action_expr = action.make_result_handler(Some(value)); + let action_stmt = make::expr_stmt(action_expr); + let then = make::block_expr(iter::once(action_stmt.into()), None); + make::expr_if(cond.into(), then, None) + } + FlowHandler::MatchOption { none } => { + let some_name = "value"; + + let some_arm = { + let path = make::ext::ident_path("Some"); + let value_pat = make::ext::simple_ident_pat(make::name(some_name)); + let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make::expr_path(make::ext::ident_path(some_name)); + make::match_arm(iter::once(pat.into()), None, value) + }; + let none_arm = { + let path = make::ext::ident_path("None"); + let pat = make::path_pat(path); + make::match_arm(iter::once(pat), None, none.make_result_handler(None)) + }; + let arms = make::match_arm_list(vec![some_arm, none_arm]); + make::expr_match(call_expr, arms) + } + FlowHandler::MatchResult { err } => { + let ok_name = "value"; + let err_name = "value"; + + let ok_arm = { + let path = make::ext::ident_path("Ok"); + let value_pat = make::ext::simple_ident_pat(make::name(ok_name)); + let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make::expr_path(make::ext::ident_path(ok_name)); + make::match_arm(iter::once(pat.into()), None, value) + }; + let err_arm = { + let path = make::ext::ident_path("Err"); + let value_pat = make::ext::simple_ident_pat(make::name(err_name)); + let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make::expr_path(make::ext::ident_path(err_name)); + make::match_arm( + iter::once(pat.into()), + None, + err.make_result_handler(Some(value)), + ) + }; + let arms = make::match_arm_list(vec![ok_arm, err_arm]); + make::expr_match(call_expr, arms) + } + } + } +} + +fn path_expr_from_local(ctx: &AssistContext<'_>, var: Local) -> ast::Expr { + let name = var.name(ctx.db()).to_string(); + make::expr_path(make::ext::ident_path(&name)) +} + +fn format_function( + ctx: &AssistContext<'_>, + module: hir::Module, + fun: &Function, + old_indent: IndentLevel, + new_indent: IndentLevel, +) -> String { + let mut fn_def = String::new(); + let params = fun.make_param_list(ctx, module); + let ret_ty = fun.make_ret_ty(ctx, module); + let body = make_body(ctx, old_indent, new_indent, fun); + let const_kw = if fun.mods.is_const { "const " } else { "" }; + let async_kw = if fun.control_flow.is_async { "async " } else { "" }; + let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" }; + let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun); + match ctx.config.snippet_cap { + Some(_) => format_to!( + fn_def, + "\n\n{}{}{}{}fn $0{}", + new_indent, + const_kw, + async_kw, + unsafe_kw, + fun.name, + ), + None => format_to!( + fn_def, + "\n\n{}{}{}{}fn {}", + new_indent, + const_kw, + async_kw, + unsafe_kw, + fun.name, + ), + } + + if let Some(generic_params) = generic_params { + format_to!(fn_def, "{}", generic_params); + } + + format_to!(fn_def, "{}", params); + + if let Some(ret_ty) = ret_ty { + format_to!(fn_def, " {}", ret_ty); + } + + if let Some(where_clause) = where_clause { + format_to!(fn_def, " {}", where_clause); + } + + format_to!(fn_def, " {}", body); + + fn_def +} + +fn make_generic_params_and_where_clause( + ctx: &AssistContext<'_>, + fun: &Function, +) -> (Option, Option) { + let used_type_params = fun.type_params(ctx); + + let generic_param_list = make_generic_param_list(ctx, fun, &used_type_params); + let where_clause = make_where_clause(ctx, fun, &used_type_params); + + (generic_param_list, where_clause) +} + +fn make_generic_param_list( + ctx: &AssistContext<'_>, + fun: &Function, + used_type_params: &[TypeParam], +) -> Option { + let mut generic_params = fun + .mods + .generic_param_lists + .iter() + .flat_map(|parent_params| { + parent_params + .generic_params() + .filter(|param| param_is_required(ctx, param, used_type_params)) + }) + .peekable(); + + if generic_params.peek().is_some() { + Some(make::generic_param_list(generic_params)) + } else { + None + } +} + +fn param_is_required( + ctx: &AssistContext<'_>, + param: &ast::GenericParam, + used_type_params: &[TypeParam], +) -> bool { + match param { + ast::GenericParam::ConstParam(_) | ast::GenericParam::LifetimeParam(_) => false, + ast::GenericParam::TypeParam(type_param) => match &ctx.sema.to_def(type_param) { + Some(def) => used_type_params.contains(def), + _ => false, + }, + } +} + +fn make_where_clause( + ctx: &AssistContext<'_>, + fun: &Function, + used_type_params: &[TypeParam], +) -> Option { + let mut predicates = fun + .mods + .where_clauses + .iter() + .flat_map(|parent_where_clause| { + parent_where_clause + .predicates() + .filter(|pred| pred_is_required(ctx, pred, used_type_params)) + }) + .peekable(); + + if predicates.peek().is_some() { + Some(make::where_clause(predicates)) + } else { + None + } +} + +fn pred_is_required( + ctx: &AssistContext<'_>, + pred: &ast::WherePred, + used_type_params: &[TypeParam], +) -> bool { + match resolved_type_param(ctx, pred) { + Some(it) => used_type_params.contains(&it), + None => false, + } +} + +fn resolved_type_param(ctx: &AssistContext<'_>, pred: &ast::WherePred) -> Option { + let path = match pred.ty()? { + ast::Type::PathType(path_type) => path_type.path(), + _ => None, + }?; + + match ctx.sema.resolve_path(&path)? { + PathResolution::TypeParam(type_param) => Some(type_param), + _ => None, + } +} + +impl Function { + /// Collect all the `TypeParam`s used in the `body` and `params`. + fn type_params(&self, ctx: &AssistContext<'_>) -> Vec { + let type_params_in_descendant_paths = + self.body.descendant_paths().filter_map(|it| match ctx.sema.resolve_path(&it) { + Some(PathResolution::TypeParam(type_param)) => Some(type_param), + _ => None, + }); + let type_params_in_params = self.params.iter().filter_map(|p| p.ty.as_type_param(ctx.db())); + type_params_in_descendant_paths.chain(type_params_in_params).collect() + } + + fn make_param_list(&self, ctx: &AssistContext<'_>, module: hir::Module) -> ast::ParamList { + let self_param = self.self_param.clone(); + let params = self.params.iter().map(|param| param.to_param(ctx, module)); + make::param_list(self_param, params) + } + + fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option { + let fun_ty = self.return_type(ctx); + let handler = if self.mods.is_in_tail { + FlowHandler::None + } else { + FlowHandler::from_ret_ty(self, &fun_ty) + }; + let ret_ty = match &handler { + FlowHandler::None => { + if matches!(fun_ty, FunType::Unit) { + return None; + } + fun_ty.make_ty(ctx, module) + } + FlowHandler::Try { kind: TryKind::Option } => { + make::ext::ty_option(fun_ty.make_ty(ctx, module)) + } + FlowHandler::Try { kind: TryKind::Result { ty: parent_ret_ty } } => { + let handler_ty = parent_ret_ty + .type_arguments() + .nth(1) + .map(|ty| make_ty(&ty, ctx, module)) + .unwrap_or_else(make::ty_placeholder); + make::ext::ty_result(fun_ty.make_ty(ctx, module), handler_ty) + } + FlowHandler::If { .. } => make::ty("ControlFlow<()>"), + FlowHandler::IfOption { action } => { + let handler_ty = action + .expr_ty(ctx) + .map(|ty| make_ty(&ty, ctx, module)) + .unwrap_or_else(make::ty_placeholder); + make::ext::ty_option(handler_ty) + } + FlowHandler::MatchOption { .. } => make::ext::ty_option(fun_ty.make_ty(ctx, module)), + FlowHandler::MatchResult { err } => { + let handler_ty = err + .expr_ty(ctx) + .map(|ty| make_ty(&ty, ctx, module)) + .unwrap_or_else(make::ty_placeholder); + make::ext::ty_result(fun_ty.make_ty(ctx, module), handler_ty) + } + }; + Some(make::ret_type(ret_ty)) + } +} + +impl FunType { + fn make_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Type { + match self { + FunType::Unit => make::ty_unit(), + FunType::Single(ty) => make_ty(ty, ctx, module), + FunType::Tuple(types) => match types.as_slice() { + [] => { + stdx::never!("tuple type with 0 elements"); + make::ty_unit() + } + [ty] => { + stdx::never!("tuple type with 1 element"); + make_ty(ty, ctx, module) + } + types => { + let types = types.iter().map(|ty| make_ty(ty, ctx, module)); + make::ty_tuple(types) + } + }, + } + } +} + +fn make_body( + ctx: &AssistContext<'_>, + old_indent: IndentLevel, + new_indent: IndentLevel, + fun: &Function, +) -> ast::BlockExpr { + let ret_ty = fun.return_type(ctx); + let handler = if fun.mods.is_in_tail { + FlowHandler::None + } else { + FlowHandler::from_ret_ty(fun, &ret_ty) + }; + + let block = match &fun.body { + FunctionBody::Expr(expr) => { + let expr = rewrite_body_segment(ctx, &fun.params, &handler, expr.syntax()); + let expr = ast::Expr::cast(expr).unwrap(); + match expr { + ast::Expr::BlockExpr(block) => { + // If the extracted expression is itself a block, there is no need to wrap it inside another block. + let block = block.dedent(old_indent); + // Recreate the block for formatting consistency with other extracted functions. + make::block_expr(block.statements(), block.tail_expr()) + } + _ => { + let expr = expr.dedent(old_indent).indent(IndentLevel(1)); + + make::block_expr(Vec::new(), Some(expr)) + } + } + } + FunctionBody::Span { parent, text_range } => { + let mut elements: Vec<_> = parent + .syntax() + .children_with_tokens() + .filter(|it| text_range.contains_range(it.text_range())) + .map(|it| match &it { + syntax::NodeOrToken::Node(n) => syntax::NodeOrToken::Node( + rewrite_body_segment(ctx, &fun.params, &handler, n), + ), + _ => it, + }) + .collect(); + + let mut tail_expr = match &elements.last() { + Some(syntax::NodeOrToken::Node(node)) if ast::Expr::can_cast(node.kind()) => { + ast::Expr::cast(node.clone()) + } + _ => None, + }; + + match tail_expr { + Some(_) => { + elements.pop(); + } + None => match fun.outliving_locals.as_slice() { + [] => {} + [var] => { + tail_expr = Some(path_expr_from_local(ctx, var.local)); + } + vars => { + let exprs = vars.iter().map(|var| path_expr_from_local(ctx, var.local)); + let expr = make::expr_tuple(exprs); + tail_expr = Some(expr); + } + }, + }; + + let body_indent = IndentLevel(1); + let elements = elements + .into_iter() + .map(|node_or_token| match &node_or_token { + syntax::NodeOrToken::Node(node) => match ast::Stmt::cast(node.clone()) { + Some(stmt) => { + let indented = stmt.dedent(old_indent).indent(body_indent); + let ast_node = indented.syntax().clone_subtree(); + syntax::NodeOrToken::Node(ast_node) + } + _ => node_or_token, + }, + _ => node_or_token, + }) + .collect::>(); + let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent)); + + make::hacky_block_expr_with_comments(elements, tail_expr) + } + }; + + let block = match &handler { + FlowHandler::None => block, + FlowHandler::Try { kind } => { + let block = with_default_tail_expr(block, make::expr_unit()); + map_tail_expr(block, |tail_expr| { + let constructor = match kind { + TryKind::Option => "Some", + TryKind::Result { .. } => "Ok", + }; + let func = make::expr_path(make::ext::ident_path(constructor)); + let args = make::arg_list(iter::once(tail_expr)); + make::expr_call(func, args) + }) + } + FlowHandler::If { .. } => { + let controlflow_continue = make::expr_call( + make::expr_path(make::path_from_text("ControlFlow::Continue")), + make::arg_list(iter::once(make::expr_unit())), + ); + with_tail_expr(block, controlflow_continue) + } + FlowHandler::IfOption { .. } => { + let none = make::expr_path(make::ext::ident_path("None")); + with_tail_expr(block, none) + } + FlowHandler::MatchOption { .. } => map_tail_expr(block, |tail_expr| { + let some = make::expr_path(make::ext::ident_path("Some")); + let args = make::arg_list(iter::once(tail_expr)); + make::expr_call(some, args) + }), + FlowHandler::MatchResult { .. } => map_tail_expr(block, |tail_expr| { + let ok = make::expr_path(make::ext::ident_path("Ok")); + let args = make::arg_list(iter::once(tail_expr)); + make::expr_call(ok, args) + }), + }; + + block.indent(new_indent) +} + +fn map_tail_expr(block: ast::BlockExpr, f: impl FnOnce(ast::Expr) -> ast::Expr) -> ast::BlockExpr { + let tail_expr = match block.tail_expr() { + Some(tail_expr) => tail_expr, + None => return block, + }; + make::block_expr(block.statements(), Some(f(tail_expr))) +} + +fn with_default_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr { + match block.tail_expr() { + Some(_) => block, + None => make::block_expr(block.statements(), Some(tail_expr)), + } +} + +fn with_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr { + let stmt_tail = block.tail_expr().map(|expr| make::expr_stmt(expr).into()); + let stmts = block.statements().chain(stmt_tail); + make::block_expr(stmts, Some(tail_expr)) +} + +fn format_type(ty: &hir::Type, ctx: &AssistContext<'_>, module: hir::Module) -> String { + ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "_".to_string()) +} + +fn make_ty(ty: &hir::Type, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Type { + let ty_str = format_type(ty, ctx, module); + make::ty(&ty_str) +} + +fn rewrite_body_segment( + ctx: &AssistContext<'_>, + params: &[Param], + handler: &FlowHandler, + syntax: &SyntaxNode, +) -> SyntaxNode { + let syntax = fix_param_usages(ctx, params, syntax); + update_external_control_flow(handler, &syntax); + syntax +} + +/// change all usages to account for added `&`/`&mut` for some params +fn fix_param_usages(ctx: &AssistContext<'_>, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { + let mut usages_for_param: Vec<(&Param, Vec)> = Vec::new(); + + let tm = TreeMutator::new(syntax); + + for param in params { + if !param.kind().is_ref() { + continue; + } + + let usages = LocalUsages::find_local_usages(ctx, param.var); + let usages = usages + .iter() + .filter(|reference| syntax.text_range().contains_range(reference.range)) + .filter_map(|reference| path_element_of_reference(syntax, reference)) + .map(|expr| tm.make_mut(&expr)); + + usages_for_param.push((param, usages.collect())); + } + + let res = tm.make_syntax_mut(syntax); + + for (param, usages) in usages_for_param { + for usage in usages { + match usage.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { + Some(ast::Expr::MethodCallExpr(_) | ast::Expr::FieldExpr(_)) => { + // do nothing + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => + { + ted::replace(node.syntax(), node.expr().unwrap().syntax()); + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => + { + ted::replace(node.syntax(), node.expr().unwrap().syntax()); + } + Some(_) | None => { + let p = &make::expr_prefix(T![*], usage.clone()).clone_for_update(); + ted::replace(usage.syntax(), p.syntax()) + } + } + } + } + + res +} + +fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) { + let mut nested_loop = None; + let mut nested_scope = None; + for event in syntax.preorder() { + match event { + WalkEvent::Enter(e) => match e.kind() { + SyntaxKind::LOOP_EXPR | SyntaxKind::WHILE_EXPR | SyntaxKind::FOR_EXPR => { + if nested_loop.is_none() { + nested_loop = Some(e.clone()); + } + } + SyntaxKind::FN + | SyntaxKind::CONST + | SyntaxKind::STATIC + | SyntaxKind::IMPL + | SyntaxKind::MODULE => { + if nested_scope.is_none() { + nested_scope = Some(e.clone()); + } + } + _ => {} + }, + WalkEvent::Leave(e) => { + if nested_scope.is_none() { + if let Some(expr) = ast::Expr::cast(e.clone()) { + match expr { + ast::Expr::ReturnExpr(return_expr) if nested_scope.is_none() => { + let expr = return_expr.expr(); + if let Some(replacement) = make_rewritten_flow(handler, expr) { + ted::replace(return_expr.syntax(), replacement.syntax()) + } + } + ast::Expr::BreakExpr(break_expr) if nested_loop.is_none() => { + let expr = break_expr.expr(); + if let Some(replacement) = make_rewritten_flow(handler, expr) { + ted::replace(break_expr.syntax(), replacement.syntax()) + } + } + ast::Expr::ContinueExpr(continue_expr) if nested_loop.is_none() => { + if let Some(replacement) = make_rewritten_flow(handler, None) { + ted::replace(continue_expr.syntax(), replacement.syntax()) + } + } + _ => { + // do nothing + } + } + } + } + + if nested_loop.as_ref() == Some(&e) { + nested_loop = None; + } + if nested_scope.as_ref() == Some(&e) { + nested_scope = None; + } + } + }; + } +} + +fn make_rewritten_flow(handler: &FlowHandler, arg_expr: Option) -> Option { + let value = match handler { + FlowHandler::None | FlowHandler::Try { .. } => return None, + FlowHandler::If { .. } => make::expr_call( + make::expr_path(make::path_from_text("ControlFlow::Break")), + make::arg_list(iter::once(make::expr_unit())), + ), + FlowHandler::IfOption { .. } => { + let expr = arg_expr.unwrap_or_else(|| make::expr_tuple(Vec::new())); + let args = make::arg_list(iter::once(expr)); + make::expr_call(make::expr_path(make::ext::ident_path("Some")), args) + } + FlowHandler::MatchOption { .. } => make::expr_path(make::ext::ident_path("None")), + FlowHandler::MatchResult { .. } => { + let expr = arg_expr.unwrap_or_else(|| make::expr_tuple(Vec::new())); + let args = make::arg_list(iter::once(expr)); + make::expr_call(make::expr_path(make::ext::ident_path("Err")), args) + } + }; + Some(make::expr_return(Some(value)).clone_for_update()) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn no_args_from_binary_expr() { + check_assist( + extract_function, + r#" +fn foo() { + foo($01 + 1$0); +} +"#, + r#" +fn foo() { + foo(fun_name()); +} + +fn $0fun_name() -> i32 { + 1 + 1 +} +"#, + ); + } + + #[test] + fn no_args_from_binary_expr_in_module() { + check_assist( + extract_function, + r#" +mod bar { + fn foo() { + foo($01 + 1$0); + } +} +"#, + r#" +mod bar { + fn foo() { + foo(fun_name()); + } + + fn $0fun_name() -> i32 { + 1 + 1 + } +} +"#, + ); + } + + #[test] + fn no_args_from_binary_expr_indented() { + check_assist( + extract_function, + r#" +fn foo() { + $0{ 1 + 1 }$0; +} +"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() -> i32 { + 1 + 1 +} +"#, + ); + } + + #[test] + fn no_args_from_stmt_with_last_expr() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + let k = 1; + $0let m = 1; + m + 1$0 +} +"#, + r#" +fn foo() -> i32 { + let k = 1; + fun_name() +} + +fn $0fun_name() -> i32 { + let m = 1; + m + 1 +} +"#, + ); + } + + #[test] + fn no_args_from_stmt_unit() { + check_assist( + extract_function, + r#" +fn foo() { + let k = 3; + $0let m = 1; + let n = m + 1;$0 + let g = 5; +} +"#, + r#" +fn foo() { + let k = 3; + fun_name(); + let g = 5; +} + +fn $0fun_name() { + let m = 1; + let n = m + 1; +} +"#, + ); + } + + #[test] + fn no_args_if() { + check_assist( + extract_function, + r#" +fn foo() { + $0if true { }$0 +} +"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + if true { } +} +"#, + ); + } + + #[test] + fn no_args_if_else() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + $0if true { 1 } else { 2 }$0 +} +"#, + r#" +fn foo() -> i32 { + fun_name() +} + +fn $0fun_name() -> i32 { + if true { 1 } else { 2 } +} +"#, + ); + } + + #[test] + fn no_args_if_let_else() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + $0if let true = false { 1 } else { 2 }$0 +} +"#, + r#" +fn foo() -> i32 { + fun_name() +} + +fn $0fun_name() -> i32 { + if let true = false { 1 } else { 2 } +} +"#, + ); + } + + #[test] + fn no_args_match() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + $0match true { + true => 1, + false => 2, + }$0 +} +"#, + r#" +fn foo() -> i32 { + fun_name() +} + +fn $0fun_name() -> i32 { + match true { + true => 1, + false => 2, + } +} +"#, + ); + } + + #[test] + fn no_args_while() { + check_assist( + extract_function, + r#" +fn foo() { + $0while true { }$0 +} +"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + while true { } +} +"#, + ); + } + + #[test] + fn no_args_for() { + check_assist( + extract_function, + r#" +fn foo() { + $0for v in &[0, 1] { }$0 +} +"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + for v in &[0, 1] { } +} +"#, + ); + } + + #[test] + fn no_args_from_loop_unit() { + check_assist( + extract_function, + r#" +fn foo() { + $0loop { + let m = 1; + }$0 +} +"#, + r#" +fn foo() { + fun_name() +} + +fn $0fun_name() -> ! { + loop { + let m = 1; + } +} +"#, + ); + } + + #[test] + fn no_args_from_loop_with_return() { + check_assist( + extract_function, + r#" +fn foo() { + let v = $0loop { + let m = 1; + break m; + }$0; +} +"#, + r#" +fn foo() { + let v = fun_name(); +} + +fn $0fun_name() -> i32 { + loop { + let m = 1; + break m; + } +} +"#, + ); + } + + #[test] + fn no_args_from_match() { + check_assist( + extract_function, + r#" +fn foo() { + let v: i32 = $0match Some(1) { + Some(x) => x, + None => 0, + }$0; +} +"#, + r#" +fn foo() { + let v: i32 = fun_name(); +} + +fn $0fun_name() -> i32 { + match Some(1) { + Some(x) => x, + None => 0, + } +} +"#, + ); + } + + #[test] + fn extract_partial_block_single_line() { + check_assist( + extract_function, + r#" +fn foo() { + let n = 1; + let mut v = $0n * n;$0 + v += 1; +} +"#, + r#" +fn foo() { + let n = 1; + let mut v = fun_name(n); + v += 1; +} + +fn $0fun_name(n: i32) -> i32 { + let mut v = n * n; + v +} +"#, + ); + } + + #[test] + fn extract_partial_block() { + check_assist( + extract_function, + r#" +fn foo() { + let m = 2; + let n = 1; + let mut v = m $0* n; + let mut w = 3;$0 + v += 1; + w += 1; +} +"#, + r#" +fn foo() { + let m = 2; + let n = 1; + let (mut v, mut w) = fun_name(m, n); + v += 1; + w += 1; +} + +fn $0fun_name(m: i32, n: i32) -> (i32, i32) { + let mut v = m * n; + let mut w = 3; + (v, w) +} +"#, + ); + } + + #[test] + fn argument_form_expr() { + check_assist( + extract_function, + r#" +fn foo() -> u32 { + let n = 2; + $0n+2$0 +} +"#, + r#" +fn foo() -> u32 { + let n = 2; + fun_name(n) +} + +fn $0fun_name(n: u32) -> u32 { + n+2 +} +"#, + ) + } + + #[test] + fn argument_used_twice_form_expr() { + check_assist( + extract_function, + r#" +fn foo() -> u32 { + let n = 2; + $0n+n$0 +} +"#, + r#" +fn foo() -> u32 { + let n = 2; + fun_name(n) +} + +fn $0fun_name(n: u32) -> u32 { + n+n +} +"#, + ) + } + + #[test] + fn two_arguments_form_expr() { + check_assist( + extract_function, + r#" +fn foo() -> u32 { + let n = 2; + let m = 3; + $0n+n*m$0 +} +"#, + r#" +fn foo() -> u32 { + let n = 2; + let m = 3; + fun_name(n, m) +} + +fn $0fun_name(n: u32, m: u32) -> u32 { + n+n*m +} +"#, + ) + } + + #[test] + fn argument_and_locals() { + check_assist( + extract_function, + r#" +fn foo() -> u32 { + let n = 2; + $0let m = 1; + n + m$0 +} +"#, + r#" +fn foo() -> u32 { + let n = 2; + fun_name(n) +} + +fn $0fun_name(n: u32) -> u32 { + let m = 1; + n + m +} +"#, + ) + } + + #[test] + fn in_comment_is_not_applicable() { + cov_mark::check!(extract_function_in_comment_is_not_applicable); + check_assist_not_applicable(extract_function, r"fn main() { 1 + /* $0comment$0 */ 1; }"); + } + + #[test] + fn part_of_expr_stmt() { + check_assist( + extract_function, + r#" +fn foo() { + $01$0 + 1; +} +"#, + r#" +fn foo() { + fun_name() + 1; +} + +fn $0fun_name() -> i32 { + 1 +} +"#, + ); + } + + #[test] + fn function_expr() { + check_assist( + extract_function, + r#" +fn foo() { + $0bar(1 + 1)$0 +} +"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + bar(1 + 1) +} +"#, + ) + } + + #[test] + fn extract_from_nested() { + check_assist( + extract_function, + r#" +fn main() { + let x = true; + let tuple = match x { + true => ($02 + 2$0, true) + _ => (0, false) + }; +} +"#, + r#" +fn main() { + let x = true; + let tuple = match x { + true => (fun_name(), true) + _ => (0, false) + }; +} + +fn $0fun_name() -> i32 { + 2 + 2 +} +"#, + ); + } + + #[test] + fn param_from_closure() { + check_assist( + extract_function, + r#" +fn main() { + let lambda = |x: u32| $0x * 2$0; +} +"#, + r#" +fn main() { + let lambda = |x: u32| fun_name(x); +} + +fn $0fun_name(x: u32) -> u32 { + x * 2 +} +"#, + ); + } + + #[test] + fn extract_return_stmt() { + check_assist( + extract_function, + r#" +fn foo() -> u32 { + $0return 2 + 2$0; +} +"#, + r#" +fn foo() -> u32 { + return fun_name(); +} + +fn $0fun_name() -> u32 { + 2 + 2 +} +"#, + ); + } + + #[test] + fn does_not_add_extra_whitespace() { + check_assist( + extract_function, + r#" +fn foo() -> u32 { + + + $0return 2 + 2$0; +} +"#, + r#" +fn foo() -> u32 { + + + return fun_name(); +} + +fn $0fun_name() -> u32 { + 2 + 2 +} +"#, + ); + } + + #[test] + fn break_stmt() { + check_assist( + extract_function, + r#" +fn main() { + let result = loop { + $0break 2 + 2$0; + }; +} +"#, + r#" +fn main() { + let result = loop { + break fun_name(); + }; +} + +fn $0fun_name() -> i32 { + 2 + 2 +} +"#, + ); + } + + #[test] + fn extract_cast() { + check_assist( + extract_function, + r#" +fn main() { + let v = $00f32 as u32$0; +} +"#, + r#" +fn main() { + let v = fun_name(); +} + +fn $0fun_name() -> u32 { + 0f32 as u32 +} +"#, + ); + } + + #[test] + fn return_not_applicable() { + check_assist_not_applicable(extract_function, r"fn foo() { $0return$0; } "); + } + + #[test] + fn method_to_freestanding() { + check_assist( + extract_function, + r#" +struct S; + +impl S { + fn foo(&self) -> i32 { + $01+1$0 + } +} +"#, + r#" +struct S; + +impl S { + fn foo(&self) -> i32 { + fun_name() + } +} + +fn $0fun_name() -> i32 { + 1+1 +} +"#, + ); + } + + #[test] + fn method_with_reference() { + check_assist( + extract_function, + r#" +struct S { f: i32 }; + +impl S { + fn foo(&self) -> i32 { + $0self.f+self.f$0 + } +} +"#, + r#" +struct S { f: i32 }; + +impl S { + fn foo(&self) -> i32 { + self.fun_name() + } + + fn $0fun_name(&self) -> i32 { + self.f+self.f + } +} +"#, + ); + } + + #[test] + fn method_with_mut() { + check_assist( + extract_function, + r#" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) { + $0self.f += 1;$0 + } +} +"#, + r#" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) { + self.fun_name(); + } + + fn $0fun_name(&mut self) { + self.f += 1; + } +} +"#, + ); + } + + #[test] + fn variable_defined_inside_and_used_after_no_ret() { + check_assist( + extract_function, + r#" +fn foo() { + let n = 1; + $0let k = n * n;$0 + let m = k + 1; +} +"#, + r#" +fn foo() { + let n = 1; + let k = fun_name(n); + let m = k + 1; +} + +fn $0fun_name(n: i32) -> i32 { + let k = n * n; + k +} +"#, + ); + } + + #[test] + fn variable_defined_inside_and_used_after_mutably_no_ret() { + check_assist( + extract_function, + r#" +fn foo() { + let n = 1; + $0let mut k = n * n;$0 + k += 1; +} +"#, + r#" +fn foo() { + let n = 1; + let mut k = fun_name(n); + k += 1; +} + +fn $0fun_name(n: i32) -> i32 { + let mut k = n * n; + k +} +"#, + ); + } + + #[test] + fn two_variables_defined_inside_and_used_after_no_ret() { + check_assist( + extract_function, + r#" +fn foo() { + let n = 1; + $0let k = n * n; + let m = k + 2;$0 + let h = k + m; +} +"#, + r#" +fn foo() { + let n = 1; + let (k, m) = fun_name(n); + let h = k + m; +} + +fn $0fun_name(n: i32) -> (i32, i32) { + let k = n * n; + let m = k + 2; + (k, m) +} +"#, + ); + } + + #[test] + fn multi_variables_defined_inside_and_used_after_mutably_no_ret() { + check_assist( + extract_function, + r#" +fn foo() { + let n = 1; + $0let mut k = n * n; + let mut m = k + 2; + let mut o = m + 3; + o += 1;$0 + k += o; + m = 1; +} +"#, + r#" +fn foo() { + let n = 1; + let (mut k, mut m, o) = fun_name(n); + k += o; + m = 1; +} + +fn $0fun_name(n: i32) -> (i32, i32, i32) { + let mut k = n * n; + let mut m = k + 2; + let mut o = m + 3; + o += 1; + (k, m, o) +} +"#, + ); + } + + #[test] + fn nontrivial_patterns_define_variables() { + check_assist( + extract_function, + r#" +struct Counter(i32); +fn foo() { + $0let Counter(n) = Counter(0);$0 + let m = n; +} +"#, + r#" +struct Counter(i32); +fn foo() { + let n = fun_name(); + let m = n; +} + +fn $0fun_name() -> i32 { + let Counter(n) = Counter(0); + n +} +"#, + ); + } + + #[test] + fn struct_with_two_fields_pattern_define_variables() { + check_assist( + extract_function, + r#" +struct Counter { n: i32, m: i32 }; +fn foo() { + $0let Counter { n, m: k } = Counter { n: 1, m: 2 };$0 + let h = n + k; +} +"#, + r#" +struct Counter { n: i32, m: i32 }; +fn foo() { + let (n, k) = fun_name(); + let h = n + k; +} + +fn $0fun_name() -> (i32, i32) { + let Counter { n, m: k } = Counter { n: 1, m: 2 }; + (n, k) +} +"#, + ); + } + + #[test] + fn mut_var_from_outer_scope() { + check_assist( + extract_function, + r#" +fn foo() { + let mut n = 1; + $0n += 1;$0 + let m = n + 1; +} +"#, + r#" +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + *n += 1; +} +"#, + ); + } + + #[test] + fn mut_field_from_outer_scope() { + check_assist( + extract_function, + r#" +struct C { n: i32 } +fn foo() { + let mut c = C { n: 0 }; + $0c.n += 1;$0 + let m = c.n + 1; +} +"#, + r#" +struct C { n: i32 } +fn foo() { + let mut c = C { n: 0 }; + fun_name(&mut c); + let m = c.n + 1; +} + +fn $0fun_name(c: &mut C) { + c.n += 1; +} +"#, + ); + } + + #[test] + fn mut_nested_field_from_outer_scope() { + check_assist( + extract_function, + r#" +struct P { n: i32} +struct C { p: P } +fn foo() { + let mut c = C { p: P { n: 0 } }; + let mut v = C { p: P { n: 0 } }; + let u = C { p: P { n: 0 } }; + $0c.p.n += u.p.n; + let r = &mut v.p.n;$0 + let m = c.p.n + v.p.n + u.p.n; +} +"#, + r#" +struct P { n: i32} +struct C { p: P } +fn foo() { + let mut c = C { p: P { n: 0 } }; + let mut v = C { p: P { n: 0 } }; + let u = C { p: P { n: 0 } }; + fun_name(&mut c, &u, &mut v); + let m = c.p.n + v.p.n + u.p.n; +} + +fn $0fun_name(c: &mut C, u: &C, v: &mut C) { + c.p.n += u.p.n; + let r = &mut v.p.n; +} +"#, + ); + } + + #[test] + fn mut_param_many_usages_stmt() { + check_assist( + extract_function, + r#" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + $0n += n; + bar(n); + bar(n+1); + bar(n*n); + bar(&n); + n.inc(); + let v = &mut n; + *v = v.succ(); + n.succ();$0 + let m = n + 1; +} +"#, + r#" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + *n += *n; + bar(*n); + bar(*n+1); + bar(*n**n); + bar(&*n); + n.inc(); + let v = n; + *v = v.succ(); + n.succ(); +} +"#, + ); + } + + #[test] + fn mut_param_many_usages_expr() { + check_assist( + extract_function, + r#" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + $0{ + n += n; + bar(n); + bar(n+1); + bar(n*n); + bar(&n); + n.inc(); + let v = &mut n; + *v = v.succ(); + n.succ(); + }$0 + let m = n + 1; +} +"#, + r#" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + *n += *n; + bar(*n); + bar(*n+1); + bar(*n**n); + bar(&*n); + n.inc(); + let v = n; + *v = v.succ(); + n.succ(); +} +"#, + ); + } + + #[test] + fn mut_param_by_value() { + check_assist( + extract_function, + r#" +fn foo() { + let mut n = 1; + $0n += 1;$0 +} +"#, + r" +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(mut n: i32) { + n += 1; +} +", + ); + } + + #[test] + fn mut_param_because_of_mut_ref() { + check_assist( + extract_function, + r#" +fn foo() { + let mut n = 1; + $0let v = &mut n; + *v += 1;$0 + let k = n; +} +"#, + r#" +fn foo() { + let mut n = 1; + fun_name(&mut n); + let k = n; +} + +fn $0fun_name(n: &mut i32) { + let v = n; + *v += 1; +} +"#, + ); + } + + #[test] + fn mut_param_by_value_because_of_mut_ref() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0let v = &mut n; + *v += 1;$0 +} +", + r#" +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(mut n: i32) { + let v = &mut n; + *v += 1; +} +"#, + ); + } + + #[test] + fn mut_method_call() { + check_assist( + extract_function, + r#" +trait I { + fn inc(&mut self); +} +impl I for i32 { + fn inc(&mut self) { *self += 1 } +} +fn foo() { + let mut n = 1; + $0n.inc();$0 +} +"#, + r#" +trait I { + fn inc(&mut self); +} +impl I for i32 { + fn inc(&mut self) { *self += 1 } +} +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(mut n: i32) { + n.inc(); +} +"#, + ); + } + + #[test] + fn shared_method_call() { + check_assist( + extract_function, + r#" +trait I { + fn succ(&self); +} +impl I for i32 { + fn succ(&self) { *self + 1 } +} +fn foo() { + let mut n = 1; + $0n.succ();$0 +} +"#, + r" +trait I { + fn succ(&self); +} +impl I for i32 { + fn succ(&self) { *self + 1 } +} +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(n: i32) { + n.succ(); +} +", + ); + } + + #[test] + fn mut_method_call_with_other_receiver() { + check_assist( + extract_function, + r#" +trait I { + fn inc(&mut self, n: i32); +} +impl I for i32 { + fn inc(&mut self, n: i32) { *self += n } +} +fn foo() { + let mut n = 1; + $0let mut m = 2; + m.inc(n);$0 +} +"#, + r" +trait I { + fn inc(&mut self, n: i32); +} +impl I for i32 { + fn inc(&mut self, n: i32) { *self += n } +} +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(n: i32) { + let mut m = 2; + m.inc(n); +} +", + ); + } + + #[test] + fn non_copy_without_usages_after() { + check_assist( + extract_function, + r#" +struct Counter(i32); +fn foo() { + let c = Counter(0); + $0let n = c.0;$0 +} +"#, + r" +struct Counter(i32); +fn foo() { + let c = Counter(0); + fun_name(c); +} + +fn $0fun_name(c: Counter) { + let n = c.0; +} +", + ); + } + + #[test] + fn non_copy_used_after() { + check_assist( + extract_function, + r" +struct Counter(i32); +fn foo() { + let c = Counter(0); + $0let n = c.0;$0 + let m = c.0; +} +", + r#" +struct Counter(i32); +fn foo() { + let c = Counter(0); + fun_name(&c); + let m = c.0; +} + +fn $0fun_name(c: &Counter) { + let n = c.0; +} +"#, + ); + } + + #[test] + fn copy_used_after() { + check_assist( + extract_function, + r#" +//- minicore: copy +fn foo() { + let n = 0; + $0let m = n;$0 + let k = n; +} +"#, + r#" +fn foo() { + let n = 0; + fun_name(n); + let k = n; +} + +fn $0fun_name(n: i32) { + let m = n; +} +"#, + ) + } + + #[test] + fn copy_custom_used_after() { + check_assist( + extract_function, + r#" +//- minicore: copy, derive +#[derive(Clone, Copy)] +struct Counter(i32); +fn foo() { + let c = Counter(0); + $0let n = c.0;$0 + let m = c.0; +} +"#, + r#" +#[derive(Clone, Copy)] +struct Counter(i32); +fn foo() { + let c = Counter(0); + fun_name(c); + let m = c.0; +} + +fn $0fun_name(c: Counter) { + let n = c.0; +} +"#, + ); + } + + #[test] + fn indented_stmts() { + check_assist( + extract_function, + r#" +fn foo() { + if true { + loop { + $0let n = 1; + let m = 2;$0 + } + } +} +"#, + r#" +fn foo() { + if true { + loop { + fun_name(); + } + } +} + +fn $0fun_name() { + let n = 1; + let m = 2; +} +"#, + ); + } + + #[test] + fn indented_stmts_inside_mod() { + check_assist( + extract_function, + r#" +mod bar { + fn foo() { + if true { + loop { + $0let n = 1; + let m = 2;$0 + } + } + } +} +"#, + r#" +mod bar { + fn foo() { + if true { + loop { + fun_name(); + } + } + } + + fn $0fun_name() { + let n = 1; + let m = 2; + } +} +"#, + ); + } + + #[test] + fn break_loop() { + check_assist( + extract_function, + r#" +//- minicore: option +fn foo() { + loop { + let n = 1; + $0let m = n + 1; + break; + let k = 2;$0 + let h = 1 + k; + } +} +"#, + r#" +fn foo() { + loop { + let n = 1; + let k = match fun_name(n) { + Some(value) => value, + None => break, + }; + let h = 1 + k; + } +} + +fn $0fun_name(n: i32) -> Option { + let m = n + 1; + return None; + let k = 2; + Some(k) +} +"#, + ); + } + + #[test] + fn return_to_parent() { + check_assist( + extract_function, + r#" +//- minicore: copy, result +fn foo() -> i64 { + let n = 1; + $0let m = n + 1; + return 1; + let k = 2;$0 + (n + k) as i64 +} +"#, + r#" +fn foo() -> i64 { + let n = 1; + let k = match fun_name(n) { + Ok(value) => value, + Err(value) => return value, + }; + (n + k) as i64 +} + +fn $0fun_name(n: i32) -> Result { + let m = n + 1; + return Err(1); + let k = 2; + Ok(k) +} +"#, + ); + } + + #[test] + fn break_and_continue() { + cov_mark::check!(external_control_flow_break_and_continue); + check_assist_not_applicable( + extract_function, + r#" +fn foo() { + loop { + let n = 1; + $0let m = n + 1; + break; + let k = 2; + continue; + let k = k + 1;$0 + let r = n + k; + } +} +"#, + ); + } + + #[test] + fn return_and_break() { + cov_mark::check!(external_control_flow_return_and_bc); + check_assist_not_applicable( + extract_function, + r#" +fn foo() { + loop { + let n = 1; + $0let m = n + 1; + break; + let k = 2; + return; + let k = k + 1;$0 + let r = n + k; + } +} +"#, + ); + } + + #[test] + fn break_loop_with_if() { + check_assist( + extract_function, + r#" +//- minicore: try +fn foo() { + loop { + let mut n = 1; + $0let m = n + 1; + break; + n += m;$0 + let h = 1 + n; + } +} +"#, + r#" +use core::ops::ControlFlow; + +fn foo() { + loop { + let mut n = 1; + if let ControlFlow::Break(_) = fun_name(&mut n) { + break; + } + let h = 1 + n; + } +} + +fn $0fun_name(n: &mut i32) -> ControlFlow<()> { + let m = *n + 1; + return ControlFlow::Break(()); + *n += m; + ControlFlow::Continue(()) +} +"#, + ); + } + + #[test] + fn break_loop_nested() { + check_assist( + extract_function, + r#" +//- minicore: try +fn foo() { + loop { + let mut n = 1; + $0let m = n + 1; + if m == 42 { + break; + }$0 + let h = 1; + } +} +"#, + r#" +use core::ops::ControlFlow; + +fn foo() { + loop { + let mut n = 1; + if let ControlFlow::Break(_) = fun_name(n) { + break; + } + let h = 1; + } +} + +fn $0fun_name(n: i32) -> ControlFlow<()> { + let m = n + 1; + if m == 42 { + return ControlFlow::Break(()); + } + ControlFlow::Continue(()) +} +"#, + ); + } + + #[test] + fn break_loop_nested_labeled() { + check_assist( + extract_function, + r#" +//- minicore: try +fn foo() { + 'bar: loop { + loop { + $0break 'bar;$0 + } + } +} +"#, + r#" +use core::ops::ControlFlow; + +fn foo() { + 'bar: loop { + loop { + if let ControlFlow::Break(_) = fun_name() { + break 'bar; + } + } + } +} + +fn $0fun_name() -> ControlFlow<()> { + return ControlFlow::Break(()); + ControlFlow::Continue(()) +} +"#, + ); + } + + #[test] + fn continue_loop_nested_labeled() { + check_assist( + extract_function, + r#" +//- minicore: try +fn foo() { + 'bar: loop { + loop { + $0continue 'bar;$0 + } + } +} +"#, + r#" +use core::ops::ControlFlow; + +fn foo() { + 'bar: loop { + loop { + if let ControlFlow::Break(_) = fun_name() { + continue 'bar; + } + } + } +} + +fn $0fun_name() -> ControlFlow<()> { + return ControlFlow::Break(()); + ControlFlow::Continue(()) +} +"#, + ); + } + + #[test] + fn return_from_nested_loop() { + check_assist( + extract_function, + r#" +fn foo() { + loop { + let n = 1;$0 + let k = 1; + loop { + return; + } + let m = k + 1;$0 + let h = 1 + m; + } +} +"#, + r#" +fn foo() { + loop { + let n = 1; + let m = match fun_name() { + Some(value) => value, + None => return, + }; + let h = 1 + m; + } +} + +fn $0fun_name() -> Option { + let k = 1; + loop { + return None; + } + let m = k + 1; + Some(m) +} +"#, + ); + } + + #[test] + fn break_from_nested_loop() { + check_assist( + extract_function, + r#" +fn foo() { + loop { + let n = 1; + $0let k = 1; + loop { + break; + } + let m = k + 1;$0 + let h = 1 + m; + } +} +"#, + r#" +fn foo() { + loop { + let n = 1; + let m = fun_name(); + let h = 1 + m; + } +} + +fn $0fun_name() -> i32 { + let k = 1; + loop { + break; + } + let m = k + 1; + m +} +"#, + ); + } + + #[test] + fn break_from_nested_and_outer_loops() { + check_assist( + extract_function, + r#" +fn foo() { + loop { + let n = 1; + $0let k = 1; + loop { + break; + } + if k == 42 { + break; + } + let m = k + 1;$0 + let h = 1 + m; + } +} +"#, + r#" +fn foo() { + loop { + let n = 1; + let m = match fun_name() { + Some(value) => value, + None => break, + }; + let h = 1 + m; + } +} + +fn $0fun_name() -> Option { + let k = 1; + loop { + break; + } + if k == 42 { + return None; + } + let m = k + 1; + Some(m) +} +"#, + ); + } + + #[test] + fn return_from_nested_fn() { + check_assist( + extract_function, + r#" +fn foo() { + loop { + let n = 1; + $0let k = 1; + fn test() { + return; + } + let m = k + 1;$0 + let h = 1 + m; + } +} +"#, + r#" +fn foo() { + loop { + let n = 1; + let m = fun_name(); + let h = 1 + m; + } +} + +fn $0fun_name() -> i32 { + let k = 1; + fn test() { + return; + } + let m = k + 1; + m +} +"#, + ); + } + + #[test] + fn break_with_value() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + loop { + let n = 1; + $0let k = 1; + if k == 42 { + break 3; + } + let m = k + 1;$0 + let h = 1; + } +} +"#, + r#" +fn foo() -> i32 { + loop { + let n = 1; + if let Some(value) = fun_name() { + break value; + } + let h = 1; + } +} + +fn $0fun_name() -> Option { + let k = 1; + if k == 42 { + return Some(3); + } + let m = k + 1; + None +} +"#, + ); + } + + #[test] + fn break_with_value_and_label() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + 'bar: loop { + let n = 1; + $0let k = 1; + if k == 42 { + break 'bar 4; + } + let m = k + 1;$0 + let h = 1; + } +} +"#, + r#" +fn foo() -> i32 { + 'bar: loop { + let n = 1; + if let Some(value) = fun_name() { + break 'bar value; + } + let h = 1; + } +} + +fn $0fun_name() -> Option { + let k = 1; + if k == 42 { + return Some(4); + } + let m = k + 1; + None +} +"#, + ); + } + + #[test] + fn break_with_value_and_return() { + check_assist( + extract_function, + r#" +fn foo() -> i64 { + loop { + let n = 1;$0 + let k = 1; + if k == 42 { + break 3; + } + let m = k + 1;$0 + let h = 1 + m; + } +} +"#, + r#" +fn foo() -> i64 { + loop { + let n = 1; + let m = match fun_name() { + Ok(value) => value, + Err(value) => break value, + }; + let h = 1 + m; + } +} + +fn $0fun_name() -> Result { + let k = 1; + if k == 42 { + return Err(3); + } + let m = k + 1; + Ok(m) +} +"#, + ); + } + + #[test] + fn try_option() { + check_assist( + extract_function, + r#" +//- minicore: option +fn bar() -> Option { None } +fn foo() -> Option<()> { + let n = bar()?; + $0let k = foo()?; + let m = k + 1;$0 + let h = 1 + m; + Some(()) +} +"#, + r#" +fn bar() -> Option { None } +fn foo() -> Option<()> { + let n = bar()?; + let m = fun_name()?; + let h = 1 + m; + Some(()) +} + +fn $0fun_name() -> Option { + let k = foo()?; + let m = k + 1; + Some(m) +} +"#, + ); + } + + #[test] + fn try_option_unit() { + check_assist( + extract_function, + r#" +//- minicore: option +fn foo() -> Option<()> { + let n = 1; + $0let k = foo()?; + let m = k + 1;$0 + let h = 1 + n; + Some(()) +} +"#, + r#" +fn foo() -> Option<()> { + let n = 1; + fun_name()?; + let h = 1 + n; + Some(()) +} + +fn $0fun_name() -> Option<()> { + let k = foo()?; + let m = k + 1; + Some(()) +} +"#, + ); + } + + #[test] + fn try_result() { + check_assist( + extract_function, + r#" +//- minicore: result +fn foo() -> Result<(), i64> { + let n = 1; + $0let k = foo()?; + let m = k + 1;$0 + let h = 1 + m; + Ok(()) +} +"#, + r#" +fn foo() -> Result<(), i64> { + let n = 1; + let m = fun_name()?; + let h = 1 + m; + Ok(()) +} + +fn $0fun_name() -> Result { + let k = foo()?; + let m = k + 1; + Ok(m) +} +"#, + ); + } + + #[test] + fn try_option_with_return() { + check_assist( + extract_function, + r#" +//- minicore: option +fn foo() -> Option<()> { + let n = 1; + $0let k = foo()?; + if k == 42 { + return None; + } + let m = k + 1;$0 + let h = 1 + m; + Some(()) +} +"#, + r#" +fn foo() -> Option<()> { + let n = 1; + let m = fun_name()?; + let h = 1 + m; + Some(()) +} + +fn $0fun_name() -> Option { + let k = foo()?; + if k == 42 { + return None; + } + let m = k + 1; + Some(m) +} +"#, + ); + } + + #[test] + fn try_result_with_return() { + check_assist( + extract_function, + r#" +//- minicore: result +fn foo() -> Result<(), i64> { + let n = 1; + $0let k = foo()?; + if k == 42 { + return Err(1); + } + let m = k + 1;$0 + let h = 1 + m; + Ok(()) +} +"#, + r#" +fn foo() -> Result<(), i64> { + let n = 1; + let m = fun_name()?; + let h = 1 + m; + Ok(()) +} + +fn $0fun_name() -> Result { + let k = foo()?; + if k == 42 { + return Err(1); + } + let m = k + 1; + Ok(m) +} +"#, + ); + } + + #[test] + fn try_and_break() { + cov_mark::check!(external_control_flow_try_and_bc); + check_assist_not_applicable( + extract_function, + r#" +//- minicore: option +fn foo() -> Option<()> { + loop { + let n = Some(1); + $0let m = n? + 1; + break; + let k = 2; + let k = k + 1;$0 + let r = n + k; + } + Some(()) +} +"#, + ); + } + + #[test] + fn try_and_return_ok() { + check_assist( + extract_function, + r#" +//- minicore: result +fn foo() -> Result<(), i64> { + let n = 1; + $0let k = foo()?; + if k == 42 { + return Ok(1); + } + let m = k + 1;$0 + let h = 1 + m; + Ok(()) +} +"#, + r#" +fn foo() -> Result<(), i64> { + let n = 1; + let m = fun_name()?; + let h = 1 + m; + Ok(()) +} + +fn $0fun_name() -> Result { + let k = foo()?; + if k == 42 { + return Ok(1); + } + let m = k + 1; + Ok(m) +} +"#, + ); + } + + #[test] + fn param_usage_in_macro() { + check_assist( + extract_function, + r#" +macro_rules! m { + ($val:expr) => { $val }; +} + +fn foo() { + let n = 1; + $0let k = n * m!(n);$0 + let m = k + 1; +} +"#, + r#" +macro_rules! m { + ($val:expr) => { $val }; +} + +fn foo() { + let n = 1; + let k = fun_name(n); + let m = k + 1; +} + +fn $0fun_name(n: i32) -> i32 { + let k = n * m!(n); + k +} +"#, + ); + } + + #[test] + fn extract_with_await() { + check_assist( + extract_function, + r#" +//- minicore: future +fn main() { + $0some_function().await;$0 +} + +async fn some_function() { + +} +"#, + r#" +fn main() { + fun_name().await; +} + +async fn $0fun_name() { + some_function().await; +} + +async fn some_function() { + +} +"#, + ); + } + + #[test] + fn extract_with_await_and_result_not_producing_match_expr() { + check_assist( + extract_function, + r#" +//- minicore: future, result +async fn foo() -> Result<(), ()> { + $0async {}.await; + Err(())?$0 +} +"#, + r#" +async fn foo() -> Result<(), ()> { + fun_name().await? +} + +async fn $0fun_name() -> Result<(), ()> { + async {}.await; + Err(())? +} +"#, + ); + } + + #[test] + fn extract_with_await_and_result_producing_match_expr() { + check_assist( + extract_function, + r#" +//- minicore: future +async fn foo() -> i32 { + loop { + let n = 1;$0 + let k = async { 1 }.await; + if k == 42 { + break 3; + } + let m = k + 1;$0 + let h = 1 + m; + } +} +"#, + r#" +async fn foo() -> i32 { + loop { + let n = 1; + let m = match fun_name().await { + Ok(value) => value, + Err(value) => break value, + }; + let h = 1 + m; + } +} + +async fn $0fun_name() -> Result { + let k = async { 1 }.await; + if k == 42 { + return Err(3); + } + let m = k + 1; + Ok(m) +} +"#, + ); + } + + #[test] + fn extract_with_await_in_args() { + check_assist( + extract_function, + r#" +//- minicore: future +fn main() { + $0function_call("a", some_function().await);$0 +} + +async fn some_function() { + +} +"#, + r#" +fn main() { + fun_name().await; +} + +async fn $0fun_name() { + function_call("a", some_function().await); +} + +async fn some_function() { + +} +"#, + ); + } + + #[test] + fn extract_does_not_extract_standalone_blocks() { + check_assist_not_applicable( + extract_function, + r#" +fn main() $0{}$0 +"#, + ); + } + + #[test] + fn extract_adds_comma_for_match_arm() { + check_assist( + extract_function, + r#" +fn main() { + match 6 { + 100 => $0{ 100 }$0 + _ => 0, + }; +} +"#, + r#" +fn main() { + match 6 { + 100 => fun_name(), + _ => 0, + }; +} + +fn $0fun_name() -> i32 { + 100 +} +"#, + ); + check_assist( + extract_function, + r#" +fn main() { + match 6 { + 100 => $0{ 100 }$0, + _ => 0, + }; +} +"#, + r#" +fn main() { + match 6 { + 100 => fun_name(), + _ => 0, + }; +} + +fn $0fun_name() -> i32 { + 100 +} +"#, + ); + } + + #[test] + fn extract_does_not_tear_comments_apart() { + check_assist( + extract_function, + r#" +fn foo() { + /*$0*/ + foo(); + foo(); + /*$0*/ +} +"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + /**/ + foo(); + foo(); + /**/ +} +"#, + ); + } + + #[test] + fn extract_does_not_tear_body_apart() { + check_assist( + extract_function, + r#" +fn foo() { + $0foo(); +}$0 +"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + foo(); +} +"#, + ); + } + + #[test] + fn extract_does_not_wrap_res_in_res() { + check_assist( + extract_function, + r#" +//- minicore: result +fn foo() -> Result<(), i64> { + $0Result::::Ok(0)?; + Ok(())$0 +} +"#, + r#" +fn foo() -> Result<(), i64> { + fun_name()? +} + +fn $0fun_name() -> Result<(), i64> { + Result::::Ok(0)?; + Ok(()) +} +"#, + ); + } + + #[test] + fn extract_knows_const() { + check_assist( + extract_function, + r#" +const fn foo() { + $0()$0 +} +"#, + r#" +const fn foo() { + fun_name(); +} + +const fn $0fun_name() { + () +} +"#, + ); + check_assist( + extract_function, + r#" +const FOO: () = { + $0()$0 +}; +"#, + r#" +const FOO: () = { + fun_name(); +}; + +const fn $0fun_name() { + () +} +"#, + ); + } + + #[test] + fn extract_does_not_move_outer_loop_vars() { + check_assist( + extract_function, + r#" +fn foo() { + let mut x = 5; + for _ in 0..10 { + $0x += 1;$0 + } +} +"#, + r#" +fn foo() { + let mut x = 5; + for _ in 0..10 { + fun_name(&mut x); + } +} + +fn $0fun_name(x: &mut i32) { + *x += 1; +} +"#, + ); + check_assist( + extract_function, + r#" +fn foo() { + for _ in 0..10 { + let mut x = 5; + $0x += 1;$0 + } +} +"#, + r#" +fn foo() { + for _ in 0..10 { + let mut x = 5; + fun_name(x); + } +} + +fn $0fun_name(mut x: i32) { + x += 1; +} +"#, + ); + check_assist( + extract_function, + r#" +fn foo() { + loop { + let mut x = 5; + for _ in 0..10 { + $0x += 1;$0 + } + } +} +"#, + r#" +fn foo() { + loop { + let mut x = 5; + for _ in 0..10 { + fun_name(&mut x); + } + } +} + +fn $0fun_name(x: &mut i32) { + *x += 1; +} +"#, + ); + } + + // regression test for #9822 + #[test] + fn extract_mut_ref_param_has_no_mut_binding_in_loop() { + check_assist( + extract_function, + r#" +struct Foo; +impl Foo { + fn foo(&mut self) {} +} +fn foo() { + let mut x = Foo; + while false { + let y = &mut x; + $0y.foo();$0 + } + let z = x; +} +"#, + r#" +struct Foo; +impl Foo { + fn foo(&mut self) {} +} +fn foo() { + let mut x = Foo; + while false { + let y = &mut x; + fun_name(y); + } + let z = x; +} + +fn $0fun_name(y: &mut Foo) { + y.foo(); +} +"#, + ); + } + + #[test] + fn extract_with_macro_arg() { + check_assist( + extract_function, + r#" +macro_rules! m { + ($val:expr) => { $val }; +} +fn main() { + let bar = "bar"; + $0m!(bar);$0 +} +"#, + r#" +macro_rules! m { + ($val:expr) => { $val }; +} +fn main() { + let bar = "bar"; + fun_name(bar); +} + +fn $0fun_name(bar: &str) { + m!(bar); +} +"#, + ); + } + + #[test] + fn unresolveable_types_default_to_placeholder() { + check_assist( + extract_function, + r#" +fn foo() { + let a = __unresolved; + let _ = $0{a}$0; +} +"#, + r#" +fn foo() { + let a = __unresolved; + let _ = fun_name(a); +} + +fn $0fun_name(a: _) -> _ { + a +} +"#, + ); + } + + #[test] + fn reference_mutable_param_with_further_usages() { + check_assist( + extract_function, + r#" +pub struct Foo { + field: u32, +} + +pub fn testfn(arg: &mut Foo) { + $0arg.field = 8;$0 + // Simulating access after the extracted portion + arg.field = 16; +} +"#, + r#" +pub struct Foo { + field: u32, +} + +pub fn testfn(arg: &mut Foo) { + fun_name(arg); + // Simulating access after the extracted portion + arg.field = 16; +} + +fn $0fun_name(arg: &mut Foo) { + arg.field = 8; +} +"#, + ); + } + + #[test] + fn reference_mutable_param_without_further_usages() { + check_assist( + extract_function, + r#" +pub struct Foo { + field: u32, +} + +pub fn testfn(arg: &mut Foo) { + $0arg.field = 8;$0 +} +"#, + r#" +pub struct Foo { + field: u32, +} + +pub fn testfn(arg: &mut Foo) { + fun_name(arg); +} + +fn $0fun_name(arg: &mut Foo) { + arg.field = 8; +} +"#, + ); + } + + #[test] + fn extract_function_copies_comment_at_start() { + check_assist( + extract_function, + r#" +fn func() { + let i = 0; + $0// comment here! + let x = 0;$0 +} +"#, + r#" +fn func() { + let i = 0; + fun_name(); +} + +fn $0fun_name() { + // comment here! + let x = 0; +} +"#, + ); + } + + #[test] + fn extract_function_copies_comment_in_between() { + check_assist( + extract_function, + r#" +fn func() { + let i = 0;$0 + let a = 0; + // comment here! + let x = 0;$0 +} +"#, + r#" +fn func() { + let i = 0; + fun_name(); +} + +fn $0fun_name() { + let a = 0; + // comment here! + let x = 0; +} +"#, + ); + } + + #[test] + fn extract_function_copies_comment_at_end() { + check_assist( + extract_function, + r#" +fn func() { + let i = 0; + $0let x = 0; + // comment here!$0 +} +"#, + r#" +fn func() { + let i = 0; + fun_name(); +} + +fn $0fun_name() { + let x = 0; + // comment here! +} +"#, + ); + } + + #[test] + fn extract_function_copies_comment_indented() { + check_assist( + extract_function, + r#" +fn func() { + let i = 0; + $0let x = 0; + while(true) { + // comment here! + }$0 +} +"#, + r#" +fn func() { + let i = 0; + fun_name(); +} + +fn $0fun_name() { + let x = 0; + while(true) { + // comment here! + } +} +"#, + ); + } + + // FIXME: we do want to preserve whitespace + #[test] + fn extract_function_does_not_preserve_whitespace() { + check_assist( + extract_function, + r#" +fn func() { + let i = 0; + $0let a = 0; + + let x = 0;$0 +} +"#, + r#" +fn func() { + let i = 0; + fun_name(); +} + +fn $0fun_name() { + let a = 0; + let x = 0; +} +"#, + ); + } + + #[test] + fn extract_function_long_form_comment() { + check_assist( + extract_function, + r#" +fn func() { + let i = 0; + $0/* a comment */ + let x = 0;$0 +} +"#, + r#" +fn func() { + let i = 0; + fun_name(); +} + +fn $0fun_name() { + /* a comment */ + let x = 0; +} +"#, + ); + } + + #[test] + fn it_should_not_generate_duplicate_function_names() { + check_assist( + extract_function, + r#" +fn fun_name() { + $0let x = 0;$0 +} +"#, + r#" +fn fun_name() { + fun_name1(); +} + +fn $0fun_name1() { + let x = 0; +} +"#, + ); + } + + #[test] + fn should_increment_suffix_until_it_finds_space() { + check_assist( + extract_function, + r#" +fn fun_name1() { + let y = 0; +} + +fn fun_name() { + $0let x = 0;$0 +} +"#, + r#" +fn fun_name1() { + let y = 0; +} + +fn fun_name() { + fun_name2(); +} + +fn $0fun_name2() { + let x = 0; +} +"#, + ); + } + + #[test] + fn extract_method_from_trait_impl() { + check_assist( + extract_function, + r#" +struct Struct(i32); +trait Trait { + fn bar(&self) -> i32; +} + +impl Trait for Struct { + fn bar(&self) -> i32 { + $0self.0 + 2$0 + } +} +"#, + r#" +struct Struct(i32); +trait Trait { + fn bar(&self) -> i32; +} + +impl Trait for Struct { + fn bar(&self) -> i32 { + self.fun_name() + } +} + +impl Struct { + fn $0fun_name(&self) -> i32 { + self.0 + 2 + } +} +"#, + ); + } + + #[test] + fn closure_arguments() { + check_assist( + extract_function, + r#" +fn parent(factor: i32) { + let v = &[1, 2, 3]; + + $0v.iter().map(|it| it * factor);$0 +} +"#, + r#" +fn parent(factor: i32) { + let v = &[1, 2, 3]; + + fun_name(v, factor); +} + +fn $0fun_name(v: &[i32; 3], factor: i32) { + v.iter().map(|it| it * factor); +} +"#, + ); + } + + #[test] + fn preserve_generics() { + check_assist( + extract_function, + r#" +fn func(i: T) { + $0foo(i);$0 +} +"#, + r#" +fn func(i: T) { + fun_name(i); +} + +fn $0fun_name(i: T) { + foo(i); +} +"#, + ); + } + + #[test] + fn preserve_generics_from_body() { + check_assist( + extract_function, + r#" +fn func() -> T { + $0T::default()$0 +} +"#, + r#" +fn func() -> T { + fun_name() +} + +fn $0fun_name() -> T { + T::default() +} +"#, + ); + } + + #[test] + fn filter_unused_generics() { + check_assist( + extract_function, + r#" +fn func(i: T, u: U) { + bar(u); + $0foo(i);$0 +} +"#, + r#" +fn func(i: T, u: U) { + bar(u); + fun_name(i); +} + +fn $0fun_name(i: T) { + foo(i); +} +"#, + ); + } + + #[test] + fn empty_generic_param_list() { + check_assist( + extract_function, + r#" +fn func(t: T, i: u32) { + bar(t); + $0foo(i);$0 +} +"#, + r#" +fn func(t: T, i: u32) { + bar(t); + fun_name(i); +} + +fn $0fun_name(i: u32) { + foo(i); +} +"#, + ); + } + + #[test] + fn preserve_where_clause() { + check_assist( + extract_function, + r#" +fn func(i: T) where T: Debug { + $0foo(i);$0 +} +"#, + r#" +fn func(i: T) where T: Debug { + fun_name(i); +} + +fn $0fun_name(i: T) where T: Debug { + foo(i); +} +"#, + ); + } + + #[test] + fn filter_unused_where_clause() { + check_assist( + extract_function, + r#" +fn func(i: T, u: U) where T: Debug, U: Copy { + bar(u); + $0foo(i);$0 +} +"#, + r#" +fn func(i: T, u: U) where T: Debug, U: Copy { + bar(u); + fun_name(i); +} + +fn $0fun_name(i: T) where T: Debug { + foo(i); +} +"#, + ); + } + + #[test] + fn nested_generics() { + check_assist( + extract_function, + r#" +struct Struct>(T); +impl + Copy> Struct { + fn func>(&self, v: V) -> i32 { + let t = self.0; + $0t.into() + v.into()$0 + } +} +"#, + r#" +struct Struct>(T); +impl + Copy> Struct { + fn func>(&self, v: V) -> i32 { + let t = self.0; + fun_name(t, v) + } +} + +fn $0fun_name + Copy, V: Into>(t: T, v: V) -> i32 { + t.into() + v.into() +} +"#, + ); + } + + #[test] + fn filters_unused_nested_generics() { + check_assist( + extract_function, + r#" +struct Struct, U: Debug>(T, U); +impl + Copy, U: Debug> Struct { + fn func>(&self, v: V) -> i32 { + let t = self.0; + $0t.into() + v.into()$0 + } +} +"#, + r#" +struct Struct, U: Debug>(T, U); +impl + Copy, U: Debug> Struct { + fn func>(&self, v: V) -> i32 { + let t = self.0; + fun_name(t, v) + } +} + +fn $0fun_name + Copy, V: Into>(t: T, v: V) -> i32 { + t.into() + v.into() +} +"#, + ); + } + + #[test] + fn nested_where_clauses() { + check_assist( + extract_function, + r#" +struct Struct(T) where T: Into; +impl Struct where T: Into + Copy { + fn func(&self, v: V) -> i32 where V: Into { + let t = self.0; + $0t.into() + v.into()$0 + } +} +"#, + r#" +struct Struct(T) where T: Into; +impl Struct where T: Into + Copy { + fn func(&self, v: V) -> i32 where V: Into { + let t = self.0; + fun_name(t, v) + } +} + +fn $0fun_name(t: T, v: V) -> i32 where T: Into + Copy, V: Into { + t.into() + v.into() +} +"#, + ); + } + + #[test] + fn filters_unused_nested_where_clauses() { + check_assist( + extract_function, + r#" +struct Struct(T, U) where T: Into, U: Debug; +impl Struct where T: Into + Copy, U: Debug { + fn func(&self, v: V) -> i32 where V: Into { + let t = self.0; + $0t.into() + v.into()$0 + } +} +"#, + r#" +struct Struct(T, U) where T: Into, U: Debug; +impl Struct where T: Into + Copy, U: Debug { + fn func(&self, v: V) -> i32 where V: Into { + let t = self.0; + fun_name(t, v) + } +} + +fn $0fun_name(t: T, v: V) -> i32 where T: Into + Copy, V: Into { + t.into() + v.into() +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs new file mode 100644 index 000000000..b3c4d306a --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs @@ -0,0 +1,1770 @@ +use std::{ + collections::{HashMap, HashSet}, + iter, +}; + +use hir::{HasSource, ModuleSource}; +use ide_db::{ + assists::{AssistId, AssistKind}, + base_db::FileId, + defs::{Definition, NameClass, NameRefClass}, + search::{FileReference, SearchScope}, +}; +use stdx::format_to; +use syntax::{ + algo::find_node_at_range, + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, HasName, HasVisibility, + }, + match_ast, ted, AstNode, SourceFile, + SyntaxKind::{self, WHITESPACE}, + SyntaxNode, TextRange, +}; + +use crate::{AssistContext, Assists}; + +use super::remove_unused_param::range_to_remove; + +// Assist: extract_module +// +// Extracts a selected region as seperate module. All the references, visibility and imports are +// resolved. +// +// ``` +// $0fn foo(name: i32) -> i32 { +// name + 1 +// }$0 +// +// fn bar(name: i32) -> i32 { +// name + 2 +// } +// ``` +// -> +// ``` +// mod modname { +// pub(crate) fn foo(name: i32) -> i32 { +// name + 1 +// } +// } +// +// fn bar(name: i32) -> i32 { +// name + 2 +// } +// ``` +pub(crate) fn extract_module(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + if ctx.has_empty_selection() { + return None; + } + + let node = ctx.covering_element(); + let node = match node { + syntax::NodeOrToken::Node(n) => n, + syntax::NodeOrToken::Token(t) => t.parent()?, + }; + + //If the selection is inside impl block, we need to place new module outside impl block, + //as impl blocks cannot contain modules + + let mut impl_parent: Option = None; + let mut impl_child_count: usize = 0; + if let Some(parent_assoc_list) = node.parent() { + if let Some(parent_impl) = parent_assoc_list.parent() { + if let Some(impl_) = ast::Impl::cast(parent_impl) { + impl_child_count = parent_assoc_list.children().count(); + impl_parent = Some(impl_); + } + } + } + + let mut curr_parent_module: Option = None; + if let Some(mod_syn_opt) = node.ancestors().find(|it| ast::Module::can_cast(it.kind())) { + curr_parent_module = ast::Module::cast(mod_syn_opt); + } + + let mut module = extract_target(&node, ctx.selection_trimmed())?; + if module.body_items.is_empty() { + return None; + } + + let old_item_indent = module.body_items[0].indent_level(); + + acc.add( + AssistId("extract_module", AssistKind::RefactorExtract), + "Extract Module", + module.text_range, + |builder| { + //This takes place in three steps: + // + //- Firstly, we will update the references(usages) e.g. converting a + // function call bar() to modname::bar(), and similarly for other items + // + //- Secondly, changing the visibility of each item inside the newly selected module + // i.e. making a fn a() {} to pub(crate) fn a() {} + // + //- Thirdly, resolving all the imports this includes removing paths from imports + // outside the module, shifting/cloning them inside new module, or shifting the imports, or making + // new import statemnts + + //We are getting item usages and record_fields together, record_fields + //for change_visibility and usages for first point mentioned above in the process + let (usages_to_be_processed, record_fields) = module.get_usages_and_record_fields(ctx); + + let import_paths_to_be_removed = module.resolve_imports(curr_parent_module, ctx); + module.change_visibility(record_fields); + + let mut body_items: Vec = Vec::new(); + let mut items_to_be_processed: Vec = module.body_items.clone(); + let mut new_item_indent = old_item_indent + 1; + + if impl_parent.is_some() { + new_item_indent = old_item_indent + 2; + } else { + items_to_be_processed = [module.use_items.clone(), items_to_be_processed].concat(); + } + + for item in items_to_be_processed { + let item = item.indent(IndentLevel(1)); + let mut indented_item = String::new(); + format_to!(indented_item, "{}{}", new_item_indent, item.to_string()); + body_items.push(indented_item); + } + + let mut body = body_items.join("\n\n"); + + if let Some(impl_) = &impl_parent { + let mut impl_body_def = String::new(); + + if let Some(self_ty) = impl_.self_ty() { + format_to!( + impl_body_def, + "{}impl {} {{\n{}\n{}}}", + old_item_indent + 1, + self_ty.to_string(), + body, + old_item_indent + 1 + ); + + body = impl_body_def; + + // Add the import for enum/struct corresponding to given impl block + module.make_use_stmt_of_node_with_super(self_ty.syntax()); + for item in module.use_items { + let mut indented_item = String::new(); + format_to!(indented_item, "{}{}", old_item_indent + 1, item.to_string()); + body = format!("{}\n\n{}", indented_item, body); + } + } + } + + let mut module_def = String::new(); + + format_to!(module_def, "mod {} {{\n{}\n{}}}", module.name, body, old_item_indent); + + let mut usages_to_be_updated_for_curr_file = vec![]; + for usages_to_be_updated_for_file in usages_to_be_processed { + if usages_to_be_updated_for_file.0 == ctx.file_id() { + usages_to_be_updated_for_curr_file = usages_to_be_updated_for_file.1; + continue; + } + builder.edit_file(usages_to_be_updated_for_file.0); + for usage_to_be_processed in usages_to_be_updated_for_file.1 { + builder.replace(usage_to_be_processed.0, usage_to_be_processed.1) + } + } + + builder.edit_file(ctx.file_id()); + for usage_to_be_processed in usages_to_be_updated_for_curr_file { + builder.replace(usage_to_be_processed.0, usage_to_be_processed.1) + } + + for import_path_text_range in import_paths_to_be_removed { + builder.delete(import_path_text_range); + } + + if let Some(impl_) = impl_parent { + // Remove complete impl block if it has only one child (as such it will be empty + // after deleting that child) + let node_to_be_removed = if impl_child_count == 1 { + impl_.syntax() + } else { + //Remove selected node + &node + }; + + builder.delete(node_to_be_removed.text_range()); + // Remove preceding indentation from node + if let Some(range) = indent_range_before_given_node(node_to_be_removed) { + builder.delete(range); + } + + builder.insert(impl_.syntax().text_range().end(), format!("\n\n{}", module_def)); + } else { + builder.replace(module.text_range, module_def) + } + }, + ) +} + +#[derive(Debug)] +struct Module { + text_range: TextRange, + name: &'static str, + /// All items except use items. + body_items: Vec, + /// Use items are kept separately as they help when the selection is inside an impl block, + /// we can directly take these items and keep them outside generated impl block inside + /// generated module. + use_items: Vec, +} + +fn extract_target(node: &SyntaxNode, selection_range: TextRange) -> Option { + let selected_nodes = node + .children() + .filter(|node| selection_range.contains_range(node.text_range())) + .chain(iter::once(node.clone())); + let (use_items, body_items) = selected_nodes + .filter_map(ast::Item::cast) + .partition(|item| matches!(item, ast::Item::Use(..))); + + Some(Module { text_range: selection_range, name: "modname", body_items, use_items }) +} + +impl Module { + fn get_usages_and_record_fields( + &self, + ctx: &AssistContext<'_>, + ) -> (HashMap>, Vec) { + let mut adt_fields = Vec::new(); + let mut refs: HashMap> = HashMap::new(); + + //Here impl is not included as each item inside impl will be tied to the parent of + //implementing block(a struct, enum, etc), if the parent is in selected module, it will + //get updated by ADT section given below or if it is not, then we dont need to do any operation + for item in &self.body_items { + match_ast! { + match (item.syntax()) { + ast::Adt(it) => { + if let Some( nod ) = ctx.sema.to_def(&it) { + let node_def = Definition::Adt(nod); + self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs); + + //Enum Fields are not allowed to explicitly specify pub, it is implied + match it { + ast::Adt::Struct(x) => { + if let Some(field_list) = x.field_list() { + match field_list { + ast::FieldList::RecordFieldList(record_field_list) => { + record_field_list.fields().for_each(|record_field| { + adt_fields.push(record_field.syntax().clone()); + }); + }, + ast::FieldList::TupleFieldList(tuple_field_list) => { + tuple_field_list.fields().for_each(|tuple_field| { + adt_fields.push(tuple_field.syntax().clone()); + }); + }, + } + } + }, + ast::Adt::Union(x) => { + if let Some(record_field_list) = x.record_field_list() { + record_field_list.fields().for_each(|record_field| { + adt_fields.push(record_field.syntax().clone()); + }); + } + }, + ast::Adt::Enum(_) => {}, + } + } + }, + ast::TypeAlias(it) => { + if let Some( nod ) = ctx.sema.to_def(&it) { + let node_def = Definition::TypeAlias(nod); + self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs); + } + }, + ast::Const(it) => { + if let Some( nod ) = ctx.sema.to_def(&it) { + let node_def = Definition::Const(nod); + self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs); + } + }, + ast::Static(it) => { + if let Some( nod ) = ctx.sema.to_def(&it) { + let node_def = Definition::Static(nod); + self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs); + } + }, + ast::Fn(it) => { + if let Some( nod ) = ctx.sema.to_def(&it) { + let node_def = Definition::Function(nod); + self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs); + } + }, + ast::Macro(it) => { + if let Some(nod) = ctx.sema.to_def(&it) { + self.expand_and_group_usages_file_wise(ctx, Definition::Macro(nod), &mut refs); + } + }, + _ => (), + } + } + } + + (refs, adt_fields) + } + + fn expand_and_group_usages_file_wise( + &self, + ctx: &AssistContext<'_>, + node_def: Definition, + refs_in_files: &mut HashMap>, + ) { + for (file_id, references) in node_def.usages(&ctx.sema).all() { + let source_file = ctx.sema.parse(file_id); + let usages_in_file = references + .into_iter() + .filter_map(|usage| self.get_usage_to_be_processed(&source_file, usage)); + refs_in_files.entry(file_id).or_default().extend(usages_in_file); + } + } + + fn get_usage_to_be_processed( + &self, + source_file: &SourceFile, + FileReference { range, name, .. }: FileReference, + ) -> Option<(TextRange, String)> { + let path: ast::Path = find_node_at_range(source_file.syntax(), range)?; + + for desc in path.syntax().descendants() { + if desc.to_string() == name.syntax().to_string() + && !self.text_range.contains_range(desc.text_range()) + { + if let Some(name_ref) = ast::NameRef::cast(desc) { + return Some(( + name_ref.syntax().text_range(), + format!("{}::{}", self.name, name_ref), + )); + } + } + } + + None + } + + fn change_visibility(&mut self, record_fields: Vec) { + let (mut replacements, record_field_parents, impls) = + get_replacements_for_visibilty_change(&mut self.body_items, false); + + let mut impl_items: Vec = impls + .into_iter() + .flat_map(|impl_| impl_.syntax().descendants()) + .filter_map(ast::Item::cast) + .collect(); + + let (mut impl_item_replacements, _, _) = + get_replacements_for_visibilty_change(&mut impl_items, true); + + replacements.append(&mut impl_item_replacements); + + for (_, field_owner) in record_field_parents { + for desc in field_owner.descendants().filter_map(ast::RecordField::cast) { + let is_record_field_present = + record_fields.clone().into_iter().any(|x| x.to_string() == desc.to_string()); + if is_record_field_present { + replacements.push((desc.visibility(), desc.syntax().clone())); + } + } + } + + for (vis, syntax) in replacements { + let item = syntax.children_with_tokens().find(|node_or_token| { + match node_or_token.kind() { + // We're skipping comments, doc comments, and attribute macros that may precede the keyword + // that the visibility should be placed before. + SyntaxKind::COMMENT | SyntaxKind::ATTR | SyntaxKind::WHITESPACE => false, + _ => true, + } + }); + + add_change_vis(vis, item); + } + } + + fn resolve_imports( + &mut self, + curr_parent_module: Option, + ctx: &AssistContext<'_>, + ) -> Vec { + let mut import_paths_to_be_removed: Vec = vec![]; + let mut node_set: HashSet = HashSet::new(); + + for item in self.body_items.clone() { + for x in item.syntax().descendants() { + if let Some(name) = ast::Name::cast(x.clone()) { + if let Some(name_classify) = NameClass::classify(&ctx.sema, &name) { + //Necessary to avoid two same names going through + if !node_set.contains(&name.syntax().to_string()) { + node_set.insert(name.syntax().to_string()); + let def_opt: Option = match name_classify { + NameClass::Definition(def) => Some(def), + _ => None, + }; + + if let Some(def) = def_opt { + if let Some(import_path) = self + .process_names_and_namerefs_for_import_resolve( + def, + name.syntax(), + &curr_parent_module, + ctx, + ) + { + check_intersection_and_push( + &mut import_paths_to_be_removed, + import_path, + ); + } + } + } + } + } + + if let Some(name_ref) = ast::NameRef::cast(x) { + if let Some(name_classify) = NameRefClass::classify(&ctx.sema, &name_ref) { + //Necessary to avoid two same names going through + if !node_set.contains(&name_ref.syntax().to_string()) { + node_set.insert(name_ref.syntax().to_string()); + let def_opt: Option = match name_classify { + NameRefClass::Definition(def) => Some(def), + _ => None, + }; + + if let Some(def) = def_opt { + if let Some(import_path) = self + .process_names_and_namerefs_for_import_resolve( + def, + name_ref.syntax(), + &curr_parent_module, + ctx, + ) + { + check_intersection_and_push( + &mut import_paths_to_be_removed, + import_path, + ); + } + } + } + } + } + } + } + + import_paths_to_be_removed + } + + fn process_names_and_namerefs_for_import_resolve( + &mut self, + def: Definition, + node_syntax: &SyntaxNode, + curr_parent_module: &Option, + ctx: &AssistContext<'_>, + ) -> Option { + //We only need to find in the current file + let selection_range = ctx.selection_trimmed(); + let curr_file_id = ctx.file_id(); + let search_scope = SearchScope::single_file(curr_file_id); + let usage_res = def.usages(&ctx.sema).in_scope(search_scope).all(); + let file = ctx.sema.parse(curr_file_id); + + let mut exists_inside_sel = false; + let mut exists_outside_sel = false; + for (_, refs) in usage_res.iter() { + let mut non_use_nodes_itr = refs.iter().filter_map(|x| { + if find_node_at_range::(file.syntax(), x.range).is_none() { + let path_opt = find_node_at_range::(file.syntax(), x.range); + return path_opt; + } + + None + }); + + if non_use_nodes_itr + .clone() + .any(|x| !selection_range.contains_range(x.syntax().text_range())) + { + exists_outside_sel = true; + } + if non_use_nodes_itr.any(|x| selection_range.contains_range(x.syntax().text_range())) { + exists_inside_sel = true; + } + } + + let source_exists_outside_sel_in_same_mod = does_source_exists_outside_sel_in_same_mod( + def, + ctx, + curr_parent_module, + selection_range, + curr_file_id, + ); + + let use_stmt_opt: Option = usage_res.into_iter().find_map(|(file_id, refs)| { + if file_id == curr_file_id { + refs.into_iter() + .rev() + .find_map(|fref| find_node_at_range(file.syntax(), fref.range)) + } else { + None + } + }); + + let mut use_tree_str_opt: Option> = None; + //Exists inside and outside selection + // - Use stmt for item is present -> get the use_tree_str and reconstruct the path in new + // module + // - Use stmt for item is not present -> + //If it is not found, the definition is either ported inside new module or it stays + //outside: + //- Def is inside: Nothing to import + //- Def is outside: Import it inside with super + + //Exists inside selection but not outside -> Check for the import of it in original module, + //get the use_tree_str, reconstruct the use stmt in new module + + let mut import_path_to_be_removed: Option = None; + if exists_inside_sel && exists_outside_sel { + //Changes to be made only inside new module + + //If use_stmt exists, find the use_tree_str, reconstruct it inside new module + //If not, insert a use stmt with super and the given nameref + if let Some((use_tree_str, _)) = + self.process_use_stmt_for_import_resolve(use_stmt_opt, node_syntax) + { + use_tree_str_opt = Some(use_tree_str); + } else if source_exists_outside_sel_in_same_mod { + //Considered only after use_stmt is not present + //source_exists_outside_sel_in_same_mod | exists_outside_sel(exists_inside_sel = + //true for all cases) + // false | false -> Do nothing + // false | true -> If source is in selection -> nothing to do, If source is outside + // mod -> ust_stmt transversal + // true | false -> super import insertion + // true | true -> super import insertion + self.make_use_stmt_of_node_with_super(node_syntax); + } + } else if exists_inside_sel && !exists_outside_sel { + //Changes to be made inside new module, and remove import from outside + + if let Some((mut use_tree_str, text_range_opt)) = + self.process_use_stmt_for_import_resolve(use_stmt_opt, node_syntax) + { + if let Some(text_range) = text_range_opt { + import_path_to_be_removed = Some(text_range); + } + + if source_exists_outside_sel_in_same_mod { + if let Some(first_path_in_use_tree) = use_tree_str.last() { + let first_path_in_use_tree_str = first_path_in_use_tree.to_string(); + if !first_path_in_use_tree_str.contains("super") + && !first_path_in_use_tree_str.contains("crate") + { + let super_path = make::ext::ident_path("super"); + use_tree_str.push(super_path); + } + } + } + + use_tree_str_opt = Some(use_tree_str); + } else if source_exists_outside_sel_in_same_mod { + self.make_use_stmt_of_node_with_super(node_syntax); + } + } + + if let Some(use_tree_str) = use_tree_str_opt { + let mut use_tree_str = use_tree_str; + use_tree_str.reverse(); + + if !(!exists_outside_sel && exists_inside_sel && source_exists_outside_sel_in_same_mod) + { + if let Some(first_path_in_use_tree) = use_tree_str.first() { + let first_path_in_use_tree_str = first_path_in_use_tree.to_string(); + if first_path_in_use_tree_str.contains("super") { + let super_path = make::ext::ident_path("super"); + use_tree_str.insert(0, super_path) + } + } + } + + let use_ = + make::use_(None, make::use_tree(make::join_paths(use_tree_str), None, None, false)); + let item = ast::Item::from(use_); + self.use_items.insert(0, item); + } + + import_path_to_be_removed + } + + fn make_use_stmt_of_node_with_super(&mut self, node_syntax: &SyntaxNode) -> ast::Item { + let super_path = make::ext::ident_path("super"); + let node_path = make::ext::ident_path(&node_syntax.to_string()); + let use_ = make::use_( + None, + make::use_tree(make::join_paths(vec![super_path, node_path]), None, None, false), + ); + + let item = ast::Item::from(use_); + self.use_items.insert(0, item.clone()); + item + } + + fn process_use_stmt_for_import_resolve( + &self, + use_stmt_opt: Option, + node_syntax: &SyntaxNode, + ) -> Option<(Vec, Option)> { + if let Some(use_stmt) = use_stmt_opt { + for desc in use_stmt.syntax().descendants() { + if let Some(path_seg) = ast::PathSegment::cast(desc) { + if path_seg.syntax().to_string() == node_syntax.to_string() { + let mut use_tree_str = vec![path_seg.parent_path()]; + get_use_tree_paths_from_path(path_seg.parent_path(), &mut use_tree_str); + for ancs in path_seg.syntax().ancestors() { + //Here we are looking for use_tree with same string value as node + //passed above as the range_to_remove function looks for a comma and + //then includes it in the text range to remove it. But the comma only + //appears at the use_tree level + if let Some(use_tree) = ast::UseTree::cast(ancs) { + if use_tree.syntax().to_string() == node_syntax.to_string() { + return Some(( + use_tree_str, + Some(range_to_remove(use_tree.syntax())), + )); + } + } + } + + return Some((use_tree_str, None)); + } + } + } + } + + None + } +} + +fn check_intersection_and_push( + import_paths_to_be_removed: &mut Vec, + import_path: TextRange, +) { + if import_paths_to_be_removed.len() > 0 { + // Text ranges recieved here for imports are extended to the + // next/previous comma which can cause intersections among them + // and later deletion of these can cause panics similar + // to reported in #11766. So to mitigate it, we + // check for intersection between all current members + // and if it exists we combine both text ranges into + // one + let r = import_paths_to_be_removed + .into_iter() + .position(|it| it.intersect(import_path).is_some()); + match r { + Some(it) => { + import_paths_to_be_removed[it] = import_paths_to_be_removed[it].cover(import_path) + } + None => import_paths_to_be_removed.push(import_path), + } + } else { + import_paths_to_be_removed.push(import_path); + } +} + +fn does_source_exists_outside_sel_in_same_mod( + def: Definition, + ctx: &AssistContext<'_>, + curr_parent_module: &Option, + selection_range: TextRange, + curr_file_id: FileId, +) -> bool { + let mut source_exists_outside_sel_in_same_mod = false; + match def { + Definition::Module(x) => { + let source = x.definition_source(ctx.db()); + let have_same_parent; + if let Some(ast_module) = &curr_parent_module { + if let Some(hir_module) = x.parent(ctx.db()) { + have_same_parent = + compare_hir_and_ast_module(ast_module, hir_module, ctx).is_some(); + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + have_same_parent = source_file_id == curr_file_id; + } + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + have_same_parent = source_file_id == curr_file_id; + } + + if have_same_parent { + match source.value { + ModuleSource::Module(module_) => { + source_exists_outside_sel_in_same_mod = + !selection_range.contains_range(module_.syntax().text_range()); + } + _ => {} + } + } + } + Definition::Function(x) => { + if let Some(source) = x.source(ctx.db()) { + let have_same_parent = if let Some(ast_module) = &curr_parent_module { + compare_hir_and_ast_module(ast_module, x.module(ctx.db()), ctx).is_some() + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + source_file_id == curr_file_id + }; + + if have_same_parent { + source_exists_outside_sel_in_same_mod = + !selection_range.contains_range(source.value.syntax().text_range()); + } + } + } + Definition::Adt(x) => { + if let Some(source) = x.source(ctx.db()) { + let have_same_parent = if let Some(ast_module) = &curr_parent_module { + compare_hir_and_ast_module(ast_module, x.module(ctx.db()), ctx).is_some() + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + source_file_id == curr_file_id + }; + + if have_same_parent { + source_exists_outside_sel_in_same_mod = + !selection_range.contains_range(source.value.syntax().text_range()); + } + } + } + Definition::Variant(x) => { + if let Some(source) = x.source(ctx.db()) { + let have_same_parent = if let Some(ast_module) = &curr_parent_module { + compare_hir_and_ast_module(ast_module, x.module(ctx.db()), ctx).is_some() + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + source_file_id == curr_file_id + }; + + if have_same_parent { + source_exists_outside_sel_in_same_mod = + !selection_range.contains_range(source.value.syntax().text_range()); + } + } + } + Definition::Const(x) => { + if let Some(source) = x.source(ctx.db()) { + let have_same_parent = if let Some(ast_module) = &curr_parent_module { + compare_hir_and_ast_module(ast_module, x.module(ctx.db()), ctx).is_some() + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + source_file_id == curr_file_id + }; + + if have_same_parent { + source_exists_outside_sel_in_same_mod = + !selection_range.contains_range(source.value.syntax().text_range()); + } + } + } + Definition::Static(x) => { + if let Some(source) = x.source(ctx.db()) { + let have_same_parent = if let Some(ast_module) = &curr_parent_module { + compare_hir_and_ast_module(ast_module, x.module(ctx.db()), ctx).is_some() + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + source_file_id == curr_file_id + }; + + if have_same_parent { + source_exists_outside_sel_in_same_mod = + !selection_range.contains_range(source.value.syntax().text_range()); + } + } + } + Definition::Trait(x) => { + if let Some(source) = x.source(ctx.db()) { + let have_same_parent = if let Some(ast_module) = &curr_parent_module { + compare_hir_and_ast_module(ast_module, x.module(ctx.db()), ctx).is_some() + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + source_file_id == curr_file_id + }; + + if have_same_parent { + source_exists_outside_sel_in_same_mod = + !selection_range.contains_range(source.value.syntax().text_range()); + } + } + } + Definition::TypeAlias(x) => { + if let Some(source) = x.source(ctx.db()) { + let have_same_parent = if let Some(ast_module) = &curr_parent_module { + compare_hir_and_ast_module(ast_module, x.module(ctx.db()), ctx).is_some() + } else { + let source_file_id = source.file_id.original_file(ctx.db()); + source_file_id == curr_file_id + }; + + if have_same_parent { + source_exists_outside_sel_in_same_mod = + !selection_range.contains_range(source.value.syntax().text_range()); + } + } + } + _ => {} + } + + source_exists_outside_sel_in_same_mod +} + +fn get_replacements_for_visibilty_change( + items: &mut [ast::Item], + is_clone_for_updated: bool, +) -> ( + Vec<(Option, SyntaxNode)>, + Vec<(Option, SyntaxNode)>, + Vec, +) { + let mut replacements = Vec::new(); + let mut record_field_parents = Vec::new(); + let mut impls = Vec::new(); + + for item in items { + if !is_clone_for_updated { + *item = item.clone_for_update(); + } + //Use stmts are ignored + match item { + ast::Item::Const(it) => replacements.push((it.visibility(), it.syntax().clone())), + ast::Item::Enum(it) => replacements.push((it.visibility(), it.syntax().clone())), + ast::Item::ExternCrate(it) => replacements.push((it.visibility(), it.syntax().clone())), + ast::Item::Fn(it) => replacements.push((it.visibility(), it.syntax().clone())), + //Associated item's visibility should not be changed + ast::Item::Impl(it) if it.for_token().is_none() => impls.push(it.clone()), + ast::Item::MacroDef(it) => replacements.push((it.visibility(), it.syntax().clone())), + ast::Item::Module(it) => replacements.push((it.visibility(), it.syntax().clone())), + ast::Item::Static(it) => replacements.push((it.visibility(), it.syntax().clone())), + ast::Item::Struct(it) => { + replacements.push((it.visibility(), it.syntax().clone())); + record_field_parents.push((it.visibility(), it.syntax().clone())); + } + ast::Item::Trait(it) => replacements.push((it.visibility(), it.syntax().clone())), + ast::Item::TypeAlias(it) => replacements.push((it.visibility(), it.syntax().clone())), + ast::Item::Union(it) => { + replacements.push((it.visibility(), it.syntax().clone())); + record_field_parents.push((it.visibility(), it.syntax().clone())); + } + _ => (), + } + } + + (replacements, record_field_parents, impls) +} + +fn get_use_tree_paths_from_path( + path: ast::Path, + use_tree_str: &mut Vec, +) -> Option<&mut Vec> { + path.syntax().ancestors().filter(|x| x.to_string() != path.to_string()).find_map(|x| { + if let Some(use_tree) = ast::UseTree::cast(x) { + if let Some(upper_tree_path) = use_tree.path() { + if upper_tree_path.to_string() != path.to_string() { + use_tree_str.push(upper_tree_path.clone()); + get_use_tree_paths_from_path(upper_tree_path, use_tree_str); + return Some(use_tree); + } + } + } + None + })?; + + Some(use_tree_str) +} + +fn add_change_vis(vis: Option, node_or_token_opt: Option) { + if vis.is_none() { + if let Some(node_or_token) = node_or_token_opt { + let pub_crate_vis = make::visibility_pub_crate().clone_for_update(); + ted::insert(ted::Position::before(node_or_token), pub_crate_vis.syntax()); + } + } +} + +fn compare_hir_and_ast_module( + ast_module: &ast::Module, + hir_module: hir::Module, + ctx: &AssistContext<'_>, +) -> Option<()> { + let hir_mod_name = hir_module.name(ctx.db())?; + let ast_mod_name = ast_module.name()?; + if hir_mod_name.to_string() != ast_mod_name.to_string() { + return None; + } + + Some(()) +} + +fn indent_range_before_given_node(node: &SyntaxNode) -> Option { + node.siblings_with_tokens(syntax::Direction::Prev) + .find(|x| x.kind() == WHITESPACE) + .map(|x| x.text_range()) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_not_applicable_without_selection() { + check_assist_not_applicable( + extract_module, + r" +$0pub struct PublicStruct { + field: i32, +} + ", + ) + } + + #[test] + fn test_extract_module() { + check_assist( + extract_module, + r" + mod thirdpartycrate { + pub mod nest { + pub struct SomeType; + pub struct SomeType2; + } + pub struct SomeType1; + } + + mod bar { + use crate::thirdpartycrate::{nest::{SomeType, SomeType2}, SomeType1}; + + pub struct PublicStruct { + field: PrivateStruct, + field1: SomeType1, + } + + impl PublicStruct { + pub fn new() -> Self { + Self { field: PrivateStruct::new(), field1: SomeType1 } + } + } + + fn foo() { + let _s = PrivateStruct::new(); + let _a = bar(); + } + +$0struct PrivateStruct { + inner: SomeType, +} + +pub struct PrivateStruct1 { + pub inner: i32, +} + +impl PrivateStruct { + fn new() -> Self { + PrivateStruct { inner: SomeType } + } +} + +fn bar() -> i32 { + 2 +}$0 + } + ", + r" + mod thirdpartycrate { + pub mod nest { + pub struct SomeType; + pub struct SomeType2; + } + pub struct SomeType1; + } + + mod bar { + use crate::thirdpartycrate::{nest::{SomeType2}, SomeType1}; + + pub struct PublicStruct { + field: modname::PrivateStruct, + field1: SomeType1, + } + + impl PublicStruct { + pub fn new() -> Self { + Self { field: modname::PrivateStruct::new(), field1: SomeType1 } + } + } + + fn foo() { + let _s = modname::PrivateStruct::new(); + let _a = modname::bar(); + } + +mod modname { + use crate::thirdpartycrate::nest::SomeType; + + pub(crate) struct PrivateStruct { + pub(crate) inner: SomeType, + } + + pub struct PrivateStruct1 { + pub inner: i32, + } + + impl PrivateStruct { + pub(crate) fn new() -> Self { + PrivateStruct { inner: SomeType } + } + } + + pub(crate) fn bar() -> i32 { + 2 + } +} + } + ", + ); + } + + #[test] + fn test_extract_module_for_function_only() { + check_assist( + extract_module, + r" +$0fn foo(name: i32) -> i32 { + name + 1 +}$0 + + fn bar(name: i32) -> i32 { + name + 2 + } + ", + r" +mod modname { + pub(crate) fn foo(name: i32) -> i32 { + name + 1 + } +} + + fn bar(name: i32) -> i32 { + name + 2 + } + ", + ) + } + + #[test] + fn test_extract_module_for_impl_having_corresponding_adt_in_selection() { + check_assist( + extract_module, + r" + mod impl_play { +$0struct A {} + +impl A { + pub fn new_a() -> i32 { + 2 + } +}$0 + + fn a() { + let _a = A::new_a(); + } + } + ", + r" + mod impl_play { +mod modname { + pub(crate) struct A {} + + impl A { + pub fn new_a() -> i32 { + 2 + } + } +} + + fn a() { + let _a = modname::A::new_a(); + } + } + ", + ) + } + + #[test] + fn test_import_resolve_when_its_only_inside_selection() { + check_assist( + extract_module, + r" + mod foo { + pub struct PrivateStruct; + pub struct PrivateStruct1; + } + + mod bar { + use super::foo::{PrivateStruct, PrivateStruct1}; + +$0struct Strukt { + field: PrivateStruct, +}$0 + + struct Strukt1 { + field: PrivateStruct1, + } + } + ", + r" + mod foo { + pub struct PrivateStruct; + pub struct PrivateStruct1; + } + + mod bar { + use super::foo::{PrivateStruct1}; + +mod modname { + use super::super::foo::PrivateStruct; + + pub(crate) struct Strukt { + pub(crate) field: PrivateStruct, + } +} + + struct Strukt1 { + field: PrivateStruct1, + } + } + ", + ) + } + + #[test] + fn test_import_resolve_when_its_inside_and_outside_selection_and_source_not_in_same_mod() { + check_assist( + extract_module, + r" + mod foo { + pub struct PrivateStruct; + } + + mod bar { + use super::foo::PrivateStruct; + +$0struct Strukt { + field: PrivateStruct, +}$0 + + struct Strukt1 { + field: PrivateStruct, + } + } + ", + r" + mod foo { + pub struct PrivateStruct; + } + + mod bar { + use super::foo::PrivateStruct; + +mod modname { + use super::super::foo::PrivateStruct; + + pub(crate) struct Strukt { + pub(crate) field: PrivateStruct, + } +} + + struct Strukt1 { + field: PrivateStruct, + } + } + ", + ) + } + + #[test] + fn test_import_resolve_when_its_inside_and_outside_selection_and_source_is_in_same_mod() { + check_assist( + extract_module, + r" + mod bar { + pub struct PrivateStruct; + +$0struct Strukt { + field: PrivateStruct, +}$0 + + struct Strukt1 { + field: PrivateStruct, + } + } + ", + r" + mod bar { + pub struct PrivateStruct; + +mod modname { + use super::PrivateStruct; + + pub(crate) struct Strukt { + pub(crate) field: PrivateStruct, + } +} + + struct Strukt1 { + field: PrivateStruct, + } + } + ", + ) + } + + #[test] + fn test_extract_module_for_correspoding_adt_of_impl_present_in_same_mod_but_not_in_selection() { + check_assist( + extract_module, + r" + mod impl_play { + struct A {} + +$0impl A { + pub fn new_a() -> i32 { + 2 + } +}$0 + + fn a() { + let _a = A::new_a(); + } + } + ", + r" + mod impl_play { + struct A {} + +mod modname { + use super::A; + + impl A { + pub fn new_a() -> i32 { + 2 + } + } +} + + fn a() { + let _a = A::new_a(); + } + } + ", + ) + } + + #[test] + fn test_extract_module_for_impl_not_having_corresponding_adt_in_selection_and_not_in_same_mod_but_with_super( + ) { + check_assist( + extract_module, + r" + mod foo { + pub struct A {} + } + mod impl_play { + use super::foo::A; + +$0impl A { + pub fn new_a() -> i32 { + 2 + } +}$0 + + fn a() { + let _a = A::new_a(); + } + } + ", + r" + mod foo { + pub struct A {} + } + mod impl_play { + use super::foo::A; + +mod modname { + use super::super::foo::A; + + impl A { + pub fn new_a() -> i32 { + 2 + } + } +} + + fn a() { + let _a = A::new_a(); + } + } + ", + ) + } + + #[test] + fn test_import_resolve_for_trait_bounds_on_function() { + check_assist( + extract_module, + r" + mod impl_play2 { + trait JustATrait {} + +$0struct A {} + +fn foo(arg: T) -> T { + arg +} + +impl JustATrait for A {} + +fn bar() { + let a = A {}; + foo(a); +}$0 + } + ", + r" + mod impl_play2 { + trait JustATrait {} + +mod modname { + use super::JustATrait; + + pub(crate) struct A {} + + pub(crate) fn foo(arg: T) -> T { + arg + } + + impl JustATrait for A {} + + pub(crate) fn bar() { + let a = A {}; + foo(a); + } +} + } + ", + ) + } + + #[test] + fn test_extract_module_for_module() { + check_assist( + extract_module, + r" + mod impl_play2 { +$0mod impl_play { + pub struct A {} +}$0 + } + ", + r" + mod impl_play2 { +mod modname { + pub(crate) mod impl_play { + pub struct A {} + } +} + } + ", + ) + } + + #[test] + fn test_extract_module_with_multiple_files() { + check_assist( + extract_module, + r" + //- /main.rs + mod foo; + + use foo::PrivateStruct; + + pub struct Strukt { + field: PrivateStruct, + } + + fn main() { + $0struct Strukt1 { + field: Strukt, + }$0 + } + //- /foo.rs + pub struct PrivateStruct; + ", + r" + mod foo; + + use foo::PrivateStruct; + + pub struct Strukt { + field: PrivateStruct, + } + + fn main() { + mod modname { + use super::Strukt; + + pub(crate) struct Strukt1 { + pub(crate) field: Strukt, + } + } + } + ", + ) + } + + #[test] + fn test_extract_module_macro_rules() { + check_assist( + extract_module, + r" +$0macro_rules! m { + () => {}; +}$0 +m! {} + ", + r" +mod modname { + macro_rules! m { + () => {}; + } +} +modname::m! {} + ", + ); + } + + #[test] + fn test_do_not_apply_visibility_modifier_to_trait_impl_items() { + check_assist( + extract_module, + r" + trait ATrait { + fn function(); + } + + struct A {} + +$0impl ATrait for A { + fn function() {} +}$0 + ", + r" + trait ATrait { + fn function(); + } + + struct A {} + +mod modname { + use super::A; + + use super::ATrait; + + impl ATrait for A { + fn function() {} + } +} + ", + ) + } + + #[test] + fn test_if_inside_impl_block_generate_module_outside() { + check_assist( + extract_module, + r" + struct A {} + + impl A { +$0fn foo() {}$0 + fn bar() {} + } + ", + r" + struct A {} + + impl A { + fn bar() {} + } + +mod modname { + use super::A; + + impl A { + pub(crate) fn foo() {} + } +} + ", + ) + } + + #[test] + fn test_if_inside_impl_block_generate_module_outside_but_impl_block_having_one_child() { + check_assist( + extract_module, + r" + struct A {} + struct B {} + + impl A { +$0fn foo(x: B) {}$0 + } + ", + r" + struct A {} + struct B {} + +mod modname { + use super::B; + + use super::A; + + impl A { + pub(crate) fn foo(x: B) {} + } +} + ", + ) + } + + #[test] + fn test_issue_11766() { + //https://github.com/rust-lang/rust-analyzer/issues/11766 + check_assist( + extract_module, + r" + mod x { + pub struct Foo; + pub struct Bar; + } + + use x::{Bar, Foo}; + + $0type A = (Foo, Bar);$0 + ", + r" + mod x { + pub struct Foo; + pub struct Bar; + } + + use x::{}; + + mod modname { + use super::x::Bar; + + use super::x::Foo; + + pub(crate) type A = (Foo, Bar); + } + ", + ) + } + + #[test] + fn test_issue_12790() { + check_assist( + extract_module, + r" + $0/// A documented function + fn documented_fn() {} + + // A commented function with a #[] attribute macro + #[cfg(test)] + fn attribute_fn() {} + + // A normally commented function + fn normal_fn() {} + + /// A documented Struct + struct DocumentedStruct { + // Normal field + x: i32, + + /// Documented field + y: i32, + + // Macroed field + #[cfg(test)] + z: i32, + } + + // A macroed Struct + #[cfg(test)] + struct MacroedStruct { + // Normal field + x: i32, + + /// Documented field + y: i32, + + // Macroed field + #[cfg(test)] + z: i32, + } + + // A normal Struct + struct NormalStruct { + // Normal field + x: i32, + + /// Documented field + y: i32, + + // Macroed field + #[cfg(test)] + z: i32, + } + + /// A documented type + type DocumentedType = i32; + + // A macroed type + #[cfg(test)] + type MacroedType = i32; + + /// A module to move + mod module {} + + /// An impl to move + impl NormalStruct { + /// A method + fn new() {} + } + + /// A documented trait + trait DocTrait { + /// Inner function + fn doc() {} + } + + /// An enum + enum DocumentedEnum { + /// A variant + A, + /// Another variant + B { x: i32, y: i32 } + } + + /// Documented const + const MY_CONST: i32 = 0;$0 + ", + r" + mod modname { + /// A documented function + pub(crate) fn documented_fn() {} + + // A commented function with a #[] attribute macro + #[cfg(test)] + pub(crate) fn attribute_fn() {} + + // A normally commented function + pub(crate) fn normal_fn() {} + + /// A documented Struct + pub(crate) struct DocumentedStruct { + // Normal field + pub(crate) x: i32, + + /// Documented field + pub(crate) y: i32, + + // Macroed field + #[cfg(test)] + pub(crate) z: i32, + } + + // A macroed Struct + #[cfg(test)] + pub(crate) struct MacroedStruct { + // Normal field + pub(crate) x: i32, + + /// Documented field + pub(crate) y: i32, + + // Macroed field + #[cfg(test)] + pub(crate) z: i32, + } + + // A normal Struct + pub(crate) struct NormalStruct { + // Normal field + pub(crate) x: i32, + + /// Documented field + pub(crate) y: i32, + + // Macroed field + #[cfg(test)] + pub(crate) z: i32, + } + + /// A documented type + pub(crate) type DocumentedType = i32; + + // A macroed type + #[cfg(test)] + pub(crate) type MacroedType = i32; + + /// A module to move + pub(crate) mod module {} + + /// An impl to move + impl NormalStruct { + /// A method + pub(crate) fn new() {} + } + + /// A documented trait + pub(crate) trait DocTrait { + /// Inner function + fn doc() {} + } + + /// An enum + pub(crate) enum DocumentedEnum { + /// A variant + A, + /// Another variant + B { x: i32, y: i32 } + } + + /// Documented const + pub(crate) const MY_CONST: i32 = 0; + } + ", + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs new file mode 100644 index 000000000..a93648f2d --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs @@ -0,0 +1,1076 @@ +use std::iter; + +use either::Either; +use hir::{Module, ModuleDef, Name, Variant}; +use ide_db::{ + defs::Definition, + helpers::mod_path_to_ast, + imports::insert_use::{insert_use, ImportScope, InsertUseConfig}, + search::FileReference, + FxHashSet, RootDatabase, +}; +use itertools::{Itertools, Position}; +use syntax::{ + ast::{ + self, edit::IndentLevel, edit_in_place::Indent, make, AstNode, HasAttrs, HasGenericParams, + HasName, HasVisibility, + }, + match_ast, ted, SyntaxElement, + SyntaxKind::*, + SyntaxNode, T, +}; + +use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: extract_struct_from_enum_variant +// +// Extracts a struct from enum variant. +// +// ``` +// enum A { $0One(u32, u32) } +// ``` +// -> +// ``` +// struct One(u32, u32); +// +// enum A { One(One) } +// ``` +pub(crate) fn extract_struct_from_enum_variant( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let variant = ctx.find_node_at_offset::()?; + let field_list = extract_field_list_if_applicable(&variant)?; + + let variant_name = variant.name()?; + let variant_hir = ctx.sema.to_def(&variant)?; + if existing_definition(ctx.db(), &variant_name, &variant_hir) { + cov_mark::hit!(test_extract_enum_not_applicable_if_struct_exists); + return None; + } + + let enum_ast = variant.parent_enum(); + let enum_hir = ctx.sema.to_def(&enum_ast)?; + let target = variant.syntax().text_range(); + acc.add( + AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite), + "Extract struct from enum variant", + target, + |builder| { + let variant_hir_name = variant_hir.name(ctx.db()); + let enum_module_def = ModuleDef::from(enum_hir); + let usages = Definition::Variant(variant_hir).usages(&ctx.sema).all(); + + let mut visited_modules_set = FxHashSet::default(); + let current_module = enum_hir.module(ctx.db()); + visited_modules_set.insert(current_module); + // record file references of the file the def resides in, we only want to swap to the edited file in the builder once + let mut def_file_references = None; + for (file_id, references) in usages { + if file_id == ctx.file_id() { + def_file_references = Some(references); + continue; + } + builder.edit_file(file_id); + let processed = process_references( + ctx, + builder, + &mut visited_modules_set, + &enum_module_def, + &variant_hir_name, + references, + ); + processed.into_iter().for_each(|(path, node, import)| { + apply_references(ctx.config.insert_use, path, node, import) + }); + } + builder.edit_file(ctx.file_id()); + + let variant = builder.make_mut(variant.clone()); + if let Some(references) = def_file_references { + let processed = process_references( + ctx, + builder, + &mut visited_modules_set, + &enum_module_def, + &variant_hir_name, + references, + ); + processed.into_iter().for_each(|(path, node, import)| { + apply_references(ctx.config.insert_use, path, node, import) + }); + } + + let indent = enum_ast.indent_level(); + let generic_params = enum_ast + .generic_param_list() + .and_then(|known_generics| extract_generic_params(&known_generics, &field_list)); + let generics = generic_params.as_ref().map(|generics| generics.clone_for_update()); + let def = + create_struct_def(variant_name.clone(), &variant, &field_list, generics, &enum_ast); + def.reindent_to(indent); + + let start_offset = &variant.parent_enum().syntax().clone(); + ted::insert_all_raw( + ted::Position::before(start_offset), + vec![ + def.syntax().clone().into(), + make::tokens::whitespace(&format!("\n\n{}", indent)).into(), + ], + ); + + update_variant(&variant, generic_params.map(|g| g.clone_for_update())); + }, + ) +} + +fn extract_field_list_if_applicable( + variant: &ast::Variant, +) -> Option> { + match variant.kind() { + ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => { + Some(Either::Left(field_list)) + } + ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => { + Some(Either::Right(field_list)) + } + _ => None, + } +} + +fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Variant) -> bool { + variant + .parent_enum(db) + .module(db) + .scope(db, None) + .into_iter() + .filter(|(_, def)| match def { + // only check type-namespace + hir::ScopeDef::ModuleDef(def) => matches!( + def, + ModuleDef::Module(_) + | ModuleDef::Adt(_) + | ModuleDef::Variant(_) + | ModuleDef::Trait(_) + | ModuleDef::TypeAlias(_) + | ModuleDef::BuiltinType(_) + ), + _ => false, + }) + .any(|(name, _)| name.to_string() == variant_name.to_string()) +} + +fn extract_generic_params( + known_generics: &ast::GenericParamList, + field_list: &Either, +) -> Option { + let mut generics = known_generics.generic_params().map(|param| (param, false)).collect_vec(); + + let tagged_one = match field_list { + Either::Left(field_list) => field_list + .fields() + .filter_map(|f| f.ty()) + .fold(false, |tagged, ty| tag_generics_in_variant(&ty, &mut generics) || tagged), + Either::Right(field_list) => field_list + .fields() + .filter_map(|f| f.ty()) + .fold(false, |tagged, ty| tag_generics_in_variant(&ty, &mut generics) || tagged), + }; + + let generics = generics.into_iter().filter_map(|(param, tag)| tag.then(|| param)); + tagged_one.then(|| make::generic_param_list(generics)) +} + +fn tag_generics_in_variant(ty: &ast::Type, generics: &mut [(ast::GenericParam, bool)]) -> bool { + let mut tagged_one = false; + + for token in ty.syntax().descendants_with_tokens().filter_map(SyntaxElement::into_token) { + for (param, tag) in generics.iter_mut().filter(|(_, tag)| !tag) { + match param { + ast::GenericParam::LifetimeParam(lt) + if matches!(token.kind(), T![lifetime_ident]) => + { + if let Some(lt) = lt.lifetime() { + if lt.text().as_str() == token.text() { + *tag = true; + tagged_one = true; + break; + } + } + } + param if matches!(token.kind(), T![ident]) => { + if match param { + ast::GenericParam::ConstParam(konst) => konst + .name() + .map(|name| name.text().as_str() == token.text()) + .unwrap_or_default(), + ast::GenericParam::TypeParam(ty) => ty + .name() + .map(|name| name.text().as_str() == token.text()) + .unwrap_or_default(), + ast::GenericParam::LifetimeParam(lt) => lt + .lifetime() + .map(|lt| lt.text().as_str() == token.text()) + .unwrap_or_default(), + } { + *tag = true; + tagged_one = true; + break; + } + } + _ => (), + } + } + } + + tagged_one +} + +fn create_struct_def( + variant_name: ast::Name, + variant: &ast::Variant, + field_list: &Either, + generics: Option, + enum_: &ast::Enum, +) -> ast::Struct { + let enum_vis = enum_.visibility(); + + let insert_vis = |node: &'_ SyntaxNode, vis: &'_ SyntaxNode| { + let vis = vis.clone_for_update(); + ted::insert(ted::Position::before(node), vis); + }; + + // for fields without any existing visibility, use visibility of enum + let field_list: ast::FieldList = match field_list { + Either::Left(field_list) => { + let field_list = field_list.clone_for_update(); + + if let Some(vis) = &enum_vis { + field_list + .fields() + .filter(|field| field.visibility().is_none()) + .filter_map(|field| field.name()) + .for_each(|it| insert_vis(it.syntax(), vis.syntax())); + } + + field_list.into() + } + Either::Right(field_list) => { + let field_list = field_list.clone_for_update(); + + if let Some(vis) = &enum_vis { + field_list + .fields() + .filter(|field| field.visibility().is_none()) + .filter_map(|field| field.ty()) + .for_each(|it| insert_vis(it.syntax(), vis.syntax())); + } + + field_list.into() + } + }; + + field_list.reindent_to(IndentLevel::single()); + + let strukt = make::struct_(enum_vis, variant_name, generics, field_list).clone_for_update(); + + // FIXME: Consider making this an actual function somewhere (like in `AttrsOwnerEdit`) after some deliberation + let attrs_and_docs = |node: &SyntaxNode| { + let mut select_next_ws = false; + node.children_with_tokens().filter(move |child| { + let accept = match child.kind() { + ATTR | COMMENT => { + select_next_ws = true; + return true; + } + WHITESPACE if select_next_ws => true, + _ => false, + }; + select_next_ws = false; + + accept + }) + }; + + // copy attributes & comments from variant + let variant_attrs = attrs_and_docs(variant.syntax()) + .map(|tok| match tok.kind() { + WHITESPACE => make::tokens::single_newline().into(), + _ => tok, + }) + .collect(); + ted::insert_all(ted::Position::first_child_of(strukt.syntax()), variant_attrs); + + // copy attributes from enum + ted::insert_all( + ted::Position::first_child_of(strukt.syntax()), + enum_.attrs().map(|it| it.syntax().clone_for_update().into()).collect(), + ); + strukt +} + +fn update_variant(variant: &ast::Variant, generics: Option) -> Option<()> { + let name = variant.name()?; + let ty = generics + .filter(|generics| generics.generic_params().count() > 0) + .map(|generics| { + let mut generic_str = String::with_capacity(8); + + for (p, more) in generics.generic_params().with_position().map(|p| match p { + Position::First(p) | Position::Middle(p) => (p, true), + Position::Last(p) | Position::Only(p) => (p, false), + }) { + match p { + ast::GenericParam::ConstParam(konst) => { + if let Some(name) = konst.name() { + generic_str.push_str(name.text().as_str()); + } + } + ast::GenericParam::LifetimeParam(lt) => { + if let Some(lt) = lt.lifetime() { + generic_str.push_str(lt.text().as_str()); + } + } + ast::GenericParam::TypeParam(ty) => { + if let Some(name) = ty.name() { + generic_str.push_str(name.text().as_str()); + } + } + } + if more { + generic_str.push_str(", "); + } + } + + make::ty(&format!("{}<{}>", &name.text(), &generic_str)) + }) + .unwrap_or_else(|| make::ty(&name.text())); + + let tuple_field = make::tuple_field(None, ty); + let replacement = make::variant( + name, + Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))), + ) + .clone_for_update(); + ted::replace(variant.syntax(), replacement.syntax()); + Some(()) +} + +fn apply_references( + insert_use_cfg: InsertUseConfig, + segment: ast::PathSegment, + node: SyntaxNode, + import: Option<(ImportScope, hir::ModPath)>, +) { + if let Some((scope, path)) = import { + insert_use(&scope, mod_path_to_ast(&path), &insert_use_cfg); + } + // deep clone to prevent cycle + let path = make::path_from_segments(iter::once(segment.clone_subtree()), false); + ted::insert_raw(ted::Position::before(segment.syntax()), path.clone_for_update().syntax()); + ted::insert_raw(ted::Position::before(segment.syntax()), make::token(T!['('])); + ted::insert_raw(ted::Position::after(&node), make::token(T![')'])); +} + +fn process_references( + ctx: &AssistContext<'_>, + builder: &mut AssistBuilder, + visited_modules: &mut FxHashSet, + enum_module_def: &ModuleDef, + variant_hir_name: &Name, + refs: Vec, +) -> Vec<(ast::PathSegment, SyntaxNode, Option<(ImportScope, hir::ModPath)>)> { + // we have to recollect here eagerly as we are about to edit the tree we need to calculate the changes + // and corresponding nodes up front + refs.into_iter() + .flat_map(|reference| { + let (segment, scope_node, module) = reference_to_node(&ctx.sema, reference)?; + let segment = builder.make_mut(segment); + let scope_node = builder.make_syntax_mut(scope_node); + if !visited_modules.contains(&module) { + let mod_path = module.find_use_path_prefixed( + ctx.sema.db, + *enum_module_def, + ctx.config.insert_use.prefix_kind, + ); + if let Some(mut mod_path) = mod_path { + mod_path.pop_segment(); + mod_path.push_segment(variant_hir_name.clone()); + let scope = ImportScope::find_insert_use_container(&scope_node, &ctx.sema)?; + visited_modules.insert(module); + return Some((segment, scope_node, Some((scope, mod_path)))); + } + } + Some((segment, scope_node, None)) + }) + .collect() +} + +fn reference_to_node( + sema: &hir::Semantics<'_, RootDatabase>, + reference: FileReference, +) -> Option<(ast::PathSegment, SyntaxNode, hir::Module)> { + let segment = + reference.name.as_name_ref()?.syntax().parent().and_then(ast::PathSegment::cast)?; + let parent = segment.parent_path().syntax().parent()?; + let expr_or_pat = match_ast! { + match parent { + ast::PathExpr(_it) => parent.parent()?, + ast::RecordExpr(_it) => parent, + ast::TupleStructPat(_it) => parent, + ast::RecordPat(_it) => parent, + _ => return None, + } + }; + let module = sema.scope(&expr_or_pat)?.module(); + Some((segment, expr_or_pat, module)) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_extract_struct_several_fields_tuple() { + check_assist( + extract_struct_from_enum_variant, + "enum A { $0One(u32, u32) }", + r#"struct One(u32, u32); + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_several_fields_named() { + check_assist( + extract_struct_from_enum_variant, + "enum A { $0One { foo: u32, bar: u32 } }", + r#"struct One{ foo: u32, bar: u32 } + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_one_field_named() { + check_assist( + extract_struct_from_enum_variant, + "enum A { $0One { foo: u32 } }", + r#"struct One{ foo: u32 } + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_carries_over_generics() { + check_assist( + extract_struct_from_enum_variant, + r"enum En { Var { a: T$0 } }", + r#"struct Var{ a: T } + +enum En { Var(Var) }"#, + ); + } + + #[test] + fn test_extract_struct_carries_over_attributes() { + check_assist( + extract_struct_from_enum_variant, + r#"#[derive(Debug)] +#[derive(Clone)] +enum Enum { Variant{ field: u32$0 } }"#, + r#"#[derive(Debug)]#[derive(Clone)] struct Variant{ field: u32 } + +#[derive(Debug)] +#[derive(Clone)] +enum Enum { Variant(Variant) }"#, + ); + } + + #[test] + fn test_extract_struct_indent_to_parent_enum() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum Enum { + Variant { + field: u32$0 + } +}"#, + r#" +struct Variant{ + field: u32 +} + +enum Enum { + Variant(Variant) +}"#, + ); + } + + #[test] + fn test_extract_struct_indent_to_parent_enum_in_mod() { + check_assist( + extract_struct_from_enum_variant, + r#" +mod indenting { + enum Enum { + Variant { + field: u32$0 + } + } +}"#, + r#" +mod indenting { + struct Variant{ + field: u32 + } + + enum Enum { + Variant(Variant) + } +}"#, + ); + } + + #[test] + fn test_extract_struct_keep_comments_and_attrs_one_field_named() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum A { + $0One { + // leading comment + /// doc comment + #[an_attr] + foo: u32 + // trailing comment + } +}"#, + r#" +struct One{ + // leading comment + /// doc comment + #[an_attr] + foo: u32 + // trailing comment +} + +enum A { + One(One) +}"#, + ); + } + + #[test] + fn test_extract_struct_keep_comments_and_attrs_several_fields_named() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum A { + $0One { + // comment + /// doc + #[attr] + foo: u32, + // comment + #[attr] + /// doc + bar: u32 + } +}"#, + r#" +struct One{ + // comment + /// doc + #[attr] + foo: u32, + // comment + #[attr] + /// doc + bar: u32 +} + +enum A { + One(One) +}"#, + ); + } + + #[test] + fn test_extract_struct_keep_comments_and_attrs_several_fields_tuple() { + check_assist( + extract_struct_from_enum_variant, + "enum A { $0One(/* comment */ #[attr] u32, /* another */ u32 /* tail */) }", + r#" +struct One(/* comment */ #[attr] u32, /* another */ u32 /* tail */); + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_keep_comments_and_attrs_on_variant_struct() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum A { + /* comment */ + // other + /// comment + #[attr] + $0One { + a: u32 + } +}"#, + r#" +/* comment */ +// other +/// comment +#[attr] +struct One{ + a: u32 +} + +enum A { + One(One) +}"#, + ); + } + + #[test] + fn test_extract_struct_keep_comments_and_attrs_on_variant_tuple() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum A { + /* comment */ + // other + /// comment + #[attr] + $0One(u32, u32) +}"#, + r#" +/* comment */ +// other +/// comment +#[attr] +struct One(u32, u32); + +enum A { + One(One) +}"#, + ); + } + + #[test] + fn test_extract_struct_keep_existing_visibility_named() { + check_assist( + extract_struct_from_enum_variant, + "enum A { $0One{ a: u32, pub(crate) b: u32, pub(super) c: u32, d: u32 } }", + r#" +struct One{ a: u32, pub(crate) b: u32, pub(super) c: u32, d: u32 } + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_keep_existing_visibility_tuple() { + check_assist( + extract_struct_from_enum_variant, + "enum A { $0One(u32, pub(crate) u32, pub(super) u32, u32) }", + r#" +struct One(u32, pub(crate) u32, pub(super) u32, u32); + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_enum_variant_name_value_namespace() { + check_assist( + extract_struct_from_enum_variant, + r#"const One: () = (); +enum A { $0One(u32, u32) }"#, + r#"const One: () = (); +struct One(u32, u32); + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_no_visibility() { + check_assist( + extract_struct_from_enum_variant, + "enum A { $0One(u32, u32) }", + r#" +struct One(u32, u32); + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_pub_visibility() { + check_assist( + extract_struct_from_enum_variant, + "pub enum A { $0One(u32, u32) }", + r#" +pub struct One(pub u32, pub u32); + +pub enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_pub_in_mod_visibility() { + check_assist( + extract_struct_from_enum_variant, + "pub(in something) enum A { $0One{ a: u32, b: u32 } }", + r#" +pub(in something) struct One{ pub(in something) a: u32, pub(in something) b: u32 } + +pub(in something) enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_pub_crate_visibility() { + check_assist( + extract_struct_from_enum_variant, + "pub(crate) enum A { $0One{ a: u32, b: u32, c: u32 } }", + r#" +pub(crate) struct One{ pub(crate) a: u32, pub(crate) b: u32, pub(crate) c: u32 } + +pub(crate) enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_with_complex_imports() { + check_assist( + extract_struct_from_enum_variant, + r#"mod my_mod { + fn another_fn() { + let m = my_other_mod::MyEnum::MyField(1, 1); + } + + pub mod my_other_mod { + fn another_fn() { + let m = MyEnum::MyField(1, 1); + } + + pub enum MyEnum { + $0MyField(u8, u8), + } + } +} + +fn another_fn() { + let m = my_mod::my_other_mod::MyEnum::MyField(1, 1); +}"#, + r#"use my_mod::my_other_mod::MyField; + +mod my_mod { + use self::my_other_mod::MyField; + + fn another_fn() { + let m = my_other_mod::MyEnum::MyField(MyField(1, 1)); + } + + pub mod my_other_mod { + fn another_fn() { + let m = MyEnum::MyField(MyField(1, 1)); + } + + pub struct MyField(pub u8, pub u8); + + pub enum MyEnum { + MyField(MyField), + } + } +} + +fn another_fn() { + let m = my_mod::my_other_mod::MyEnum::MyField(MyField(1, 1)); +}"#, + ); + } + + #[test] + fn extract_record_fix_references() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum E { + $0V { i: i32, j: i32 } +} + +fn f() { + let E::V { i, j } = E::V { i: 9, j: 2 }; +} +"#, + r#" +struct V{ i: i32, j: i32 } + +enum E { + V(V) +} + +fn f() { + let E::V(V { i, j }) = E::V(V { i: 9, j: 2 }); +} +"#, + ) + } + + #[test] + fn extract_record_fix_references2() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum E { + $0V(i32, i32) +} + +fn f() { + let E::V(i, j) = E::V(9, 2); +} +"#, + r#" +struct V(i32, i32); + +enum E { + V(V) +} + +fn f() { + let E::V(V(i, j)) = E::V(V(9, 2)); +} +"#, + ) + } + + #[test] + fn test_several_files() { + check_assist( + extract_struct_from_enum_variant, + r#" +//- /main.rs +enum E { + $0V(i32, i32) +} +mod foo; + +//- /foo.rs +use crate::E; +fn f() { + let e = E::V(9, 2); +} +"#, + r#" +//- /main.rs +struct V(i32, i32); + +enum E { + V(V) +} +mod foo; + +//- /foo.rs +use crate::{E, V}; +fn f() { + let e = E::V(V(9, 2)); +} +"#, + ) + } + + #[test] + fn test_several_files_record() { + check_assist( + extract_struct_from_enum_variant, + r#" +//- /main.rs +enum E { + $0V { i: i32, j: i32 } +} +mod foo; + +//- /foo.rs +use crate::E; +fn f() { + let e = E::V { i: 9, j: 2 }; +} +"#, + r#" +//- /main.rs +struct V{ i: i32, j: i32 } + +enum E { + V(V) +} +mod foo; + +//- /foo.rs +use crate::{E, V}; +fn f() { + let e = E::V(V { i: 9, j: 2 }); +} +"#, + ) + } + + #[test] + fn test_extract_struct_record_nested_call_exp() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum A { $0One { a: u32, b: u32 } } + +struct B(A); + +fn foo() { + let _ = B(A::One { a: 1, b: 2 }); +} +"#, + r#" +struct One{ a: u32, b: u32 } + +enum A { One(One) } + +struct B(A); + +fn foo() { + let _ = B(A::One(One { a: 1, b: 2 })); +} +"#, + ); + } + + #[test] + fn test_extract_enum_not_applicable_for_element_with_no_fields() { + check_assist_not_applicable(extract_struct_from_enum_variant, r#"enum A { $0One }"#); + } + + #[test] + fn test_extract_enum_not_applicable_if_struct_exists() { + cov_mark::check!(test_extract_enum_not_applicable_if_struct_exists); + check_assist_not_applicable( + extract_struct_from_enum_variant, + r#" +struct One; +enum A { $0One(u8, u32) } +"#, + ); + } + + #[test] + fn test_extract_not_applicable_one_field() { + check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0One(u32) }"); + } + + #[test] + fn test_extract_not_applicable_no_field_tuple() { + check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0None() }"); + } + + #[test] + fn test_extract_not_applicable_no_field_named() { + check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0None {} }"); + } + + #[test] + fn test_extract_struct_only_copies_needed_generics() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum X<'a, 'b, 'x> { + $0A { a: &'a &'x mut () }, + B { b: &'b () }, + C { c: () }, +} +"#, + r#" +struct A<'a, 'x>{ a: &'a &'x mut () } + +enum X<'a, 'b, 'x> { + A(A<'a, 'x>), + B { b: &'b () }, + C { c: () }, +} +"#, + ); + } + + #[test] + fn test_extract_struct_with_liftime_type_const() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum X<'b, T, V, const C: usize> { + $0A { a: T, b: X<'b>, c: [u8; C] }, + D { d: V }, +} +"#, + r#" +struct A<'b, T, const C: usize>{ a: T, b: X<'b>, c: [u8; C] } + +enum X<'b, T, V, const C: usize> { + A(A<'b, T, C>), + D { d: V }, +} +"#, + ); + } + + #[test] + fn test_extract_struct_without_generics() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum X<'a, 'b> { + A { a: &'a () }, + B { b: &'b () }, + $0C { c: () }, +} +"#, + r#" +struct C{ c: () } + +enum X<'a, 'b> { + A { a: &'a () }, + B { b: &'b () }, + C(C), +} +"#, + ); + } + + #[test] + fn test_extract_struct_keeps_trait_bounds() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum En { + $0A { a: T }, + B { b: V }, +} +"#, + r#" +struct A{ a: T } + +enum En { + A(A), + B { b: V }, +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs new file mode 100644 index 000000000..af584cdb4 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs @@ -0,0 +1,360 @@ +use either::Either; +use ide_db::syntax_helpers::node_ext::walk_ty; +use itertools::Itertools; +use syntax::{ + ast::{self, edit::IndentLevel, AstNode, HasGenericParams, HasName}, + match_ast, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: extract_type_alias +// +// Extracts the selected type as a type alias. +// +// ``` +// struct S { +// field: $0(u8, u8, u8)$0, +// } +// ``` +// -> +// ``` +// type $0Type = (u8, u8, u8); +// +// struct S { +// field: Type, +// } +// ``` +pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + if ctx.has_empty_selection() { + return None; + } + + let ty = ctx.find_node_at_range::()?; + let item = ty.syntax().ancestors().find_map(ast::Item::cast)?; + let assoc_owner = item.syntax().ancestors().nth(2).and_then(|it| { + match_ast! { + match it { + ast::Trait(tr) => Some(Either::Left(tr)), + ast::Impl(impl_) => Some(Either::Right(impl_)), + _ => None, + } + } + }); + let node = assoc_owner.as_ref().map_or_else( + || item.syntax(), + |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax), + ); + let insert_pos = node.text_range().start(); + let target = ty.syntax().text_range(); + + acc.add( + AssistId("extract_type_alias", AssistKind::RefactorExtract), + "Extract type as type alias", + target, + |builder| { + let mut known_generics = match item.generic_param_list() { + Some(it) => it.generic_params().collect(), + None => Vec::new(), + }; + if let Some(it) = assoc_owner.as_ref().and_then(|it| match it { + Either::Left(it) => it.generic_param_list(), + Either::Right(it) => it.generic_param_list(), + }) { + known_generics.extend(it.generic_params()); + } + let generics = collect_used_generics(&ty, &known_generics); + + let replacement = if !generics.is_empty() { + format!( + "Type<{}>", + generics.iter().format_with(", ", |generic, f| { + match generic { + ast::GenericParam::ConstParam(cp) => f(&cp.name().unwrap()), + ast::GenericParam::LifetimeParam(lp) => f(&lp.lifetime().unwrap()), + ast::GenericParam::TypeParam(tp) => f(&tp.name().unwrap()), + } + }) + ) + } else { + String::from("Type") + }; + builder.replace(target, replacement); + + let indent = IndentLevel::from_node(node); + let generics = if !generics.is_empty() { + format!("<{}>", generics.iter().format(", ")) + } else { + String::new() + }; + match ctx.config.snippet_cap { + Some(cap) => { + builder.insert_snippet( + cap, + insert_pos, + format!("type $0Type{} = {};\n\n{}", generics, ty, indent), + ); + } + None => { + builder.insert( + insert_pos, + format!("type Type{} = {};\n\n{}", generics, ty, indent), + ); + } + } + }, + ) +} + +fn collect_used_generics<'gp>( + ty: &ast::Type, + known_generics: &'gp [ast::GenericParam], +) -> Vec<&'gp ast::GenericParam> { + // can't use a closure -> closure here cause lifetime inference fails for that + fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ { + move |gp: &&ast::GenericParam| match gp { + ast::GenericParam::LifetimeParam(lp) => { + lp.lifetime().map_or(false, |lt| lt.text() == text) + } + _ => false, + } + } + + let mut generics = Vec::new(); + walk_ty(ty, &mut |ty| match ty { + ast::Type::PathType(ty) => { + if let Some(path) = ty.path() { + if let Some(name_ref) = path.as_single_name_ref() { + if let Some(param) = known_generics.iter().find(|gp| { + match gp { + ast::GenericParam::ConstParam(cp) => cp.name(), + ast::GenericParam::TypeParam(tp) => tp.name(), + _ => None, + } + .map_or(false, |n| n.text() == name_ref.text()) + }) { + generics.push(param); + } + } + generics.extend( + path.segments() + .filter_map(|seg| seg.generic_arg_list()) + .flat_map(|it| it.generic_args()) + .filter_map(|it| match it { + ast::GenericArg::LifetimeArg(lt) => { + let lt = lt.lifetime()?; + known_generics.iter().find(find_lifetime(<.text())) + } + _ => None, + }), + ); + } + } + ast::Type::ImplTraitType(impl_ty) => { + if let Some(it) = impl_ty.type_bound_list() { + generics.extend( + it.bounds() + .filter_map(|it| it.lifetime()) + .filter_map(|lt| known_generics.iter().find(find_lifetime(<.text()))), + ); + } + } + ast::Type::DynTraitType(dyn_ty) => { + if let Some(it) = dyn_ty.type_bound_list() { + generics.extend( + it.bounds() + .filter_map(|it| it.lifetime()) + .filter_map(|lt| known_generics.iter().find(find_lifetime(<.text()))), + ); + } + } + ast::Type::RefType(ref_) => generics.extend( + ref_.lifetime().and_then(|lt| known_generics.iter().find(find_lifetime(<.text()))), + ), + _ => (), + }); + // stable resort to lifetime, type, const + generics.sort_by_key(|gp| match gp { + ast::GenericParam::ConstParam(_) => 2, + ast::GenericParam::LifetimeParam(_) => 0, + ast::GenericParam::TypeParam(_) => 1, + }); + generics +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_not_applicable_without_selection() { + check_assist_not_applicable( + extract_type_alias, + r" +struct S { + field: $0(u8, u8, u8), +} + ", + ); + } + + #[test] + fn test_simple_types() { + check_assist( + extract_type_alias, + r" +struct S { + field: $0u8$0, +} + ", + r#" +type $0Type = u8; + +struct S { + field: Type, +} + "#, + ); + } + + #[test] + fn test_generic_type_arg() { + check_assist( + extract_type_alias, + r" +fn generic() {} + +fn f() { + generic::<$0()$0>(); +} + ", + r#" +fn generic() {} + +type $0Type = (); + +fn f() { + generic::(); +} + "#, + ); + } + + #[test] + fn test_inner_type_arg() { + check_assist( + extract_type_alias, + r" +struct Vec {} +struct S { + v: Vec$0>>, +} + ", + r#" +struct Vec {} +type $0Type = Vec; + +struct S { + v: Vec>, +} + "#, + ); + } + + #[test] + fn test_extract_inner_type() { + check_assist( + extract_type_alias, + r" +struct S { + field: ($0u8$0,), +} + ", + r#" +type $0Type = u8; + +struct S { + field: (Type,), +} + "#, + ); + } + + #[test] + fn extract_from_impl_or_trait() { + // When invoked in an impl/trait, extracted type alias should be placed next to the + // impl/trait, not inside. + check_assist( + extract_type_alias, + r#" +impl S { + fn f() -> $0(u8, u8)$0 {} +} + "#, + r#" +type $0Type = (u8, u8); + +impl S { + fn f() -> Type {} +} + "#, + ); + check_assist( + extract_type_alias, + r#" +trait Tr { + fn f() -> $0(u8, u8)$0 {} +} + "#, + r#" +type $0Type = (u8, u8); + +trait Tr { + fn f() -> Type {} +} + "#, + ); + } + + #[test] + fn indentation() { + check_assist( + extract_type_alias, + r#" +mod m { + fn f() -> $0u8$0 {} +} + "#, + r#" +mod m { + type $0Type = u8; + + fn f() -> Type {} +} + "#, + ); + } + + #[test] + fn generics() { + check_assist( + extract_type_alias, + r#" +struct Struct; +impl<'outer, Outer, const OUTER: usize> () { + fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct, Struct, Outer, &'inner (), Inner, &'outer ())$0) {} +} +"#, + r#" +struct Struct; +type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct, Struct, Outer, &'inner (), Inner, &'outer ()); + +impl<'outer, Outer, const OUTER: usize> () { + fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {} +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_variable.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_variable.rs new file mode 100644 index 000000000..3596b6f82 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_variable.rs @@ -0,0 +1,1279 @@ +use stdx::format_to; +use syntax::{ + ast::{self, AstNode}, + NodeOrToken, + SyntaxKind::{ + BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD, + PATH_EXPR, RETURN_EXPR, + }, + SyntaxNode, +}; + +use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: extract_variable +// +// Extracts subexpression into a variable. +// +// ``` +// fn main() { +// $0(1 + 2)$0 * 4; +// } +// ``` +// -> +// ``` +// fn main() { +// let $0var_name = (1 + 2); +// var_name * 4; +// } +// ``` +pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + if ctx.has_empty_selection() { + return None; + } + + let node = match ctx.covering_element() { + NodeOrToken::Node(it) => it, + NodeOrToken::Token(it) if it.kind() == COMMENT => { + cov_mark::hit!(extract_var_in_comment_is_not_applicable); + return None; + } + NodeOrToken::Token(it) => it.parent()?, + }; + let node = node.ancestors().take_while(|anc| anc.text_range() == node.text_range()).last()?; + let to_extract = node + .descendants() + .take_while(|it| ctx.selection_trimmed().contains_range(it.text_range())) + .find_map(valid_target_expr)?; + + if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) { + if ty_info.adjusted().is_unit() { + return None; + } + } + + let reference_modifier = match get_receiver_type(ctx, &to_extract) { + Some(receiver_type) if receiver_type.is_mutable_reference() => "&mut ", + Some(receiver_type) if receiver_type.is_reference() => "&", + _ => "", + }; + + let parent_ref_expr = to_extract.syntax().parent().and_then(ast::RefExpr::cast); + let var_modifier = match parent_ref_expr { + Some(expr) if expr.mut_token().is_some() => "mut ", + _ => "", + }; + + let anchor = Anchor::from(&to_extract)?; + let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone(); + let target = to_extract.syntax().text_range(); + acc.add( + AssistId("extract_variable", AssistKind::RefactorExtract), + "Extract into variable", + target, + move |edit| { + let field_shorthand = + match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) { + Some(field) => field.name_ref(), + None => None, + }; + + let mut buf = String::new(); + + let var_name = match &field_shorthand { + Some(it) => it.to_string(), + None => suggest_name::for_variable(&to_extract, &ctx.sema), + }; + let expr_range = match &field_shorthand { + Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()), + None => to_extract.syntax().text_range(), + }; + + match anchor { + Anchor::Before(_) | Anchor::Replace(_) => { + format_to!(buf, "let {}{} = {}", var_modifier, var_name, reference_modifier) + } + Anchor::WrapInBlock(_) => { + format_to!(buf, "{{ let {} = {}", var_name, reference_modifier) + } + }; + format_to!(buf, "{}", to_extract.syntax()); + + if let Anchor::Replace(stmt) = anchor { + cov_mark::hit!(test_extract_var_expr_stmt); + if stmt.semicolon_token().is_none() { + buf.push(';'); + } + match ctx.config.snippet_cap { + Some(cap) => { + let snip = buf.replace( + &format!("let {}{}", var_modifier, var_name), + &format!("let {}$0{}", var_modifier, var_name), + ); + edit.replace_snippet(cap, expr_range, snip) + } + None => edit.replace(expr_range, buf), + } + return; + } + + buf.push(';'); + + // We want to maintain the indent level, + // but we do not want to duplicate possible + // extra newlines in the indent block + let text = indent.text(); + if text.starts_with('\n') { + buf.push('\n'); + buf.push_str(text.trim_start_matches('\n')); + } else { + buf.push_str(text); + } + + edit.replace(expr_range, var_name.clone()); + let offset = anchor.syntax().text_range().start(); + match ctx.config.snippet_cap { + Some(cap) => { + let snip = buf.replace( + &format!("let {}{}", var_modifier, var_name), + &format!("let {}$0{}", var_modifier, var_name), + ); + edit.insert_snippet(cap, offset, snip) + } + None => edit.insert(offset, buf), + } + + if let Anchor::WrapInBlock(_) = anchor { + edit.insert(anchor.syntax().text_range().end(), " }"); + } + }, + ) +} + +/// Check whether the node is a valid expression which can be extracted to a variable. +/// In general that's true for any expression, but in some cases that would produce invalid code. +fn valid_target_expr(node: SyntaxNode) -> Option { + match node.kind() { + PATH_EXPR | LOOP_EXPR => None, + BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()), + RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()), + BLOCK_EXPR => { + ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from) + } + _ => ast::Expr::cast(node), + } +} + +fn get_receiver_type(ctx: &AssistContext<'_>, expression: &ast::Expr) -> Option { + let receiver = get_receiver(expression.clone())?; + Some(ctx.sema.type_of_expr(&receiver)?.original()) +} + +/// In the expression `a.b.c.x()`, find `a` +fn get_receiver(expression: ast::Expr) -> Option { + match expression { + ast::Expr::FieldExpr(field) if field.expr().is_some() => { + let nested_expression = &field.expr()?; + get_receiver(nested_expression.to_owned()) + } + _ => Some(expression), + } +} + +#[derive(Debug)] +enum Anchor { + Before(SyntaxNode), + Replace(ast::ExprStmt), + WrapInBlock(SyntaxNode), +} + +impl Anchor { + fn from(to_extract: &ast::Expr) -> Option { + to_extract + .syntax() + .ancestors() + .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind())) + .find_map(|node| { + if ast::MacroCall::can_cast(node.kind()) { + return None; + } + if let Some(expr) = + node.parent().and_then(ast::StmtList::cast).and_then(|it| it.tail_expr()) + { + if expr.syntax() == &node { + cov_mark::hit!(test_extract_var_last_expr); + return Some(Anchor::Before(node)); + } + } + + if let Some(parent) = node.parent() { + if parent.kind() == CLOSURE_EXPR { + cov_mark::hit!(test_extract_var_in_closure_no_block); + return Some(Anchor::WrapInBlock(node)); + } + if parent.kind() == MATCH_ARM { + if node.kind() == MATCH_GUARD { + cov_mark::hit!(test_extract_var_in_match_guard); + } else { + cov_mark::hit!(test_extract_var_in_match_arm_no_block); + return Some(Anchor::WrapInBlock(node)); + } + } + } + + if let Some(stmt) = ast::Stmt::cast(node.clone()) { + if let ast::Stmt::ExprStmt(stmt) = stmt { + if stmt.expr().as_ref() == Some(to_extract) { + return Some(Anchor::Replace(stmt)); + } + } + return Some(Anchor::Before(node)); + } + None + }) + } + + fn syntax(&self) -> &SyntaxNode { + match self { + Anchor::Before(it) | Anchor::WrapInBlock(it) => it, + Anchor::Replace(stmt) => stmt.syntax(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn test_extract_var_simple() { + check_assist( + extract_variable, + r#" +fn foo() { + foo($01 + 1$0); +}"#, + r#" +fn foo() { + let $0var_name = 1 + 1; + foo(var_name); +}"#, + ); + } + + #[test] + fn extract_var_in_comment_is_not_applicable() { + cov_mark::check!(extract_var_in_comment_is_not_applicable); + check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }"); + } + + #[test] + fn test_extract_var_expr_stmt() { + cov_mark::check!(test_extract_var_expr_stmt); + check_assist( + extract_variable, + r#" +fn foo() { + $0 1 + 1$0; +}"#, + r#" +fn foo() { + let $0var_name = 1 + 1; +}"#, + ); + check_assist( + extract_variable, + r" +fn foo() { + $0{ let x = 0; x }$0 + something_else(); +}", + r" +fn foo() { + let $0var_name = { let x = 0; x }; + something_else(); +}", + ); + } + + #[test] + fn test_extract_var_part_of_expr_stmt() { + check_assist( + extract_variable, + r" +fn foo() { + $01$0 + 1; +}", + r" +fn foo() { + let $0var_name = 1; + var_name + 1; +}", + ); + } + + #[test] + fn test_extract_var_last_expr() { + cov_mark::check!(test_extract_var_last_expr); + check_assist( + extract_variable, + r#" +fn foo() { + bar($01 + 1$0) +} +"#, + r#" +fn foo() { + let $0var_name = 1 + 1; + bar(var_name) +} +"#, + ); + check_assist( + extract_variable, + r#" +fn foo() -> i32 { + $0bar(1 + 1)$0 +} + +fn bar(i: i32) -> i32 { + i +} +"#, + r#" +fn foo() -> i32 { + let $0bar = bar(1 + 1); + bar +} + +fn bar(i: i32) -> i32 { + i +} +"#, + ) + } + + #[test] + fn test_extract_var_in_match_arm_no_block() { + cov_mark::check!(test_extract_var_in_match_arm_no_block); + check_assist( + extract_variable, + r#" +fn main() { + let x = true; + let tuple = match x { + true => ($02 + 2$0, true) + _ => (0, false) + }; +} +"#, + r#" +fn main() { + let x = true; + let tuple = match x { + true => { let $0var_name = 2 + 2; (var_name, true) } + _ => (0, false) + }; +} +"#, + ); + } + + #[test] + fn test_extract_var_in_match_arm_with_block() { + check_assist( + extract_variable, + r#" +fn main() { + let x = true; + let tuple = match x { + true => { + let y = 1; + ($02 + y$0, true) + } + _ => (0, false) + }; +} +"#, + r#" +fn main() { + let x = true; + let tuple = match x { + true => { + let y = 1; + let $0var_name = 2 + y; + (var_name, true) + } + _ => (0, false) + }; +} +"#, + ); + } + + #[test] + fn test_extract_var_in_match_guard() { + cov_mark::check!(test_extract_var_in_match_guard); + check_assist( + extract_variable, + r#" +fn main() { + match () { + () if $010 > 0$0 => 1 + _ => 2 + }; +} +"#, + r#" +fn main() { + let $0var_name = 10 > 0; + match () { + () if var_name => 1 + _ => 2 + }; +} +"#, + ); + } + + #[test] + fn test_extract_var_in_closure_no_block() { + cov_mark::check!(test_extract_var_in_closure_no_block); + check_assist( + extract_variable, + r#" +fn main() { + let lambda = |x: u32| $0x * 2$0; +} +"#, + r#" +fn main() { + let lambda = |x: u32| { let $0var_name = x * 2; var_name }; +} +"#, + ); + } + + #[test] + fn test_extract_var_in_closure_with_block() { + check_assist( + extract_variable, + r#" +fn main() { + let lambda = |x: u32| { $0x * 2$0 }; +} +"#, + r#" +fn main() { + let lambda = |x: u32| { let $0var_name = x * 2; var_name }; +} +"#, + ); + } + + #[test] + fn test_extract_var_path_simple() { + check_assist( + extract_variable, + " +fn main() { + let o = $0Some(true)$0; +} +", + " +fn main() { + let $0var_name = Some(true); + let o = var_name; +} +", + ); + } + + #[test] + fn test_extract_var_path_method() { + check_assist( + extract_variable, + " +fn main() { + let v = $0bar.foo()$0; +} +", + " +fn main() { + let $0foo = bar.foo(); + let v = foo; +} +", + ); + } + + #[test] + fn test_extract_var_return() { + check_assist( + extract_variable, + " +fn foo() -> u32 { + $0return 2 + 2$0; +} +", + " +fn foo() -> u32 { + let $0var_name = 2 + 2; + return var_name; +} +", + ); + } + + #[test] + fn test_extract_var_does_not_add_extra_whitespace() { + check_assist( + extract_variable, + " +fn foo() -> u32 { + + + $0return 2 + 2$0; +} +", + " +fn foo() -> u32 { + + + let $0var_name = 2 + 2; + return var_name; +} +", + ); + + check_assist( + extract_variable, + " +fn foo() -> u32 { + + $0return 2 + 2$0; +} +", + " +fn foo() -> u32 { + + let $0var_name = 2 + 2; + return var_name; +} +", + ); + + check_assist( + extract_variable, + " +fn foo() -> u32 { + let foo = 1; + + // bar + + + $0return 2 + 2$0; +} +", + " +fn foo() -> u32 { + let foo = 1; + + // bar + + + let $0var_name = 2 + 2; + return var_name; +} +", + ); + } + + #[test] + fn test_extract_var_break() { + check_assist( + extract_variable, + " +fn main() { + let result = loop { + $0break 2 + 2$0; + }; +} +", + " +fn main() { + let result = loop { + let $0var_name = 2 + 2; + break var_name; + }; +} +", + ); + } + + #[test] + fn test_extract_var_for_cast() { + check_assist( + extract_variable, + " +fn main() { + let v = $00f32 as u32$0; +} +", + " +fn main() { + let $0var_name = 0f32 as u32; + let v = var_name; +} +", + ); + } + + #[test] + fn extract_var_field_shorthand() { + check_assist( + extract_variable, + r#" +struct S { + foo: i32 +} + +fn main() { + S { foo: $01 + 1$0 } +} +"#, + r#" +struct S { + foo: i32 +} + +fn main() { + let $0foo = 1 + 1; + S { foo } +} +"#, + ) + } + + #[test] + fn extract_var_name_from_type() { + check_assist( + extract_variable, + r#" +struct Test(i32); + +fn foo() -> Test { + $0{ Test(10) }$0 +} +"#, + r#" +struct Test(i32); + +fn foo() -> Test { + let $0test = { Test(10) }; + test +} +"#, + ) + } + + #[test] + fn extract_var_name_from_parameter() { + check_assist( + extract_variable, + r#" +fn bar(test: u32, size: u32) + +fn foo() { + bar(1, $01+1$0); +} +"#, + r#" +fn bar(test: u32, size: u32) + +fn foo() { + let $0size = 1+1; + bar(1, size); +} +"#, + ) + } + + #[test] + fn extract_var_parameter_name_has_precedence_over_type() { + check_assist( + extract_variable, + r#" +struct TextSize(u32); +fn bar(test: u32, size: TextSize) + +fn foo() { + bar(1, $0{ TextSize(1+1) }$0); +} +"#, + r#" +struct TextSize(u32); +fn bar(test: u32, size: TextSize) + +fn foo() { + let $0size = { TextSize(1+1) }; + bar(1, size); +} +"#, + ) + } + + #[test] + fn extract_var_name_from_function() { + check_assist( + extract_variable, + r#" +fn is_required(test: u32, size: u32) -> bool + +fn foo() -> bool { + $0is_required(1, 2)$0 +} +"#, + r#" +fn is_required(test: u32, size: u32) -> bool + +fn foo() -> bool { + let $0is_required = is_required(1, 2); + is_required +} +"#, + ) + } + + #[test] + fn extract_var_name_from_method() { + check_assist( + extract_variable, + r#" +struct S; +impl S { + fn bar(&self, n: u32) -> u32 { n } +} + +fn foo() -> u32 { + $0S.bar(1)$0 +} +"#, + r#" +struct S; +impl S { + fn bar(&self, n: u32) -> u32 { n } +} + +fn foo() -> u32 { + let $0bar = S.bar(1); + bar +} +"#, + ) + } + + #[test] + fn extract_var_name_from_method_param() { + check_assist( + extract_variable, + r#" +struct S; +impl S { + fn bar(&self, n: u32, size: u32) { n } +} + +fn foo() { + S.bar($01 + 1$0, 2) +} +"#, + r#" +struct S; +impl S { + fn bar(&self, n: u32, size: u32) { n } +} + +fn foo() { + let $0n = 1 + 1; + S.bar(n, 2) +} +"#, + ) + } + + #[test] + fn extract_var_name_from_ufcs_method_param() { + check_assist( + extract_variable, + r#" +struct S; +impl S { + fn bar(&self, n: u32, size: u32) { n } +} + +fn foo() { + S::bar(&S, $01 + 1$0, 2) +} +"#, + r#" +struct S; +impl S { + fn bar(&self, n: u32, size: u32) { n } +} + +fn foo() { + let $0n = 1 + 1; + S::bar(&S, n, 2) +} +"#, + ) + } + + #[test] + fn extract_var_parameter_name_has_precedence_over_function() { + check_assist( + extract_variable, + r#" +fn bar(test: u32, size: u32) + +fn foo() { + bar(1, $0symbol_size(1, 2)$0); +} +"#, + r#" +fn bar(test: u32, size: u32) + +fn foo() { + let $0size = symbol_size(1, 2); + bar(1, size); +} +"#, + ) + } + + #[test] + fn extract_macro_call() { + check_assist( + extract_variable, + r" +struct Vec; +macro_rules! vec { + () => {Vec} +} +fn main() { + let _ = $0vec![]$0; +} +", + r" +struct Vec; +macro_rules! vec { + () => {Vec} +} +fn main() { + let $0vec = vec![]; + let _ = vec; +} +", + ); + } + + #[test] + fn test_extract_var_for_return_not_applicable() { + check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } "); + } + + #[test] + fn test_extract_var_for_break_not_applicable() { + check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }"); + } + + #[test] + fn test_extract_var_unit_expr_not_applicable() { + check_assist_not_applicable( + extract_variable, + r#" +fn foo() { + let mut i = 3; + $0if i >= 0 { + i += 1; + } else { + i -= 1; + }$0 +}"#, + ); + } + + // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic + #[test] + fn extract_var_target() { + check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2"); + + check_assist_target( + extract_variable, + " +fn main() { + let x = true; + let tuple = match x { + true => ($02 + 2$0, true) + _ => (0, false) + }; +} +", + "2 + 2", + ); + } + + #[test] + fn extract_var_no_block_body() { + check_assist_not_applicable( + extract_variable, + r" +const X: usize = $0100$0; +", + ); + } + + #[test] + fn test_extract_var_mutable_reference_parameter() { + check_assist( + extract_variable, + r#" +struct S { + vec: Vec +} + +fn foo(s: &mut S) { + $0s.vec$0.push(0); +}"#, + r#" +struct S { + vec: Vec +} + +fn foo(s: &mut S) { + let $0vec = &mut s.vec; + vec.push(0); +}"#, + ); + } + + #[test] + fn test_extract_var_mutable_reference_parameter_deep_nesting() { + check_assist( + extract_variable, + r#" +struct Y { + field: X +} +struct X { + field: S +} +struct S { + vec: Vec +} + +fn foo(f: &mut Y) { + $0f.field.field.vec$0.push(0); +}"#, + r#" +struct Y { + field: X +} +struct X { + field: S +} +struct S { + vec: Vec +} + +fn foo(f: &mut Y) { + let $0vec = &mut f.field.field.vec; + vec.push(0); +}"#, + ); + } + + #[test] + fn test_extract_var_reference_parameter() { + check_assist( + extract_variable, + r#" +struct X; + +impl X { + fn do_thing(&self) { + + } +} + +struct S { + sub: X +} + +fn foo(s: &S) { + $0s.sub$0.do_thing(); +}"#, + r#" +struct X; + +impl X { + fn do_thing(&self) { + + } +} + +struct S { + sub: X +} + +fn foo(s: &S) { + let $0x = &s.sub; + x.do_thing(); +}"#, + ); + } + + #[test] + fn test_extract_var_reference_parameter_deep_nesting() { + check_assist( + extract_variable, + r#" +struct Z; +impl Z { + fn do_thing(&self) { + + } +} + +struct Y { + field: Z +} + +struct X { + field: Y +} + +struct S { + sub: X +} + +fn foo(s: &S) { + $0s.sub.field.field$0.do_thing(); +}"#, + r#" +struct Z; +impl Z { + fn do_thing(&self) { + + } +} + +struct Y { + field: Z +} + +struct X { + field: Y +} + +struct S { + sub: X +} + +fn foo(s: &S) { + let $0z = &s.sub.field.field; + z.do_thing(); +}"#, + ); + } + + #[test] + fn test_extract_var_regular_parameter() { + check_assist( + extract_variable, + r#" +struct X; + +impl X { + fn do_thing(&self) { + + } +} + +struct S { + sub: X +} + +fn foo(s: S) { + $0s.sub$0.do_thing(); +}"#, + r#" +struct X; + +impl X { + fn do_thing(&self) { + + } +} + +struct S { + sub: X +} + +fn foo(s: S) { + let $0x = s.sub; + x.do_thing(); +}"#, + ); + } + + #[test] + fn test_extract_var_mutable_reference_local() { + check_assist( + extract_variable, + r#" +struct X; + +struct S { + sub: X +} + +impl S { + fn new() -> S { + S { + sub: X::new() + } + } +} + +impl X { + fn new() -> X { + X { } + } + fn do_thing(&self) { + + } +} + + +fn foo() { + let local = &mut S::new(); + $0local.sub$0.do_thing(); +}"#, + r#" +struct X; + +struct S { + sub: X +} + +impl S { + fn new() -> S { + S { + sub: X::new() + } + } +} + +impl X { + fn new() -> X { + X { } + } + fn do_thing(&self) { + + } +} + + +fn foo() { + let local = &mut S::new(); + let $0x = &mut local.sub; + x.do_thing(); +}"#, + ); + } + + #[test] + fn test_extract_var_reference_local() { + check_assist( + extract_variable, + r#" +struct X; + +struct S { + sub: X +} + +impl S { + fn new() -> S { + S { + sub: X::new() + } + } +} + +impl X { + fn new() -> X { + X { } + } + fn do_thing(&self) { + + } +} + + +fn foo() { + let local = &S::new(); + $0local.sub$0.do_thing(); +}"#, + r#" +struct X; + +struct S { + sub: X +} + +impl S { + fn new() -> S { + S { + sub: X::new() + } + } +} + +impl X { + fn new() -> X { + X { } + } + fn do_thing(&self) { + + } +} + + +fn foo() { + let local = &S::new(); + let $0x = &local.sub; + x.do_thing(); +}"#, + ); + } + + #[test] + fn test_extract_var_for_mutable_borrow() { + check_assist( + extract_variable, + r#" +fn foo() { + let v = &mut $00$0; +}"#, + r#" +fn foo() { + let mut $0var_name = 0; + let v = &mut var_name; +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/fix_visibility.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/fix_visibility.rs new file mode 100644 index 000000000..b33846f54 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/fix_visibility.rs @@ -0,0 +1,606 @@ +use hir::{db::HirDatabase, HasSource, HasVisibility, PathResolution}; +use ide_db::base_db::FileId; +use syntax::{ + ast::{self, HasVisibility as _}, + AstNode, TextRange, TextSize, +}; + +use crate::{utils::vis_offset, AssistContext, AssistId, AssistKind, Assists}; + +// FIXME: this really should be a fix for diagnostic, rather than an assist. + +// Assist: fix_visibility +// +// Makes inaccessible item public. +// +// ``` +// mod m { +// fn frobnicate() {} +// } +// fn main() { +// m::frobnicate$0() {} +// } +// ``` +// -> +// ``` +// mod m { +// $0pub(crate) fn frobnicate() {} +// } +// fn main() { +// m::frobnicate() {} +// } +// ``` +pub(crate) fn fix_visibility(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + add_vis_to_referenced_module_def(acc, ctx) + .or_else(|| add_vis_to_referenced_record_field(acc, ctx)) +} + +fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let path: ast::Path = ctx.find_node_at_offset()?; + let path_res = ctx.sema.resolve_path(&path)?; + let def = match path_res { + PathResolution::Def(def) => def, + _ => return None, + }; + + let current_module = ctx.sema.scope(path.syntax())?.module(); + let target_module = def.module(ctx.db())?; + + if def.visibility(ctx.db()).is_visible_from(ctx.db(), current_module.into()) { + return None; + }; + + let (offset, current_visibility, target, target_file, target_name) = + target_data_for_def(ctx.db(), def)?; + + let missing_visibility = + if current_module.krate() == target_module.krate() { "pub(crate)" } else { "pub" }; + + let assist_label = match target_name { + None => format!("Change visibility to {}", missing_visibility), + Some(name) => format!("Change visibility of {} to {}", name, missing_visibility), + }; + + acc.add(AssistId("fix_visibility", AssistKind::QuickFix), assist_label, target, |builder| { + builder.edit_file(target_file); + match ctx.config.snippet_cap { + Some(cap) => match current_visibility { + Some(current_visibility) => builder.replace_snippet( + cap, + current_visibility.syntax().text_range(), + format!("$0{}", missing_visibility), + ), + None => builder.insert_snippet(cap, offset, format!("$0{} ", missing_visibility)), + }, + None => match current_visibility { + Some(current_visibility) => { + builder.replace(current_visibility.syntax().text_range(), missing_visibility) + } + None => builder.insert(offset, format!("{} ", missing_visibility)), + }, + } + }) +} + +fn add_vis_to_referenced_record_field(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let record_field: ast::RecordExprField = ctx.find_node_at_offset()?; + let (record_field_def, _, _) = ctx.sema.resolve_record_field(&record_field)?; + + let current_module = ctx.sema.scope(record_field.syntax())?.module(); + let visibility = record_field_def.visibility(ctx.db()); + if visibility.is_visible_from(ctx.db(), current_module.into()) { + return None; + } + + let parent = record_field_def.parent_def(ctx.db()); + let parent_name = parent.name(ctx.db()); + let target_module = parent.module(ctx.db()); + + let in_file_source = record_field_def.source(ctx.db())?; + let (offset, current_visibility, target) = match in_file_source.value { + hir::FieldSource::Named(it) => { + let s = it.syntax(); + (vis_offset(s), it.visibility(), s.text_range()) + } + hir::FieldSource::Pos(it) => { + let s = it.syntax(); + (vis_offset(s), it.visibility(), s.text_range()) + } + }; + + let missing_visibility = + if current_module.krate() == target_module.krate() { "pub(crate)" } else { "pub" }; + let target_file = in_file_source.file_id.original_file(ctx.db()); + + let target_name = record_field_def.name(ctx.db()); + let assist_label = + format!("Change visibility of {}.{} to {}", parent_name, target_name, missing_visibility); + + acc.add(AssistId("fix_visibility", AssistKind::QuickFix), assist_label, target, |builder| { + builder.edit_file(target_file); + match ctx.config.snippet_cap { + Some(cap) => match current_visibility { + Some(current_visibility) => builder.replace_snippet( + cap, + current_visibility.syntax().text_range(), + format!("$0{}", missing_visibility), + ), + None => builder.insert_snippet(cap, offset, format!("$0{} ", missing_visibility)), + }, + None => match current_visibility { + Some(current_visibility) => { + builder.replace(current_visibility.syntax().text_range(), missing_visibility) + } + None => builder.insert(offset, format!("{} ", missing_visibility)), + }, + } + }) +} + +fn target_data_for_def( + db: &dyn HirDatabase, + def: hir::ModuleDef, +) -> Option<(TextSize, Option, TextRange, FileId, Option)> { + fn offset_target_and_file_id( + db: &dyn HirDatabase, + x: S, + ) -> Option<(TextSize, Option, TextRange, FileId)> + where + S: HasSource, + Ast: AstNode + ast::HasVisibility, + { + let source = x.source(db)?; + let in_file_syntax = source.syntax(); + let file_id = in_file_syntax.file_id; + let syntax = in_file_syntax.value; + let current_visibility = source.value.visibility(); + Some(( + vis_offset(syntax), + current_visibility, + syntax.text_range(), + file_id.original_file(db.upcast()), + )) + } + + let target_name; + let (offset, current_visibility, target, target_file) = match def { + hir::ModuleDef::Function(f) => { + target_name = Some(f.name(db)); + offset_target_and_file_id(db, f)? + } + hir::ModuleDef::Adt(adt) => { + target_name = Some(adt.name(db)); + match adt { + hir::Adt::Struct(s) => offset_target_and_file_id(db, s)?, + hir::Adt::Union(u) => offset_target_and_file_id(db, u)?, + hir::Adt::Enum(e) => offset_target_and_file_id(db, e)?, + } + } + hir::ModuleDef::Const(c) => { + target_name = c.name(db); + offset_target_and_file_id(db, c)? + } + hir::ModuleDef::Static(s) => { + target_name = Some(s.name(db)); + offset_target_and_file_id(db, s)? + } + hir::ModuleDef::Trait(t) => { + target_name = Some(t.name(db)); + offset_target_and_file_id(db, t)? + } + hir::ModuleDef::TypeAlias(t) => { + target_name = Some(t.name(db)); + offset_target_and_file_id(db, t)? + } + hir::ModuleDef::Module(m) => { + target_name = m.name(db); + let in_file_source = m.declaration_source(db)?; + let file_id = in_file_source.file_id.original_file(db.upcast()); + let syntax = in_file_source.value.syntax(); + (vis_offset(syntax), in_file_source.value.visibility(), syntax.text_range(), file_id) + } + // FIXME + hir::ModuleDef::Macro(_) => return None, + // Enum variants can't be private, we can't modify builtin types + hir::ModuleDef::Variant(_) | hir::ModuleDef::BuiltinType(_) => return None, + }; + + Some((offset, current_visibility, target, target_file, target_name)) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn fix_visibility_of_fn() { + check_assist( + fix_visibility, + r"mod foo { fn foo() {} } + fn main() { foo::foo$0() } ", + r"mod foo { $0pub(crate) fn foo() {} } + fn main() { foo::foo() } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub fn foo() {} } + fn main() { foo::foo$0() } ", + ) + } + + #[test] + fn fix_visibility_of_adt_in_submodule() { + check_assist( + fix_visibility, + r"mod foo { struct Foo; } + fn main() { foo::Foo$0 } ", + r"mod foo { $0pub(crate) struct Foo; } + fn main() { foo::Foo } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub struct Foo; } + fn main() { foo::Foo$0 } ", + ); + check_assist( + fix_visibility, + r"mod foo { enum Foo; } + fn main() { foo::Foo$0 } ", + r"mod foo { $0pub(crate) enum Foo; } + fn main() { foo::Foo } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub enum Foo; } + fn main() { foo::Foo$0 } ", + ); + check_assist( + fix_visibility, + r"mod foo { union Foo; } + fn main() { foo::Foo$0 } ", + r"mod foo { $0pub(crate) union Foo; } + fn main() { foo::Foo } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub union Foo; } + fn main() { foo::Foo$0 } ", + ); + } + + #[test] + fn fix_visibility_of_adt_in_other_file() { + check_assist( + fix_visibility, + r" +//- /main.rs +mod foo; +fn main() { foo::Foo$0 } + +//- /foo.rs +struct Foo; +", + r"$0pub(crate) struct Foo; +", + ); + } + + #[test] + fn fix_visibility_of_struct_field() { + check_assist( + fix_visibility, + r"mod foo { pub struct Foo { bar: (), } } + fn main() { foo::Foo { $0bar: () }; } ", + r"mod foo { pub struct Foo { $0pub(crate) bar: (), } } + fn main() { foo::Foo { bar: () }; } ", + ); + check_assist( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { $0bar: () }; } +//- /foo.rs +pub struct Foo { bar: () } +", + r"pub struct Foo { $0pub(crate) bar: () } +", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub struct Foo { pub bar: (), } } + fn main() { foo::Foo { $0bar: () }; } ", + ); + check_assist_not_applicable( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { $0bar: () }; } +//- /foo.rs +pub struct Foo { pub bar: () } +", + ); + } + + #[test] + fn fix_visibility_of_enum_variant_field() { + // Enum variants, as well as their fields, always get the enum's visibility. In fact, rustc + // rejects any visibility specifiers on them, so this assist should never fire on them. + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub enum Foo { Bar { bar: () } } } + fn main() { foo::Foo::Bar { $0bar: () }; } ", + ); + check_assist_not_applicable( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo::Bar { $0bar: () }; } +//- /foo.rs +pub enum Foo { Bar { bar: () } } +", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub struct Foo { pub bar: (), } } + fn main() { foo::Foo { $0bar: () }; } ", + ); + check_assist_not_applicable( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { $0bar: () }; } +//- /foo.rs +pub struct Foo { pub bar: () } +", + ); + } + + #[test] + fn fix_visibility_of_union_field() { + check_assist( + fix_visibility, + r"mod foo { pub union Foo { bar: (), } } + fn main() { foo::Foo { $0bar: () }; } ", + r"mod foo { pub union Foo { $0pub(crate) bar: (), } } + fn main() { foo::Foo { bar: () }; } ", + ); + check_assist( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { $0bar: () }; } +//- /foo.rs +pub union Foo { bar: () } +", + r"pub union Foo { $0pub(crate) bar: () } +", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub union Foo { pub bar: (), } } + fn main() { foo::Foo { $0bar: () }; } ", + ); + check_assist_not_applicable( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { $0bar: () }; } +//- /foo.rs +pub union Foo { pub bar: () } +", + ); + } + + #[test] + fn fix_visibility_of_const() { + check_assist( + fix_visibility, + r"mod foo { const FOO: () = (); } + fn main() { foo::FOO$0 } ", + r"mod foo { $0pub(crate) const FOO: () = (); } + fn main() { foo::FOO } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub const FOO: () = (); } + fn main() { foo::FOO$0 } ", + ); + } + + #[test] + fn fix_visibility_of_static() { + check_assist( + fix_visibility, + r"mod foo { static FOO: () = (); } + fn main() { foo::FOO$0 } ", + r"mod foo { $0pub(crate) static FOO: () = (); } + fn main() { foo::FOO } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub static FOO: () = (); } + fn main() { foo::FOO$0 } ", + ); + } + + #[test] + fn fix_visibility_of_trait() { + check_assist( + fix_visibility, + r"mod foo { trait Foo { fn foo(&self) {} } } + fn main() { let x: &dyn foo::$0Foo; } ", + r"mod foo { $0pub(crate) trait Foo { fn foo(&self) {} } } + fn main() { let x: &dyn foo::Foo; } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub trait Foo { fn foo(&self) {} } } + fn main() { let x: &dyn foo::Foo$0; } ", + ); + } + + #[test] + fn fix_visibility_of_type_alias() { + check_assist( + fix_visibility, + r"mod foo { type Foo = (); } + fn main() { let x: foo::Foo$0; } ", + r"mod foo { $0pub(crate) type Foo = (); } + fn main() { let x: foo::Foo; } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub type Foo = (); } + fn main() { let x: foo::Foo$0; } ", + ); + } + + #[test] + fn fix_visibility_of_module() { + check_assist( + fix_visibility, + r"mod foo { mod bar { fn bar() {} } } + fn main() { foo::bar$0::bar(); } ", + r"mod foo { $0pub(crate) mod bar { fn bar() {} } } + fn main() { foo::bar::bar(); } ", + ); + + check_assist( + fix_visibility, + r" +//- /main.rs +mod foo; +fn main() { foo::bar$0::baz(); } + +//- /foo.rs +mod bar { + pub fn baz() {} +} +", + r"$0pub(crate) mod bar { + pub fn baz() {} +} +", + ); + + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub mod bar { pub fn bar() {} } } + fn main() { foo::bar$0::bar(); } ", + ); + } + + #[test] + fn fix_visibility_of_inline_module_in_other_file() { + check_assist( + fix_visibility, + r" +//- /main.rs +mod foo; +fn main() { foo::bar$0::baz(); } + +//- /foo.rs +mod bar; +//- /foo/bar.rs +pub fn baz() {} +", + r"$0pub(crate) mod bar; +", + ); + } + + #[test] + fn fix_visibility_of_module_declaration_in_other_file() { + check_assist( + fix_visibility, + r" +//- /main.rs +mod foo; +fn main() { foo::bar$0>::baz(); } + +//- /foo.rs +mod bar { + pub fn baz() {} +} +", + r"$0pub(crate) mod bar { + pub fn baz() {} +} +", + ); + } + + #[test] + fn adds_pub_when_target_is_in_another_crate() { + check_assist( + fix_visibility, + r" +//- /main.rs crate:a deps:foo +foo::Bar$0 +//- /lib.rs crate:foo +struct Bar; +", + r"$0pub struct Bar; +", + ) + } + + #[test] + fn replaces_pub_crate_with_pub() { + check_assist( + fix_visibility, + r" +//- /main.rs crate:a deps:foo +foo::Bar$0 +//- /lib.rs crate:foo +pub(crate) struct Bar; +", + r"$0pub struct Bar; +", + ); + check_assist( + fix_visibility, + r" +//- /main.rs crate:a deps:foo +fn main() { + foo::Foo { $0bar: () }; +} +//- /lib.rs crate:foo +pub struct Foo { pub(crate) bar: () } +", + r"pub struct Foo { $0pub bar: () } +", + ); + } + + #[test] + fn fix_visibility_of_reexport() { + // FIXME: broken test, this should fix visibility of the re-export + // rather than the struct. + check_assist( + fix_visibility, + r#" +mod foo { + use bar::Baz; + mod bar { pub(super) struct Baz; } +} +foo::Baz$0 +"#, + r#" +mod foo { + use bar::Baz; + mod bar { $0pub(crate) struct Baz; } +} +foo::Baz +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_binexpr.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_binexpr.rs new file mode 100644 index 000000000..2ea6f58fa --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_binexpr.rs @@ -0,0 +1,139 @@ +use syntax::ast::{self, AstNode, BinExpr}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: flip_binexpr +// +// Flips operands of a binary expression. +// +// ``` +// fn main() { +// let _ = 90 +$0 2; +// } +// ``` +// -> +// ``` +// fn main() { +// let _ = 2 + 90; +// } +// ``` +pub(crate) fn flip_binexpr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let expr = ctx.find_node_at_offset::()?; + let lhs = expr.lhs()?.syntax().clone(); + let rhs = expr.rhs()?.syntax().clone(); + let op_range = expr.op_token()?.text_range(); + // The assist should be applied only if the cursor is on the operator + let cursor_in_range = op_range.contains_range(ctx.selection_trimmed()); + if !cursor_in_range { + return None; + } + let action: FlipAction = expr.op_kind()?.into(); + // The assist should not be applied for certain operators + if let FlipAction::DontFlip = action { + return None; + } + + acc.add( + AssistId("flip_binexpr", AssistKind::RefactorRewrite), + "Flip binary expression", + op_range, + |edit| { + if let FlipAction::FlipAndReplaceOp(new_op) = action { + edit.replace(op_range, new_op); + } + edit.replace(lhs.text_range(), rhs.text()); + edit.replace(rhs.text_range(), lhs.text()); + }, + ) +} + +enum FlipAction { + // Flip the expression + Flip, + // Flip the expression and replace the operator with this string + FlipAndReplaceOp(&'static str), + // Do not flip the expression + DontFlip, +} + +impl From for FlipAction { + fn from(op_kind: ast::BinaryOp) -> Self { + match op_kind { + ast::BinaryOp::Assignment { .. } => FlipAction::DontFlip, + ast::BinaryOp::CmpOp(ast::CmpOp::Ord { ordering, strict }) => { + let rev_op = match (ordering, strict) { + (ast::Ordering::Less, true) => ">", + (ast::Ordering::Less, false) => ">=", + (ast::Ordering::Greater, true) => "<", + (ast::Ordering::Greater, false) => "<=", + }; + FlipAction::FlipAndReplaceOp(rev_op) + } + _ => FlipAction::Flip, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn flip_binexpr_target_is_the_op() { + check_assist_target(flip_binexpr, "fn f() { let res = 1 ==$0 2; }", "==") + } + + #[test] + fn flip_binexpr_not_applicable_for_assignment() { + check_assist_not_applicable(flip_binexpr, "fn f() { let mut _x = 1; _x +=$0 2 }") + } + + #[test] + fn flip_binexpr_works_for_eq() { + check_assist(flip_binexpr, "fn f() { let res = 1 ==$0 2; }", "fn f() { let res = 2 == 1; }") + } + + #[test] + fn flip_binexpr_works_for_gt() { + check_assist(flip_binexpr, "fn f() { let res = 1 >$0 2; }", "fn f() { let res = 2 < 1; }") + } + + #[test] + fn flip_binexpr_works_for_lteq() { + check_assist(flip_binexpr, "fn f() { let res = 1 <=$0 2; }", "fn f() { let res = 2 >= 1; }") + } + + #[test] + fn flip_binexpr_works_for_complex_expr() { + check_assist( + flip_binexpr, + "fn f() { let res = (1 + 1) ==$0 (2 + 2); }", + "fn f() { let res = (2 + 2) == (1 + 1); }", + ) + } + + #[test] + fn flip_binexpr_works_inside_match() { + check_assist( + flip_binexpr, + r#" + fn dyn_eq(&self, other: &dyn Diagnostic) -> bool { + match other.downcast_ref::() { + None => false, + Some(it) => it ==$0 self, + } + } + "#, + r#" + fn dyn_eq(&self, other: &dyn Diagnostic) -> bool { + match other.downcast_ref::() { + None => false, + Some(it) => self == it, + } + } + "#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs new file mode 100644 index 000000000..f40f2713a --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs @@ -0,0 +1,92 @@ +use syntax::{algo::non_trivia_sibling, Direction, SyntaxKind, T}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: flip_comma +// +// Flips two comma-separated items. +// +// ``` +// fn main() { +// ((1, 2),$0 (3, 4)); +// } +// ``` +// -> +// ``` +// fn main() { +// ((3, 4), (1, 2)); +// } +// ``` +pub(crate) fn flip_comma(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let comma = ctx.find_token_syntax_at_offset(T![,])?; + let prev = non_trivia_sibling(comma.clone().into(), Direction::Prev)?; + let next = non_trivia_sibling(comma.clone().into(), Direction::Next)?; + + // Don't apply a "flip" in case of a last comma + // that typically comes before punctuation + if next.kind().is_punct() { + return None; + } + + // Don't apply a "flip" inside the macro call + // since macro input are just mere tokens + if comma.parent_ancestors().any(|it| it.kind() == SyntaxKind::MACRO_CALL) { + return None; + } + + acc.add( + AssistId("flip_comma", AssistKind::RefactorRewrite), + "Flip comma", + comma.text_range(), + |edit| { + edit.replace(prev.text_range(), next.to_string()); + edit.replace(next.text_range(), prev.to_string()); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn flip_comma_works_for_function_parameters() { + check_assist( + flip_comma, + r#"fn foo(x: i32,$0 y: Result<(), ()>) {}"#, + r#"fn foo(y: Result<(), ()>, x: i32) {}"#, + ) + } + + #[test] + fn flip_comma_target() { + check_assist_target(flip_comma, r#"fn foo(x: i32,$0 y: Result<(), ()>) {}"#, ",") + } + + #[test] + fn flip_comma_before_punct() { + // See https://github.com/rust-lang/rust-analyzer/issues/1619 + // "Flip comma" assist shouldn't be applicable to the last comma in enum or struct + // declaration body. + check_assist_not_applicable(flip_comma, "pub enum Test { A,$0 }"); + check_assist_not_applicable(flip_comma, "pub struct Test { foo: usize,$0 }"); + } + + #[test] + fn flip_comma_works() { + check_assist( + flip_comma, + r#"fn main() {((1, 2),$0 (3, 4));}"#, + r#"fn main() {((3, 4), (1, 2));}"#, + ) + } + + #[test] + fn flip_comma_not_applicable_for_macro_input() { + // "Flip comma" assist shouldn't be applicable inside the macro call + // See https://github.com/rust-lang/rust-analyzer/issues/7693 + check_assist_not_applicable(flip_comma, r#"bar!(a,$0 b)"#); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_trait_bound.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_trait_bound.rs new file mode 100644 index 000000000..e3ae4970b --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_trait_bound.rs @@ -0,0 +1,121 @@ +use syntax::{ + algo::non_trivia_sibling, + ast::{self, AstNode}, + Direction, T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: flip_trait_bound +// +// Flips two trait bounds. +// +// ``` +// fn foo() { } +// ``` +// -> +// ``` +// fn foo() { } +// ``` +pub(crate) fn flip_trait_bound(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + // We want to replicate the behavior of `flip_binexpr` by only suggesting + // the assist when the cursor is on a `+` + let plus = ctx.find_token_syntax_at_offset(T![+])?; + + // Make sure we're in a `TypeBoundList` + if ast::TypeBoundList::cast(plus.parent()?).is_none() { + return None; + } + + let (before, after) = ( + non_trivia_sibling(plus.clone().into(), Direction::Prev)?, + non_trivia_sibling(plus.clone().into(), Direction::Next)?, + ); + + let target = plus.text_range(); + acc.add( + AssistId("flip_trait_bound", AssistKind::RefactorRewrite), + "Flip trait bounds", + target, + |edit| { + edit.replace(before.text_range(), after.to_string()); + edit.replace(after.text_range(), before.to_string()); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn flip_trait_bound_assist_available() { + check_assist_target(flip_trait_bound, "struct S where T: A $0+ B + C { }", "+") + } + + #[test] + fn flip_trait_bound_not_applicable_for_single_trait_bound() { + check_assist_not_applicable(flip_trait_bound, "struct S where T: $0A { }") + } + + #[test] + fn flip_trait_bound_works_for_struct() { + check_assist( + flip_trait_bound, + "struct S where T: A $0+ B { }", + "struct S where T: B + A { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_trait_impl() { + check_assist( + flip_trait_bound, + "impl X for S where T: A +$0 B { }", + "impl X for S where T: B + A { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_fn() { + check_assist(flip_trait_bound, "fn f(t: T) { }", "fn f(t: T) { }") + } + + #[test] + fn flip_trait_bound_works_for_fn_where_clause() { + check_assist( + flip_trait_bound, + "fn f(t: T) where T: A +$0 B { }", + "fn f(t: T) where T: B + A { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_lifetime() { + check_assist( + flip_trait_bound, + "fn f(t: T) where T: A $0+ 'static { }", + "fn f(t: T) where T: 'static + A { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_complex_bounds() { + check_assist( + flip_trait_bound, + "struct S where T: A $0+ b_mod::B + C { }", + "struct S where T: b_mod::B + A + C { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_long_bounds() { + check_assist( + flip_trait_bound, + "struct S where T: A + B + C + D + E + F +$0 G + H + I + J { }", + "struct S where T: A + B + C + D + E + G + F + H + I + J { }", + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_constant.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_constant.rs new file mode 100644 index 000000000..eaa6de73e --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_constant.rs @@ -0,0 +1,255 @@ +use crate::assist_context::{AssistContext, Assists}; +use hir::{HasVisibility, HirDisplay, Module}; +use ide_db::{ + assists::{AssistId, AssistKind}, + base_db::{FileId, Upcast}, + defs::{Definition, NameRefClass}, +}; +use syntax::{ + ast::{self, edit::IndentLevel, NameRef}, + AstNode, Direction, SyntaxKind, TextSize, +}; + +// Assist: generate_constant +// +// Generate a named constant. +// +// ``` +// struct S { i: usize } +// impl S { pub fn new(n: usize) {} } +// fn main() { +// let v = S::new(CAPA$0CITY); +// } +// ``` +// -> +// ``` +// struct S { i: usize } +// impl S { pub fn new(n: usize) {} } +// fn main() { +// const CAPACITY: usize = $0; +// let v = S::new(CAPACITY); +// } +// ``` + +pub(crate) fn generate_constant(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let constant_token = ctx.find_node_at_offset::()?; + if constant_token.to_string().chars().any(|it| !(it.is_uppercase() || it == '_')) { + cov_mark::hit!(not_constant_name); + return None; + } + if NameRefClass::classify(&ctx.sema, &constant_token).is_some() { + cov_mark::hit!(already_defined); + return None; + } + let expr = constant_token.syntax().ancestors().find_map(ast::Expr::cast)?; + let statement = expr.syntax().ancestors().find_map(ast::Stmt::cast)?; + let ty = ctx.sema.type_of_expr(&expr)?; + let scope = ctx.sema.scope(statement.syntax())?; + let constant_module = scope.module(); + let type_name = ty.original().display_source_code(ctx.db(), constant_module.into()).ok()?; + let target = statement.syntax().parent()?.text_range(); + let path = constant_token.syntax().ancestors().find_map(ast::Path::cast)?; + + let name_refs = path.segments().map(|s| s.name_ref()); + let mut outer_exists = false; + let mut not_exist_name_ref = Vec::new(); + let mut current_module = constant_module; + for name_ref in name_refs { + let name_ref_value = name_ref?; + let name_ref_class = NameRefClass::classify(&ctx.sema, &name_ref_value); + match name_ref_class { + Some(NameRefClass::Definition(Definition::Module(m))) => { + if !m.visibility(ctx.sema.db).is_visible_from(ctx.sema.db, constant_module.into()) { + return None; + } + outer_exists = true; + current_module = m; + } + Some(_) => { + return None; + } + None => { + not_exist_name_ref.push(name_ref_value); + } + } + } + let (offset, indent, file_id, post_string) = + target_data_for_generate_constant(ctx, current_module, constant_module).unwrap_or_else( + || { + let indent = IndentLevel::from_node(statement.syntax()); + (statement.syntax().text_range().start(), indent, None, format!("\n{}", indent)) + }, + ); + + let text = get_text_for_generate_constant(not_exist_name_ref, indent, outer_exists, type_name)?; + acc.add( + AssistId("generate_constant", AssistKind::QuickFix), + "Generate constant", + target, + |builder| { + if let Some(file_id) = file_id { + builder.edit_file(file_id); + } + builder.insert(offset, format!("{}{}", text, post_string)); + }, + ) +} + +fn get_text_for_generate_constant( + mut not_exist_name_ref: Vec, + indent: IndentLevel, + outer_exists: bool, + type_name: String, +) -> Option { + let constant_token = not_exist_name_ref.pop()?; + let vis = if not_exist_name_ref.len() == 0 && !outer_exists { "" } else { "\npub " }; + let mut text = format!("{}const {}: {} = $0;", vis, constant_token, type_name); + while let Some(name_ref) = not_exist_name_ref.pop() { + let vis = if not_exist_name_ref.len() == 0 && !outer_exists { "" } else { "\npub " }; + text = text.replace("\n", "\n "); + text = format!("{}mod {} {{{}\n}}", vis, name_ref.to_string(), text); + } + Some(text.replace("\n", &format!("\n{}", indent))) +} + +fn target_data_for_generate_constant( + ctx: &AssistContext<'_>, + current_module: Module, + constant_module: Module, +) -> Option<(TextSize, IndentLevel, Option, String)> { + if current_module == constant_module { + // insert in current file + return None; + } + let in_file_source = current_module.definition_source(ctx.sema.db); + let file_id = in_file_source.file_id.original_file(ctx.sema.db.upcast()); + match in_file_source.value { + hir::ModuleSource::Module(module_node) => { + let indent = IndentLevel::from_node(module_node.syntax()); + let l_curly_token = module_node.item_list()?.l_curly_token()?; + let offset = l_curly_token.text_range().end(); + + let siblings_has_newline = l_curly_token + .siblings_with_tokens(Direction::Next) + .find(|it| it.kind() == SyntaxKind::WHITESPACE && it.to_string().contains("\n")) + .is_some(); + let post_string = + if siblings_has_newline { format!("{}", indent) } else { format!("\n{}", indent) }; + Some((offset, indent + 1, Some(file_id), post_string)) + } + _ => Some((TextSize::from(0), 0.into(), Some(file_id), "\n".into())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn test_trivial() { + check_assist( + generate_constant, + r#"struct S { i: usize } +impl S { + pub fn new(n: usize) {} +} +fn main() { + let v = S::new(CAPA$0CITY); +}"#, + r#"struct S { i: usize } +impl S { + pub fn new(n: usize) {} +} +fn main() { + const CAPACITY: usize = $0; + let v = S::new(CAPACITY); +}"#, + ); + } + #[test] + fn test_wont_apply_when_defined() { + cov_mark::check!(already_defined); + check_assist_not_applicable( + generate_constant, + r#"struct S { i: usize } +impl S { + pub fn new(n: usize) {} +} +fn main() { + const CAPACITY: usize = 10; + let v = S::new(CAPAC$0ITY); +}"#, + ); + } + #[test] + fn test_wont_apply_when_maybe_not_constant() { + cov_mark::check!(not_constant_name); + check_assist_not_applicable( + generate_constant, + r#"struct S { i: usize } +impl S { + pub fn new(n: usize) {} +} +fn main() { + let v = S::new(capa$0city); +}"#, + ); + } + + #[test] + fn test_constant_with_path() { + check_assist( + generate_constant, + r#"mod foo {} +fn bar() -> i32 { + foo::A_CON$0STANT +}"#, + r#"mod foo { + pub const A_CONSTANT: i32 = $0; +} +fn bar() -> i32 { + foo::A_CONSTANT +}"#, + ); + } + + #[test] + fn test_constant_with_longer_path() { + check_assist( + generate_constant, + r#"mod foo { + pub mod goo {} +} +fn bar() -> i32 { + foo::goo::A_CON$0STANT +}"#, + r#"mod foo { + pub mod goo { + pub const A_CONSTANT: i32 = $0; + } +} +fn bar() -> i32 { + foo::goo::A_CONSTANT +}"#, + ); + } + + #[test] + fn test_constant_with_not_exist_longer_path() { + check_assist( + generate_constant, + r#"fn bar() -> i32 { + foo::goo::A_CON$0STANT +}"#, + r#"mod foo { + pub mod goo { + pub const A_CONSTANT: i32 = $0; + } +} +fn bar() -> i32 { + foo::goo::A_CONSTANT +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_enum_variant.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_enum_variant.rs new file mode 100644 index 000000000..5e9995a98 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_enum_variant.rs @@ -0,0 +1,179 @@ +use ide_db::{famous_defs::FamousDefs, RootDatabase}; +use syntax::ast::{self, AstNode, HasName}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: generate_default_from_enum_variant +// +// Adds a Default impl for an enum using a variant. +// +// ``` +// enum Version { +// Undefined, +// Minor$0, +// Major, +// } +// ``` +// -> +// ``` +// enum Version { +// Undefined, +// Minor, +// Major, +// } +// +// impl Default for Version { +// fn default() -> Self { +// Self::Minor +// } +// } +// ``` +pub(crate) fn generate_default_from_enum_variant( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let variant = ctx.find_node_at_offset::()?; + let variant_name = variant.name()?; + let enum_name = variant.parent_enum().name()?; + if !matches!(variant.kind(), ast::StructKind::Unit) { + cov_mark::hit!(test_gen_default_on_non_unit_variant_not_implemented); + return None; + } + + if existing_default_impl(&ctx.sema, &variant).is_some() { + cov_mark::hit!(test_gen_default_impl_already_exists); + return None; + } + + let target = variant.syntax().text_range(); + acc.add( + AssistId("generate_default_from_enum_variant", AssistKind::Generate), + "Generate `Default` impl from this enum variant", + target, + |edit| { + let start_offset = variant.parent_enum().syntax().text_range().end(); + let buf = format!( + r#" + +impl Default for {0} {{ + fn default() -> Self {{ + Self::{1} + }} +}}"#, + enum_name, variant_name + ); + edit.insert(start_offset, buf); + }, + ) +} + +fn existing_default_impl( + sema: &'_ hir::Semantics<'_, RootDatabase>, + variant: &ast::Variant, +) -> Option<()> { + let variant = sema.to_def(variant)?; + let enum_ = variant.parent_enum(sema.db); + let krate = enum_.module(sema.db).krate(); + + let default_trait = FamousDefs(sema, krate).core_default_Default()?; + let enum_type = enum_.ty(sema.db); + + if enum_type.impls_trait(sema.db, default_trait, &[]) { + Some(()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_generate_default_from_variant() { + check_assist( + generate_default_from_enum_variant, + r#" +//- minicore: default +enum Variant { + Undefined, + Minor$0, + Major, +} +"#, + r#" +enum Variant { + Undefined, + Minor, + Major, +} + +impl Default for Variant { + fn default() -> Self { + Self::Minor + } +} +"#, + ); + } + + #[test] + fn test_generate_default_already_implemented() { + cov_mark::check!(test_gen_default_impl_already_exists); + check_assist_not_applicable( + generate_default_from_enum_variant, + r#" +//- minicore: default +enum Variant { + Undefined, + Minor$0, + Major, +} + +impl Default for Variant { + fn default() -> Self { + Self::Minor + } +} +"#, + ); + } + + #[test] + fn test_add_from_impl_no_element() { + cov_mark::check!(test_gen_default_on_non_unit_variant_not_implemented); + check_assist_not_applicable( + generate_default_from_enum_variant, + r#" +//- minicore: default +enum Variant { + Undefined, + Minor(u32)$0, + Major, +} +"#, + ); + } + + #[test] + fn test_generate_default_from_variant_with_one_variant() { + check_assist( + generate_default_from_enum_variant, + r#" +//- minicore: default +enum Variant { Undefi$0ned } +"#, + r#" +enum Variant { Undefined } + +impl Default for Variant { + fn default() -> Self { + Self::Undefined + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_new.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_new.rs new file mode 100644 index 000000000..cbd33de19 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_new.rs @@ -0,0 +1,657 @@ +use ide_db::famous_defs::FamousDefs; +use itertools::Itertools; +use stdx::format_to; +use syntax::{ + ast::{self, HasGenericParams, HasName, HasTypeBounds, Impl}, + AstNode, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, +}; + +// Assist: generate_default_from_new +// +// Generates default implementation from new method. +// +// ``` +// struct Example { _inner: () } +// +// impl Example { +// pub fn n$0ew() -> Self { +// Self { _inner: () } +// } +// } +// ``` +// -> +// ``` +// struct Example { _inner: () } +// +// impl Example { +// pub fn new() -> Self { +// Self { _inner: () } +// } +// } +// +// impl Default for Example { +// fn default() -> Self { +// Self::new() +// } +// } +// ``` +pub(crate) fn generate_default_from_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let fn_node = ctx.find_node_at_offset::()?; + let fn_name = fn_node.name()?; + + if fn_name.text() != "new" { + cov_mark::hit!(other_function_than_new); + return None; + } + + if fn_node.param_list()?.params().next().is_some() { + cov_mark::hit!(new_function_with_parameters); + return None; + } + + let impl_ = fn_node.syntax().ancestors().into_iter().find_map(ast::Impl::cast)?; + if is_default_implemented(ctx, &impl_) { + cov_mark::hit!(default_block_is_already_present); + cov_mark::hit!(struct_in_module_with_default); + return None; + } + + let insert_location = impl_.syntax().text_range(); + + acc.add( + AssistId("generate_default_from_new", crate::AssistKind::Generate), + "Generate a Default impl from a new fn", + insert_location, + move |builder| { + let default_code = " fn default() -> Self { + Self::new() + }"; + let code = generate_trait_impl_text_from_impl(&impl_, "Default", default_code); + builder.insert(insert_location.end(), code); + }, + ) +} + +fn generate_trait_impl_text_from_impl(impl_: &ast::Impl, trait_text: &str, code: &str) -> String { + let generic_params = impl_.generic_param_list(); + let mut buf = String::with_capacity(code.len()); + buf.push_str("\n\n"); + buf.push_str("impl"); + + if let Some(generic_params) = &generic_params { + let lifetimes = generic_params.lifetime_params().map(|lt| format!("{}", lt.syntax())); + let toc_params = generic_params.type_or_const_params().map(|toc_param| match toc_param { + ast::TypeOrConstParam::Type(type_param) => { + let mut buf = String::new(); + if let Some(it) = type_param.name() { + format_to!(buf, "{}", it.syntax()); + } + if let Some(it) = type_param.colon_token() { + format_to!(buf, "{} ", it); + } + if let Some(it) = type_param.type_bound_list() { + format_to!(buf, "{}", it.syntax()); + } + buf + } + ast::TypeOrConstParam::Const(const_param) => const_param.syntax().to_string(), + }); + let generics = lifetimes.chain(toc_params).format(", "); + format_to!(buf, "<{}>", generics); + } + + buf.push(' '); + buf.push_str(trait_text); + buf.push_str(" for "); + buf.push_str(&impl_.self_ty().unwrap().syntax().text().to_string()); + + match impl_.where_clause() { + Some(where_clause) => { + format_to!(buf, "\n{}\n{{\n{}\n}}", where_clause, code); + } + None => { + format_to!(buf, " {{\n{}\n}}", code); + } + } + + buf +} + +fn is_default_implemented(ctx: &AssistContext<'_>, impl_: &Impl) -> bool { + let db = ctx.sema.db; + let impl_ = ctx.sema.to_def(impl_); + let impl_def = match impl_ { + Some(value) => value, + None => return false, + }; + + let ty = impl_def.self_ty(db); + let krate = impl_def.module(db).krate(); + let default = FamousDefs(&ctx.sema, krate).core_default_Default(); + let default_trait = match default { + Some(value) => value, + None => return false, + }; + + ty.impls_trait(db, default_trait, &[]) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn generate_default() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +struct Example { _inner: () } + +impl Example { + pub fn ne$0w() -> Self { + Self { _inner: () } + } +} + +fn main() {} +"#, + r#" +struct Example { _inner: () } + +impl Example { + pub fn new() -> Self { + Self { _inner: () } + } +} + +impl Default for Example { + fn default() -> Self { + Self::new() + } +} + +fn main() {} +"#, + ); + } + + #[test] + fn generate_default2() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +struct Test { value: u32 } + +impl Test { + pub fn ne$0w() -> Self { + Self { value: 0 } + } +} +"#, + r#" +struct Test { value: u32 } + +impl Test { + pub fn new() -> Self { + Self { value: 0 } + } +} + +impl Default for Test { + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn new_function_with_generic() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +pub struct Foo { + _bar: *mut T, +} + +impl Foo { + pub fn ne$0w() -> Self { + unimplemented!() + } +} +"#, + r#" +pub struct Foo { + _bar: *mut T, +} + +impl Foo { + pub fn new() -> Self { + unimplemented!() + } +} + +impl Default for Foo { + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn new_function_with_generics() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +pub struct Foo { + _tars: *mut T, + _bar: *mut B, +} + +impl Foo { + pub fn ne$0w() -> Self { + unimplemented!() + } +} +"#, + r#" +pub struct Foo { + _tars: *mut T, + _bar: *mut B, +} + +impl Foo { + pub fn new() -> Self { + unimplemented!() + } +} + +impl Default for Foo { + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn new_function_with_generic_and_bound() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +pub struct Foo { + t: T, +} + +impl> Foo { + pub fn ne$0w() -> Self { + Foo { t: 0.into() } + } +} +"#, + r#" +pub struct Foo { + t: T, +} + +impl> Foo { + pub fn new() -> Self { + Foo { t: 0.into() } + } +} + +impl> Default for Foo { + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn new_function_with_generics_and_bounds() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +pub struct Foo { + _tars: T, + _bar: B, +} + +impl, B: From> Foo { + pub fn ne$0w() -> Self { + unimplemented!() + } +} +"#, + r#" +pub struct Foo { + _tars: T, + _bar: B, +} + +impl, B: From> Foo { + pub fn new() -> Self { + unimplemented!() + } +} + +impl, B: From> Default for Foo { + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn new_function_with_generic_and_where() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +pub struct Foo { + t: T, +} + +impl> Foo +where + Option: Debug +{ + pub fn ne$0w() -> Self { + Foo { t: 0.into() } + } +} +"#, + r#" +pub struct Foo { + t: T, +} + +impl> Foo +where + Option: Debug +{ + pub fn new() -> Self { + Foo { t: 0.into() } + } +} + +impl> Default for Foo +where + Option: Debug +{ + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn new_function_with_generics_and_wheres() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +pub struct Foo { + _tars: T, + _bar: B, +} + +impl, B: From> Foo +where + Option: Debug, Option: Debug, +{ + pub fn ne$0w() -> Self { + unimplemented!() + } +} +"#, + r#" +pub struct Foo { + _tars: T, + _bar: B, +} + +impl, B: From> Foo +where + Option: Debug, Option: Debug, +{ + pub fn new() -> Self { + unimplemented!() + } +} + +impl, B: From> Default for Foo +where + Option: Debug, Option: Debug, +{ + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn new_function_with_parameters() { + cov_mark::check!(new_function_with_parameters); + check_assist_not_applicable( + generate_default_from_new, + r#" +//- minicore: default +struct Example { _inner: () } + +impl Example { + pub fn $0new(value: ()) -> Self { + Self { _inner: value } + } +} +"#, + ); + } + + #[test] + fn other_function_than_new() { + cov_mark::check!(other_function_than_new); + check_assist_not_applicable( + generate_default_from_new, + r#" +struct Example { _inner: () } + +impl Example { + pub fn a$0dd() -> Self { + Self { _inner: () } + } +} + +"#, + ); + } + + #[test] + fn default_block_is_already_present() { + cov_mark::check!(default_block_is_already_present); + check_assist_not_applicable( + generate_default_from_new, + r#" +//- minicore: default +struct Example { _inner: () } + +impl Example { + pub fn n$0ew() -> Self { + Self { _inner: () } + } +} + +impl Default for Example { + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn standalone_new_function() { + check_assist_not_applicable( + generate_default_from_new, + r#" +fn n$0ew() -> u32 { + 0 +} +"#, + ); + } + + #[test] + fn multiple_struct_blocks() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +struct Example { _inner: () } +struct Test { value: u32 } + +impl Example { + pub fn new$0() -> Self { + Self { _inner: () } + } +} +"#, + r#" +struct Example { _inner: () } +struct Test { value: u32 } + +impl Example { + pub fn new() -> Self { + Self { _inner: () } + } +} + +impl Default for Example { + fn default() -> Self { + Self::new() + } +} +"#, + ); + } + + #[test] + fn when_struct_is_after_impl() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +impl Example { + pub fn $0new() -> Self { + Self { _inner: () } + } +} + +struct Example { _inner: () } +"#, + r#" +impl Example { + pub fn new() -> Self { + Self { _inner: () } + } +} + +impl Default for Example { + fn default() -> Self { + Self::new() + } +} + +struct Example { _inner: () } +"#, + ); + } + + #[test] + fn struct_in_module() { + check_assist( + generate_default_from_new, + r#" +//- minicore: default +mod test { + struct Example { _inner: () } + + impl Example { + pub fn n$0ew() -> Self { + Self { _inner: () } + } + } +} +"#, + r#" +mod test { + struct Example { _inner: () } + + impl Example { + pub fn new() -> Self { + Self { _inner: () } + } + } + +impl Default for Example { + fn default() -> Self { + Self::new() + } +} +} +"#, + ); + } + + #[test] + fn struct_in_module_with_default() { + cov_mark::check!(struct_in_module_with_default); + check_assist_not_applicable( + generate_default_from_new, + r#" +//- minicore: default +mod test { + struct Example { _inner: () } + + impl Example { + pub fn n$0ew() -> Self { + Self { _inner: () } + } + } + + impl Default for Example { + fn default() -> Self { + Self::new() + } + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_methods.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_methods.rs new file mode 100644 index 000000000..85b193663 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_methods.rs @@ -0,0 +1,334 @@ +use hir::{self, HasCrate, HasSource, HasVisibility}; +use syntax::ast::{self, make, AstNode, HasGenericParams, HasName, HasVisibility as _}; + +use crate::{ + utils::{convert_param_list_to_arg_list, find_struct_impl, render_snippet, Cursor}, + AssistContext, AssistId, AssistKind, Assists, GroupLabel, +}; +use syntax::ast::edit::AstNodeEdit; + +// Assist: generate_delegate_methods +// +// Generate delegate methods. +// +// ``` +// struct Age(u8); +// impl Age { +// fn age(&self) -> u8 { +// self.0 +// } +// } +// +// struct Person { +// ag$0e: Age, +// } +// ``` +// -> +// ``` +// struct Age(u8); +// impl Age { +// fn age(&self) -> u8 { +// self.0 +// } +// } +// +// struct Person { +// age: Age, +// } +// +// impl Person { +// $0fn age(&self) -> u8 { +// self.age.age() +// } +// } +// ``` +pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let strukt = ctx.find_node_at_offset::()?; + let strukt_name = strukt.name()?; + let current_module = ctx.sema.scope(strukt.syntax())?.module(); + + let (field_name, field_ty, target) = match ctx.find_node_at_offset::() { + Some(field) => { + let field_name = field.name()?; + let field_ty = field.ty()?; + (format!("{}", field_name), field_ty, field.syntax().text_range()) + } + None => { + let field = ctx.find_node_at_offset::()?; + let field_list = ctx.find_node_at_offset::()?; + let field_list_index = field_list.fields().position(|it| it == field)?; + let field_ty = field.ty()?; + (format!("{}", field_list_index), field_ty, field.syntax().text_range()) + } + }; + + let sema_field_ty = ctx.sema.resolve_type(&field_ty)?; + let krate = sema_field_ty.krate(ctx.db()); + let mut methods = vec![]; + sema_field_ty.iterate_assoc_items(ctx.db(), krate, |item| { + if let hir::AssocItem::Function(f) = item { + if f.self_param(ctx.db()).is_some() && f.is_visible_from(ctx.db(), current_module) { + methods.push(f) + } + } + Option::<()>::None + }); + + for method in methods { + let adt = ast::Adt::Struct(strukt.clone()); + let name = method.name(ctx.db()).to_string(); + let impl_def = find_struct_impl(ctx, &adt, &name).flatten(); + acc.add_group( + &GroupLabel("Generate delegate methods…".to_owned()), + AssistId("generate_delegate_methods", AssistKind::Generate), + format!("Generate delegate for `{}.{}()`", field_name, method.name(ctx.db())), + target, + |builder| { + // Create the function + let method_source = match method.source(ctx.db()) { + Some(source) => source.value, + None => return, + }; + let method_name = method.name(ctx.db()); + let vis = method_source.visibility(); + let name = make::name(&method.name(ctx.db()).to_string()); + let params = + method_source.param_list().unwrap_or_else(|| make::param_list(None, [])); + let type_params = method_source.generic_param_list(); + let arg_list = match method_source.param_list() { + Some(list) => convert_param_list_to_arg_list(list), + None => make::arg_list([]), + }; + let tail_expr = make::expr_method_call( + make::ext::field_from_idents(["self", &field_name]).unwrap(), // This unwrap is ok because we have at least 1 arg in the list + make::name_ref(&method_name.to_string()), + arg_list, + ); + let body = make::block_expr([], Some(tail_expr)); + let ret_type = method_source.ret_type(); + let is_async = method_source.async_token().is_some(); + let f = make::fn_(vis, name, type_params, params, body, ret_type, is_async) + .indent(ast::edit::IndentLevel(1)) + .clone_for_update(); + + let cursor = Cursor::Before(f.syntax()); + + // Create or update an impl block, attach the function to it, + // then insert into our code. + match impl_def { + Some(impl_def) => { + // Remember where in our source our `impl` block lives. + let impl_def = impl_def.clone_for_update(); + let old_range = impl_def.syntax().text_range(); + + // Attach the function to the impl block + let assoc_items = impl_def.get_or_create_assoc_item_list(); + assoc_items.add_item(f.clone().into()); + + // Update the impl block. + match ctx.config.snippet_cap { + Some(cap) => { + let snippet = render_snippet(cap, impl_def.syntax(), cursor); + builder.replace_snippet(cap, old_range, snippet); + } + None => { + builder.replace(old_range, impl_def.syntax().to_string()); + } + } + } + None => { + // Attach the function to the impl block + let name = &strukt_name.to_string(); + let params = strukt.generic_param_list(); + let ty_params = params.clone(); + let impl_def = make::impl_(make::ext::ident_path(name), params, ty_params) + .clone_for_update(); + let assoc_items = impl_def.get_or_create_assoc_item_list(); + assoc_items.add_item(f.clone().into()); + + // Insert the impl block. + match ctx.config.snippet_cap { + Some(cap) => { + let offset = strukt.syntax().text_range().end(); + let snippet = render_snippet(cap, impl_def.syntax(), cursor); + let snippet = format!("\n\n{}", snippet); + builder.insert_snippet(cap, offset, snippet); + } + None => { + let offset = strukt.syntax().text_range().end(); + let snippet = format!("\n\n{}", impl_def.syntax()); + builder.insert(offset, snippet); + } + } + } + } + }, + )?; + } + Some(()) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_generate_delegate_create_impl_block() { + check_assist( + generate_delegate_methods, + r#" +struct Age(u8); +impl Age { + fn age(&self) -> u8 { + self.0 + } +} + +struct Person { + ag$0e: Age, +}"#, + r#" +struct Age(u8); +impl Age { + fn age(&self) -> u8 { + self.0 + } +} + +struct Person { + age: Age, +} + +impl Person { + $0fn age(&self) -> u8 { + self.age.age() + } +}"#, + ); + } + + #[test] + fn test_generate_delegate_update_impl_block() { + check_assist( + generate_delegate_methods, + r#" +struct Age(u8); +impl Age { + fn age(&self) -> u8 { + self.0 + } +} + +struct Person { + ag$0e: Age, +} + +impl Person {}"#, + r#" +struct Age(u8); +impl Age { + fn age(&self) -> u8 { + self.0 + } +} + +struct Person { + age: Age, +} + +impl Person { + $0fn age(&self) -> u8 { + self.age.age() + } +}"#, + ); + } + + #[test] + fn test_generate_delegate_tuple_struct() { + check_assist( + generate_delegate_methods, + r#" +struct Age(u8); +impl Age { + fn age(&self) -> u8 { + self.0 + } +} + +struct Person(A$0ge);"#, + r#" +struct Age(u8); +impl Age { + fn age(&self) -> u8 { + self.0 + } +} + +struct Person(Age); + +impl Person { + $0fn age(&self) -> u8 { + self.0.age() + } +}"#, + ); + } + + #[test] + fn test_generate_delegate_enable_all_attributes() { + check_assist( + generate_delegate_methods, + r#" +struct Age(T); +impl Age { + pub(crate) async fn age(&'a mut self, ty: T, arg: J) -> T { + self.0 + } +} + +struct Person { + ag$0e: Age, +}"#, + r#" +struct Age(T); +impl Age { + pub(crate) async fn age(&'a mut self, ty: T, arg: J) -> T { + self.0 + } +} + +struct Person { + age: Age, +} + +impl Person { + $0pub(crate) async fn age(&'a mut self, ty: T, arg: J) -> T { + self.age.age(ty, arg) + } +}"#, + ); + } + + #[test] + fn test_generate_delegate_visibility() { + check_assist_not_applicable( + generate_delegate_methods, + r#" +mod m { + pub struct Age(u8); + impl Age { + fn age(&self) -> u8 { + self.0 + } + } +} + +struct Person { + ag$0e: m::Age, +}"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_deref.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_deref.rs new file mode 100644 index 000000000..b9637ee8d --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_deref.rs @@ -0,0 +1,343 @@ +use std::fmt::Display; + +use hir::{ModPath, ModuleDef}; +use ide_db::{famous_defs::FamousDefs, RootDatabase}; +use syntax::{ + ast::{self, HasName}, + AstNode, SyntaxNode, +}; + +use crate::{ + assist_context::{AssistBuilder, AssistContext, Assists}, + utils::generate_trait_impl_text, + AssistId, AssistKind, +}; + +// Assist: generate_deref +// +// Generate `Deref` impl using the given struct field. +// +// ``` +// # //- minicore: deref, deref_mut +// struct A; +// struct B { +// $0a: A +// } +// ``` +// -> +// ``` +// struct A; +// struct B { +// a: A +// } +// +// impl core::ops::Deref for B { +// type Target = A; +// +// fn deref(&self) -> &Self::Target { +// &self.a +// } +// } +// ``` +pub(crate) fn generate_deref(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + generate_record_deref(acc, ctx).or_else(|| generate_tuple_deref(acc, ctx)) +} + +fn generate_record_deref(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let strukt = ctx.find_node_at_offset::()?; + let field = ctx.find_node_at_offset::()?; + + let deref_type_to_generate = match existing_deref_impl(&ctx.sema, &strukt) { + None => DerefType::Deref, + Some(DerefType::Deref) => DerefType::DerefMut, + Some(DerefType::DerefMut) => { + cov_mark::hit!(test_add_record_deref_impl_already_exists); + return None; + } + }; + + let module = ctx.sema.to_def(&strukt)?.module(ctx.db()); + let trait_ = deref_type_to_generate.to_trait(&ctx.sema, module.krate())?; + let trait_path = module.find_use_path(ctx.db(), ModuleDef::Trait(trait_))?; + + let field_type = field.ty()?; + let field_name = field.name()?; + let target = field.syntax().text_range(); + acc.add( + AssistId("generate_deref", AssistKind::Generate), + format!("Generate `{:?}` impl using `{}`", deref_type_to_generate, field_name), + target, + |edit| { + generate_edit( + edit, + strukt, + field_type.syntax(), + field_name.syntax(), + deref_type_to_generate, + trait_path, + ) + }, + ) +} + +fn generate_tuple_deref(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let strukt = ctx.find_node_at_offset::()?; + let field = ctx.find_node_at_offset::()?; + let field_list = ctx.find_node_at_offset::()?; + let field_list_index = + field_list.syntax().children().into_iter().position(|s| &s == field.syntax())?; + + let deref_type_to_generate = match existing_deref_impl(&ctx.sema, &strukt) { + None => DerefType::Deref, + Some(DerefType::Deref) => DerefType::DerefMut, + Some(DerefType::DerefMut) => { + cov_mark::hit!(test_add_field_deref_impl_already_exists); + return None; + } + }; + + let module = ctx.sema.to_def(&strukt)?.module(ctx.db()); + let trait_ = deref_type_to_generate.to_trait(&ctx.sema, module.krate())?; + let trait_path = module.find_use_path(ctx.db(), ModuleDef::Trait(trait_))?; + + let field_type = field.ty()?; + let target = field.syntax().text_range(); + acc.add( + AssistId("generate_deref", AssistKind::Generate), + format!("Generate `{:?}` impl using `{}`", deref_type_to_generate, field.syntax()), + target, + |edit| { + generate_edit( + edit, + strukt, + field_type.syntax(), + field_list_index, + deref_type_to_generate, + trait_path, + ) + }, + ) +} + +fn generate_edit( + edit: &mut AssistBuilder, + strukt: ast::Struct, + field_type_syntax: &SyntaxNode, + field_name: impl Display, + deref_type: DerefType, + trait_path: ModPath, +) { + let start_offset = strukt.syntax().text_range().end(); + let impl_code = match deref_type { + DerefType::Deref => format!( + r#" type Target = {0}; + + fn deref(&self) -> &Self::Target {{ + &self.{1} + }}"#, + field_type_syntax, field_name + ), + DerefType::DerefMut => format!( + r#" fn deref_mut(&mut self) -> &mut Self::Target {{ + &mut self.{} + }}"#, + field_name + ), + }; + let strukt_adt = ast::Adt::Struct(strukt); + let deref_impl = generate_trait_impl_text(&strukt_adt, &trait_path.to_string(), &impl_code); + edit.insert(start_offset, deref_impl); +} + +fn existing_deref_impl( + sema: &hir::Semantics<'_, RootDatabase>, + strukt: &ast::Struct, +) -> Option { + let strukt = sema.to_def(strukt)?; + let krate = strukt.module(sema.db).krate(); + + let deref_trait = FamousDefs(sema, krate).core_ops_Deref()?; + let deref_mut_trait = FamousDefs(sema, krate).core_ops_DerefMut()?; + let strukt_type = strukt.ty(sema.db); + + if strukt_type.impls_trait(sema.db, deref_trait, &[]) { + if strukt_type.impls_trait(sema.db, deref_mut_trait, &[]) { + Some(DerefType::DerefMut) + } else { + Some(DerefType::Deref) + } + } else { + None + } +} + +#[derive(Debug)] +enum DerefType { + Deref, + DerefMut, +} + +impl DerefType { + fn to_trait( + &self, + sema: &hir::Semantics<'_, RootDatabase>, + krate: hir::Crate, + ) -> Option { + match self { + DerefType::Deref => FamousDefs(sema, krate).core_ops_Deref(), + DerefType::DerefMut => FamousDefs(sema, krate).core_ops_DerefMut(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_generate_record_deref() { + check_assist( + generate_deref, + r#" +//- minicore: deref +struct A { } +struct B { $0a: A }"#, + r#" +struct A { } +struct B { a: A } + +impl core::ops::Deref for B { + type Target = A; + + fn deref(&self) -> &Self::Target { + &self.a + } +}"#, + ); + } + + #[test] + fn test_generate_record_deref_short_path() { + check_assist( + generate_deref, + r#" +//- minicore: deref +use core::ops::Deref; +struct A { } +struct B { $0a: A }"#, + r#" +use core::ops::Deref; +struct A { } +struct B { a: A } + +impl Deref for B { + type Target = A; + + fn deref(&self) -> &Self::Target { + &self.a + } +}"#, + ); + } + + #[test] + fn test_generate_field_deref_idx_0() { + check_assist( + generate_deref, + r#" +//- minicore: deref +struct A { } +struct B($0A);"#, + r#" +struct A { } +struct B(A); + +impl core::ops::Deref for B { + type Target = A; + + fn deref(&self) -> &Self::Target { + &self.0 + } +}"#, + ); + } + #[test] + fn test_generate_field_deref_idx_1() { + check_assist( + generate_deref, + r#" +//- minicore: deref +struct A { } +struct B(u8, $0A);"#, + r#" +struct A { } +struct B(u8, A); + +impl core::ops::Deref for B { + type Target = A; + + fn deref(&self) -> &Self::Target { + &self.1 + } +}"#, + ); + } + + #[test] + fn test_generates_derefmut_when_deref_present() { + check_assist( + generate_deref, + r#" +//- minicore: deref, deref_mut +struct B { $0a: u8 } + +impl core::ops::Deref for B {} +"#, + r#" +struct B { a: u8 } + +impl core::ops::DerefMut for B { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.a + } +} + +impl core::ops::Deref for B {} +"#, + ); + } + + #[test] + fn test_generate_record_deref_not_applicable_if_already_impl() { + cov_mark::check!(test_add_record_deref_impl_already_exists); + check_assist_not_applicable( + generate_deref, + r#" +//- minicore: deref, deref_mut +struct A { } +struct B { $0a: A } + +impl core::ops::Deref for B {} +impl core::ops::DerefMut for B {} +"#, + ) + } + + #[test] + fn test_generate_field_deref_not_applicable_if_already_impl() { + cov_mark::check!(test_add_field_deref_impl_already_exists); + check_assist_not_applicable( + generate_deref, + r#" +//- minicore: deref, deref_mut +struct A { } +struct B($0A) + +impl core::ops::Deref for B {} +impl core::ops::DerefMut for B {} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_derive.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_derive.rs new file mode 100644 index 000000000..339245b94 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_derive.rs @@ -0,0 +1,132 @@ +use syntax::{ + ast::{self, AstNode, HasAttrs}, + SyntaxKind::{COMMENT, WHITESPACE}, + TextSize, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: generate_derive +// +// Adds a new `#[derive()]` clause to a struct or enum. +// +// ``` +// struct Point { +// x: u32, +// y: u32,$0 +// } +// ``` +// -> +// ``` +// #[derive($0)] +// struct Point { +// x: u32, +// y: u32, +// } +// ``` +pub(crate) fn generate_derive(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let cap = ctx.config.snippet_cap?; + let nominal = ctx.find_node_at_offset::()?; + let node_start = derive_insertion_offset(&nominal)?; + let target = nominal.syntax().text_range(); + acc.add( + AssistId("generate_derive", AssistKind::Generate), + "Add `#[derive]`", + target, + |builder| { + let derive_attr = nominal + .attrs() + .filter_map(|x| x.as_simple_call()) + .filter(|(name, _arg)| name == "derive") + .map(|(_name, arg)| arg) + .next(); + match derive_attr { + None => { + builder.insert_snippet(cap, node_start, "#[derive($0)]\n"); + } + Some(tt) => { + // Just move the cursor. + builder.insert_snippet( + cap, + tt.syntax().text_range().end() - TextSize::of(')'), + "$0", + ) + } + }; + }, + ) +} + +// Insert `derive` after doc comments. +fn derive_insertion_offset(nominal: &ast::Adt) -> Option { + let non_ws_child = nominal + .syntax() + .children_with_tokens() + .find(|it| it.kind() != COMMENT && it.kind() != WHITESPACE)?; + Some(non_ws_child.text_range().start()) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_target}; + + use super::*; + + #[test] + fn add_derive_new() { + check_assist( + generate_derive, + "struct Foo { a: i32, $0}", + "#[derive($0)]\nstruct Foo { a: i32, }", + ); + check_assist( + generate_derive, + "struct Foo { $0 a: i32, }", + "#[derive($0)]\nstruct Foo { a: i32, }", + ); + } + + #[test] + fn add_derive_existing() { + check_assist( + generate_derive, + "#[derive(Clone)]\nstruct Foo { a: i32$0, }", + "#[derive(Clone$0)]\nstruct Foo { a: i32, }", + ); + } + + #[test] + fn add_derive_new_with_doc_comment() { + check_assist( + generate_derive, + " +/// `Foo` is a pretty important struct. +/// It does stuff. +struct Foo { a: i32$0, } + ", + " +/// `Foo` is a pretty important struct. +/// It does stuff. +#[derive($0)] +struct Foo { a: i32, } + ", + ); + } + + #[test] + fn add_derive_target() { + check_assist_target( + generate_derive, + " +struct SomeThingIrrelevant; +/// `Foo` is a pretty important struct. +/// It does stuff. +struct Foo { a: i32$0, } +struct EvenMoreIrrelevant; + ", + "/// `Foo` is a pretty important struct. +/// It does stuff. +struct Foo { a: i32, }", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_documentation_template.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_documentation_template.rs new file mode 100644 index 000000000..c91141f8e --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_documentation_template.rs @@ -0,0 +1,1328 @@ +use hir::{AsAssocItem, HasVisibility, ModuleDef, Visibility}; +use ide_db::assists::{AssistId, AssistKind}; +use itertools::Itertools; +use stdx::{format_to, to_lower_snake_case}; +use syntax::{ + algo::skip_whitespace_token, + ast::{self, edit::IndentLevel, HasDocComments, HasName}, + match_ast, AstNode, AstToken, +}; + +use crate::assist_context::{AssistContext, Assists}; + +// Assist: generate_documentation_template +// +// Adds a documentation template above a function definition / declaration. +// +// ``` +// pub struct S; +// impl S { +// pub unsafe fn set_len$0(&mut self, len: usize) -> Result<(), std::io::Error> { +// /* ... */ +// } +// } +// ``` +// -> +// ``` +// pub struct S; +// impl S { +// /// Sets the length of this [`S`]. +// /// +// /// # Errors +// /// +// /// This function will return an error if . +// /// +// /// # Safety +// /// +// /// . +// pub unsafe fn set_len(&mut self, len: usize) -> Result<(), std::io::Error> { +// /* ... */ +// } +// } +// ``` +pub(crate) fn generate_documentation_template( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let name = ctx.find_node_at_offset::()?; + let ast_func = name.syntax().parent().and_then(ast::Fn::cast)?; + if is_in_trait_impl(&ast_func, ctx) || ast_func.doc_comments().next().is_some() { + return None; + } + + let parent_syntax = ast_func.syntax(); + let text_range = parent_syntax.text_range(); + let indent_level = IndentLevel::from_node(parent_syntax); + + acc.add( + AssistId("generate_documentation_template", AssistKind::Generate), + "Generate a documentation template", + text_range, + |builder| { + // Introduction / short function description before the sections + let mut doc_lines = vec![introduction_builder(&ast_func, ctx).unwrap_or(".".into())]; + // Then come the sections + for section_builder in [panics_builder, errors_builder, safety_builder] { + if let Some(mut lines) = section_builder(&ast_func) { + doc_lines.push("".into()); + doc_lines.append(&mut lines); + } + } + builder.insert(text_range.start(), documentation_from_lines(doc_lines, indent_level)); + }, + ) +} + +// Assist: generate_doc_example +// +// Generates a rustdoc example when editing an item's documentation. +// +// ``` +// /// Adds two numbers.$0 +// pub fn add(a: i32, b: i32) -> i32 { a + b } +// ``` +// -> +// ``` +// /// Adds two numbers. +// /// +// /// # Examples +// /// +// /// ``` +// /// use test::add; +// /// +// /// assert_eq!(add(a, b), ); +// /// ``` +// pub fn add(a: i32, b: i32) -> i32 { a + b } +// ``` +pub(crate) fn generate_doc_example(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let tok: ast::Comment = ctx.find_token_at_offset()?; + let node = tok.syntax().parent()?; + let last_doc_token = + ast::AnyHasDocComments::cast(node.clone())?.doc_comments().last()?.syntax().clone(); + let next_token = skip_whitespace_token(last_doc_token.next_token()?, syntax::Direction::Next)?; + + let example = match_ast! { + match node { + ast::Fn(it) => make_example_for_fn(&it, ctx)?, + _ => return None, + } + }; + + let mut lines = string_vec_from(&["", "# Examples", "", "```"]); + lines.extend(example.lines().map(String::from)); + lines.push("```".into()); + let indent_level = IndentLevel::from_node(&node); + + acc.add( + AssistId("generate_doc_example", AssistKind::Generate), + "Generate a documentation example", + node.text_range(), + |builder| { + builder.insert( + next_token.text_range().start(), + documentation_from_lines(lines, indent_level), + ); + }, + ) +} + +fn make_example_for_fn(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> Option { + if !is_public(ast_func, ctx)? { + // Doctests for private items can't actually name the item, so they're pretty useless. + return None; + } + + if is_in_trait_def(ast_func, ctx) { + // This is not yet implemented. + return None; + } + + let mut example = String::new(); + + let is_unsafe = ast_func.unsafe_token().is_some(); + let param_list = ast_func.param_list()?; + let ref_mut_params = ref_mut_params(¶m_list); + let self_name = self_name(ast_func); + + format_to!(example, "use {};\n\n", build_path(ast_func, ctx)?); + if let Some(self_name) = &self_name { + if let Some(mtbl) = is_ref_mut_self(ast_func) { + let mtbl = if mtbl == true { " mut" } else { "" }; + format_to!(example, "let{} {} = ;\n", mtbl, self_name); + } + } + for param_name in &ref_mut_params { + format_to!(example, "let mut {} = ;\n", param_name); + } + // Call the function, check result + let function_call = function_call(ast_func, ¶m_list, self_name.as_deref(), is_unsafe)?; + if returns_a_value(ast_func, ctx) { + if count_parameters(¶m_list) < 3 { + format_to!(example, "assert_eq!({}, );\n", function_call); + } else { + format_to!(example, "let result = {};\n", function_call); + example.push_str("assert_eq!(result, );\n"); + } + } else { + format_to!(example, "{};\n", function_call); + } + // Check the mutated values + if is_ref_mut_self(ast_func) == Some(true) { + format_to!(example, "assert_eq!({}, );", self_name?); + } + for param_name in &ref_mut_params { + format_to!(example, "assert_eq!({}, );", param_name); + } + Some(example) +} + +fn introduction_builder(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> Option { + let hir_func = ctx.sema.to_def(ast_func)?; + let container = hir_func.as_assoc_item(ctx.db())?.container(ctx.db()); + if let hir::AssocItemContainer::Impl(imp) = container { + let ret_ty = hir_func.ret_type(ctx.db()); + let self_ty = imp.self_ty(ctx.db()); + let name = ast_func.name()?.to_string(); + let linkable_self_ty = self_type_without_lifetimes(ast_func); + let linkable_self_ty = linkable_self_ty.as_deref(); + + let intro_for_new = || { + let is_new = name == "new"; + if is_new && ret_ty == self_ty { + Some(format!("Creates a new [`{}`].", linkable_self_ty?)) + } else { + None + } + }; + + let intro_for_getter = || match ( + hir_func.self_param(ctx.sema.db), + &*hir_func.params_without_self(ctx.sema.db), + ) { + (Some(self_param), []) if self_param.access(ctx.sema.db) != hir::Access::Owned => { + if name.starts_with("as_") || name.starts_with("to_") || name == "get" { + return None; + } + let mut what = name.trim_end_matches("_mut").replace('_', " "); + if what == "len" { + what = "length".into() + } + let reference = if ret_ty.is_mutable_reference() { + " a mutable reference to" + } else if ret_ty.is_reference() { + " a reference to" + } else { + "" + }; + Some(format!("Returns{reference} the {what} of this [`{}`].", linkable_self_ty?)) + } + _ => None, + }; + + let intro_for_setter = || { + if !name.starts_with("set_") { + return None; + } + + let mut what = name.trim_start_matches("set_").replace('_', " "); + if what == "len" { + what = "length".into() + }; + Some(format!("Sets the {what} of this [`{}`].", linkable_self_ty?)) + }; + + if let Some(intro) = intro_for_new() { + return Some(intro); + } + if let Some(intro) = intro_for_getter() { + return Some(intro); + } + if let Some(intro) = intro_for_setter() { + return Some(intro); + } + } + None +} + +/// Builds an optional `# Panics` section +fn panics_builder(ast_func: &ast::Fn) -> Option> { + match can_panic(ast_func) { + Some(true) => Some(string_vec_from(&["# Panics", "", "Panics if ."])), + _ => None, + } +} + +/// Builds an optional `# Errors` section +fn errors_builder(ast_func: &ast::Fn) -> Option> { + match return_type(ast_func)?.to_string().contains("Result") { + true => Some(string_vec_from(&["# Errors", "", "This function will return an error if ."])), + false => None, + } +} + +/// Builds an optional `# Safety` section +fn safety_builder(ast_func: &ast::Fn) -> Option> { + let is_unsafe = ast_func.unsafe_token().is_some(); + match is_unsafe { + true => Some(string_vec_from(&["# Safety", "", "."])), + false => None, + } +} + +/// Checks if the function is public / exported +fn is_public(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> Option { + let hir_func = ctx.sema.to_def(ast_func)?; + Some( + hir_func.visibility(ctx.db()) == Visibility::Public + && all_parent_mods_public(&hir_func, ctx), + ) +} + +/// Checks that all parent modules of the function are public / exported +fn all_parent_mods_public(hir_func: &hir::Function, ctx: &AssistContext<'_>) -> bool { + let mut module = hir_func.module(ctx.db()); + loop { + if let Some(parent) = module.parent(ctx.db()) { + match ModuleDef::from(module).visibility(ctx.db()) { + Visibility::Public => module = parent, + _ => break false, + } + } else { + break true; + } + } +} + +/// Returns the name of the current crate +fn crate_name(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> Option { + let krate = ctx.sema.scope(ast_func.syntax())?.krate(); + Some(krate.display_name(ctx.db())?.to_string()) +} + +/// `None` if function without a body; some bool to guess if function can panic +fn can_panic(ast_func: &ast::Fn) -> Option { + let body = ast_func.body()?.to_string(); + let can_panic = body.contains("panic!(") + // FIXME it would be better to not match `debug_assert*!` macro invocations + || body.contains("assert!(") + || body.contains(".unwrap()") + || body.contains(".expect("); + Some(can_panic) +} + +/// Helper function to get the name that should be given to `self` arguments +fn self_name(ast_func: &ast::Fn) -> Option { + self_partial_type(ast_func).map(|name| to_lower_snake_case(&name)) +} + +/// Heper function to get the name of the type of `self` +fn self_type(ast_func: &ast::Fn) -> Option { + ast_func.syntax().ancestors().find_map(ast::Impl::cast).and_then(|i| i.self_ty()) +} + +/// Output the real name of `Self` like `MyType`, without the lifetimes. +fn self_type_without_lifetimes(ast_func: &ast::Fn) -> Option { + let path_segment = match self_type(ast_func)? { + ast::Type::PathType(path_type) => path_type.path()?.segment()?, + _ => return None, + }; + let mut name = path_segment.name_ref()?.to_string(); + let generics = path_segment.generic_arg_list().into_iter().flat_map(|list| { + list.generic_args() + .filter(|generic| matches!(generic, ast::GenericArg::TypeArg(_))) + .map(|generic| generic.to_string()) + }); + let generics: String = generics.format(", ").to_string(); + if !generics.is_empty() { + name.push('<'); + name.push_str(&generics); + name.push('>'); + } + Some(name) +} + +/// Heper function to get the name of the type of `self` without generic arguments +fn self_partial_type(ast_func: &ast::Fn) -> Option { + let mut self_type = self_type(ast_func)?.to_string(); + if let Some(idx) = self_type.find(|c| ['<', ' '].contains(&c)) { + self_type.truncate(idx); + } + Some(self_type) +} + +/// Helper function to determine if the function is in a trait implementation +fn is_in_trait_impl(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> bool { + ctx.sema + .to_def(ast_func) + .and_then(|hir_func| hir_func.as_assoc_item(ctx.db())) + .and_then(|assoc_item| assoc_item.containing_trait_impl(ctx.db())) + .is_some() +} + +/// Helper function to determine if the function definition is in a trait definition +fn is_in_trait_def(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> bool { + ctx.sema + .to_def(ast_func) + .and_then(|hir_func| hir_func.as_assoc_item(ctx.db())) + .and_then(|assoc_item| assoc_item.containing_trait(ctx.db())) + .is_some() +} + +/// Returns `None` if no `self` at all, `Some(true)` if there is `&mut self` else `Some(false)` +fn is_ref_mut_self(ast_func: &ast::Fn) -> Option { + let self_param = ast_func.param_list()?.self_param()?; + Some(self_param.mut_token().is_some() && self_param.amp_token().is_some()) +} + +/// Helper function to determine if a parameter is `&mut` +fn is_a_ref_mut_param(param: &ast::Param) -> bool { + match param.ty() { + Some(ast::Type::RefType(param_ref)) => param_ref.mut_token().is_some(), + _ => false, + } +} + +/// Helper function to build the list of `&mut` parameters +fn ref_mut_params(param_list: &ast::ParamList) -> Vec { + param_list + .params() + .filter_map(|param| match is_a_ref_mut_param(¶m) { + // Maybe better filter the param name (to do this maybe extract a function from + // `arguments_from_params`?) in case of a `mut a: &mut T`. Anyway managing most (not + // all) cases might be enough, the goal is just to produce a template. + true => Some(param.pat()?.to_string()), + false => None, + }) + .collect() +} + +/// Helper function to build the comma-separated list of arguments of the function +fn arguments_from_params(param_list: &ast::ParamList) -> String { + let args_iter = param_list.params().map(|param| match param.pat() { + // To avoid `mut` in the function call (which would be a nonsense), `Pat` should not be + // written as is so its variants must be managed independently. Other variants (for + // instance `TuplePat`) could be managed later. + Some(ast::Pat::IdentPat(ident_pat)) => match ident_pat.name() { + Some(name) => match is_a_ref_mut_param(¶m) { + true => format!("&mut {}", name), + false => name.to_string(), + }, + None => "_".to_string(), + }, + _ => "_".to_string(), + }); + args_iter.format(", ").to_string() +} + +/// Helper function to build a function call. `None` if expected `self_name` was not provided +fn function_call( + ast_func: &ast::Fn, + param_list: &ast::ParamList, + self_name: Option<&str>, + is_unsafe: bool, +) -> Option { + let name = ast_func.name()?; + let arguments = arguments_from_params(param_list); + let function_call = if param_list.self_param().is_some() { + format!("{}.{}({})", self_name?, name, arguments) + } else if let Some(implementation) = self_partial_type(ast_func) { + format!("{}::{}({})", implementation, name, arguments) + } else { + format!("{}({})", name, arguments) + }; + match is_unsafe { + true => Some(format!("unsafe {{ {} }}", function_call)), + false => Some(function_call), + } +} + +/// Helper function to count the parameters including `self` +fn count_parameters(param_list: &ast::ParamList) -> usize { + param_list.params().count() + if param_list.self_param().is_some() { 1 } else { 0 } +} + +/// Helper function to transform lines of documentation into a Rust code documentation +fn documentation_from_lines(doc_lines: Vec, indent_level: IndentLevel) -> String { + let mut result = String::new(); + for doc_line in doc_lines { + result.push_str("///"); + if !doc_line.is_empty() { + result.push(' '); + result.push_str(&doc_line); + } + result.push('\n'); + result.push_str(&indent_level.to_string()); + } + result +} + +/// Helper function to transform an array of borrowed strings to an owned `Vec` +fn string_vec_from(string_array: &[&str]) -> Vec { + string_array.iter().map(|&s| s.to_owned()).collect() +} + +/// Helper function to build the path of the module in the which is the node +fn build_path(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> Option { + let crate_name = crate_name(ast_func, ctx)?; + let leaf = self_partial_type(ast_func) + .or_else(|| ast_func.name().map(|n| n.to_string())) + .unwrap_or_else(|| "*".into()); + let module_def: ModuleDef = ctx.sema.to_def(ast_func)?.module(ctx.db()).into(); + match module_def.canonical_path(ctx.db()) { + Some(path) => Some(format!("{}::{}::{}", crate_name, path, leaf)), + None => Some(format!("{}::{}", crate_name, leaf)), + } +} + +/// Helper function to get the return type of a function +fn return_type(ast_func: &ast::Fn) -> Option { + ast_func.ret_type()?.ty() +} + +/// Helper function to determine if the function returns some data +fn returns_a_value(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> bool { + ctx.sema + .to_def(ast_func) + .map(|hir_func| hir_func.ret_type(ctx.db())) + .map(|ret_ty| !ret_ty.is_unit() && !ret_ty.is_never()) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn not_applicable_on_function_calls() { + check_assist_not_applicable( + generate_documentation_template, + r#" +fn hello_world() {} +fn calls_hello_world() { + hello_world$0(); +} +"#, + ) + } + + #[test] + fn not_applicable_in_trait_impl() { + check_assist_not_applicable( + generate_documentation_template, + r#" +trait MyTrait {} +struct MyStruct; +impl MyTrait for MyStruct { + fn hello_world$0(); +} +"#, + ) + } + + #[test] + fn not_applicable_if_function_already_documented() { + check_assist_not_applicable( + generate_documentation_template, + r#" +/// Some documentation here +pub fn $0documented_function() {} +"#, + ); + } + + #[test] + fn supports_noop_function() { + check_assist( + generate_documentation_template, + r#" +pub fn no$0op() {} +"#, + r#" +/// . +pub fn noop() {} +"#, + ); + } + + #[test] + fn is_applicable_if_function_is_private() { + check_assist( + generate_documentation_template, + r#" +fn priv$0ate() {} +"#, + r#" +/// . +fn private() {} +"#, + ); + } + + #[test] + fn no_doc_example_for_private_fn() { + check_assist_not_applicable( + generate_doc_example, + r#" +///$0 +fn private() {} +"#, + ); + } + + #[test] + fn supports_a_parameter() { + check_assist( + generate_doc_example, + r#" +/// $0. +pub fn noop_with_param(_a: i32) {} +"#, + r#" +/// . +/// +/// # Examples +/// +/// ``` +/// use test::noop_with_param; +/// +/// noop_with_param(_a); +/// ``` +pub fn noop_with_param(_a: i32) {} +"#, + ); + } + + #[test] + fn detects_unsafe_function() { + check_assist( + generate_documentation_template, + r#" +pub unsafe fn no$0op_unsafe() {} +"#, + r#" +/// . +/// +/// # Safety +/// +/// . +pub unsafe fn noop_unsafe() {} +"#, + ); + check_assist( + generate_doc_example, + r#" +/// . +/// +/// # Safety$0 +/// +/// . +pub unsafe fn noop_unsafe() {} +"#, + r#" +/// . +/// +/// # Safety +/// +/// . +/// +/// # Examples +/// +/// ``` +/// use test::noop_unsafe; +/// +/// unsafe { noop_unsafe() }; +/// ``` +pub unsafe fn noop_unsafe() {} +"#, + ); + } + + #[test] + fn guesses_panic_macro_can_panic() { + check_assist( + generate_documentation_template, + r#" +pub fn panic$0s_if(a: bool) { + if a { + panic!(); + } +} +"#, + r#" +/// . +/// +/// # Panics +/// +/// Panics if . +pub fn panics_if(a: bool) { + if a { + panic!(); + } +} +"#, + ); + } + + #[test] + fn guesses_assert_macro_can_panic() { + check_assist( + generate_documentation_template, + r#" +pub fn $0panics_if_not(a: bool) { + assert!(a == true); +} +"#, + r#" +/// . +/// +/// # Panics +/// +/// Panics if . +pub fn panics_if_not(a: bool) { + assert!(a == true); +} +"#, + ); + } + + #[test] + fn guesses_unwrap_can_panic() { + check_assist( + generate_documentation_template, + r#" +pub fn $0panics_if_none(a: Option<()>) { + a.unwrap(); +} +"#, + r#" +/// . +/// +/// # Panics +/// +/// Panics if . +pub fn panics_if_none(a: Option<()>) { + a.unwrap(); +} +"#, + ); + } + + #[test] + fn guesses_expect_can_panic() { + check_assist( + generate_documentation_template, + r#" +pub fn $0panics_if_none2(a: Option<()>) { + a.expect("Bouh!"); +} +"#, + r#" +/// . +/// +/// # Panics +/// +/// Panics if . +pub fn panics_if_none2(a: Option<()>) { + a.expect("Bouh!"); +} +"#, + ); + } + + #[test] + fn checks_output_in_example() { + check_assist( + generate_doc_example, + r#" +///$0 +pub fn returns_a_value$0() -> i32 { + 0 +} +"#, + r#" +/// +/// +/// # Examples +/// +/// ``` +/// use test::returns_a_value; +/// +/// assert_eq!(returns_a_value(), ); +/// ``` +pub fn returns_a_value() -> i32 { + 0 +} +"#, + ); + } + + #[test] + fn detects_result_output() { + check_assist( + generate_documentation_template, + r#" +pub fn returns_a_result$0() -> Result { + Ok(0) +} +"#, + r#" +/// . +/// +/// # Errors +/// +/// This function will return an error if . +pub fn returns_a_result() -> Result { + Ok(0) +} +"#, + ); + } + + #[test] + fn checks_ref_mut_in_example() { + check_assist( + generate_doc_example, + r#" +///$0 +pub fn modifies_a_value$0(a: &mut i32) { + *a = 0; +} +"#, + r#" +/// +/// +/// # Examples +/// +/// ``` +/// use test::modifies_a_value; +/// +/// let mut a = ; +/// modifies_a_value(&mut a); +/// assert_eq!(a, ); +/// ``` +pub fn modifies_a_value(a: &mut i32) { + *a = 0; +} +"#, + ); + } + + #[test] + fn stores_result_if_at_least_3_params() { + check_assist( + generate_doc_example, + r#" +///$0 +pub fn sum3$0(a: i32, b: i32, c: i32) -> i32 { + a + b + c +} +"#, + r#" +/// +/// +/// # Examples +/// +/// ``` +/// use test::sum3; +/// +/// let result = sum3(a, b, c); +/// assert_eq!(result, ); +/// ``` +pub fn sum3(a: i32, b: i32, c: i32) -> i32 { + a + b + c +} +"#, + ); + } + + #[test] + fn supports_fn_in_mods() { + check_assist( + generate_doc_example, + r#" +pub mod a { + pub mod b { + ///$0 + pub fn noop() {} + } +} +"#, + r#" +pub mod a { + pub mod b { + /// + /// + /// # Examples + /// + /// ``` + /// use test::a::b::noop; + /// + /// noop(); + /// ``` + pub fn noop() {} + } +} +"#, + ); + } + + #[test] + fn supports_fn_in_impl() { + check_assist( + generate_doc_example, + r#" +pub struct MyStruct; +impl MyStruct { + ///$0 + pub fn noop() {} +} +"#, + r#" +pub struct MyStruct; +impl MyStruct { + /// + /// + /// # Examples + /// + /// ``` + /// use test::MyStruct; + /// + /// MyStruct::noop(); + /// ``` + pub fn noop() {} +} +"#, + ); + } + + #[test] + fn supports_unsafe_fn_in_trait() { + check_assist( + generate_documentation_template, + r#" +pub trait MyTrait { + unsafe fn unsafe_funct$0ion_trait(); +} +"#, + r#" +pub trait MyTrait { + /// . + /// + /// # Safety + /// + /// . + unsafe fn unsafe_function_trait(); +} +"#, + ); + } + + #[test] + fn supports_fn_in_trait_with_default_panicking() { + check_assist( + generate_documentation_template, + r#" +pub trait MyTrait { + fn function_trait_with_$0default_panicking() { + panic!() + } +} +"#, + r#" +pub trait MyTrait { + /// . + /// + /// # Panics + /// + /// Panics if . + fn function_trait_with_default_panicking() { + panic!() + } +} +"#, + ); + } + + #[test] + fn supports_fn_in_trait_returning_result() { + check_assist( + generate_documentation_template, + r#" +pub trait MyTrait { + fn function_tr$0ait_returning_result() -> Result<(), std::io::Error>; +} +"#, + r#" +pub trait MyTrait { + /// . + /// + /// # Errors + /// + /// This function will return an error if . + fn function_trait_returning_result() -> Result<(), std::io::Error>; +} +"#, + ); + } + + #[test] + fn detects_new() { + check_assist( + generate_documentation_template, + r#" +pub struct String(u8); +impl String { + pub fn new$0(x: u8) -> String { + String(x) + } +} +"#, + r#" +pub struct String(u8); +impl String { + /// Creates a new [`String`]. + pub fn new(x: u8) -> String { + String(x) + } +} +"#, + ); + check_assist( + generate_documentation_template, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct { + pub x: T, +} +impl MyGenericStruct { + pub fn new$0(x: T) -> MyGenericStruct { + MyGenericStruct { x } + } +} +"#, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct { + pub x: T, +} +impl MyGenericStruct { + /// Creates a new [`MyGenericStruct`]. + pub fn new(x: T) -> MyGenericStruct { + MyGenericStruct { x } + } +} +"#, + ); + } + + #[test] + fn removes_one_lifetime_from_description() { + check_assist( + generate_documentation_template, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct<'a, T> { + pub x: &'a T, +} +impl<'a, T> MyGenericStruct<'a, T> { + pub fn new$0(x: &'a T) -> Self { + MyGenericStruct { x } + } +} +"#, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct<'a, T> { + pub x: &'a T, +} +impl<'a, T> MyGenericStruct<'a, T> { + /// Creates a new [`MyGenericStruct`]. + pub fn new(x: &'a T) -> Self { + MyGenericStruct { x } + } +} +"#, + ); + } + + #[test] + fn removes_all_lifetimes_from_description() { + check_assist( + generate_documentation_template, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct<'a, 'b, T> { + pub x: &'a T, + pub y: &'b T, +} +impl<'a, 'b, T> MyGenericStruct<'a, 'b, T> { + pub fn new$0(x: &'a T, y: &'b T) -> Self { + MyGenericStruct { x, y } + } +} +"#, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct<'a, 'b, T> { + pub x: &'a T, + pub y: &'b T, +} +impl<'a, 'b, T> MyGenericStruct<'a, 'b, T> { + /// Creates a new [`MyGenericStruct`]. + pub fn new(x: &'a T, y: &'b T) -> Self { + MyGenericStruct { x, y } + } +} +"#, + ); + } + + #[test] + fn removes_all_lifetimes_and_brackets_from_description() { + check_assist( + generate_documentation_template, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct<'a, 'b> { + pub x: &'a usize, + pub y: &'b usize, +} +impl<'a, 'b> MyGenericStruct<'a, 'b> { + pub fn new$0(x: &'a usize, y: &'b usize) -> Self { + MyGenericStruct { x, y } + } +} +"#, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct<'a, 'b> { + pub x: &'a usize, + pub y: &'b usize, +} +impl<'a, 'b> MyGenericStruct<'a, 'b> { + /// Creates a new [`MyGenericStruct`]. + pub fn new(x: &'a usize, y: &'b usize) -> Self { + MyGenericStruct { x, y } + } +} +"#, + ); + } + + #[test] + fn detects_new_with_self() { + check_assist( + generate_documentation_template, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct2 { + pub x: T, +} +impl MyGenericStruct2 { + pub fn new$0(x: T) -> Self { + MyGenericStruct2 { x } + } +} +"#, + r#" +#[derive(Debug, PartialEq)] +pub struct MyGenericStruct2 { + pub x: T, +} +impl MyGenericStruct2 { + /// Creates a new [`MyGenericStruct2`]. + pub fn new(x: T) -> Self { + MyGenericStruct2 { x } + } +} +"#, + ); + } + + #[test] + fn supports_method_call() { + check_assist( + generate_doc_example, + r#" +impl MyGenericStruct { + ///$0 + pub fn consume(self) {} +} +"#, + r#" +impl MyGenericStruct { + /// + /// + /// # Examples + /// + /// ``` + /// use test::MyGenericStruct; + /// + /// let my_generic_struct = ; + /// my_generic_struct.consume(); + /// ``` + pub fn consume(self) {} +} +"#, + ); + } + + #[test] + fn checks_modified_self_param() { + check_assist( + generate_doc_example, + r#" +impl MyGenericStruct { + ///$0 + pub fn modify(&mut self, new_value: T) { + self.x = new_value; + } +} +"#, + r#" +impl MyGenericStruct { + /// + /// + /// # Examples + /// + /// ``` + /// use test::MyGenericStruct; + /// + /// let mut my_generic_struct = ; + /// my_generic_struct.modify(new_value); + /// assert_eq!(my_generic_struct, ); + /// ``` + pub fn modify(&mut self, new_value: T) { + self.x = new_value; + } +} +"#, + ); + } + + #[test] + fn generates_intro_for_getters() { + check_assist( + generate_documentation_template, + r#" +pub struct S; +impl S { + pub fn speed$0(&self) -> f32 { 0.0 } +} +"#, + r#" +pub struct S; +impl S { + /// Returns the speed of this [`S`]. + pub fn speed(&self) -> f32 { 0.0 } +} +"#, + ); + check_assist( + generate_documentation_template, + r#" +pub struct S; +impl S { + pub fn data$0(&self) -> &[u8] { &[] } +} +"#, + r#" +pub struct S; +impl S { + /// Returns a reference to the data of this [`S`]. + pub fn data(&self) -> &[u8] { &[] } +} +"#, + ); + check_assist( + generate_documentation_template, + r#" +pub struct S; +impl S { + pub fn data$0(&mut self) -> &mut [u8] { &mut [] } +} +"#, + r#" +pub struct S; +impl S { + /// Returns a mutable reference to the data of this [`S`]. + pub fn data(&mut self) -> &mut [u8] { &mut [] } +} +"#, + ); + check_assist( + generate_documentation_template, + r#" +pub struct S; +impl S { + pub fn data_mut$0(&mut self) -> &mut [u8] { &mut [] } +} +"#, + r#" +pub struct S; +impl S { + /// Returns a mutable reference to the data of this [`S`]. + pub fn data_mut(&mut self) -> &mut [u8] { &mut [] } +} +"#, + ); + } + + #[test] + fn no_getter_intro_for_prefixed_methods() { + check_assist( + generate_documentation_template, + r#" +pub struct S; +impl S { + pub fn as_bytes$0(&self) -> &[u8] { &[] } +} +"#, + r#" +pub struct S; +impl S { + /// . + pub fn as_bytes(&self) -> &[u8] { &[] } +} +"#, + ); + } + + #[test] + fn generates_intro_for_setters() { + check_assist( + generate_documentation_template, + r#" +pub struct S; +impl S { + pub fn set_data$0(&mut self, data: Vec) {} +} +"#, + r#" +pub struct S; +impl S { + /// Sets the data of this [`S`]. + pub fn set_data(&mut self, data: Vec) {} +} +"#, + ); + check_assist( + generate_documentation_template, + r#" +pub struct S; +impl S { + pub fn set_domain_name$0(&mut self, name: String) {} +} +"#, + r#" +pub struct S; +impl S { + /// Sets the domain name of this [`S`]. + pub fn set_domain_name(&mut self, name: String) {} +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_is_method.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_is_method.rs new file mode 100644 index 000000000..52d27d8a7 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_is_method.rs @@ -0,0 +1,316 @@ +use ide_db::assists::GroupLabel; +use stdx::to_lower_snake_case; +use syntax::ast::HasVisibility; +use syntax::ast::{self, AstNode, HasName}; + +use crate::{ + utils::{add_method_to_adt, find_struct_impl}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: generate_enum_is_method +// +// Generate an `is_` method for this enum variant. +// +// ``` +// enum Version { +// Undefined, +// Minor$0, +// Major, +// } +// ``` +// -> +// ``` +// enum Version { +// Undefined, +// Minor, +// Major, +// } +// +// impl Version { +// /// Returns `true` if the version is [`Minor`]. +// /// +// /// [`Minor`]: Version::Minor +// #[must_use] +// fn is_minor(&self) -> bool { +// matches!(self, Self::Minor) +// } +// } +// ``` +pub(crate) fn generate_enum_is_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let variant = ctx.find_node_at_offset::()?; + let variant_name = variant.name()?; + let parent_enum = ast::Adt::Enum(variant.parent_enum()); + let pattern_suffix = match variant.kind() { + ast::StructKind::Record(_) => " { .. }", + ast::StructKind::Tuple(_) => "(..)", + ast::StructKind::Unit => "", + }; + + let enum_name = parent_enum.name()?; + let enum_lowercase_name = to_lower_snake_case(&enum_name.to_string()).replace('_', " "); + let fn_name = format!("is_{}", &to_lower_snake_case(&variant_name.text())); + + // Return early if we've found an existing new fn + let impl_def = find_struct_impl(ctx, &parent_enum, &fn_name)?; + + let target = variant.syntax().text_range(); + acc.add_group( + &GroupLabel("Generate an `is_`,`as_`, or `try_into_` for this enum variant".to_owned()), + AssistId("generate_enum_is_method", AssistKind::Generate), + "Generate an `is_` method for this enum variant", + target, + |builder| { + let vis = parent_enum.visibility().map_or(String::new(), |v| format!("{} ", v)); + let method = format!( + " /// Returns `true` if the {} is [`{variant}`]. + /// + /// [`{variant}`]: {}::{variant} + #[must_use] + {}fn {}(&self) -> bool {{ + matches!(self, Self::{variant}{}) + }}", + enum_lowercase_name, + enum_name, + vis, + fn_name, + pattern_suffix, + variant = variant_name + ); + + add_method_to_adt(builder, &parent_enum, impl_def, &method); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_generate_enum_is_from_variant() { + check_assist( + generate_enum_is_method, + r#" +enum Variant { + Undefined, + Minor$0, + Major, +}"#, + r#"enum Variant { + Undefined, + Minor, + Major, +} + +impl Variant { + /// Returns `true` if the variant is [`Minor`]. + /// + /// [`Minor`]: Variant::Minor + #[must_use] + fn is_minor(&self) -> bool { + matches!(self, Self::Minor) + } +}"#, + ); + } + + #[test] + fn test_generate_enum_is_already_implemented() { + check_assist_not_applicable( + generate_enum_is_method, + r#" +enum Variant { + Undefined, + Minor$0, + Major, +} + +impl Variant { + fn is_minor(&self) -> bool { + matches!(self, Self::Minor) + } +}"#, + ); + } + + #[test] + fn test_generate_enum_is_from_tuple_variant() { + check_assist( + generate_enum_is_method, + r#" +enum Variant { + Undefined, + Minor(u32)$0, + Major, +}"#, + r#"enum Variant { + Undefined, + Minor(u32), + Major, +} + +impl Variant { + /// Returns `true` if the variant is [`Minor`]. + /// + /// [`Minor`]: Variant::Minor + #[must_use] + fn is_minor(&self) -> bool { + matches!(self, Self::Minor(..)) + } +}"#, + ); + } + + #[test] + fn test_generate_enum_is_from_record_variant() { + check_assist( + generate_enum_is_method, + r#" +enum Variant { + Undefined, + Minor { foo: i32 }$0, + Major, +}"#, + r#"enum Variant { + Undefined, + Minor { foo: i32 }, + Major, +} + +impl Variant { + /// Returns `true` if the variant is [`Minor`]. + /// + /// [`Minor`]: Variant::Minor + #[must_use] + fn is_minor(&self) -> bool { + matches!(self, Self::Minor { .. }) + } +}"#, + ); + } + + #[test] + fn test_generate_enum_is_from_variant_with_one_variant() { + check_assist( + generate_enum_is_method, + r#"enum Variant { Undefi$0ned }"#, + r#" +enum Variant { Undefined } + +impl Variant { + /// Returns `true` if the variant is [`Undefined`]. + /// + /// [`Undefined`]: Variant::Undefined + #[must_use] + fn is_undefined(&self) -> bool { + matches!(self, Self::Undefined) + } +}"#, + ); + } + + #[test] + fn test_generate_enum_is_from_variant_with_visibility_marker() { + check_assist( + generate_enum_is_method, + r#" +pub(crate) enum Variant { + Undefined, + Minor$0, + Major, +}"#, + r#"pub(crate) enum Variant { + Undefined, + Minor, + Major, +} + +impl Variant { + /// Returns `true` if the variant is [`Minor`]. + /// + /// [`Minor`]: Variant::Minor + #[must_use] + pub(crate) fn is_minor(&self) -> bool { + matches!(self, Self::Minor) + } +}"#, + ); + } + + #[test] + fn test_multiple_generate_enum_is_from_variant() { + check_assist( + generate_enum_is_method, + r#" +enum Variant { + Undefined, + Minor, + Major$0, +} + +impl Variant { + /// Returns `true` if the variant is [`Minor`]. + /// + /// [`Minor`]: Variant::Minor + #[must_use] + fn is_minor(&self) -> bool { + matches!(self, Self::Minor) + } +}"#, + r#"enum Variant { + Undefined, + Minor, + Major, +} + +impl Variant { + /// Returns `true` if the variant is [`Minor`]. + /// + /// [`Minor`]: Variant::Minor + #[must_use] + fn is_minor(&self) -> bool { + matches!(self, Self::Minor) + } + + /// Returns `true` if the variant is [`Major`]. + /// + /// [`Major`]: Variant::Major + #[must_use] + fn is_major(&self) -> bool { + matches!(self, Self::Major) + } +}"#, + ); + } + + #[test] + fn test_generate_enum_is_variant_names() { + check_assist( + generate_enum_is_method, + r#" +enum GeneratorState { + Yielded, + Complete$0, + Major, +}"#, + r#"enum GeneratorState { + Yielded, + Complete, + Major, +} + +impl GeneratorState { + /// Returns `true` if the generator state is [`Complete`]. + /// + /// [`Complete`]: GeneratorState::Complete + #[must_use] + fn is_complete(&self) -> bool { + matches!(self, Self::Complete) + } +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_projection_method.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_projection_method.rs new file mode 100644 index 000000000..b19aa0f65 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_projection_method.rs @@ -0,0 +1,342 @@ +use ide_db::assists::GroupLabel; +use itertools::Itertools; +use stdx::to_lower_snake_case; +use syntax::ast::HasVisibility; +use syntax::ast::{self, AstNode, HasName}; + +use crate::{ + utils::{add_method_to_adt, find_struct_impl}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: generate_enum_try_into_method +// +// Generate a `try_into_` method for this enum variant. +// +// ``` +// enum Value { +// Number(i32), +// Text(String)$0, +// } +// ``` +// -> +// ``` +// enum Value { +// Number(i32), +// Text(String), +// } +// +// impl Value { +// fn try_into_text(self) -> Result { +// if let Self::Text(v) = self { +// Ok(v) +// } else { +// Err(self) +// } +// } +// } +// ``` +pub(crate) fn generate_enum_try_into_method( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + generate_enum_projection_method( + acc, + ctx, + "generate_enum_try_into_method", + "Generate a `try_into_` method for this enum variant", + ProjectionProps { + fn_name_prefix: "try_into", + self_param: "self", + return_prefix: "Result<", + return_suffix: ", Self>", + happy_case: "Ok", + sad_case: "Err(self)", + }, + ) +} + +// Assist: generate_enum_as_method +// +// Generate an `as_` method for this enum variant. +// +// ``` +// enum Value { +// Number(i32), +// Text(String)$0, +// } +// ``` +// -> +// ``` +// enum Value { +// Number(i32), +// Text(String), +// } +// +// impl Value { +// fn as_text(&self) -> Option<&String> { +// if let Self::Text(v) = self { +// Some(v) +// } else { +// None +// } +// } +// } +// ``` +pub(crate) fn generate_enum_as_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + generate_enum_projection_method( + acc, + ctx, + "generate_enum_as_method", + "Generate an `as_` method for this enum variant", + ProjectionProps { + fn_name_prefix: "as", + self_param: "&self", + return_prefix: "Option<&", + return_suffix: ">", + happy_case: "Some", + sad_case: "None", + }, + ) +} + +struct ProjectionProps { + fn_name_prefix: &'static str, + self_param: &'static str, + return_prefix: &'static str, + return_suffix: &'static str, + happy_case: &'static str, + sad_case: &'static str, +} + +fn generate_enum_projection_method( + acc: &mut Assists, + ctx: &AssistContext<'_>, + assist_id: &'static str, + assist_description: &str, + props: ProjectionProps, +) -> Option<()> { + let variant = ctx.find_node_at_offset::()?; + let variant_name = variant.name()?; + let parent_enum = ast::Adt::Enum(variant.parent_enum()); + + let (pattern_suffix, field_type, bound_name) = match variant.kind() { + ast::StructKind::Record(record) => { + let (field,) = record.fields().collect_tuple()?; + let name = field.name()?.to_string(); + let ty = field.ty()?; + let pattern_suffix = format!(" {{ {} }}", name); + (pattern_suffix, ty, name) + } + ast::StructKind::Tuple(tuple) => { + let (field,) = tuple.fields().collect_tuple()?; + let ty = field.ty()?; + ("(v)".to_owned(), ty, "v".to_owned()) + } + ast::StructKind::Unit => return None, + }; + + let fn_name = + format!("{}_{}", props.fn_name_prefix, &to_lower_snake_case(&variant_name.text())); + + // Return early if we've found an existing new fn + let impl_def = find_struct_impl(ctx, &parent_enum, &fn_name)?; + + let target = variant.syntax().text_range(); + acc.add_group( + &GroupLabel("Generate an `is_`,`as_`, or `try_into_` for this enum variant".to_owned()), + AssistId(assist_id, AssistKind::Generate), + assist_description, + target, + |builder| { + let vis = parent_enum.visibility().map_or(String::new(), |v| format!("{} ", v)); + let method = format!( + " {0}fn {1}({2}) -> {3}{4}{5} {{ + if let Self::{6}{7} = self {{ + {8}({9}) + }} else {{ + {10} + }} + }}", + vis, + fn_name, + props.self_param, + props.return_prefix, + field_type.syntax(), + props.return_suffix, + variant_name, + pattern_suffix, + props.happy_case, + bound_name, + props.sad_case, + ); + + add_method_to_adt(builder, &parent_enum, impl_def, &method); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_generate_enum_try_into_tuple_variant() { + check_assist( + generate_enum_try_into_method, + r#" +enum Value { + Number(i32), + Text(String)$0, +}"#, + r#"enum Value { + Number(i32), + Text(String), +} + +impl Value { + fn try_into_text(self) -> Result { + if let Self::Text(v) = self { + Ok(v) + } else { + Err(self) + } + } +}"#, + ); + } + + #[test] + fn test_generate_enum_try_into_already_implemented() { + check_assist_not_applicable( + generate_enum_try_into_method, + r#"enum Value { + Number(i32), + Text(String)$0, +} + +impl Value { + fn try_into_text(self) -> Result { + if let Self::Text(v) = self { + Ok(v) + } else { + Err(self) + } + } +}"#, + ); + } + + #[test] + fn test_generate_enum_try_into_unit_variant() { + check_assist_not_applicable( + generate_enum_try_into_method, + r#"enum Value { + Number(i32), + Text(String), + Unit$0, +}"#, + ); + } + + #[test] + fn test_generate_enum_try_into_record_with_multiple_fields() { + check_assist_not_applicable( + generate_enum_try_into_method, + r#"enum Value { + Number(i32), + Text(String), + Both { first: i32, second: String }$0, +}"#, + ); + } + + #[test] + fn test_generate_enum_try_into_tuple_with_multiple_fields() { + check_assist_not_applicable( + generate_enum_try_into_method, + r#"enum Value { + Number(i32), + Text(String, String)$0, +}"#, + ); + } + + #[test] + fn test_generate_enum_try_into_record_variant() { + check_assist( + generate_enum_try_into_method, + r#"enum Value { + Number(i32), + Text { text: String }$0, +}"#, + r#"enum Value { + Number(i32), + Text { text: String }, +} + +impl Value { + fn try_into_text(self) -> Result { + if let Self::Text { text } = self { + Ok(text) + } else { + Err(self) + } + } +}"#, + ); + } + + #[test] + fn test_generate_enum_as_tuple_variant() { + check_assist( + generate_enum_as_method, + r#" +enum Value { + Number(i32), + Text(String)$0, +}"#, + r#"enum Value { + Number(i32), + Text(String), +} + +impl Value { + fn as_text(&self) -> Option<&String> { + if let Self::Text(v) = self { + Some(v) + } else { + None + } + } +}"#, + ); + } + + #[test] + fn test_generate_enum_as_record_variant() { + check_assist( + generate_enum_as_method, + r#"enum Value { + Number(i32), + Text { text: String }$0, +}"#, + r#"enum Value { + Number(i32), + Text { text: String }, +} + +impl Value { + fn as_text(&self) -> Option<&String> { + if let Self::Text { text } = self { + Some(text) + } else { + None + } + } +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs new file mode 100644 index 000000000..4461fbd5a --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs @@ -0,0 +1,227 @@ +use hir::{HasSource, InFile}; +use ide_db::assists::{AssistId, AssistKind}; +use syntax::{ + ast::{self, edit::IndentLevel}, + AstNode, TextSize, +}; + +use crate::assist_context::{AssistContext, Assists}; + +// Assist: generate_enum_variant +// +// Adds a variant to an enum. +// +// ``` +// enum Countries { +// Ghana, +// } +// +// fn main() { +// let country = Countries::Lesotho$0; +// } +// ``` +// -> +// ``` +// enum Countries { +// Ghana, +// Lesotho, +// } +// +// fn main() { +// let country = Countries::Lesotho; +// } +// ``` +pub(crate) fn generate_enum_variant(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let path_expr: ast::PathExpr = ctx.find_node_at_offset()?; + let path = path_expr.path()?; + + if ctx.sema.resolve_path(&path).is_some() { + // No need to generate anything if the path resolves + return None; + } + + let name_ref = path.segment()?.name_ref()?; + if name_ref.text().starts_with(char::is_lowercase) { + // Don't suggest generating variant if the name starts with a lowercase letter + return None; + } + + if let Some(hir::PathResolution::Def(hir::ModuleDef::Adt(hir::Adt::Enum(e)))) = + ctx.sema.resolve_path(&path.qualifier()?) + { + let target = path.syntax().text_range(); + return add_variant_to_accumulator(acc, ctx, target, e, &name_ref); + } + + None +} + +fn add_variant_to_accumulator( + acc: &mut Assists, + ctx: &AssistContext<'_>, + target: syntax::TextRange, + adt: hir::Enum, + name_ref: &ast::NameRef, +) -> Option<()> { + let db = ctx.db(); + let InFile { file_id, value: enum_node } = adt.source(db)?.original_ast_node(db)?; + let enum_indent = IndentLevel::from_node(&enum_node.syntax()); + + let variant_list = enum_node.variant_list()?; + let offset = variant_list.syntax().text_range().end() - TextSize::of('}'); + let empty_enum = variant_list.variants().next().is_none(); + + acc.add( + AssistId("generate_enum_variant", AssistKind::Generate), + "Generate variant", + target, + |builder| { + builder.edit_file(file_id.original_file(db)); + let text = format!( + "{maybe_newline}{indent_1}{name},\n{enum_indent}", + maybe_newline = if empty_enum { "\n" } else { "" }, + indent_1 = IndentLevel(1), + name = name_ref, + enum_indent = enum_indent + ); + builder.insert(offset, text) + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn generate_basic_enum_variant_in_empty_enum() { + check_assist( + generate_enum_variant, + r" +enum Foo {} +fn main() { + Foo::Bar$0 +} +", + r" +enum Foo { + Bar, +} +fn main() { + Foo::Bar +} +", + ) + } + + #[test] + fn generate_basic_enum_variant_in_non_empty_enum() { + check_assist( + generate_enum_variant, + r" +enum Foo { + Bar, +} +fn main() { + Foo::Baz$0 +} +", + r" +enum Foo { + Bar, + Baz, +} +fn main() { + Foo::Baz +} +", + ) + } + + #[test] + fn generate_basic_enum_variant_in_different_file() { + check_assist( + generate_enum_variant, + r" +//- /main.rs +mod foo; +use foo::Foo; + +fn main() { + Foo::Baz$0 +} + +//- /foo.rs +enum Foo { + Bar, +} +", + r" +enum Foo { + Bar, + Baz, +} +", + ) + } + + #[test] + fn not_applicable_for_existing_variant() { + check_assist_not_applicable( + generate_enum_variant, + r" +enum Foo { + Bar, +} +fn main() { + Foo::Bar$0 +} +", + ) + } + + #[test] + fn not_applicable_for_lowercase() { + check_assist_not_applicable( + generate_enum_variant, + r" +enum Foo { + Bar, +} +fn main() { + Foo::new$0 +} +", + ) + } + + #[test] + fn indentation_level_is_correct() { + check_assist( + generate_enum_variant, + r" +mod m { + enum Foo { + Bar, + } +} +fn main() { + m::Foo::Baz$0 +} +", + r" +mod m { + enum Foo { + Bar, + Baz, + } +} +fn main() { + m::Foo::Baz +} +", + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs new file mode 100644 index 000000000..507ea012b --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs @@ -0,0 +1,310 @@ +use ide_db::{famous_defs::FamousDefs, RootDatabase}; +use syntax::ast::{self, AstNode, HasName}; + +use crate::{utils::generate_trait_impl_text, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: generate_from_impl_for_enum +// +// Adds a From impl for this enum variant with one tuple field. +// +// ``` +// enum A { $0One(u32) } +// ``` +// -> +// ``` +// enum A { One(u32) } +// +// impl From for A { +// fn from(v: u32) -> Self { +// Self::One(v) +// } +// } +// ``` +pub(crate) fn generate_from_impl_for_enum( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let variant = ctx.find_node_at_offset::()?; + let variant_name = variant.name()?; + let enum_ = ast::Adt::Enum(variant.parent_enum()); + let (field_name, field_type) = match variant.kind() { + ast::StructKind::Tuple(field_list) => { + if field_list.fields().count() != 1 { + return None; + } + (None, field_list.fields().next()?.ty()?) + } + ast::StructKind::Record(field_list) => { + if field_list.fields().count() != 1 { + return None; + } + let field = field_list.fields().next()?; + (Some(field.name()?), field.ty()?) + } + ast::StructKind::Unit => return None, + }; + + if existing_from_impl(&ctx.sema, &variant).is_some() { + cov_mark::hit!(test_add_from_impl_already_exists); + return None; + } + + let target = variant.syntax().text_range(); + acc.add( + AssistId("generate_from_impl_for_enum", AssistKind::Generate), + "Generate `From` impl for this enum variant", + target, + |edit| { + let start_offset = variant.parent_enum().syntax().text_range().end(); + let from_trait = format!("From<{}>", field_type.syntax()); + let impl_code = if let Some(name) = field_name { + format!( + r#" fn from({0}: {1}) -> Self {{ + Self::{2} {{ {0} }} + }}"#, + name.text(), + field_type.syntax(), + variant_name, + ) + } else { + format!( + r#" fn from(v: {}) -> Self {{ + Self::{}(v) + }}"#, + field_type.syntax(), + variant_name, + ) + }; + let from_impl = generate_trait_impl_text(&enum_, &from_trait, &impl_code); + edit.insert(start_offset, from_impl); + }, + ) +} + +fn existing_from_impl( + sema: &'_ hir::Semantics<'_, RootDatabase>, + variant: &ast::Variant, +) -> Option<()> { + let variant = sema.to_def(variant)?; + let enum_ = variant.parent_enum(sema.db); + let krate = enum_.module(sema.db).krate(); + + let from_trait = FamousDefs(sema, krate).core_convert_From()?; + + let enum_type = enum_.ty(sema.db); + + let wrapped_type = variant.fields(sema.db).get(0)?.ty(sema.db); + + if enum_type.impls_trait(sema.db, from_trait, &[wrapped_type]) { + Some(()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_generate_from_impl_for_enum() { + check_assist( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum A { $0One(u32) } +"#, + r#" +enum A { One(u32) } + +impl From for A { + fn from(v: u32) -> Self { + Self::One(v) + } +} +"#, + ); + } + + #[test] + fn test_generate_from_impl_for_enum_complicated_path() { + check_assist( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum A { $0One(foo::bar::baz::Boo) } +"#, + r#" +enum A { One(foo::bar::baz::Boo) } + +impl From for A { + fn from(v: foo::bar::baz::Boo) -> Self { + Self::One(v) + } +} +"#, + ); + } + + #[test] + fn test_add_from_impl_no_element() { + check_assist_not_applicable( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum A { $0One } +"#, + ); + } + + #[test] + fn test_add_from_impl_more_than_one_element_in_tuple() { + check_assist_not_applicable( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum A { $0One(u32, String) } +"#, + ); + } + + #[test] + fn test_add_from_impl_struct_variant() { + check_assist( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum A { $0One { x: u32 } } +"#, + r#" +enum A { One { x: u32 } } + +impl From for A { + fn from(x: u32) -> Self { + Self::One { x } + } +} +"#, + ); + } + + #[test] + fn test_add_from_impl_already_exists() { + cov_mark::check!(test_add_from_impl_already_exists); + check_assist_not_applicable( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum A { $0One(u32), } + +impl From for A { + fn from(v: u32) -> Self { + Self::One(v) + } +} +"#, + ); + } + + #[test] + fn test_add_from_impl_different_variant_impl_exists() { + check_assist( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum A { $0One(u32), Two(String), } + +impl From for A { + fn from(v: String) -> Self { + A::Two(v) + } +} + +pub trait From { + fn from(T) -> Self; +} +"#, + r#" +enum A { One(u32), Two(String), } + +impl From for A { + fn from(v: u32) -> Self { + Self::One(v) + } +} + +impl From for A { + fn from(v: String) -> Self { + A::Two(v) + } +} + +pub trait From { + fn from(T) -> Self; +} +"#, + ); + } + + #[test] + fn test_add_from_impl_static_str() { + check_assist( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum A { $0One(&'static str) } +"#, + r#" +enum A { One(&'static str) } + +impl From<&'static str> for A { + fn from(v: &'static str) -> Self { + Self::One(v) + } +} +"#, + ); + } + + #[test] + fn test_add_from_impl_generic_enum() { + check_assist( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum Generic { $0One(T), Two(U) } +"#, + r#" +enum Generic { One(T), Two(U) } + +impl From for Generic { + fn from(v: T) -> Self { + Self::One(v) + } +} +"#, + ); + } + + #[test] + fn test_add_from_impl_with_lifetime() { + check_assist( + generate_from_impl_for_enum, + r#" +//- minicore: from +enum Generic<'a> { $0One(&'a i32) } +"#, + r#" +enum Generic<'a> { One(&'a i32) } + +impl<'a> From<&'a i32> for Generic<'a> { + fn from(v: &'a i32) -> Self { + Self::One(v) + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs new file mode 100644 index 000000000..d564a0540 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs @@ -0,0 +1,1787 @@ +use hir::{HasSource, HirDisplay, Module, Semantics, TypeInfo}; +use ide_db::{ + base_db::FileId, + defs::{Definition, NameRefClass}, + famous_defs::FamousDefs, + FxHashMap, FxHashSet, RootDatabase, SnippetCap, +}; +use stdx::to_lower_snake_case; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, AstNode, CallExpr, HasArgList, HasModuleItem, + }, + SyntaxKind, SyntaxNode, TextRange, TextSize, +}; + +use crate::{ + utils::convert_reference_type, + utils::{find_struct_impl, render_snippet, Cursor}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: generate_function +// +// Adds a stub function with a signature matching the function under the cursor. +// +// ``` +// struct Baz; +// fn baz() -> Baz { Baz } +// fn foo() { +// bar$0("", baz()); +// } +// +// ``` +// -> +// ``` +// struct Baz; +// fn baz() -> Baz { Baz } +// fn foo() { +// bar("", baz()); +// } +// +// fn bar(arg: &str, baz: Baz) ${0:-> _} { +// todo!() +// } +// +// ``` +pub(crate) fn generate_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + gen_fn(acc, ctx).or_else(|| gen_method(acc, ctx)) +} + +fn gen_fn(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let path_expr: ast::PathExpr = ctx.find_node_at_offset()?; + let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?; + let path = path_expr.path()?; + let name_ref = path.segment()?.name_ref()?; + if ctx.sema.resolve_path(&path).is_some() { + // The function call already resolves, no need to add a function + return None; + } + + let fn_name = &*name_ref.text(); + let target_module; + let mut adt_name = None; + + let (target, file, insert_offset) = match path.qualifier() { + Some(qualifier) => match ctx.sema.resolve_path(&qualifier) { + Some(hir::PathResolution::Def(hir::ModuleDef::Module(module))) => { + target_module = Some(module); + get_fn_target(ctx, &target_module, call.clone())? + } + Some(hir::PathResolution::Def(hir::ModuleDef::Adt(adt))) => { + if let hir::Adt::Enum(_) = adt { + // Don't suggest generating function if the name starts with an uppercase letter + if name_ref.text().starts_with(char::is_uppercase) { + return None; + } + } + + let current_module = ctx.sema.scope(call.syntax())?.module(); + let module = adt.module(ctx.sema.db); + target_module = if current_module == module { None } else { Some(module) }; + if current_module.krate() != module.krate() { + return None; + } + let (impl_, file) = get_adt_source(ctx, &adt, fn_name)?; + let (target, insert_offset) = get_method_target(ctx, &module, &impl_)?; + adt_name = if impl_.is_none() { Some(adt.name(ctx.sema.db)) } else { None }; + (target, file, insert_offset) + } + _ => { + return None; + } + }, + _ => { + target_module = None; + get_fn_target(ctx, &target_module, call.clone())? + } + }; + let function_builder = FunctionBuilder::from_call(ctx, &call, fn_name, target_module, target)?; + let text_range = call.syntax().text_range(); + let label = format!("Generate {} function", function_builder.fn_name); + add_func_to_accumulator( + acc, + ctx, + text_range, + function_builder, + insert_offset, + file, + adt_name, + label, + ) +} + +fn gen_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let call: ast::MethodCallExpr = ctx.find_node_at_offset()?; + if ctx.sema.resolve_method_call(&call).is_some() { + return None; + } + + let fn_name = call.name_ref()?; + let adt = ctx.sema.type_of_expr(&call.receiver()?)?.original().strip_references().as_adt()?; + + let current_module = ctx.sema.scope(call.syntax())?.module(); + let target_module = adt.module(ctx.sema.db); + + if current_module.krate() != target_module.krate() { + return None; + } + let (impl_, file) = get_adt_source(ctx, &adt, fn_name.text().as_str())?; + let (target, insert_offset) = get_method_target(ctx, &target_module, &impl_)?; + let function_builder = + FunctionBuilder::from_method_call(ctx, &call, &fn_name, target_module, target)?; + let text_range = call.syntax().text_range(); + let adt_name = if impl_.is_none() { Some(adt.name(ctx.sema.db)) } else { None }; + let label = format!("Generate {} method", function_builder.fn_name); + add_func_to_accumulator( + acc, + ctx, + text_range, + function_builder, + insert_offset, + file, + adt_name, + label, + ) +} + +fn add_func_to_accumulator( + acc: &mut Assists, + ctx: &AssistContext<'_>, + text_range: TextRange, + function_builder: FunctionBuilder, + insert_offset: TextSize, + file: FileId, + adt_name: Option, + label: String, +) -> Option<()> { + acc.add(AssistId("generate_function", AssistKind::Generate), label, text_range, |builder| { + let function_template = function_builder.render(); + let mut func = function_template.to_string(ctx.config.snippet_cap); + if let Some(name) = adt_name { + func = format!("\nimpl {} {{\n{}\n}}", name, func); + } + builder.edit_file(file); + match ctx.config.snippet_cap { + Some(cap) => builder.insert_snippet(cap, insert_offset, func), + None => builder.insert(insert_offset, func), + } + }) +} + +fn get_adt_source( + ctx: &AssistContext<'_>, + adt: &hir::Adt, + fn_name: &str, +) -> Option<(Option, FileId)> { + let range = adt.source(ctx.sema.db)?.syntax().original_file_range(ctx.sema.db); + let file = ctx.sema.parse(range.file_id); + let adt_source = + ctx.sema.find_node_at_offset_with_macros(file.syntax(), range.range.start())?; + find_struct_impl(ctx, &adt_source, fn_name).map(|impl_| (impl_, range.file_id)) +} + +struct FunctionTemplate { + leading_ws: String, + fn_def: ast::Fn, + ret_type: Option, + should_focus_return_type: bool, + trailing_ws: String, + tail_expr: ast::Expr, +} + +impl FunctionTemplate { + fn to_string(&self, cap: Option) -> String { + let f = match cap { + Some(cap) => { + let cursor = if self.should_focus_return_type { + // Focus the return type if there is one + match self.ret_type { + Some(ref ret_type) => ret_type.syntax(), + None => self.tail_expr.syntax(), + } + } else { + self.tail_expr.syntax() + }; + render_snippet(cap, self.fn_def.syntax(), Cursor::Replace(cursor)) + } + None => self.fn_def.to_string(), + }; + + format!("{}{}{}", self.leading_ws, f, self.trailing_ws) + } +} + +struct FunctionBuilder { + target: GeneratedFunctionTarget, + fn_name: ast::Name, + type_params: Option, + params: ast::ParamList, + ret_type: Option, + should_focus_return_type: bool, + needs_pub: bool, + is_async: bool, +} + +impl FunctionBuilder { + /// Prepares a generated function that matches `call`. + /// The function is generated in `target_module` or next to `call` + fn from_call( + ctx: &AssistContext<'_>, + call: &ast::CallExpr, + fn_name: &str, + target_module: Option, + target: GeneratedFunctionTarget, + ) -> Option { + let needs_pub = target_module.is_some(); + let target_module = + target_module.or_else(|| ctx.sema.scope(target.syntax()).map(|it| it.module()))?; + let fn_name = make::name(fn_name); + let (type_params, params) = + fn_args(ctx, target_module, ast::CallableExpr::Call(call.clone()))?; + + let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast); + let is_async = await_expr.is_some(); + + let (ret_type, should_focus_return_type) = + make_return_type(ctx, &ast::Expr::CallExpr(call.clone()), target_module); + + Some(Self { + target, + fn_name, + type_params, + params, + ret_type, + should_focus_return_type, + needs_pub, + is_async, + }) + } + + fn from_method_call( + ctx: &AssistContext<'_>, + call: &ast::MethodCallExpr, + name: &ast::NameRef, + target_module: Module, + target: GeneratedFunctionTarget, + ) -> Option { + let needs_pub = + !module_is_descendant(&ctx.sema.scope(call.syntax())?.module(), &target_module, ctx); + let fn_name = make::name(&name.text()); + let (type_params, params) = + fn_args(ctx, target_module, ast::CallableExpr::MethodCall(call.clone()))?; + + let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast); + let is_async = await_expr.is_some(); + + let (ret_type, should_focus_return_type) = + make_return_type(ctx, &ast::Expr::MethodCallExpr(call.clone()), target_module); + + Some(Self { + target, + fn_name, + type_params, + params, + ret_type, + should_focus_return_type, + needs_pub, + is_async, + }) + } + + fn render(self) -> FunctionTemplate { + let placeholder_expr = make::ext::expr_todo(); + let fn_body = make::block_expr(vec![], Some(placeholder_expr)); + let visibility = if self.needs_pub { Some(make::visibility_pub_crate()) } else { None }; + let mut fn_def = make::fn_( + visibility, + self.fn_name, + self.type_params, + self.params, + fn_body, + self.ret_type, + self.is_async, + ); + let leading_ws; + let trailing_ws; + + match self.target { + GeneratedFunctionTarget::BehindItem(it) => { + let indent = IndentLevel::from_node(&it); + leading_ws = format!("\n\n{}", indent); + fn_def = fn_def.indent(indent); + trailing_ws = String::new(); + } + GeneratedFunctionTarget::InEmptyItemList(it) => { + let indent = IndentLevel::from_node(&it); + leading_ws = format!("\n{}", indent + 1); + fn_def = fn_def.indent(indent + 1); + trailing_ws = format!("\n{}", indent); + } + }; + + FunctionTemplate { + leading_ws, + ret_type: fn_def.ret_type(), + // PANIC: we guarantee we always create a function body with a tail expr + tail_expr: fn_def.body().unwrap().tail_expr().unwrap(), + should_focus_return_type: self.should_focus_return_type, + fn_def, + trailing_ws, + } + } +} + +/// Makes an optional return type along with whether the return type should be focused by the cursor. +/// If we cannot infer what the return type should be, we create a placeholder type. +/// +/// The rule for whether we focus a return type or not (and thus focus the function body), +/// is rather simple: +/// * If we could *not* infer what the return type should be, focus it (so the user can fill-in +/// the correct return type). +/// * If we could infer the return type, don't focus it (and thus focus the function body) so the +/// user can change the `todo!` function body. +fn make_return_type( + ctx: &AssistContext<'_>, + call: &ast::Expr, + target_module: Module, +) -> (Option, bool) { + let (ret_ty, should_focus_return_type) = { + match ctx.sema.type_of_expr(call).map(TypeInfo::original) { + Some(ty) if ty.is_unknown() => (Some(make::ty_placeholder()), true), + None => (Some(make::ty_placeholder()), true), + Some(ty) if ty.is_unit() => (None, false), + Some(ty) => { + let rendered = ty.display_source_code(ctx.db(), target_module.into()); + match rendered { + Ok(rendered) => (Some(make::ty(&rendered)), false), + Err(_) => (Some(make::ty_placeholder()), true), + } + } + } + }; + let ret_type = ret_ty.map(make::ret_type); + (ret_type, should_focus_return_type) +} + +fn get_fn_target( + ctx: &AssistContext<'_>, + target_module: &Option, + call: CallExpr, +) -> Option<(GeneratedFunctionTarget, FileId, TextSize)> { + let mut file = ctx.file_id(); + let target = match target_module { + Some(target_module) => { + let module_source = target_module.definition_source(ctx.db()); + let (in_file, target) = next_space_for_fn_in_module(ctx.sema.db, &module_source)?; + file = in_file; + target + } + None => next_space_for_fn_after_call_site(ast::CallableExpr::Call(call))?, + }; + Some((target.clone(), file, get_insert_offset(&target))) +} + +fn get_method_target( + ctx: &AssistContext<'_>, + target_module: &Module, + impl_: &Option, +) -> Option<(GeneratedFunctionTarget, TextSize)> { + let target = match impl_ { + Some(impl_) => next_space_for_fn_in_impl(impl_)?, + None => { + next_space_for_fn_in_module(ctx.sema.db, &target_module.definition_source(ctx.sema.db))? + .1 + } + }; + Some((target.clone(), get_insert_offset(&target))) +} + +fn get_insert_offset(target: &GeneratedFunctionTarget) -> TextSize { + match &target { + GeneratedFunctionTarget::BehindItem(it) => it.text_range().end(), + GeneratedFunctionTarget::InEmptyItemList(it) => it.text_range().start() + TextSize::of('{'), + } +} + +#[derive(Clone)] +enum GeneratedFunctionTarget { + BehindItem(SyntaxNode), + InEmptyItemList(SyntaxNode), +} + +impl GeneratedFunctionTarget { + fn syntax(&self) -> &SyntaxNode { + match self { + GeneratedFunctionTarget::BehindItem(it) => it, + GeneratedFunctionTarget::InEmptyItemList(it) => it, + } + } +} + +/// Computes the type variables and arguments required for the generated function +fn fn_args( + ctx: &AssistContext<'_>, + target_module: hir::Module, + call: ast::CallableExpr, +) -> Option<(Option, ast::ParamList)> { + let mut arg_names = Vec::new(); + let mut arg_types = Vec::new(); + for arg in call.arg_list()?.args() { + arg_names.push(fn_arg_name(&ctx.sema, &arg)); + arg_types.push(fn_arg_type(ctx, target_module, &arg)); + } + deduplicate_arg_names(&mut arg_names); + let params = arg_names.into_iter().zip(arg_types).map(|(name, ty)| { + make::param(make::ext::simple_ident_pat(make::name(&name)).into(), make::ty(&ty)) + }); + + Some(( + None, + make::param_list( + match call { + ast::CallableExpr::Call(_) => None, + ast::CallableExpr::MethodCall(_) => Some(make::self_param()), + }, + params, + ), + )) +} + +/// Makes duplicate argument names unique by appending incrementing numbers. +/// +/// ``` +/// let mut names: Vec = +/// vec!["foo".into(), "foo".into(), "bar".into(), "baz".into(), "bar".into()]; +/// deduplicate_arg_names(&mut names); +/// let expected: Vec = +/// vec!["foo_1".into(), "foo_2".into(), "bar_1".into(), "baz".into(), "bar_2".into()]; +/// assert_eq!(names, expected); +/// ``` +fn deduplicate_arg_names(arg_names: &mut Vec) { + let mut arg_name_counts = FxHashMap::default(); + for name in arg_names.iter() { + *arg_name_counts.entry(name).or_insert(0) += 1; + } + let duplicate_arg_names: FxHashSet = arg_name_counts + .into_iter() + .filter(|(_, count)| *count >= 2) + .map(|(name, _)| name.clone()) + .collect(); + + let mut counter_per_name = FxHashMap::default(); + for arg_name in arg_names.iter_mut() { + if duplicate_arg_names.contains(arg_name) { + let counter = counter_per_name.entry(arg_name.clone()).or_insert(1); + arg_name.push('_'); + arg_name.push_str(&counter.to_string()); + *counter += 1; + } + } +} + +fn fn_arg_name(sema: &Semantics<'_, RootDatabase>, arg_expr: &ast::Expr) -> String { + let name = (|| match arg_expr { + ast::Expr::CastExpr(cast_expr) => Some(fn_arg_name(sema, &cast_expr.expr()?)), + expr => { + let name_ref = expr + .syntax() + .descendants() + .filter_map(ast::NameRef::cast) + .filter(|name| name.ident_token().is_some()) + .last()?; + if let Some(NameRefClass::Definition(Definition::Const(_) | Definition::Static(_))) = + NameRefClass::classify(sema, &name_ref) + { + return Some(name_ref.to_string().to_lowercase()); + }; + Some(to_lower_snake_case(&name_ref.to_string())) + } + })(); + match name { + Some(mut name) if name.starts_with(|c: char| c.is_ascii_digit()) => { + name.insert_str(0, "arg"); + name + } + Some(name) => name, + None => "arg".to_string(), + } +} + +fn fn_arg_type(ctx: &AssistContext<'_>, target_module: hir::Module, fn_arg: &ast::Expr) -> String { + fn maybe_displayed_type( + ctx: &AssistContext<'_>, + target_module: hir::Module, + fn_arg: &ast::Expr, + ) -> Option { + let ty = ctx.sema.type_of_expr(fn_arg)?.adjusted(); + if ty.is_unknown() { + return None; + } + + if ty.is_reference() || ty.is_mutable_reference() { + let famous_defs = &FamousDefs(&ctx.sema, ctx.sema.scope(fn_arg.syntax())?.krate()); + convert_reference_type(ty.strip_references(), ctx.db(), famous_defs) + .map(|conversion| conversion.convert_type(ctx.db())) + .or_else(|| ty.display_source_code(ctx.db(), target_module.into()).ok()) + } else { + ty.display_source_code(ctx.db(), target_module.into()).ok() + } + } + + maybe_displayed_type(ctx, target_module, fn_arg).unwrap_or_else(|| String::from("_")) +} + +/// Returns the position inside the current mod or file +/// directly after the current block +/// We want to write the generated function directly after +/// fns, impls or macro calls, but inside mods +fn next_space_for_fn_after_call_site(expr: ast::CallableExpr) -> Option { + let mut ancestors = expr.syntax().ancestors().peekable(); + let mut last_ancestor: Option = None; + while let Some(next_ancestor) = ancestors.next() { + match next_ancestor.kind() { + SyntaxKind::SOURCE_FILE => { + break; + } + SyntaxKind::ITEM_LIST => { + if ancestors.peek().map(|a| a.kind()) == Some(SyntaxKind::MODULE) { + break; + } + } + _ => {} + } + last_ancestor = Some(next_ancestor); + } + last_ancestor.map(GeneratedFunctionTarget::BehindItem) +} + +fn next_space_for_fn_in_module( + db: &dyn hir::db::AstDatabase, + module_source: &hir::InFile, +) -> Option<(FileId, GeneratedFunctionTarget)> { + let file = module_source.file_id.original_file(db); + let assist_item = match &module_source.value { + hir::ModuleSource::SourceFile(it) => match it.items().last() { + Some(last_item) => GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()), + None => GeneratedFunctionTarget::BehindItem(it.syntax().clone()), + }, + hir::ModuleSource::Module(it) => match it.item_list().and_then(|it| it.items().last()) { + Some(last_item) => GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()), + None => GeneratedFunctionTarget::InEmptyItemList(it.item_list()?.syntax().clone()), + }, + hir::ModuleSource::BlockExpr(it) => { + if let Some(last_item) = + it.statements().take_while(|stmt| matches!(stmt, ast::Stmt::Item(_))).last() + { + GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()) + } else { + GeneratedFunctionTarget::InEmptyItemList(it.syntax().clone()) + } + } + }; + Some((file, assist_item)) +} + +fn next_space_for_fn_in_impl(impl_: &ast::Impl) -> Option { + if let Some(last_item) = impl_.assoc_item_list().and_then(|it| it.assoc_items().last()) { + Some(GeneratedFunctionTarget::BehindItem(last_item.syntax().clone())) + } else { + Some(GeneratedFunctionTarget::InEmptyItemList(impl_.assoc_item_list()?.syntax().clone())) + } +} + +fn module_is_descendant(module: &hir::Module, ans: &hir::Module, ctx: &AssistContext<'_>) -> bool { + if module == ans { + return true; + } + for c in ans.children(ctx.sema.db) { + if module_is_descendant(module, &c, ctx) { + return true; + } + } + false +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn add_function_with_no_args() { + check_assist( + generate_function, + r" +fn foo() { + bar$0(); +} +", + r" +fn foo() { + bar(); +} + +fn bar() ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn add_function_from_method() { + // This ensures that the function is correctly generated + // in the next outer mod or file + check_assist( + generate_function, + r" +impl Foo { + fn foo() { + bar$0(); + } +} +", + r" +impl Foo { + fn foo() { + bar(); + } +} + +fn bar() ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn add_function_directly_after_current_block() { + // The new fn should not be created at the end of the file or module + check_assist( + generate_function, + r" +fn foo1() { + bar$0(); +} + +fn foo2() {} +", + r" +fn foo1() { + bar(); +} + +fn bar() ${0:-> _} { + todo!() +} + +fn foo2() {} +", + ) + } + + #[test] + fn add_function_with_no_args_in_same_module() { + check_assist( + generate_function, + r" +mod baz { + fn foo() { + bar$0(); + } +} +", + r" +mod baz { + fn foo() { + bar(); + } + + fn bar() ${0:-> _} { + todo!() + } +} +", + ) + } + + #[test] + fn add_function_with_upper_camel_case_arg() { + check_assist( + generate_function, + r" +struct BazBaz; +fn foo() { + bar$0(BazBaz); +} +", + r" +struct BazBaz; +fn foo() { + bar(BazBaz); +} + +fn bar(baz_baz: BazBaz) ${0:-> _} { + todo!() +} +", + ); + } + + #[test] + fn add_function_with_upper_camel_case_arg_as_cast() { + check_assist( + generate_function, + r" +struct BazBaz; +fn foo() { + bar$0(&BazBaz as *const BazBaz); +} +", + r" +struct BazBaz; +fn foo() { + bar(&BazBaz as *const BazBaz); +} + +fn bar(baz_baz: *const BazBaz) ${0:-> _} { + todo!() +} +", + ); + } + + #[test] + fn add_function_with_function_call_arg() { + check_assist( + generate_function, + r" +struct Baz; +fn baz() -> Baz { todo!() } +fn foo() { + bar$0(baz()); +} +", + r" +struct Baz; +fn baz() -> Baz { todo!() } +fn foo() { + bar(baz()); +} + +fn bar(baz: Baz) ${0:-> _} { + todo!() +} +", + ); + } + + #[test] + fn add_function_with_method_call_arg() { + check_assist( + generate_function, + r" +struct Baz; +impl Baz { + fn foo(&self) -> Baz { + ba$0r(self.baz()) + } + fn baz(&self) -> Baz { + Baz + } +} +", + r" +struct Baz; +impl Baz { + fn foo(&self) -> Baz { + bar(self.baz()) + } + fn baz(&self) -> Baz { + Baz + } +} + +fn bar(baz: Baz) -> Baz { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_string_literal_arg() { + check_assist( + generate_function, + r#" +fn foo() { + $0bar("bar") +} +"#, + r#" +fn foo() { + bar("bar") +} + +fn bar(arg: &str) { + ${0:todo!()} +} +"#, + ) + } + + #[test] + fn add_function_with_char_literal_arg() { + check_assist( + generate_function, + r#" +fn foo() { + $0bar('x') +} +"#, + r#" +fn foo() { + bar('x') +} + +fn bar(arg: char) { + ${0:todo!()} +} +"#, + ) + } + + #[test] + fn add_function_with_int_literal_arg() { + check_assist( + generate_function, + r" +fn foo() { + $0bar(42) +} +", + r" +fn foo() { + bar(42) +} + +fn bar(arg: i32) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_cast_int_literal_arg() { + check_assist( + generate_function, + r" +fn foo() { + $0bar(42 as u8) +} +", + r" +fn foo() { + bar(42 as u8) +} + +fn bar(arg: u8) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn name_of_cast_variable_is_used() { + // Ensures that the name of the cast type isn't used + // in the generated function signature. + check_assist( + generate_function, + r" +fn foo() { + let x = 42; + bar$0(x as u8) +} +", + r" +fn foo() { + let x = 42; + bar(x as u8) +} + +fn bar(x: u8) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_variable_arg() { + check_assist( + generate_function, + r" +fn foo() { + let worble = (); + $0bar(worble) +} +", + r" +fn foo() { + let worble = (); + bar(worble) +} + +fn bar(worble: ()) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_impl_trait_arg() { + check_assist( + generate_function, + r#" +//- minicore: sized +trait Foo {} +fn foo() -> impl Foo { + todo!() +} +fn baz() { + $0bar(foo()) +} +"#, + r#" +trait Foo {} +fn foo() -> impl Foo { + todo!() +} +fn baz() { + bar(foo()) +} + +fn bar(foo: impl Foo) { + ${0:todo!()} +} +"#, + ) + } + + #[test] + fn borrowed_arg() { + check_assist( + generate_function, + r" +struct Baz; +fn baz() -> Baz { todo!() } + +fn foo() { + bar$0(&baz()) +} +", + r" +struct Baz; +fn baz() -> Baz { todo!() } + +fn foo() { + bar(&baz()) +} + +fn bar(baz: &Baz) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_qualified_path_arg() { + check_assist( + generate_function, + r" +mod Baz { + pub struct Bof; + pub fn baz() -> Bof { Bof } +} +fn foo() { + $0bar(Baz::baz()) +} +", + r" +mod Baz { + pub struct Bof; + pub fn baz() -> Bof { Bof } +} +fn foo() { + bar(Baz::baz()) +} + +fn bar(baz: Baz::Bof) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_generic_arg() { + // FIXME: This is wrong, generated `bar` should include generic parameter. + check_assist( + generate_function, + r" +fn foo(t: T) { + $0bar(t) +} +", + r" +fn foo(t: T) { + bar(t) +} + +fn bar(t: T) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_fn_arg() { + // FIXME: The argument in `bar` is wrong. + check_assist( + generate_function, + r" +struct Baz; +impl Baz { + fn new() -> Self { Baz } +} +fn foo() { + $0bar(Baz::new); +} +", + r" +struct Baz; +impl Baz { + fn new() -> Self { Baz } +} +fn foo() { + bar(Baz::new); +} + +fn bar(new: fn) ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn add_function_with_closure_arg() { + // FIXME: The argument in `bar` is wrong. + check_assist( + generate_function, + r" +fn foo() { + let closure = |x: i64| x - 1; + $0bar(closure) +} +", + r" +fn foo() { + let closure = |x: i64| x - 1; + bar(closure) +} + +fn bar(closure: _) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn unresolveable_types_default_to_placeholder() { + check_assist( + generate_function, + r" +fn foo() { + $0bar(baz) +} +", + r" +fn foo() { + bar(baz) +} + +fn bar(baz: _) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn arg_names_dont_overlap() { + check_assist( + generate_function, + r" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + $0bar(baz(), baz()) +} +", + r" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + bar(baz(), baz()) +} + +fn bar(baz_1: Baz, baz_2: Baz) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn arg_name_counters_start_at_1_per_name() { + check_assist( + generate_function, + r#" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + $0bar(baz(), baz(), "foo", "bar") +} +"#, + r#" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + bar(baz(), baz(), "foo", "bar") +} + +fn bar(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) { + ${0:todo!()} +} +"#, + ) + } + + #[test] + fn add_function_in_module() { + check_assist( + generate_function, + r" +mod bar {} + +fn foo() { + bar::my_fn$0() +} +", + r" +mod bar { + pub(crate) fn my_fn() { + ${0:todo!()} + } +} + +fn foo() { + bar::my_fn() +} +", + ) + } + + #[test] + fn qualified_path_uses_correct_scope() { + check_assist( + generate_function, + r#" +mod foo { + pub struct Foo; +} +fn bar() { + use foo::Foo; + let foo = Foo; + baz$0(foo) +} +"#, + r#" +mod foo { + pub struct Foo; +} +fn bar() { + use foo::Foo; + let foo = Foo; + baz(foo) +} + +fn baz(foo: foo::Foo) { + ${0:todo!()} +} +"#, + ) + } + + #[test] + fn add_function_in_module_containing_other_items() { + check_assist( + generate_function, + r" +mod bar { + fn something_else() {} +} + +fn foo() { + bar::my_fn$0() +} +", + r" +mod bar { + fn something_else() {} + + pub(crate) fn my_fn() { + ${0:todo!()} + } +} + +fn foo() { + bar::my_fn() +} +", + ) + } + + #[test] + fn add_function_in_nested_module() { + check_assist( + generate_function, + r" +mod bar { + mod baz {} +} + +fn foo() { + bar::baz::my_fn$0() +} +", + r" +mod bar { + mod baz { + pub(crate) fn my_fn() { + ${0:todo!()} + } + } +} + +fn foo() { + bar::baz::my_fn() +} +", + ) + } + + #[test] + fn add_function_in_another_file() { + check_assist( + generate_function, + r" +//- /main.rs +mod foo; + +fn main() { + foo::bar$0() +} +//- /foo.rs +", + r" + + +pub(crate) fn bar() { + ${0:todo!()} +}", + ) + } + + #[test] + fn add_function_with_return_type() { + check_assist( + generate_function, + r" +fn main() { + let x: u32 = foo$0(); +} +", + r" +fn main() { + let x: u32 = foo(); +} + +fn foo() -> u32 { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_not_applicable_if_function_already_exists() { + check_assist_not_applicable( + generate_function, + r" +fn foo() { + bar$0(); +} + +fn bar() {} +", + ) + } + + #[test] + fn add_function_not_applicable_if_unresolved_variable_in_call_is_selected() { + check_assist_not_applicable( + // bar is resolved, but baz isn't. + // The assist is only active if the cursor is on an unresolved path, + // but the assist should only be offered if the path is a function call. + generate_function, + r#" +fn foo() { + bar(b$0az); +} + +fn bar(baz: ()) {} +"#, + ) + } + + #[test] + fn create_method_with_no_args() { + check_assist( + generate_function, + r#" +struct Foo; +impl Foo { + fn foo(&self) { + self.bar()$0; + } +} +"#, + r#" +struct Foo; +impl Foo { + fn foo(&self) { + self.bar(); + } + + fn bar(&self) ${0:-> _} { + todo!() + } +} +"#, + ) + } + + #[test] + fn create_function_with_async() { + check_assist( + generate_function, + r" +fn foo() { + $0bar(42).await(); +} +", + r" +fn foo() { + bar(42).await(); +} + +async fn bar(arg: i32) ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn create_method() { + check_assist( + generate_function, + r" +struct S; +fn foo() {S.bar$0();} +", + r" +struct S; +fn foo() {S.bar();} +impl S { + + +fn bar(&self) ${0:-> _} { + todo!() +} +} +", + ) + } + + #[test] + fn create_method_within_an_impl() { + check_assist( + generate_function, + r" +struct S; +fn foo() {S.bar$0();} +impl S {} + +", + r" +struct S; +fn foo() {S.bar();} +impl S { + fn bar(&self) ${0:-> _} { + todo!() + } +} + +", + ) + } + + #[test] + fn create_method_from_different_module() { + check_assist( + generate_function, + r" +mod s { + pub struct S; +} +fn foo() {s::S.bar$0();} +", + r" +mod s { + pub struct S; +impl S { + + + pub(crate) fn bar(&self) ${0:-> _} { + todo!() + } +} +} +fn foo() {s::S.bar();} +", + ) + } + + #[test] + fn create_method_from_descendant_module() { + check_assist( + generate_function, + r" +struct S; +mod s { + fn foo() { + super::S.bar$0(); + } +} + +", + r" +struct S; +mod s { + fn foo() { + super::S.bar(); + } +} +impl S { + + +fn bar(&self) ${0:-> _} { + todo!() +} +} + +", + ) + } + + #[test] + fn create_method_with_cursor_anywhere_on_call_expresion() { + check_assist( + generate_function, + r" +struct S; +fn foo() {$0S.bar();} +", + r" +struct S; +fn foo() {S.bar();} +impl S { + + +fn bar(&self) ${0:-> _} { + todo!() +} +} +", + ) + } + + #[test] + fn create_static_method() { + check_assist( + generate_function, + r" +struct S; +fn foo() {S::bar$0();} +", + r" +struct S; +fn foo() {S::bar();} +impl S { + + +fn bar() ${0:-> _} { + todo!() +} +} +", + ) + } + + #[test] + fn create_static_method_within_an_impl() { + check_assist( + generate_function, + r" +struct S; +fn foo() {S::bar$0();} +impl S {} + +", + r" +struct S; +fn foo() {S::bar();} +impl S { + fn bar() ${0:-> _} { + todo!() + } +} + +", + ) + } + + #[test] + fn create_static_method_from_different_module() { + check_assist( + generate_function, + r" +mod s { + pub struct S; +} +fn foo() {s::S::bar$0();} +", + r" +mod s { + pub struct S; +impl S { + + + pub(crate) fn bar() ${0:-> _} { + todo!() + } +} +} +fn foo() {s::S::bar();} +", + ) + } + + #[test] + fn create_static_method_with_cursor_anywhere_on_call_expresion() { + check_assist( + generate_function, + r" +struct S; +fn foo() {$0S::bar();} +", + r" +struct S; +fn foo() {S::bar();} +impl S { + + +fn bar() ${0:-> _} { + todo!() +} +} +", + ) + } + + #[test] + fn no_panic_on_invalid_global_path() { + check_assist( + generate_function, + r" +fn main() { + ::foo$0(); +} +", + r" +fn main() { + ::foo(); +} + +fn foo() ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn handle_tuple_indexing() { + check_assist( + generate_function, + r" +fn main() { + let a = ((),); + foo$0(a.0); +} +", + r" +fn main() { + let a = ((),); + foo(a.0); +} + +fn foo(a: ()) ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn add_function_with_const_arg() { + check_assist( + generate_function, + r" +const VALUE: usize = 0; +fn main() { + foo$0(VALUE); +} +", + r" +const VALUE: usize = 0; +fn main() { + foo(VALUE); +} + +fn foo(value: usize) ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn add_function_with_static_arg() { + check_assist( + generate_function, + r" +static VALUE: usize = 0; +fn main() { + foo$0(VALUE); +} +", + r" +static VALUE: usize = 0; +fn main() { + foo(VALUE); +} + +fn foo(value: usize) ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn add_function_with_static_mut_arg() { + check_assist( + generate_function, + r" +static mut VALUE: usize = 0; +fn main() { + foo$0(VALUE); +} +", + r" +static mut VALUE: usize = 0; +fn main() { + foo(VALUE); +} + +fn foo(value: usize) ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn not_applicable_for_enum_variant() { + check_assist_not_applicable( + generate_function, + r" +enum Foo {} +fn main() { + Foo::Bar$0(true) +} +", + ); + } + + #[test] + fn applicable_for_enum_method() { + check_assist( + generate_function, + r" +enum Foo {} +fn main() { + Foo::new$0(); +} +", + r" +enum Foo {} +fn main() { + Foo::new(); +} +impl Foo { + + +fn new() ${0:-> _} { + todo!() +} +} +", + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_getter.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_getter.rs new file mode 100644 index 000000000..76fcef0ca --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_getter.rs @@ -0,0 +1,492 @@ +use ide_db::famous_defs::FamousDefs; +use stdx::{format_to, to_lower_snake_case}; +use syntax::ast::{self, AstNode, HasName, HasVisibility}; + +use crate::{ + utils::{convert_reference_type, find_impl_block_end, find_struct_impl, generate_impl_text}, + AssistContext, AssistId, AssistKind, Assists, GroupLabel, +}; + +// Assist: generate_getter +// +// Generate a getter method. +// +// ``` +// # //- minicore: as_ref +// # pub struct String; +// # impl AsRef for String { +// # fn as_ref(&self) -> &str { +// # "" +// # } +// # } +// # +// struct Person { +// nam$0e: String, +// } +// ``` +// -> +// ``` +// # pub struct String; +// # impl AsRef for String { +// # fn as_ref(&self) -> &str { +// # "" +// # } +// # } +// # +// struct Person { +// name: String, +// } +// +// impl Person { +// fn $0name(&self) -> &str { +// self.name.as_ref() +// } +// } +// ``` +pub(crate) fn generate_getter(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + generate_getter_impl(acc, ctx, false) +} + +// Assist: generate_getter_mut +// +// Generate a mut getter method. +// +// ``` +// struct Person { +// nam$0e: String, +// } +// ``` +// -> +// ``` +// struct Person { +// name: String, +// } +// +// impl Person { +// fn $0name_mut(&mut self) -> &mut String { +// &mut self.name +// } +// } +// ``` +pub(crate) fn generate_getter_mut(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + generate_getter_impl(acc, ctx, true) +} + +pub(crate) fn generate_getter_impl( + acc: &mut Assists, + ctx: &AssistContext<'_>, + mutable: bool, +) -> Option<()> { + let strukt = ctx.find_node_at_offset::()?; + let field = ctx.find_node_at_offset::()?; + + let field_name = field.name()?; + let field_ty = field.ty()?; + + // Return early if we've found an existing fn + let mut fn_name = to_lower_snake_case(&field_name.to_string()); + if mutable { + format_to!(fn_name, "_mut"); + } + let impl_def = find_struct_impl(ctx, &ast::Adt::Struct(strukt.clone()), fn_name.as_str())?; + + let (id, label) = if mutable { + ("generate_getter_mut", "Generate a mut getter method") + } else { + ("generate_getter", "Generate a getter method") + }; + let target = field.syntax().text_range(); + acc.add_group( + &GroupLabel("Generate getter/setter".to_owned()), + AssistId(id, AssistKind::Generate), + label, + target, + |builder| { + let mut buf = String::with_capacity(512); + + if impl_def.is_some() { + buf.push('\n'); + } + + let vis = strukt.visibility().map_or(String::new(), |v| format!("{} ", v)); + let (ty, body) = if mutable { + (format!("&mut {}", field_ty), format!("&mut self.{}", field_name)) + } else { + (|| { + let krate = ctx.sema.scope(field_ty.syntax())?.krate(); + let famous_defs = &FamousDefs(&ctx.sema, krate); + ctx.sema + .resolve_type(&field_ty) + .and_then(|ty| convert_reference_type(ty, ctx.db(), famous_defs)) + .map(|conversion| { + cov_mark::hit!(convert_reference_type); + ( + conversion.convert_type(ctx.db()), + conversion.getter(field_name.to_string()), + ) + }) + })() + .unwrap_or_else(|| (format!("&{}", field_ty), format!("&self.{}", field_name))) + }; + + format_to!( + buf, + " {}fn {}(&{}self) -> {} {{ + {} + }}", + vis, + fn_name, + mutable.then(|| "mut ").unwrap_or_default(), + ty, + body, + ); + + let start_offset = impl_def + .and_then(|impl_def| find_impl_block_end(impl_def, &mut buf)) + .unwrap_or_else(|| { + buf = generate_impl_text(&ast::Adt::Struct(strukt.clone()), &buf); + strukt.syntax().text_range().end() + }); + + match ctx.config.snippet_cap { + Some(cap) => { + builder.insert_snippet(cap, start_offset, buf.replacen("fn ", "fn $0", 1)) + } + None => builder.insert(start_offset, buf), + } + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_generate_getter_from_field() { + check_assist( + generate_getter, + r#" +struct Context { + dat$0a: Data, +} +"#, + r#" +struct Context { + data: Data, +} + +impl Context { + fn $0data(&self) -> &Data { + &self.data + } +} +"#, + ); + + check_assist( + generate_getter_mut, + r#" +struct Context { + dat$0a: Data, +} +"#, + r#" +struct Context { + data: Data, +} + +impl Context { + fn $0data_mut(&mut self) -> &mut Data { + &mut self.data + } +} +"#, + ); + } + + #[test] + fn test_generate_getter_already_implemented() { + check_assist_not_applicable( + generate_getter, + r#" +struct Context { + dat$0a: Data, +} + +impl Context { + fn data(&self) -> &Data { + &self.data + } +} +"#, + ); + + check_assist_not_applicable( + generate_getter_mut, + r#" +struct Context { + dat$0a: Data, +} + +impl Context { + fn data_mut(&mut self) -> &mut Data { + &mut self.data + } +} +"#, + ); + } + + #[test] + fn test_generate_getter_from_field_with_visibility_marker() { + check_assist( + generate_getter, + r#" +pub(crate) struct Context { + dat$0a: Data, +} +"#, + r#" +pub(crate) struct Context { + data: Data, +} + +impl Context { + pub(crate) fn $0data(&self) -> &Data { + &self.data + } +} +"#, + ); + } + + #[test] + fn test_multiple_generate_getter() { + check_assist( + generate_getter, + r#" +struct Context { + data: Data, + cou$0nt: usize, +} + +impl Context { + fn data(&self) -> &Data { + &self.data + } +} +"#, + r#" +struct Context { + data: Data, + count: usize, +} + +impl Context { + fn data(&self) -> &Data { + &self.data + } + + fn $0count(&self) -> &usize { + &self.count + } +} +"#, + ); + } + + #[test] + fn test_not_a_special_case() { + cov_mark::check_count!(convert_reference_type, 0); + // Fake string which doesn't implement AsRef + check_assist( + generate_getter, + r#" +pub struct String; + +struct S { foo: $0String } +"#, + r#" +pub struct String; + +struct S { foo: String } + +impl S { + fn $0foo(&self) -> &String { + &self.foo + } +} +"#, + ); + } + + #[test] + fn test_convert_reference_type() { + cov_mark::check_count!(convert_reference_type, 6); + + // Copy + check_assist( + generate_getter, + r#" +//- minicore: copy +struct S { foo: $0bool } +"#, + r#" +struct S { foo: bool } + +impl S { + fn $0foo(&self) -> bool { + self.foo + } +} +"#, + ); + + // AsRef + check_assist( + generate_getter, + r#" +//- minicore: as_ref +pub struct String; +impl AsRef for String { + fn as_ref(&self) -> &str { + "" + } +} + +struct S { foo: $0String } +"#, + r#" +pub struct String; +impl AsRef for String { + fn as_ref(&self) -> &str { + "" + } +} + +struct S { foo: String } + +impl S { + fn $0foo(&self) -> &str { + self.foo.as_ref() + } +} +"#, + ); + + // AsRef + check_assist( + generate_getter, + r#" +//- minicore: as_ref +struct Sweets; + +pub struct Box(T); +impl AsRef for Box { + fn as_ref(&self) -> &T { + &self.0 + } +} + +struct S { foo: $0Box } +"#, + r#" +struct Sweets; + +pub struct Box(T); +impl AsRef for Box { + fn as_ref(&self) -> &T { + &self.0 + } +} + +struct S { foo: Box } + +impl S { + fn $0foo(&self) -> &Sweets { + self.foo.as_ref() + } +} +"#, + ); + + // AsRef<[T]> + check_assist( + generate_getter, + r#" +//- minicore: as_ref +pub struct Vec; +impl AsRef<[T]> for Vec { + fn as_ref(&self) -> &[T] { + &[] + } +} + +struct S { foo: $0Vec<()> } +"#, + r#" +pub struct Vec; +impl AsRef<[T]> for Vec { + fn as_ref(&self) -> &[T] { + &[] + } +} + +struct S { foo: Vec<()> } + +impl S { + fn $0foo(&self) -> &[()] { + self.foo.as_ref() + } +} +"#, + ); + + // Option + check_assist( + generate_getter, + r#" +//- minicore: option +struct Failure; + +struct S { foo: $0Option } +"#, + r#" +struct Failure; + +struct S { foo: Option } + +impl S { + fn $0foo(&self) -> Option<&Failure> { + self.foo.as_ref() + } +} +"#, + ); + + // Result + check_assist( + generate_getter, + r#" +//- minicore: result +struct Context { + dat$0a: Result, +} +"#, + r#" +struct Context { + data: Result, +} + +impl Context { + fn $0data(&self) -> Result<&bool, &i32> { + self.data.as_ref() + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_impl.rs new file mode 100644 index 000000000..68287a20b --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_impl.rs @@ -0,0 +1,177 @@ +use syntax::ast::{self, AstNode, HasName}; + +use crate::{utils::generate_impl_text, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: generate_impl +// +// Adds a new inherent impl for a type. +// +// ``` +// struct Ctx { +// data: T,$0 +// } +// ``` +// -> +// ``` +// struct Ctx { +// data: T, +// } +// +// impl Ctx { +// $0 +// } +// ``` +pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let nominal = ctx.find_node_at_offset::()?; + let name = nominal.name()?; + let target = nominal.syntax().text_range(); + + acc.add( + AssistId("generate_impl", AssistKind::Generate), + format!("Generate impl for `{}`", name), + target, + |edit| { + let start_offset = nominal.syntax().text_range().end(); + match ctx.config.snippet_cap { + Some(cap) => { + let snippet = generate_impl_text(&nominal, " $0"); + edit.insert_snippet(cap, start_offset, snippet); + } + None => { + let snippet = generate_impl_text(&nominal, ""); + edit.insert(start_offset, snippet); + } + } + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_target}; + + use super::*; + + #[test] + fn test_add_impl() { + check_assist( + generate_impl, + "struct Foo {$0}\n", + "struct Foo {}\n\nimpl Foo {\n $0\n}\n", + ); + check_assist( + generate_impl, + "struct Foo {$0}", + "struct Foo {}\n\nimpl Foo {\n $0\n}", + ); + check_assist( + generate_impl, + "struct Foo<'a, T: Foo<'a>> {$0}", + "struct Foo<'a, T: Foo<'a>> {}\n\nimpl<'a, T: Foo<'a>> Foo<'a, T> {\n $0\n}", + ); + check_assist( + generate_impl, + r#" + struct MyOwnArray {}$0"#, + r#" + struct MyOwnArray {} + + impl MyOwnArray { + $0 + }"#, + ); + check_assist( + generate_impl, + r#" + #[cfg(feature = "foo")] + struct Foo<'a, T: Foo<'a>> {$0}"#, + r#" + #[cfg(feature = "foo")] + struct Foo<'a, T: Foo<'a>> {} + + #[cfg(feature = "foo")] + impl<'a, T: Foo<'a>> Foo<'a, T> { + $0 + }"#, + ); + + check_assist( + generate_impl, + r#" + #[cfg(not(feature = "foo"))] + struct Foo<'a, T: Foo<'a>> {$0}"#, + r#" + #[cfg(not(feature = "foo"))] + struct Foo<'a, T: Foo<'a>> {} + + #[cfg(not(feature = "foo"))] + impl<'a, T: Foo<'a>> Foo<'a, T> { + $0 + }"#, + ); + + check_assist( + generate_impl, + r#" + struct Defaulted {}$0"#, + r#" + struct Defaulted {} + + impl Defaulted { + $0 + }"#, + ); + + check_assist( + generate_impl, + r#" + struct Defaulted<'a, 'b: 'a, T: Debug + Clone + 'a + 'b = String, const S: usize> {}$0"#, + r#" + struct Defaulted<'a, 'b: 'a, T: Debug + Clone + 'a + 'b = String, const S: usize> {} + + impl<'a, 'b: 'a, T: Debug + Clone + 'a + 'b, const S: usize> Defaulted<'a, 'b, T, S> { + $0 + }"#, + ); + + check_assist( + generate_impl, + r#"pub trait Trait {} +struct Struct$0 +where + T: Trait, +{ + inner: T, +}"#, + r#"pub trait Trait {} +struct Struct +where + T: Trait, +{ + inner: T, +} + +impl Struct +where + T: Trait, +{ + $0 +}"#, + ); + } + + #[test] + fn add_impl_target() { + check_assist_target( + generate_impl, + " +struct SomeThingIrrelevant; +/// Has a lifetime parameter +struct Foo<'a, T: Foo<'a>> {$0} +struct EvenMoreIrrelevant; +", + "/// Has a lifetime parameter +struct Foo<'a, T: Foo<'a>> {}", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_is_empty_from_len.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_is_empty_from_len.rs new file mode 100644 index 000000000..9ce525ca3 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_is_empty_from_len.rs @@ -0,0 +1,295 @@ +use hir::{known, HasSource, Name}; +use syntax::{ + ast::{self, HasName}, + AstNode, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: generate_is_empty_from_len +// +// Generates is_empty implementation from the len method. +// +// ``` +// struct MyStruct { data: Vec } +// +// impl MyStruct { +// #[must_use] +// p$0ub fn len(&self) -> usize { +// self.data.len() +// } +// } +// ``` +// -> +// ``` +// struct MyStruct { data: Vec } +// +// impl MyStruct { +// #[must_use] +// pub fn len(&self) -> usize { +// self.data.len() +// } +// +// #[must_use] +// pub fn is_empty(&self) -> bool { +// self.len() == 0 +// } +// } +// ``` +pub(crate) fn generate_is_empty_from_len(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let fn_node = ctx.find_node_at_offset::()?; + let fn_name = fn_node.name()?; + + if fn_name.text() != "len" { + cov_mark::hit!(len_function_not_present); + return None; + } + + if fn_node.param_list()?.params().next().is_some() { + cov_mark::hit!(len_function_with_parameters); + return None; + } + + let impl_ = fn_node.syntax().ancestors().find_map(ast::Impl::cast)?; + let len_fn = get_impl_method(ctx, &impl_, &known::len)?; + if !len_fn.ret_type(ctx.sema.db).is_usize() { + cov_mark::hit!(len_fn_different_return_type); + return None; + } + + if get_impl_method(ctx, &impl_, &known::is_empty).is_some() { + cov_mark::hit!(is_empty_already_implemented); + return None; + } + + let node = len_fn.source(ctx.sema.db)?; + let range = node.syntax().value.text_range(); + + acc.add( + AssistId("generate_is_empty_from_len", AssistKind::Generate), + "Generate a is_empty impl from a len function", + range, + |builder| { + let code = r#" + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + }"# + .to_string(); + builder.insert(range.end(), code) + }, + ) +} + +fn get_impl_method( + ctx: &AssistContext<'_>, + impl_: &ast::Impl, + fn_name: &Name, +) -> Option { + let db = ctx.sema.db; + let impl_def: hir::Impl = ctx.sema.to_def(impl_)?; + + let scope = ctx.sema.scope(impl_.syntax())?; + let ty = impl_def.self_ty(db); + ty.iterate_method_candidates( + db, + &scope, + &scope.visible_traits().0, + None, + Some(fn_name), + |func| Some(func), + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn len_function_not_present() { + cov_mark::check!(len_function_not_present); + check_assist_not_applicable( + generate_is_empty_from_len, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + p$0ub fn test(&self) -> usize { + self.data.len() + } + } +"#, + ); + } + + #[test] + fn len_function_with_parameters() { + cov_mark::check!(len_function_with_parameters); + check_assist_not_applicable( + generate_is_empty_from_len, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + p$0ub fn len(&self, _i: bool) -> usize { + self.data.len() + } +} +"#, + ); + } + + #[test] + fn is_empty_already_implemented() { + cov_mark::check!(is_empty_already_implemented); + check_assist_not_applicable( + generate_is_empty_from_len, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + p$0ub fn len(&self) -> usize { + self.data.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +"#, + ); + } + + #[test] + fn len_fn_different_return_type() { + cov_mark::check!(len_fn_different_return_type); + check_assist_not_applicable( + generate_is_empty_from_len, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + p$0ub fn len(&self) -> u32 { + self.data.len() + } +} +"#, + ); + } + + #[test] + fn generate_is_empty() { + check_assist( + generate_is_empty_from_len, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + p$0ub fn len(&self) -> usize { + self.data.len() + } +} +"#, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + pub fn len(&self) -> usize { + self.data.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +"#, + ); + } + + #[test] + fn multiple_functions_in_impl() { + check_assist( + generate_is_empty_from_len, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + pub fn new() -> Self { + Self { data: 0 } + } + + #[must_use] + p$0ub fn len(&self) -> usize { + self.data.len() + } + + pub fn work(&self) -> Option { + + } +} +"#, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + pub fn new() -> Self { + Self { data: 0 } + } + + #[must_use] + pub fn len(&self) -> usize { + self.data.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn work(&self) -> Option { + + } +} +"#, + ); + } + + #[test] + fn multiple_impls() { + check_assist_not_applicable( + generate_is_empty_from_len, + r#" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + p$0ub fn len(&self) -> usize { + self.data.len() + } +} + +impl MyStruct { + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_new.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_new.rs new file mode 100644 index 000000000..6c93875e9 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_new.rs @@ -0,0 +1,495 @@ +use ide_db::{ + imports::import_assets::item_for_path_search, use_trivial_contructor::use_trivial_constructor, +}; +use itertools::Itertools; +use stdx::format_to; +use syntax::ast::{self, AstNode, HasName, HasVisibility, StructKind}; + +use crate::{ + utils::{find_impl_block_start, find_struct_impl, generate_impl_text}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: generate_new +// +// Adds a `fn new` for a type. +// +// ``` +// struct Ctx { +// data: T,$0 +// } +// ``` +// -> +// ``` +// struct Ctx { +// data: T, +// } +// +// impl Ctx { +// fn $0new(data: T) -> Self { Self { data } } +// } +// ``` +pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let strukt = ctx.find_node_at_offset::()?; + + // We want to only apply this to non-union structs with named fields + let field_list = match strukt.kind() { + StructKind::Record(named) => named, + _ => return None, + }; + + // Return early if we've found an existing new fn + let impl_def = find_struct_impl(ctx, &ast::Adt::Struct(strukt.clone()), "new")?; + + let current_module = ctx.sema.scope(strukt.syntax())?.module(); + + let target = strukt.syntax().text_range(); + acc.add(AssistId("generate_new", AssistKind::Generate), "Generate `new`", target, |builder| { + let mut buf = String::with_capacity(512); + + if impl_def.is_some() { + buf.push('\n'); + } + + let vis = strukt.visibility().map_or(String::new(), |v| format!("{} ", v)); + + let trivial_constructors = field_list + .fields() + .map(|f| { + let ty = ctx.sema.resolve_type(&f.ty()?)?; + + let item_in_ns = hir::ItemInNs::from(hir::ModuleDef::from(ty.as_adt()?)); + + let type_path = current_module + .find_use_path(ctx.sema.db, item_for_path_search(ctx.sema.db, item_in_ns)?)?; + + let expr = use_trivial_constructor( + &ctx.sema.db, + ide_db::helpers::mod_path_to_ast(&type_path), + &ty, + )?; + + Some(format!("{}: {}", f.name()?.syntax(), expr)) + }) + .collect::>(); + + let params = field_list + .fields() + .enumerate() + .filter_map(|(i, f)| { + if trivial_constructors[i].is_none() { + Some(format!("{}: {}", f.name()?.syntax(), f.ty()?.syntax())) + } else { + None + } + }) + .format(", "); + + let fields = field_list + .fields() + .enumerate() + .filter_map(|(i, f)| { + let contructor = trivial_constructors[i].clone(); + if contructor.is_some() { + contructor + } else { + Some(f.name()?.to_string()) + } + }) + .format(", "); + + format_to!(buf, " {}fn new({}) -> Self {{ Self {{ {} }} }}", vis, params, fields); + + let start_offset = impl_def + .and_then(|impl_def| find_impl_block_start(impl_def, &mut buf)) + .unwrap_or_else(|| { + buf = generate_impl_text(&ast::Adt::Struct(strukt.clone()), &buf); + strukt.syntax().text_range().end() + }); + + match ctx.config.snippet_cap { + None => builder.insert(start_offset, buf), + Some(cap) => { + buf = buf.replace("fn new", "fn $0new"); + builder.insert_snippet(cap, start_offset, buf); + } + } + }) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn test_generate_new_with_zst_fields() { + check_assist( + generate_new, + r#" +struct Empty; + +struct Foo { empty: Empty $0} +"#, + r#" +struct Empty; + +struct Foo { empty: Empty } + +impl Foo { + fn $0new() -> Self { Self { empty: Empty } } +} +"#, + ); + check_assist( + generate_new, + r#" +struct Empty; + +struct Foo { baz: String, empty: Empty $0} +"#, + r#" +struct Empty; + +struct Foo { baz: String, empty: Empty } + +impl Foo { + fn $0new(baz: String) -> Self { Self { baz, empty: Empty } } +} +"#, + ); + check_assist( + generate_new, + r#" +enum Empty { Bar } + +struct Foo { empty: Empty $0} +"#, + r#" +enum Empty { Bar } + +struct Foo { empty: Empty } + +impl Foo { + fn $0new() -> Self { Self { empty: Empty::Bar } } +} +"#, + ); + + // make sure the assist only works on unit variants + check_assist( + generate_new, + r#" +struct Empty {} + +struct Foo { empty: Empty $0} +"#, + r#" +struct Empty {} + +struct Foo { empty: Empty } + +impl Foo { + fn $0new(empty: Empty) -> Self { Self { empty } } +} +"#, + ); + check_assist( + generate_new, + r#" +enum Empty { Bar {} } + +struct Foo { empty: Empty $0} +"#, + r#" +enum Empty { Bar {} } + +struct Foo { empty: Empty } + +impl Foo { + fn $0new(empty: Empty) -> Self { Self { empty } } +} +"#, + ); + } + + #[test] + fn test_generate_new() { + check_assist( + generate_new, + r#" +struct Foo {$0} +"#, + r#" +struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } +} +"#, + ); + check_assist( + generate_new, + r#" +struct Foo {$0} +"#, + r#" +struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } +} +"#, + ); + check_assist( + generate_new, + r#" +struct Foo<'a, T: Foo<'a>> {$0} +"#, + r#" +struct Foo<'a, T: Foo<'a>> {} + +impl<'a, T: Foo<'a>> Foo<'a, T> { + fn $0new() -> Self { Self { } } +} +"#, + ); + check_assist( + generate_new, + r#" +struct Foo { baz: String $0} +"#, + r#" +struct Foo { baz: String } + +impl Foo { + fn $0new(baz: String) -> Self { Self { baz } } +} +"#, + ); + check_assist( + generate_new, + r#" +struct Foo { baz: String, qux: Vec $0} +"#, + r#" +struct Foo { baz: String, qux: Vec } + +impl Foo { + fn $0new(baz: String, qux: Vec) -> Self { Self { baz, qux } } +} +"#, + ); + } + + #[test] + fn check_that_visibility_modifiers_dont_get_brought_in() { + check_assist( + generate_new, + r#" +struct Foo { pub baz: String, pub qux: Vec $0} +"#, + r#" +struct Foo { pub baz: String, pub qux: Vec } + +impl Foo { + fn $0new(baz: String, qux: Vec) -> Self { Self { baz, qux } } +} +"#, + ); + } + + #[test] + fn check_it_reuses_existing_impls() { + check_assist( + generate_new, + r#" +struct Foo {$0} + +impl Foo {} +"#, + r#" +struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } +} +"#, + ); + check_assist( + generate_new, + r#" +struct Foo {$0} + +impl Foo { + fn qux(&self) {} +} +"#, + r#" +struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } + + fn qux(&self) {} +} +"#, + ); + + check_assist( + generate_new, + r#" +struct Foo {$0} + +impl Foo { + fn qux(&self) {} + fn baz() -> i32 { + 5 + } +} +"#, + r#" +struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } + + fn qux(&self) {} + fn baz() -> i32 { + 5 + } +} +"#, + ); + } + + #[test] + fn check_visibility_of_new_fn_based_on_struct() { + check_assist( + generate_new, + r#" +pub struct Foo {$0} +"#, + r#" +pub struct Foo {} + +impl Foo { + pub fn $0new() -> Self { Self { } } +} +"#, + ); + check_assist( + generate_new, + r#" +pub(crate) struct Foo {$0} +"#, + r#" +pub(crate) struct Foo {} + +impl Foo { + pub(crate) fn $0new() -> Self { Self { } } +} +"#, + ); + } + + #[test] + fn generate_new_not_applicable_if_fn_exists() { + check_assist_not_applicable( + generate_new, + r#" +struct Foo {$0} + +impl Foo { + fn new() -> Self { + Self + } +} +"#, + ); + + check_assist_not_applicable( + generate_new, + r#" +struct Foo {$0} + +impl Foo { + fn New() -> Self { + Self + } +} +"#, + ); + } + + #[test] + fn generate_new_target() { + check_assist_target( + generate_new, + r#" +struct SomeThingIrrelevant; +/// Has a lifetime parameter +struct Foo<'a, T: Foo<'a>> {$0} +struct EvenMoreIrrelevant; +"#, + "/// Has a lifetime parameter +struct Foo<'a, T: Foo<'a>> {}", + ); + } + + #[test] + fn test_unrelated_new() { + check_assist( + generate_new, + r#" +pub struct AstId { + file_id: HirFileId, + file_ast_id: FileAstId, +} + +impl AstId { + pub fn new(file_id: HirFileId, file_ast_id: FileAstId) -> AstId { + AstId { file_id, file_ast_id } + } +} + +pub struct Source { + pub file_id: HirFileId,$0 + pub ast: T, +} + +impl Source { + pub fn map U, U>(self, f: F) -> Source { + Source { file_id: self.file_id, ast: f(self.ast) } + } +} +"#, + r#" +pub struct AstId { + file_id: HirFileId, + file_ast_id: FileAstId, +} + +impl AstId { + pub fn new(file_id: HirFileId, file_ast_id: FileAstId) -> AstId { + AstId { file_id, file_ast_id } + } +} + +pub struct Source { + pub file_id: HirFileId, + pub ast: T, +} + +impl Source { + pub fn $0new(file_id: HirFileId, ast: T) -> Self { Self { file_id, ast } } + + pub fn map U, U>(self, f: F) -> Source { + Source { file_id: self.file_id, ast: f(self.ast) } + } +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_setter.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_setter.rs new file mode 100644 index 000000000..2a7ad6ce3 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_setter.rs @@ -0,0 +1,184 @@ +use stdx::{format_to, to_lower_snake_case}; +use syntax::ast::{self, AstNode, HasName, HasVisibility}; + +use crate::{ + utils::{find_impl_block_end, find_struct_impl, generate_impl_text}, + AssistContext, AssistId, AssistKind, Assists, GroupLabel, +}; + +// Assist: generate_setter +// +// Generate a setter method. +// +// ``` +// struct Person { +// nam$0e: String, +// } +// ``` +// -> +// ``` +// struct Person { +// name: String, +// } +// +// impl Person { +// fn set_name(&mut self, name: String) { +// self.name = name; +// } +// } +// ``` +pub(crate) fn generate_setter(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let strukt = ctx.find_node_at_offset::()?; + let field = ctx.find_node_at_offset::()?; + + let field_name = field.name()?; + let field_ty = field.ty()?; + + // Return early if we've found an existing fn + let fn_name = to_lower_snake_case(&field_name.to_string()); + let impl_def = find_struct_impl( + ctx, + &ast::Adt::Struct(strukt.clone()), + format!("set_{}", fn_name).as_str(), + )?; + + let target = field.syntax().text_range(); + acc.add_group( + &GroupLabel("Generate getter/setter".to_owned()), + AssistId("generate_setter", AssistKind::Generate), + "Generate a setter method", + target, + |builder| { + let mut buf = String::with_capacity(512); + + if impl_def.is_some() { + buf.push('\n'); + } + + let vis = strukt.visibility().map_or(String::new(), |v| format!("{} ", v)); + format_to!( + buf, + " {}fn set_{}(&mut self, {}: {}) {{ + self.{} = {}; + }}", + vis, + fn_name, + fn_name, + field_ty, + fn_name, + fn_name, + ); + + let start_offset = impl_def + .and_then(|impl_def| find_impl_block_end(impl_def, &mut buf)) + .unwrap_or_else(|| { + buf = generate_impl_text(&ast::Adt::Struct(strukt.clone()), &buf); + strukt.syntax().text_range().end() + }); + + builder.insert(start_offset, buf); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + fn check_not_applicable(ra_fixture: &str) { + check_assist_not_applicable(generate_setter, ra_fixture) + } + + #[test] + fn test_generate_setter_from_field() { + check_assist( + generate_setter, + r#" +struct Person { + dat$0a: T, +}"#, + r#" +struct Person { + data: T, +} + +impl Person { + fn set_data(&mut self, data: T) { + self.data = data; + } +}"#, + ); + } + + #[test] + fn test_generate_setter_already_implemented() { + check_not_applicable( + r#" +struct Person { + dat$0a: T, +} + +impl Person { + fn set_data(&mut self, data: T) { + self.data = data; + } +}"#, + ); + } + + #[test] + fn test_generate_setter_from_field_with_visibility_marker() { + check_assist( + generate_setter, + r#" +pub(crate) struct Person { + dat$0a: T, +}"#, + r#" +pub(crate) struct Person { + data: T, +} + +impl Person { + pub(crate) fn set_data(&mut self, data: T) { + self.data = data; + } +}"#, + ); + } + + #[test] + fn test_multiple_generate_setter() { + check_assist( + generate_setter, + r#" +struct Context { + data: T, + cou$0nt: usize, +} + +impl Context { + fn set_data(&mut self, data: T) { + self.data = data; + } +}"#, + r#" +struct Context { + data: T, + count: usize, +} + +impl Context { + fn set_data(&mut self, data: T) { + self.data = data; + } + + fn set_count(&mut self, count: usize) { + self.count = count; + } +}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_call.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_call.rs new file mode 100644 index 000000000..80d3b9255 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_call.rs @@ -0,0 +1,1194 @@ +use ast::make; +use either::Either; +use hir::{db::HirDatabase, PathResolution, Semantics, TypeInfo}; +use ide_db::{ + base_db::{FileId, FileRange}, + defs::Definition, + imports::insert_use::remove_path_if_in_use_stmt, + path_transform::PathTransform, + search::{FileReference, SearchScope}, + syntax_helpers::{insert_whitespace_into_node::insert_ws_into, node_ext::expr_as_name_ref}, + RootDatabase, +}; +use itertools::{izip, Itertools}; +use syntax::{ + ast::{self, edit_in_place::Indent, HasArgList, PathExpr}, + ted, AstNode, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: inline_into_callers +// +// Inline a function or method body into all of its callers where possible, creating a `let` statement per parameter +// unless the parameter can be inlined. The parameter will be inlined either if it the supplied argument is a simple local +// or if the parameter is only accessed inside the function body once. +// If all calls can be inlined the function will be removed. +// +// ``` +// fn print(_: &str) {} +// fn foo$0(word: &str) { +// if !word.is_empty() { +// print(word); +// } +// } +// fn bar() { +// foo("안녕하세요"); +// foo("여러분"); +// } +// ``` +// -> +// ``` +// fn print(_: &str) {} +// +// fn bar() { +// { +// let word = "안녕하세요"; +// if !word.is_empty() { +// print(word); +// } +// }; +// { +// let word = "여러분"; +// if !word.is_empty() { +// print(word); +// } +// }; +// } +// ``` +pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let def_file = ctx.file_id(); + let name = ctx.find_node_at_offset::()?; + let ast_func = name.syntax().parent().and_then(ast::Fn::cast)?; + let func_body = ast_func.body()?; + let param_list = ast_func.param_list()?; + + let function = ctx.sema.to_def(&ast_func)?; + + let params = get_fn_params(ctx.sema.db, function, ¶m_list)?; + + let usages = Definition::Function(function).usages(&ctx.sema); + if !usages.at_least_one() { + return None; + } + + let is_recursive_fn = usages + .clone() + .in_scope(SearchScope::file_range(FileRange { + file_id: def_file, + range: func_body.syntax().text_range(), + })) + .at_least_one(); + if is_recursive_fn { + cov_mark::hit!(inline_into_callers_recursive); + return None; + } + + acc.add( + AssistId("inline_into_callers", AssistKind::RefactorInline), + "Inline into all callers", + name.syntax().text_range(), + |builder| { + let mut usages = usages.all(); + let current_file_usage = usages.references.remove(&def_file); + + let mut remove_def = true; + let mut inline_refs_for_file = |file_id, refs: Vec| { + builder.edit_file(file_id); + let count = refs.len(); + // The collects are required as we are otherwise iterating while mutating 🙅‍♀️🙅‍♂️ + let (name_refs, name_refs_use): (Vec<_>, Vec<_>) = refs + .into_iter() + .filter_map(|file_ref| match file_ref.name { + ast::NameLike::NameRef(name_ref) => Some(name_ref), + _ => None, + }) + .partition_map(|name_ref| { + match name_ref.syntax().ancestors().find_map(ast::UseTree::cast) { + Some(use_tree) => Either::Right(builder.make_mut(use_tree)), + None => Either::Left(name_ref), + } + }); + let call_infos: Vec<_> = name_refs + .into_iter() + .filter_map(CallInfo::from_name_ref) + .map(|call_info| { + let mut_node = builder.make_syntax_mut(call_info.node.syntax().clone()); + (call_info, mut_node) + }) + .collect(); + let replaced = call_infos + .into_iter() + .map(|(call_info, mut_node)| { + let replacement = + inline(&ctx.sema, def_file, function, &func_body, ¶ms, &call_info); + ted::replace(mut_node, replacement.syntax()); + }) + .count(); + if replaced + name_refs_use.len() == count { + // we replaced all usages in this file, so we can remove the imports + name_refs_use.into_iter().for_each(|use_tree| { + if let Some(path) = use_tree.path() { + remove_path_if_in_use_stmt(&path); + } + }) + } else { + remove_def = false; + } + }; + for (file_id, refs) in usages.into_iter() { + inline_refs_for_file(file_id, refs); + } + match current_file_usage { + Some(refs) => inline_refs_for_file(def_file, refs), + None => builder.edit_file(def_file), + } + if remove_def { + builder.delete(ast_func.syntax().text_range()); + } + }, + ) +} + +// Assist: inline_call +// +// Inlines a function or method body creating a `let` statement per parameter unless the parameter +// can be inlined. The parameter will be inlined either if it the supplied argument is a simple local +// or if the parameter is only accessed inside the function body once. +// +// ``` +// # //- minicore: option +// fn foo(name: Option<&str>) { +// let name = name.unwrap$0(); +// } +// ``` +// -> +// ``` +// fn foo(name: Option<&str>) { +// let name = match name { +// Some(val) => val, +// None => panic!("called `Option::unwrap()` on a `None` value"), +// }; +// } +// ``` +pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let name_ref: ast::NameRef = ctx.find_node_at_offset()?; + let call_info = CallInfo::from_name_ref(name_ref.clone())?; + let (function, label) = match &call_info.node { + ast::CallableExpr::Call(call) => { + let path = match call.expr()? { + ast::Expr::PathExpr(path) => path.path(), + _ => None, + }?; + let function = match ctx.sema.resolve_path(&path)? { + PathResolution::Def(hir::ModuleDef::Function(f)) => f, + _ => return None, + }; + (function, format!("Inline `{}`", path)) + } + ast::CallableExpr::MethodCall(call) => { + (ctx.sema.resolve_method_call(call)?, format!("Inline `{}`", name_ref)) + } + }; + + let fn_source = ctx.sema.source(function)?; + let fn_body = fn_source.value.body()?; + let param_list = fn_source.value.param_list()?; + + let FileRange { file_id, range } = fn_source.syntax().original_file_range(ctx.sema.db); + if file_id == ctx.file_id() && range.contains(ctx.offset()) { + cov_mark::hit!(inline_call_recursive); + return None; + } + let params = get_fn_params(ctx.sema.db, function, ¶m_list)?; + + if call_info.arguments.len() != params.len() { + // Can't inline the function because they've passed the wrong number of + // arguments to this function + cov_mark::hit!(inline_call_incorrect_number_of_arguments); + return None; + } + + let syntax = call_info.node.syntax().clone(); + acc.add( + AssistId("inline_call", AssistKind::RefactorInline), + label, + syntax.text_range(), + |builder| { + let replacement = inline(&ctx.sema, file_id, function, &fn_body, ¶ms, &call_info); + + builder.replace_ast( + match call_info.node { + ast::CallableExpr::Call(it) => ast::Expr::CallExpr(it), + ast::CallableExpr::MethodCall(it) => ast::Expr::MethodCallExpr(it), + }, + replacement, + ); + }, + ) +} + +struct CallInfo { + node: ast::CallableExpr, + arguments: Vec, + generic_arg_list: Option, +} + +impl CallInfo { + fn from_name_ref(name_ref: ast::NameRef) -> Option { + let parent = name_ref.syntax().parent()?; + if let Some(call) = ast::MethodCallExpr::cast(parent.clone()) { + let receiver = call.receiver()?; + let mut arguments = vec![receiver]; + arguments.extend(call.arg_list()?.args()); + Some(CallInfo { + generic_arg_list: call.generic_arg_list(), + node: ast::CallableExpr::MethodCall(call), + arguments, + }) + } else if let Some(segment) = ast::PathSegment::cast(parent) { + let path = segment.syntax().parent().and_then(ast::Path::cast)?; + let path = path.syntax().parent().and_then(ast::PathExpr::cast)?; + let call = path.syntax().parent().and_then(ast::CallExpr::cast)?; + + Some(CallInfo { + arguments: call.arg_list()?.args().collect(), + node: ast::CallableExpr::Call(call), + generic_arg_list: segment.generic_arg_list(), + }) + } else { + None + } + } +} + +fn get_fn_params( + db: &dyn HirDatabase, + function: hir::Function, + param_list: &ast::ParamList, +) -> Option, hir::Param)>> { + let mut assoc_fn_params = function.assoc_fn_params(db).into_iter(); + + let mut params = Vec::new(); + if let Some(self_param) = param_list.self_param() { + // FIXME this should depend on the receiver as well as the self_param + params.push(( + make::ident_pat( + self_param.amp_token().is_some(), + self_param.mut_token().is_some(), + make::name("this"), + ) + .into(), + None, + assoc_fn_params.next()?, + )); + } + for param in param_list.params() { + params.push((param.pat()?, param.ty(), assoc_fn_params.next()?)); + } + + Some(params) +} + +fn inline( + sema: &Semantics<'_, RootDatabase>, + function_def_file_id: FileId, + function: hir::Function, + fn_body: &ast::BlockExpr, + params: &[(ast::Pat, Option, hir::Param)], + CallInfo { node, arguments, generic_arg_list }: &CallInfo, +) -> ast::Expr { + let body = if sema.hir_file_for(fn_body.syntax()).is_macro() { + cov_mark::hit!(inline_call_defined_in_macro); + if let Some(body) = ast::BlockExpr::cast(insert_ws_into(fn_body.syntax().clone())) { + body + } else { + fn_body.clone_for_update() + } + } else { + fn_body.clone_for_update() + }; + let usages_for_locals = |local| { + Definition::Local(local) + .usages(sema) + .all() + .references + .remove(&function_def_file_id) + .unwrap_or_default() + .into_iter() + }; + let param_use_nodes: Vec> = params + .iter() + .map(|(pat, _, param)| { + if !matches!(pat, ast::Pat::IdentPat(pat) if pat.is_simple_ident()) { + return Vec::new(); + } + // FIXME: we need to fetch all locals declared in the parameter here + // not only the local if it is a simple binding + match param.as_local(sema.db) { + Some(l) => usages_for_locals(l) + .map(|FileReference { name, range, .. }| match name { + ast::NameLike::NameRef(_) => body + .syntax() + .covering_element(range) + .ancestors() + .nth(3) + .and_then(ast::PathExpr::cast), + _ => None, + }) + .collect::>>() + .unwrap_or_default(), + None => Vec::new(), + } + }) + .collect(); + if function.self_param(sema.db).is_some() { + let this = || make::name_ref("this").syntax().clone_for_update(); + if let Some(self_local) = params[0].2.as_local(sema.db) { + usages_for_locals(self_local) + .flat_map(|FileReference { name, range, .. }| match name { + ast::NameLike::NameRef(_) => Some(body.syntax().covering_element(range)), + _ => None, + }) + .for_each(|it| { + ted::replace(it, &this()); + }) + } + } + // Inline parameter expressions or generate `let` statements depending on whether inlining works or not. + for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments).rev() { + let inline_direct = |usage, replacement: &ast::Expr| { + if let Some(field) = path_expr_as_record_field(usage) { + cov_mark::hit!(inline_call_inline_direct_field); + field.replace_expr(replacement.clone_for_update()); + } else { + ted::replace(usage.syntax(), &replacement.syntax().clone_for_update()); + } + }; + // izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors + let usages: &[ast::PathExpr] = &*usages; + let expr: &ast::Expr = expr; + match usages { + // inline single use closure arguments + [usage] + if matches!(expr, ast::Expr::ClosureExpr(_)) + && usage.syntax().parent().and_then(ast::Expr::cast).is_some() => + { + cov_mark::hit!(inline_call_inline_closure); + let expr = make::expr_paren(expr.clone()); + inline_direct(usage, &expr); + } + // inline single use literals + [usage] if matches!(expr, ast::Expr::Literal(_)) => { + cov_mark::hit!(inline_call_inline_literal); + inline_direct(usage, expr); + } + // inline direct local arguments + [_, ..] if expr_as_name_ref(expr).is_some() => { + cov_mark::hit!(inline_call_inline_locals); + usages.iter().for_each(|usage| inline_direct(usage, expr)); + } + // can't inline, emit a let statement + _ => { + let ty = + sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone()); + if let Some(stmt_list) = body.stmt_list() { + stmt_list.push_front( + make::let_stmt(pat.clone(), ty, Some(expr.clone())) + .clone_for_update() + .into(), + ) + } + } + } + } + if let Some(generic_arg_list) = generic_arg_list.clone() { + if let Some((target, source)) = &sema.scope(node.syntax()).zip(sema.scope(fn_body.syntax())) + { + PathTransform::function_call(target, source, function, generic_arg_list) + .apply(body.syntax()); + } + } + + let original_indentation = match node { + ast::CallableExpr::Call(it) => it.indent_level(), + ast::CallableExpr::MethodCall(it) => it.indent_level(), + }; + body.reindent_to(original_indentation); + + match body.tail_expr() { + Some(expr) if body.statements().next().is_none() => expr, + _ => match node + .syntax() + .parent() + .and_then(ast::BinExpr::cast) + .and_then(|bin_expr| bin_expr.lhs()) + { + Some(lhs) if lhs.syntax() == node.syntax() => { + make::expr_paren(ast::Expr::BlockExpr(body)).clone_for_update() + } + _ => ast::Expr::BlockExpr(body), + }, + } +} + +fn path_expr_as_record_field(usage: &PathExpr) -> Option { + let path = usage.path()?; + let name_ref = path.as_single_name_ref()?; + ast::RecordExprField::for_name_ref(&name_ref) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn no_args_or_return_value_gets_inlined_without_block() { + check_assist( + inline_call, + r#" +fn foo() { println!("Hello, World!"); } +fn main() { + fo$0o(); +} +"#, + r#" +fn foo() { println!("Hello, World!"); } +fn main() { + { println!("Hello, World!"); }; +} +"#, + ); + } + + #[test] + fn not_applicable_when_incorrect_number_of_parameters_are_provided() { + cov_mark::check!(inline_call_incorrect_number_of_arguments); + check_assist_not_applicable( + inline_call, + r#" +fn add(a: u32, b: u32) -> u32 { a + b } +fn main() { let x = add$0(42); } +"#, + ); + } + + #[test] + fn args_with_side_effects() { + check_assist( + inline_call, + r#" +fn foo(name: String) { + println!("Hello, {}!", name); +} +fn main() { + foo$0(String::from("Michael")); +} +"#, + r#" +fn foo(name: String) { + println!("Hello, {}!", name); +} +fn main() { + { + let name = String::from("Michael"); + println!("Hello, {}!", name); + }; +} +"#, + ); + } + + #[test] + fn function_with_multiple_statements() { + check_assist( + inline_call, + r#" +fn foo(a: u32, b: u32) -> u32 { + let x = a + b; + let y = x - b; + x * y +} + +fn main() { + let x = foo$0(1, 2); +} +"#, + r#" +fn foo(a: u32, b: u32) -> u32 { + let x = a + b; + let y = x - b; + x * y +} + +fn main() { + let x = { + let b = 2; + let x = 1 + b; + let y = x - b; + x * y + }; +} +"#, + ); + } + + #[test] + fn function_with_self_param() { + check_assist( + inline_call, + r#" +struct Foo(u32); + +impl Foo { + fn add(self, a: u32) -> Self { + Foo(self.0 + a) + } +} + +fn main() { + let x = Foo::add$0(Foo(3), 2); +} +"#, + r#" +struct Foo(u32); + +impl Foo { + fn add(self, a: u32) -> Self { + Foo(self.0 + a) + } +} + +fn main() { + let x = { + let this = Foo(3); + Foo(this.0 + 2) + }; +} +"#, + ); + } + + #[test] + fn method_by_val() { + check_assist( + inline_call, + r#" +struct Foo(u32); + +impl Foo { + fn add(self, a: u32) -> Self { + Foo(self.0 + a) + } +} + +fn main() { + let x = Foo(3).add$0(2); +} +"#, + r#" +struct Foo(u32); + +impl Foo { + fn add(self, a: u32) -> Self { + Foo(self.0 + a) + } +} + +fn main() { + let x = { + let this = Foo(3); + Foo(this.0 + 2) + }; +} +"#, + ); + } + + #[test] + fn method_by_ref() { + check_assist( + inline_call, + r#" +struct Foo(u32); + +impl Foo { + fn add(&self, a: u32) -> Self { + Foo(self.0 + a) + } +} + +fn main() { + let x = Foo(3).add$0(2); +} +"#, + r#" +struct Foo(u32); + +impl Foo { + fn add(&self, a: u32) -> Self { + Foo(self.0 + a) + } +} + +fn main() { + let x = { + let ref this = Foo(3); + Foo(this.0 + 2) + }; +} +"#, + ); + } + + #[test] + fn method_by_ref_mut() { + check_assist( + inline_call, + r#" +struct Foo(u32); + +impl Foo { + fn clear(&mut self) { + self.0 = 0; + } +} + +fn main() { + let mut foo = Foo(3); + foo.clear$0(); +} +"#, + r#" +struct Foo(u32); + +impl Foo { + fn clear(&mut self) { + self.0 = 0; + } +} + +fn main() { + let mut foo = Foo(3); + { + let ref mut this = foo; + this.0 = 0; + }; +} +"#, + ); + } + + #[test] + fn function_multi_use_expr_in_param() { + check_assist( + inline_call, + r#" +fn square(x: u32) -> u32 { + x * x +} +fn main() { + let x = 51; + let y = square$0(10 + x); +} +"#, + r#" +fn square(x: u32) -> u32 { + x * x +} +fn main() { + let x = 51; + let y = { + let x = 10 + x; + x * x + }; +} +"#, + ); + } + + #[test] + fn function_use_local_in_param() { + cov_mark::check!(inline_call_inline_locals); + check_assist( + inline_call, + r#" +fn square(x: u32) -> u32 { + x * x +} +fn main() { + let local = 51; + let y = square$0(local); +} +"#, + r#" +fn square(x: u32) -> u32 { + x * x +} +fn main() { + let local = 51; + let y = local * local; +} +"#, + ); + } + + #[test] + fn method_in_impl() { + check_assist( + inline_call, + r#" +struct Foo; +impl Foo { + fn foo(&self) { + self; + self; + } + fn bar(&self) { + self.foo$0(); + } +} +"#, + r#" +struct Foo; +impl Foo { + fn foo(&self) { + self; + self; + } + fn bar(&self) { + { + let ref this = self; + this; + this; + }; + } +} +"#, + ); + } + + #[test] + fn wraps_closure_in_paren() { + cov_mark::check!(inline_call_inline_closure); + check_assist( + inline_call, + r#" +fn foo(x: fn()) { + x(); +} + +fn main() { + foo$0(|| {}) +} +"#, + r#" +fn foo(x: fn()) { + x(); +} + +fn main() { + { + (|| {})(); + } +} +"#, + ); + check_assist( + inline_call, + r#" +fn foo(x: fn()) { + x(); +} + +fn main() { + foo$0(main) +} +"#, + r#" +fn foo(x: fn()) { + x(); +} + +fn main() { + { + main(); + } +} +"#, + ); + } + + #[test] + fn inline_single_literal_expr() { + cov_mark::check!(inline_call_inline_literal); + check_assist( + inline_call, + r#" +fn foo(x: u32) -> u32{ + x +} + +fn main() { + foo$0(222); +} +"#, + r#" +fn foo(x: u32) -> u32{ + x +} + +fn main() { + 222; +} +"#, + ); + } + + #[test] + fn inline_emits_type_for_coercion() { + check_assist( + inline_call, + r#" +fn foo(x: *const u32) -> u32 { + x as u32 +} + +fn main() { + foo$0(&222); +} +"#, + r#" +fn foo(x: *const u32) -> u32 { + x as u32 +} + +fn main() { + { + let x: *const u32 = &222; + x as u32 + }; +} +"#, + ); + } + + // FIXME: const generics aren't being substituted, this is blocked on better support for them + #[test] + fn inline_substitutes_generics() { + check_assist( + inline_call, + r#" +fn foo() { + bar::() +} + +fn bar() {} + +fn main() { + foo$0::(); +} +"#, + r#" +fn foo() { + bar::() +} + +fn bar() {} + +fn main() { + bar::(); +} +"#, + ); + } + + #[test] + fn inline_callers() { + check_assist( + inline_into_callers, + r#" +fn do_the_math$0(b: u32) -> u32 { + let foo = 10; + foo * b + foo +} +fn foo() { + do_the_math(0); + let bar = 10; + do_the_math(bar); +} +"#, + r#" + +fn foo() { + { + let foo = 10; + foo * 0 + foo + }; + let bar = 10; + { + let foo = 10; + foo * bar + foo + }; +} +"#, + ); + } + + #[test] + fn inline_callers_across_files() { + check_assist( + inline_into_callers, + r#" +//- /lib.rs +mod foo; +fn do_the_math$0(b: u32) -> u32 { + let foo = 10; + foo * b + foo +} +//- /foo.rs +use super::do_the_math; +fn foo() { + do_the_math(0); + let bar = 10; + do_the_math(bar); +} +"#, + r#" +//- /lib.rs +mod foo; + +//- /foo.rs +fn foo() { + { + let foo = 10; + foo * 0 + foo + }; + let bar = 10; + { + let foo = 10; + foo * bar + foo + }; +} +"#, + ); + } + + #[test] + fn inline_callers_across_files_with_def_file() { + check_assist( + inline_into_callers, + r#" +//- /lib.rs +mod foo; +fn do_the_math$0(b: u32) -> u32 { + let foo = 10; + foo * b + foo +} +fn bar(a: u32, b: u32) -> u32 { + do_the_math(0); +} +//- /foo.rs +use super::do_the_math; +fn foo() { + do_the_math(0); +} +"#, + r#" +//- /lib.rs +mod foo; + +fn bar(a: u32, b: u32) -> u32 { + { + let foo = 10; + foo * 0 + foo + }; +} +//- /foo.rs +fn foo() { + { + let foo = 10; + foo * 0 + foo + }; +} +"#, + ); + } + + #[test] + fn inline_callers_recursive() { + cov_mark::check!(inline_into_callers_recursive); + check_assist_not_applicable( + inline_into_callers, + r#" +fn foo$0() { + foo(); +} +"#, + ); + } + + #[test] + fn inline_call_recursive() { + cov_mark::check!(inline_call_recursive); + check_assist_not_applicable( + inline_call, + r#" +fn foo() { + foo$0(); +} +"#, + ); + } + + #[test] + fn inline_call_field_shorthand() { + cov_mark::check!(inline_call_inline_direct_field); + check_assist( + inline_call, + r#" +struct Foo { + field: u32, + field1: u32, + field2: u32, + field3: u32, +} +fn foo(field: u32, field1: u32, val2: u32, val3: u32) -> Foo { + Foo { + field, + field1, + field2: val2, + field3: val3, + } +} +fn main() { + let bar = 0; + let baz = 0; + foo$0(bar, 0, baz, 0); +} +"#, + r#" +struct Foo { + field: u32, + field1: u32, + field2: u32, + field3: u32, +} +fn foo(field: u32, field1: u32, val2: u32, val3: u32) -> Foo { + Foo { + field, + field1, + field2: val2, + field3: val3, + } +} +fn main() { + let bar = 0; + let baz = 0; + Foo { + field: bar, + field1: 0, + field2: baz, + field3: 0, + }; +} +"#, + ); + } + + #[test] + fn inline_callers_wrapped_in_parentheses() { + check_assist( + inline_into_callers, + r#" +fn foo$0() -> u32 { + let x = 0; + x +} +fn bar() -> u32 { + foo() + foo() +} +"#, + r#" + +fn bar() -> u32 { + ({ + let x = 0; + x + }) + { + let x = 0; + x + } +} +"#, + ) + } + + #[test] + fn inline_call_wrapped_in_parentheses() { + check_assist( + inline_call, + r#" +fn foo() -> u32 { + let x = 0; + x +} +fn bar() -> u32 { + foo$0() + foo() +} +"#, + r#" +fn foo() -> u32 { + let x = 0; + x +} +fn bar() -> u32 { + ({ + let x = 0; + x + }) + foo() +} +"#, + ) + } + + #[test] + fn inline_call_defined_in_macro() { + cov_mark::check!(inline_call_defined_in_macro); + check_assist( + inline_call, + r#" +macro_rules! define_foo { + () => { fn foo() -> u32 { + let x = 0; + x + } }; +} +define_foo!(); +fn bar() -> u32 { + foo$0() +} +"#, + r#" +macro_rules! define_foo { + () => { fn foo() -> u32 { + let x = 0; + x + } }; +} +define_foo!(); +fn bar() -> u32 { + { + let x = 0; + x + } +} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs new file mode 100644 index 000000000..7259d6781 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs @@ -0,0 +1,954 @@ +use either::Either; +use hir::{PathResolution, Semantics}; +use ide_db::{ + base_db::FileId, + defs::Definition, + search::{FileReference, UsageSearchResult}, + RootDatabase, +}; +use syntax::{ + ast::{self, AstNode, AstToken, HasName}, + SyntaxElement, TextRange, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: inline_local_variable +// +// Inlines a local variable. +// +// ``` +// fn main() { +// let x$0 = 1 + 2; +// x * 4; +// } +// ``` +// -> +// ``` +// fn main() { +// (1 + 2) * 4; +// } +// ``` +pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let file_id = ctx.file_id(); + let range = ctx.selection_trimmed(); + let InlineData { let_stmt, delete_let, references, target } = + if let Some(path_expr) = ctx.find_node_at_offset::() { + inline_usage(&ctx.sema, path_expr, range, file_id) + } else if let Some(let_stmt) = ctx.find_node_at_offset::() { + inline_let(&ctx.sema, let_stmt, range, file_id) + } else { + None + }?; + let initializer_expr = let_stmt.initializer()?; + + let delete_range = delete_let.then(|| { + if let Some(whitespace) = let_stmt + .syntax() + .next_sibling_or_token() + .and_then(SyntaxElement::into_token) + .and_then(ast::Whitespace::cast) + { + TextRange::new( + let_stmt.syntax().text_range().start(), + whitespace.syntax().text_range().end(), + ) + } else { + let_stmt.syntax().text_range() + } + }); + + let wrap_in_parens = references + .into_iter() + .filter_map(|FileReference { range, name, .. }| match name { + ast::NameLike::NameRef(name) => Some((range, name)), + _ => None, + }) + .map(|(range, name_ref)| { + if range != name_ref.syntax().text_range() { + // Do not rename inside macros + // FIXME: This feels like a bad heuristic for macros + return None; + } + let usage_node = + name_ref.syntax().ancestors().find(|it| ast::PathExpr::can_cast(it.kind())); + let usage_parent_option = + usage_node.and_then(|it| it.parent()).and_then(ast::Expr::cast); + let usage_parent = match usage_parent_option { + Some(u) => u, + None => return Some((range, name_ref, false)), + }; + let initializer = matches!( + initializer_expr, + ast::Expr::CallExpr(_) + | ast::Expr::IndexExpr(_) + | ast::Expr::MethodCallExpr(_) + | ast::Expr::FieldExpr(_) + | ast::Expr::TryExpr(_) + | ast::Expr::Literal(_) + | ast::Expr::TupleExpr(_) + | ast::Expr::ArrayExpr(_) + | ast::Expr::ParenExpr(_) + | ast::Expr::PathExpr(_) + | ast::Expr::BlockExpr(_), + ); + let parent = matches!( + usage_parent, + ast::Expr::CallExpr(_) + | ast::Expr::TupleExpr(_) + | ast::Expr::ArrayExpr(_) + | ast::Expr::ParenExpr(_) + | ast::Expr::ForExpr(_) + | ast::Expr::WhileExpr(_) + | ast::Expr::BreakExpr(_) + | ast::Expr::ReturnExpr(_) + | ast::Expr::MatchExpr(_) + | ast::Expr::BlockExpr(_) + ); + Some((range, name_ref, !(initializer || parent))) + }) + .collect::>>()?; + + let init_str = initializer_expr.syntax().text().to_string(); + let init_in_paren = format!("({})", &init_str); + + let target = match target { + ast::NameOrNameRef::Name(it) => it.syntax().text_range(), + ast::NameOrNameRef::NameRef(it) => it.syntax().text_range(), + }; + + acc.add( + AssistId("inline_local_variable", AssistKind::RefactorInline), + "Inline variable", + target, + move |builder| { + if let Some(range) = delete_range { + builder.delete(range); + } + for (range, name, should_wrap) in wrap_in_parens { + let replacement = if should_wrap { &init_in_paren } else { &init_str }; + if ast::RecordExprField::for_field_name(&name).is_some() { + cov_mark::hit!(inline_field_shorthand); + builder.insert(range.end(), format!(": {}", replacement)); + } else { + builder.replace(range, replacement.clone()) + } + } + }, + ) +} + +struct InlineData { + let_stmt: ast::LetStmt, + delete_let: bool, + target: ast::NameOrNameRef, + references: Vec, +} + +fn inline_let( + sema: &Semantics<'_, RootDatabase>, + let_stmt: ast::LetStmt, + range: TextRange, + file_id: FileId, +) -> Option { + let bind_pat = match let_stmt.pat()? { + ast::Pat::IdentPat(pat) => pat, + _ => return None, + }; + if bind_pat.mut_token().is_some() { + cov_mark::hit!(test_not_inline_mut_variable); + return None; + } + if !bind_pat.syntax().text_range().contains_range(range) { + cov_mark::hit!(not_applicable_outside_of_bind_pat); + return None; + } + + let local = sema.to_def(&bind_pat)?; + let UsageSearchResult { mut references } = Definition::Local(local).usages(sema).all(); + match references.remove(&file_id) { + Some(references) => Some(InlineData { + let_stmt, + delete_let: true, + target: ast::NameOrNameRef::Name(bind_pat.name()?), + references, + }), + None => { + cov_mark::hit!(test_not_applicable_if_variable_unused); + None + } + } +} + +fn inline_usage( + sema: &Semantics<'_, RootDatabase>, + path_expr: ast::PathExpr, + range: TextRange, + file_id: FileId, +) -> Option { + let path = path_expr.path()?; + let name = path.as_single_name_ref()?; + if !name.syntax().text_range().contains_range(range) { + cov_mark::hit!(test_not_inline_selection_too_broad); + return None; + } + + let local = match sema.resolve_path(&path)? { + PathResolution::Local(local) => local, + _ => return None, + }; + if local.is_mut(sema.db) { + cov_mark::hit!(test_not_inline_mut_variable_use); + return None; + } + + // FIXME: Handle multiple local definitions + let bind_pat = match local.source(sema.db).value { + Either::Left(ident) => ident, + _ => return None, + }; + + let let_stmt = ast::LetStmt::cast(bind_pat.syntax().parent()?)?; + + let UsageSearchResult { mut references } = Definition::Local(local).usages(sema).all(); + let mut references = references.remove(&file_id)?; + let delete_let = references.len() == 1; + references.retain(|fref| fref.name.as_name_ref() == Some(&name)); + + Some(InlineData { let_stmt, delete_let, target: ast::NameOrNameRef::NameRef(name), references }) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_inline_let_bind_literal_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize) {} +fn foo() { + let a$0 = 1; + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn bar(a: usize) {} +fn foo() { + 1 + 1; + if 1 > 10 { + } + + while 1 > 10 { + + } + let b = 1 * 10; + bar(1); +}", + ); + } + + #[test] + fn test_inline_let_bind_bin_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize) {} +fn foo() { + let a$0 = 1 + 1; + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn bar(a: usize) {} +fn foo() { + (1 + 1) + 1; + if (1 + 1) > 10 { + } + + while (1 + 1) > 10 { + + } + let b = (1 + 1) * 10; + bar(1 + 1); +}", + ); + } + + #[test] + fn test_inline_let_bind_function_call_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize) {} +fn foo() { + let a$0 = bar(1); + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn bar(a: usize) {} +fn foo() { + bar(1) + 1; + if bar(1) > 10 { + } + + while bar(1) > 10 { + + } + let b = bar(1) * 10; + bar(bar(1)); +}", + ); + } + + #[test] + fn test_inline_let_bind_cast_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize): usize { a } +fn foo() { + let a$0 = bar(1) as u64; + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn bar(a: usize): usize { a } +fn foo() { + (bar(1) as u64) + 1; + if (bar(1) as u64) > 10 { + } + + while (bar(1) as u64) > 10 { + + } + let b = (bar(1) as u64) * 10; + bar(bar(1) as u64); +}", + ); + } + + #[test] + fn test_inline_let_bind_block_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = { 10 + 1 }; + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn foo() { + { 10 + 1 } + 1; + if { 10 + 1 } > 10 { + } + + while { 10 + 1 } > 10 { + + } + let b = { 10 + 1 } * 10; + bar({ 10 + 1 }); +}", + ); + } + + #[test] + fn test_inline_let_bind_paren_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = ( 10 + 1 ); + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn foo() { + ( 10 + 1 ) + 1; + if ( 10 + 1 ) > 10 { + } + + while ( 10 + 1 ) > 10 { + + } + let b = ( 10 + 1 ) * 10; + bar(( 10 + 1 )); +}", + ); + } + + #[test] + fn test_not_inline_mut_variable() { + cov_mark::check!(test_not_inline_mut_variable); + check_assist_not_applicable( + inline_local_variable, + r" +fn foo() { + let mut a$0 = 1 + 1; + a + 1; +}", + ); + } + + #[test] + fn test_not_inline_mut_variable_use() { + cov_mark::check!(test_not_inline_mut_variable_use); + check_assist_not_applicable( + inline_local_variable, + r" +fn foo() { + let mut a = 1 + 1; + a$0 + 1; +}", + ); + } + + #[test] + fn test_call_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = bar(10 + 1); + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let b = bar(10 + 1) * 10; + let c = bar(10 + 1) as usize; +}", + ); + } + + #[test] + fn test_index_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let x = vec![1, 2, 3]; + let a$0 = x[0]; + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let x = vec![1, 2, 3]; + let b = x[0] * 10; + let c = x[0] as usize; +}", + ); + } + + #[test] + fn test_method_call_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let bar = vec![1]; + let a$0 = bar.len(); + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let bar = vec![1]; + let b = bar.len() * 10; + let c = bar.len() as usize; +}", + ); + } + + #[test] + fn test_field_expr() { + check_assist( + inline_local_variable, + r" +struct Bar { + foo: usize +} + +fn foo() { + let bar = Bar { foo: 1 }; + let a$0 = bar.foo; + let b = a * 10; + let c = a as usize; +}", + r" +struct Bar { + foo: usize +} + +fn foo() { + let bar = Bar { foo: 1 }; + let b = bar.foo * 10; + let c = bar.foo as usize; +}", + ); + } + + #[test] + fn test_try_expr() { + check_assist( + inline_local_variable, + r" +fn foo() -> Option { + let bar = Some(1); + let a$0 = bar?; + let b = a * 10; + let c = a as usize; + None +}", + r" +fn foo() -> Option { + let bar = Some(1); + let b = bar? * 10; + let c = bar? as usize; + None +}", + ); + } + + #[test] + fn test_ref_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let bar = 10; + let a$0 = &bar; + let b = a * 10; +}", + r" +fn foo() { + let bar = 10; + let b = (&bar) * 10; +}", + ); + } + + #[test] + fn test_tuple_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = (10, 20); + let b = a[0]; +}", + r" +fn foo() { + let b = (10, 20)[0]; +}", + ); + } + + #[test] + fn test_array_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = [1, 2, 3]; + let b = a.len(); +}", + r" +fn foo() { + let b = [1, 2, 3].len(); +}", + ); + } + + #[test] + fn test_paren() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = (10 + 20); + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let b = (10 + 20) * 10; + let c = (10 + 20) as usize; +}", + ); + } + + #[test] + fn test_path_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let d = 10; + let a$0 = d; + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let d = 10; + let b = d * 10; + let c = d as usize; +}", + ); + } + + #[test] + fn test_block_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = { 10 }; + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let b = { 10 } * 10; + let c = { 10 } as usize; +}", + ); + } + + #[test] + fn test_used_in_different_expr1() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = 10 + 20; + let b = a * 10; + let c = (a, 20); + let d = [a, 10]; + let e = (a); +}", + r" +fn foo() { + let b = (10 + 20) * 10; + let c = (10 + 20, 20); + let d = [10 + 20, 10]; + let e = (10 + 20); +}", + ); + } + + #[test] + fn test_used_in_for_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = vec![10, 20]; + for i in a {} +}", + r" +fn foo() { + for i in vec![10, 20] {} +}", + ); + } + + #[test] + fn test_used_in_while_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = 1 > 0; + while a {} +}", + r" +fn foo() { + while 1 > 0 {} +}", + ); + } + + #[test] + fn test_used_in_break_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = 1 + 1; + loop { + break a; + } +}", + r" +fn foo() { + loop { + break 1 + 1; + } +}", + ); + } + + #[test] + fn test_used_in_return_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = 1 > 0; + return a; +}", + r" +fn foo() { + return 1 > 0; +}", + ); + } + + #[test] + fn test_used_in_match_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a$0 = 1 > 0; + match a {} +}", + r" +fn foo() { + match 1 > 0 {} +}", + ); + } + + #[test] + fn inline_field_shorthand() { + cov_mark::check!(inline_field_shorthand); + check_assist( + inline_local_variable, + r" +struct S { foo: i32} +fn main() { + let $0foo = 92; + S { foo } +} +", + r" +struct S { foo: i32} +fn main() { + S { foo: 92 } +} +", + ); + } + + #[test] + fn test_not_applicable_if_variable_unused() { + cov_mark::check!(test_not_applicable_if_variable_unused); + check_assist_not_applicable( + inline_local_variable, + r" +fn foo() { + let $0a = 0; +} + ", + ) + } + + #[test] + fn not_applicable_outside_of_bind_pat() { + cov_mark::check!(not_applicable_outside_of_bind_pat); + check_assist_not_applicable( + inline_local_variable, + r" +fn main() { + let x = $01 + 2; + x * 4; +} +", + ) + } + + #[test] + fn works_on_local_usage() { + check_assist( + inline_local_variable, + r#" +fn f() { + let xyz = 0; + xyz$0; +} +"#, + r#" +fn f() { + 0; +} +"#, + ); + } + + #[test] + fn does_not_remove_let_when_multiple_usages() { + check_assist( + inline_local_variable, + r#" +fn f() { + let xyz = 0; + xyz$0; + xyz; +} +"#, + r#" +fn f() { + let xyz = 0; + 0; + xyz; +} +"#, + ); + } + + #[test] + fn not_applicable_with_non_ident_pattern() { + check_assist_not_applicable( + inline_local_variable, + r#" +fn main() { + let (x, y) = (0, 1); + x$0; +} +"#, + ); + } + + #[test] + fn not_applicable_on_local_usage_in_macro() { + check_assist_not_applicable( + inline_local_variable, + r#" +macro_rules! m { + ($i:ident) => { $i } +} +fn f() { + let xyz = 0; + m!(xyz$0); // replacing it would break the macro +} +"#, + ); + check_assist_not_applicable( + inline_local_variable, + r#" +macro_rules! m { + ($i:ident) => { $i } +} +fn f() { + let xyz$0 = 0; + m!(xyz); // replacing it would break the macro +} +"#, + ); + } + + #[test] + fn test_not_inline_selection_too_broad() { + cov_mark::check!(test_not_inline_selection_too_broad); + check_assist_not_applicable( + inline_local_variable, + r#" +fn f() { + let foo = 0; + let bar = 0; + $0foo + bar$0; +} +"#, + ); + } + + #[test] + fn test_inline_ref_in_let() { + check_assist( + inline_local_variable, + r#" +fn f() { + let x = { + let y = 0; + y$0 + }; +} +"#, + r#" +fn f() { + let x = { + 0 + }; +} +"#, + ); + } + + #[test] + fn test_inline_let_unit_struct() { + check_assist_not_applicable( + inline_local_variable, + r#" +struct S; +fn f() { + let S$0 = S; + S; +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_type_alias.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_type_alias.rs new file mode 100644 index 000000000..054663a06 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_type_alias.rs @@ -0,0 +1,838 @@ +// Some ideas for future improvements: +// - Support replacing aliases which are used in expressions, e.g. `A::new()`. +// - "inline_alias_to_users" assist #10881. +// - Remove unused aliases if there are no longer any users, see inline_call.rs. + +use hir::{HasSource, PathResolution}; +use itertools::Itertools; +use std::collections::HashMap; +use syntax::{ + ast::{self, make, HasGenericParams, HasName}, + ted, AstNode, NodeOrToken, SyntaxNode, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: inline_type_alias +// +// Replace a type alias with its concrete type. +// +// ``` +// type A = Vec; +// +// fn main() { +// let a: $0A; +// } +// ``` +// -> +// ``` +// type A = Vec; +// +// fn main() { +// let a: Vec; +// } +// ``` +pub(crate) fn inline_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + enum Replacement { + Generic { lifetime_map: LifetimeMap, const_and_type_map: ConstAndTypeMap }, + Plain, + } + + let alias_instance = ctx.find_node_at_offset::()?; + let concrete_type; + let replacement; + match alias_instance.path()?.as_single_name_ref() { + Some(nameref) if nameref.Self_token().is_some() => { + match ctx.sema.resolve_path(&alias_instance.path()?)? { + PathResolution::SelfType(imp) => { + concrete_type = imp.source(ctx.db())?.value.self_ty()?; + } + // FIXME: should also work in ADT definitions + _ => return None, + } + + replacement = Replacement::Plain; + } + _ => { + let alias = get_type_alias(&ctx, &alias_instance)?; + concrete_type = alias.ty()?; + + replacement = if let Some(alias_generics) = alias.generic_param_list() { + if alias_generics.generic_params().next().is_none() { + cov_mark::hit!(no_generics_params); + return None; + } + + let instance_args = + alias_instance.syntax().descendants().find_map(ast::GenericArgList::cast); + + Replacement::Generic { + lifetime_map: LifetimeMap::new(&instance_args, &alias_generics)?, + const_and_type_map: ConstAndTypeMap::new(&instance_args, &alias_generics)?, + } + } else { + Replacement::Plain + }; + } + } + + let target = alias_instance.syntax().text_range(); + + acc.add( + AssistId("inline_type_alias", AssistKind::RefactorInline), + "Inline type alias", + target, + |builder| { + let replacement_text = match replacement { + Replacement::Generic { lifetime_map, const_and_type_map } => { + create_replacement(&lifetime_map, &const_and_type_map, &concrete_type) + } + Replacement::Plain => concrete_type.to_string(), + }; + + builder.replace(target, replacement_text); + }, + ) +} + +struct LifetimeMap(HashMap); + +impl LifetimeMap { + fn new( + instance_args: &Option, + alias_generics: &ast::GenericParamList, + ) -> Option { + let mut inner = HashMap::new(); + + let wildcard_lifetime = make::lifetime("'_"); + let lifetimes = alias_generics + .lifetime_params() + .filter_map(|lp| lp.lifetime()) + .map(|l| l.to_string()) + .collect_vec(); + + for lifetime in &lifetimes { + inner.insert(lifetime.to_string(), wildcard_lifetime.clone()); + } + + if let Some(instance_generic_args_list) = &instance_args { + for (index, lifetime) in instance_generic_args_list + .lifetime_args() + .filter_map(|arg| arg.lifetime()) + .enumerate() + { + let key = match lifetimes.get(index) { + Some(key) => key, + None => { + cov_mark::hit!(too_many_lifetimes); + return None; + } + }; + + inner.insert(key.clone(), lifetime); + } + } + + Some(Self(inner)) + } +} + +struct ConstAndTypeMap(HashMap); + +impl ConstAndTypeMap { + fn new( + instance_args: &Option, + alias_generics: &ast::GenericParamList, + ) -> Option { + let mut inner = HashMap::new(); + let instance_generics = generic_args_to_const_and_type_generics(instance_args); + let alias_generics = generic_param_list_to_const_and_type_generics(&alias_generics); + + if instance_generics.len() > alias_generics.len() { + cov_mark::hit!(too_many_generic_args); + return None; + } + + // Any declaration generics that don't have a default value must have one + // provided by the instance. + for (i, declaration_generic) in alias_generics.iter().enumerate() { + let key = declaration_generic.replacement_key()?; + + if let Some(instance_generic) = instance_generics.get(i) { + inner.insert(key, instance_generic.replacement_value()?); + } else if let Some(value) = declaration_generic.replacement_value() { + inner.insert(key, value); + } else { + cov_mark::hit!(missing_replacement_param); + return None; + } + } + + Some(Self(inner)) + } +} + +/// This doesn't attempt to ensure specified generics are compatible with those +/// required by the type alias, other than lifetimes which must either all be +/// specified or all omitted. It will replace TypeArgs with ConstArgs and vice +/// versa if they're in the wrong position. It supports partially specified +/// generics. +/// +/// 1. Map the provided instance's generic args to the type alias's generic +/// params: +/// +/// ``` +/// type A<'a, const N: usize, T = u64> = &'a [T; N]; +/// ^ alias generic params +/// let a: A<100>; +/// ^ instance generic args +/// ``` +/// +/// generic['a] = '_ due to omission +/// generic[N] = 100 due to the instance arg +/// generic[T] = u64 due to the default param +/// +/// 2. Copy the concrete type and substitute in each found mapping: +/// +/// &'_ [u64; 100] +/// +/// 3. Remove wildcard lifetimes entirely: +/// +/// &[u64; 100] +fn create_replacement( + lifetime_map: &LifetimeMap, + const_and_type_map: &ConstAndTypeMap, + concrete_type: &ast::Type, +) -> String { + let updated_concrete_type = concrete_type.clone_for_update(); + let mut replacements = Vec::new(); + let mut removals = Vec::new(); + + for syntax in updated_concrete_type.syntax().descendants() { + let syntax_string = syntax.to_string(); + let syntax_str = syntax_string.as_str(); + + if let Some(old_lifetime) = ast::Lifetime::cast(syntax.clone()) { + if let Some(new_lifetime) = lifetime_map.0.get(&old_lifetime.to_string()) { + if new_lifetime.text() == "'_" { + removals.push(NodeOrToken::Node(syntax.clone())); + + if let Some(ws) = syntax.next_sibling_or_token() { + removals.push(ws.clone()); + } + + continue; + } + + replacements.push((syntax.clone(), new_lifetime.syntax().clone_for_update())); + } + } else if let Some(replacement_syntax) = const_and_type_map.0.get(syntax_str) { + let new_string = replacement_syntax.to_string(); + let new = if new_string == "_" { + make::wildcard_pat().syntax().clone_for_update() + } else { + replacement_syntax.clone_for_update() + }; + + replacements.push((syntax.clone(), new)); + } + } + + for (old, new) in replacements { + ted::replace(old, new); + } + + for syntax in removals { + ted::remove(syntax); + } + + updated_concrete_type.to_string() +} + +fn get_type_alias(ctx: &AssistContext<'_>, path: &ast::PathType) -> Option { + let resolved_path = ctx.sema.resolve_path(&path.path()?)?; + + // We need the generics in the correct order to be able to map any provided + // instance generics to declaration generics. The `hir::TypeAlias` doesn't + // keep the order, so we must get the `ast::TypeAlias` from the hir + // definition. + if let PathResolution::Def(hir::ModuleDef::TypeAlias(ta)) = resolved_path { + Some(ctx.sema.source(ta)?.value) + } else { + None + } +} + +enum ConstOrTypeGeneric { + ConstArg(ast::ConstArg), + TypeArg(ast::TypeArg), + ConstParam(ast::ConstParam), + TypeParam(ast::TypeParam), +} + +impl ConstOrTypeGeneric { + fn replacement_key(&self) -> Option { + // Only params are used as replacement keys. + match self { + ConstOrTypeGeneric::ConstParam(cp) => Some(cp.name()?.to_string()), + ConstOrTypeGeneric::TypeParam(tp) => Some(tp.name()?.to_string()), + _ => None, + } + } + + fn replacement_value(&self) -> Option { + Some(match self { + ConstOrTypeGeneric::ConstArg(ca) => ca.expr()?.syntax().clone(), + ConstOrTypeGeneric::TypeArg(ta) => ta.syntax().clone(), + ConstOrTypeGeneric::ConstParam(cp) => cp.default_val()?.syntax().clone(), + ConstOrTypeGeneric::TypeParam(tp) => tp.default_type()?.syntax().clone(), + }) + } +} + +fn generic_param_list_to_const_and_type_generics( + generics: &ast::GenericParamList, +) -> Vec { + let mut others = Vec::new(); + + for param in generics.generic_params() { + match param { + ast::GenericParam::LifetimeParam(_) => {} + ast::GenericParam::ConstParam(cp) => { + others.push(ConstOrTypeGeneric::ConstParam(cp)); + } + ast::GenericParam::TypeParam(tp) => others.push(ConstOrTypeGeneric::TypeParam(tp)), + } + } + + others +} + +fn generic_args_to_const_and_type_generics( + generics: &Option, +) -> Vec { + let mut others = Vec::new(); + + // It's fine for there to be no instance generics because the declaration + // might have default values or they might be inferred. + if let Some(generics) = generics { + for arg in generics.generic_args() { + match arg { + ast::GenericArg::TypeArg(ta) => { + others.push(ConstOrTypeGeneric::TypeArg(ta)); + } + ast::GenericArg::ConstArg(ca) => { + others.push(ConstOrTypeGeneric::ConstArg(ca)); + } + _ => {} + } + } + } + + others +} + +#[cfg(test)] +mod test { + use super::*; + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn empty_generic_params() { + cov_mark::check!(no_generics_params); + check_assist_not_applicable( + inline_type_alias, + r#" +type A<> = T; +fn main() { + let a: $0A; +} + "#, + ); + } + + #[test] + fn too_many_generic_args() { + cov_mark::check!(too_many_generic_args); + check_assist_not_applicable( + inline_type_alias, + r#" +type A = T; +fn main() { + let a: $0A; +} + "#, + ); + } + + #[test] + fn too_many_lifetimes() { + cov_mark::check!(too_many_lifetimes); + check_assist_not_applicable( + inline_type_alias, + r#" +type A<'a> = &'a &'b u32; +fn f<'a>() { + let a: $0A<'a, 'b> = 0; +} +"#, + ); + } + + // This must be supported in order to support "inline_alias_to_users" or + // whatever it will be called. + #[test] + fn alias_as_expression_ignored() { + check_assist_not_applicable( + inline_type_alias, + r#" +type A = Vec; +fn main() { + let a: A = $0A::new(); +} +"#, + ); + } + + #[test] + fn primitive_arg() { + check_assist( + inline_type_alias, + r#" +type A = T; +fn main() { + let a: $0A = 0; +} +"#, + r#" +type A = T; +fn main() { + let a: u32 = 0; +} +"#, + ); + } + + #[test] + fn no_generic_replacements() { + check_assist( + inline_type_alias, + r#" +type A = Vec; +fn main() { + let a: $0A; +} +"#, + r#" +type A = Vec; +fn main() { + let a: Vec; +} +"#, + ); + } + + #[test] + fn param_expression() { + check_assist( + inline_type_alias, + r#" +type A = [u32; N]; +fn main() { + let a: $0A; +} +"#, + r#" +type A = [u32; N]; +fn main() { + let a: [u32; { 1 }]; +} +"#, + ); + } + + #[test] + fn param_default_value() { + check_assist( + inline_type_alias, + r#" +type A = [u32; N]; +fn main() { + let a: $0A; +} +"#, + r#" +type A = [u32; N]; +fn main() { + let a: [u32; 1]; +} +"#, + ); + } + + #[test] + fn all_param_types() { + check_assist( + inline_type_alias, + r#" +struct Struct; +type A<'inner1, 'outer1, Outer1, const INNER1: usize, Inner1: Clone, const OUTER1: usize> = (Struct, Struct, Outer1, &'inner1 (), Inner1, &'outer1 ()); +fn foo<'inner2, 'outer2, Outer2, const INNER2: usize, Inner2, const OUTER2: usize>() { + let a: $0A<'inner2, 'outer2, Outer2, INNER2, Inner2, OUTER2>; +} +"#, + r#" +struct Struct; +type A<'inner1, 'outer1, Outer1, const INNER1: usize, Inner1: Clone, const OUTER1: usize> = (Struct, Struct, Outer1, &'inner1 (), Inner1, &'outer1 ()); +fn foo<'inner2, 'outer2, Outer2, const INNER2: usize, Inner2, const OUTER2: usize>() { + let a: (Struct, Struct, Outer2, &'inner2 (), Inner2, &'outer2 ()); +} +"#, + ); + } + + #[test] + fn omitted_lifetimes() { + check_assist( + inline_type_alias, + r#" +type A<'l, 'r> = &'l &'r u32; +fn main() { + let a: $0A; +} +"#, + r#" +type A<'l, 'r> = &'l &'r u32; +fn main() { + let a: &&u32; +} +"#, + ); + } + + #[test] + fn omitted_type() { + check_assist( + inline_type_alias, + r#" +type A<'r, 'l, T = u32> = &'l std::collections::HashMap<&'r str, T>; +fn main() { + let a: $0A<'_, '_>; +} +"#, + r#" +type A<'r, 'l, T = u32> = &'l std::collections::HashMap<&'r str, T>; +fn main() { + let a: &std::collections::HashMap<&str, u32>; +} +"#, + ); + } + + #[test] + fn omitted_everything() { + check_assist( + inline_type_alias, + r#" +type A<'r, 'l, T = u32> = &'l std::collections::HashMap<&'r str, T>; +fn main() { + let v = std::collections::HashMap<&str, u32>; + let a: $0A = &v; +} +"#, + r#" +type A<'r, 'l, T = u32> = &'l std::collections::HashMap<&'r str, T>; +fn main() { + let v = std::collections::HashMap<&str, u32>; + let a: &std::collections::HashMap<&str, u32> = &v; +} +"#, + ); + } + + // This doesn't actually cause the GenericArgsList to contain a AssocTypeArg. + #[test] + fn arg_associated_type() { + check_assist( + inline_type_alias, + r#" +trait Tra { type Assoc; fn a(); } +struct Str {} +impl Tra for Str { + type Assoc = u32; + fn a() { + type A = Vec; + let a: $0A; + } +} +"#, + r#" +trait Tra { type Assoc; fn a(); } +struct Str {} +impl Tra for Str { + type Assoc = u32; + fn a() { + type A = Vec; + let a: Vec; + } +} +"#, + ); + } + + #[test] + fn param_default_associated_type() { + check_assist( + inline_type_alias, + r#" +trait Tra { type Assoc; fn a() } +struct Str {} +impl Tra for Str { + type Assoc = u32; + fn a() { + type A = Vec; + let a: $0A; + } +} +"#, + r#" +trait Tra { type Assoc; fn a() } +struct Str {} +impl Tra for Str { + type Assoc = u32; + fn a() { + type A = Vec; + let a: Vec; + } +} +"#, + ); + } + + #[test] + fn function_pointer() { + check_assist( + inline_type_alias, + r#" +type A = fn(u32); +fn foo(a: u32) {} +fn main() { + let a: $0A = foo; +} +"#, + r#" +type A = fn(u32); +fn foo(a: u32) {} +fn main() { + let a: fn(u32) = foo; +} +"#, + ); + } + + #[test] + fn closure() { + check_assist( + inline_type_alias, + r#" +type A = Box u32>; +fn main() { + let a: $0A = Box::new(|_| 0); +} +"#, + r#" +type A = Box u32>; +fn main() { + let a: Box u32> = Box::new(|_| 0); +} +"#, + ); + } + + // Type aliases can't be used in traits, but someone might use the assist to + // fix the error. + #[test] + fn bounds() { + check_assist( + inline_type_alias, + r#"type A = std::io::Write; fn f() where T: $0A {}"#, + r#"type A = std::io::Write; fn f() where T: std::io::Write {}"#, + ); + } + + #[test] + fn function_parameter() { + check_assist( + inline_type_alias, + r#" +type A = std::io::Write; +fn f(a: impl $0A) {} +"#, + r#" +type A = std::io::Write; +fn f(a: impl std::io::Write) {} +"#, + ); + } + + #[test] + fn arg_expression() { + check_assist( + inline_type_alias, + r#" +type A = [u32; N]; +fn main() { + let a: $0A<{ 1 + 1 }>; +} +"#, + r#" +type A = [u32; N]; +fn main() { + let a: [u32; { 1 + 1 }]; +} +"#, + ) + } + + #[test] + fn alias_instance_generic_path() { + check_assist( + inline_type_alias, + r#" +type A = [u32; N]; +fn main() { + let a: $0A; +} +"#, + r#" +type A = [u32; N]; +fn main() { + let a: [u32; u32::MAX]; +} +"#, + ) + } + + #[test] + fn generic_type() { + check_assist( + inline_type_alias, + r#" +type A = String; +fn f(a: Vec<$0A>) {} +"#, + r#" +type A = String; +fn f(a: Vec) {} +"#, + ); + } + + #[test] + fn missing_replacement_param() { + cov_mark::check!(missing_replacement_param); + check_assist_not_applicable( + inline_type_alias, + r#" +type A = Vec; +fn main() { + let a: $0A; +} +"#, + ); + } + + #[test] + fn full_path_type_is_replaced() { + check_assist( + inline_type_alias, + r#" +mod foo { + pub type A = String; +} +fn main() { + let a: foo::$0A; +} +"#, + r#" +mod foo { + pub type A = String; +} +fn main() { + let a: String; +} +"#, + ); + } + + #[test] + fn inline_self_type() { + check_assist( + inline_type_alias, + r#" +struct Strukt; + +impl Strukt { + fn new() -> Self$0 {} +} +"#, + r#" +struct Strukt; + +impl Strukt { + fn new() -> Strukt {} +} +"#, + ); + check_assist( + inline_type_alias, + r#" +struct Strukt<'a, T, const C: usize>(&'a [T; C]); + +impl Strukt<'_, T, C> { + fn new() -> Self$0 {} +} +"#, + r#" +struct Strukt<'a, T, const C: usize>(&'a [T; C]); + +impl Strukt<'_, T, C> { + fn new() -> Strukt<'_, T, C> {} +} +"#, + ); + check_assist( + inline_type_alias, + r#" +struct Strukt<'a, T, const C: usize>(&'a [T; C]); + +trait Tr<'b, T> {} + +impl Tr<'static, u8> for Strukt<'_, T, C> { + fn new() -> Self$0 {} +} +"#, + r#" +struct Strukt<'a, T, const C: usize>(&'a [T; C]); + +trait Tr<'b, T> {} + +impl Tr<'static, u8> for Strukt<'_, T, C> { + fn new() -> Strukt<'_, T, C> {} +} +"#, + ); + + check_assist_not_applicable( + inline_type_alias, + r#" +trait Tr { + fn new() -> Self$0; +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_generic.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_generic.rs new file mode 100644 index 000000000..062c816ae --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_generic.rs @@ -0,0 +1,144 @@ +use syntax::{ + ast::{self, edit_in_place::GenericParamsOwnerEdit, make, AstNode}, + ted, +}; + +use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: introduce_named_generic +// +// Replaces `impl Trait` function argument with the named generic. +// +// ``` +// fn foo(bar: $0impl Bar) {} +// ``` +// -> +// ``` +// fn foo(bar: B) {} +// ``` +pub(crate) fn introduce_named_generic(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let impl_trait_type = ctx.find_node_at_offset::()?; + let param = impl_trait_type.syntax().parent().and_then(ast::Param::cast)?; + let fn_ = param.syntax().ancestors().find_map(ast::Fn::cast)?; + + let type_bound_list = impl_trait_type.type_bound_list()?; + + let target = fn_.syntax().text_range(); + acc.add( + AssistId("introduce_named_generic", AssistKind::RefactorRewrite), + "Replace impl trait with generic", + target, + |edit| { + let impl_trait_type = edit.make_mut(impl_trait_type); + let fn_ = edit.make_mut(fn_); + + let type_param_name = suggest_name::for_generic_parameter(&impl_trait_type); + + let type_param = make::type_param(make::name(&type_param_name), Some(type_bound_list)) + .clone_for_update(); + let new_ty = make::ty(&type_param_name).clone_for_update(); + + ted::replace(impl_trait_type.syntax(), new_ty.syntax()); + fn_.get_or_create_generic_param_list().add_generic_param(type_param.into()) + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::check_assist; + + #[test] + fn introduce_named_generic_params() { + check_assist( + introduce_named_generic, + r#"fn foo(bar: $0impl Bar) {}"#, + r#"fn foo(bar: B) {}"#, + ); + } + + #[test] + fn replace_impl_trait_without_generic_params() { + check_assist( + introduce_named_generic, + r#"fn foo(bar: $0impl Bar) {}"#, + r#"fn foo(bar: B) {}"#, + ); + } + + #[test] + fn replace_two_impl_trait_with_generic_params() { + check_assist( + introduce_named_generic, + r#"fn foo(foo: impl Foo, bar: $0impl Bar) {}"#, + r#"fn foo(foo: impl Foo, bar: B) {}"#, + ); + } + + #[test] + fn replace_impl_trait_with_empty_generic_params() { + check_assist( + introduce_named_generic, + r#"fn foo<>(bar: $0impl Bar) {}"#, + r#"fn foo(bar: B) {}"#, + ); + } + + #[test] + fn replace_impl_trait_with_empty_multiline_generic_params() { + check_assist( + introduce_named_generic, + r#" +fn foo< +>(bar: $0impl Bar) {} +"#, + r#" +fn foo(bar: B) {} +"#, + ); + } + + #[test] + fn replace_impl_trait_with_exist_generic_letter() { + // FIXME: This is wrong, we should pick a different name if the one we + // want is already bound. + check_assist( + introduce_named_generic, + r#"fn foo(bar: $0impl Bar) {}"#, + r#"fn foo(bar: B) {}"#, + ); + } + + #[test] + fn replace_impl_trait_with_multiline_generic_params() { + check_assist( + introduce_named_generic, + r#" +fn foo< + G: Foo, + F, + H, +>(bar: $0impl Bar) {} +"#, + r#" +fn foo< + G: Foo, + F, + H, B: Bar, +>(bar: B) {} +"#, + ); + } + + #[test] + fn replace_impl_trait_multiple() { + check_assist( + introduce_named_generic, + r#"fn foo(bar: $0impl Foo + Bar) {}"#, + r#"fn foo(bar: F) {}"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_lifetime.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_lifetime.rs new file mode 100644 index 000000000..ce91dd237 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_lifetime.rs @@ -0,0 +1,338 @@ +use ide_db::FxHashSet; +use syntax::{ + ast::{self, edit_in_place::GenericParamsOwnerEdit, make, HasGenericParams}, + ted::{self, Position}, + AstNode, TextRange, +}; + +use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists}; + +static ASSIST_NAME: &str = "introduce_named_lifetime"; +static ASSIST_LABEL: &str = "Introduce named lifetime"; + +// Assist: introduce_named_lifetime +// +// Change an anonymous lifetime to a named lifetime. +// +// ``` +// impl Cursor<'_$0> { +// fn node(self) -> &SyntaxNode { +// match self { +// Cursor::Replace(node) | Cursor::Before(node) => node, +// } +// } +// } +// ``` +// -> +// ``` +// impl<'a> Cursor<'a> { +// fn node(self) -> &SyntaxNode { +// match self { +// Cursor::Replace(node) | Cursor::Before(node) => node, +// } +// } +// } +// ``` +pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + // FIXME: How can we handle renaming any one of multiple anonymous lifetimes? + // FIXME: should also add support for the case fun(f: &Foo) -> &$0Foo + let lifetime = + ctx.find_node_at_offset::().filter(|lifetime| lifetime.text() == "'_")?; + let lifetime_loc = lifetime.lifetime_ident_token()?.text_range(); + + if let Some(fn_def) = lifetime.syntax().ancestors().find_map(ast::Fn::cast) { + generate_fn_def_assist(acc, fn_def, lifetime_loc, lifetime) + } else if let Some(impl_def) = lifetime.syntax().ancestors().find_map(ast::Impl::cast) { + generate_impl_def_assist(acc, impl_def, lifetime_loc, lifetime) + } else { + None + } +} + +/// Generate the assist for the fn def case +fn generate_fn_def_assist( + acc: &mut Assists, + fn_def: ast::Fn, + lifetime_loc: TextRange, + lifetime: ast::Lifetime, +) -> Option<()> { + let param_list: ast::ParamList = fn_def.param_list()?; + let new_lifetime_param = generate_unique_lifetime_param_name(fn_def.generic_param_list())?; + let self_param = + // use the self if it's a reference and has no explicit lifetime + param_list.self_param().filter(|p| p.lifetime().is_none() && p.amp_token().is_some()); + // compute the location which implicitly has the same lifetime as the anonymous lifetime + let loc_needing_lifetime = if let Some(self_param) = self_param { + // if we have a self reference, use that + Some(NeedsLifetime::SelfParam(self_param)) + } else { + // otherwise, if there's a single reference parameter without a named liftime, use that + let fn_params_without_lifetime: Vec<_> = param_list + .params() + .filter_map(|param| match param.ty() { + Some(ast::Type::RefType(ascribed_type)) if ascribed_type.lifetime().is_none() => { + Some(NeedsLifetime::RefType(ascribed_type)) + } + _ => None, + }) + .collect(); + match fn_params_without_lifetime.len() { + 1 => Some(fn_params_without_lifetime.into_iter().next()?), + 0 => None, + // multiple unnnamed is invalid. assist is not applicable + _ => return None, + } + }; + acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| { + let fn_def = builder.make_mut(fn_def); + let lifetime = builder.make_mut(lifetime); + let loc_needing_lifetime = + loc_needing_lifetime.and_then(|it| it.make_mut(builder).to_position()); + + fn_def.get_or_create_generic_param_list().add_generic_param( + make::lifetime_param(new_lifetime_param.clone()).clone_for_update().into(), + ); + ted::replace(lifetime.syntax(), new_lifetime_param.clone_for_update().syntax()); + if let Some(position) = loc_needing_lifetime { + ted::insert(position, new_lifetime_param.clone_for_update().syntax()); + } + }) +} + +/// Generate the assist for the impl def case +fn generate_impl_def_assist( + acc: &mut Assists, + impl_def: ast::Impl, + lifetime_loc: TextRange, + lifetime: ast::Lifetime, +) -> Option<()> { + let new_lifetime_param = generate_unique_lifetime_param_name(impl_def.generic_param_list())?; + acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| { + let impl_def = builder.make_mut(impl_def); + let lifetime = builder.make_mut(lifetime); + + impl_def.get_or_create_generic_param_list().add_generic_param( + make::lifetime_param(new_lifetime_param.clone()).clone_for_update().into(), + ); + ted::replace(lifetime.syntax(), new_lifetime_param.clone_for_update().syntax()); + }) +} + +/// Given a type parameter list, generate a unique lifetime parameter name +/// which is not in the list +fn generate_unique_lifetime_param_name( + existing_type_param_list: Option, +) -> Option { + match existing_type_param_list { + Some(type_params) => { + let used_lifetime_params: FxHashSet<_> = + type_params.lifetime_params().map(|p| p.syntax().text().to_string()).collect(); + ('a'..='z').map(|it| format!("'{}", it)).find(|it| !used_lifetime_params.contains(it)) + } + None => Some("'a".to_string()), + } + .map(|it| make::lifetime(&it)) +} + +enum NeedsLifetime { + SelfParam(ast::SelfParam), + RefType(ast::RefType), +} + +impl NeedsLifetime { + fn make_mut(self, builder: &mut AssistBuilder) -> Self { + match self { + Self::SelfParam(it) => Self::SelfParam(builder.make_mut(it)), + Self::RefType(it) => Self::RefType(builder.make_mut(it)), + } + } + + fn to_position(self) -> Option { + match self { + Self::SelfParam(it) => Some(Position::after(it.amp_token()?)), + Self::RefType(it) => Some(Position::after(it.amp_token()?)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn test_example_case() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor<'_$0> { + fn node(self) -> &SyntaxNode { + match self { + Cursor::Replace(node) | Cursor::Before(node) => node, + } + } + }"#, + r#"impl<'a> Cursor<'a> { + fn node(self) -> &SyntaxNode { + match self { + Cursor::Replace(node) | Cursor::Before(node) => node, + } + } + }"#, + ); + } + + #[test] + fn test_example_case_simplified() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor<'_$0> {"#, + r#"impl<'a> Cursor<'a> {"#, + ); + } + + #[test] + fn test_example_case_cursor_after_tick() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor<'$0_> {"#, + r#"impl<'a> Cursor<'a> {"#, + ); + } + + #[test] + fn test_impl_with_other_type_param() { + check_assist( + introduce_named_lifetime, + "impl fmt::Display for SepByBuilder<'_$0, I> + where + I: Iterator, + I::Item: fmt::Display, + {", + "impl fmt::Display for SepByBuilder<'a, I> + where + I: Iterator, + I::Item: fmt::Display, + {", + ) + } + + #[test] + fn test_example_case_cursor_before_tick() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor<$0'_> {"#, + r#"impl<'a> Cursor<'a> {"#, + ); + } + + #[test] + fn test_not_applicable_cursor_position() { + check_assist_not_applicable(introduce_named_lifetime, r#"impl Cursor<'_>$0 {"#); + check_assist_not_applicable(introduce_named_lifetime, r#"impl Cursor$0<'_> {"#); + } + + #[test] + fn test_not_applicable_lifetime_already_name() { + check_assist_not_applicable(introduce_named_lifetime, r#"impl Cursor<'a$0> {"#); + check_assist_not_applicable(introduce_named_lifetime, r#"fn my_fun<'a>() -> X<'a$0>"#); + } + + #[test] + fn test_with_type_parameter() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor"#, + r#"impl Cursor"#, + ); + } + + #[test] + fn test_with_existing_lifetime_name_conflict() { + check_assist( + introduce_named_lifetime, + r#"impl<'a, 'b> Cursor<'a, 'b, '_$0>"#, + r#"impl<'a, 'b, 'c> Cursor<'a, 'b, 'c>"#, + ); + } + + #[test] + fn test_function_return_value_anon_lifetime_param() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun() -> X<'_$0>"#, + r#"fn my_fun<'a>() -> X<'a>"#, + ); + } + + #[test] + fn test_function_return_value_anon_reference_lifetime() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun() -> &'_$0 X"#, + r#"fn my_fun<'a>() -> &'a X"#, + ); + } + + #[test] + fn test_function_param_anon_lifetime() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun(x: X<'_$0>)"#, + r#"fn my_fun<'a>(x: X<'a>)"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_params() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun(f: &Foo) -> X<'_$0>"#, + r#"fn my_fun<'a>(f: &'a Foo) -> X<'a>"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_params_in_presence_of_other_lifetime() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun<'other>(f: &Foo, b: &'other Bar) -> X<'_$0>"#, + r#"fn my_fun<'other, 'a>(f: &'a Foo, b: &'other Bar) -> X<'a>"#, + ); + } + + #[test] + fn test_function_not_applicable_without_self_and_multiple_unnamed_param_lifetimes() { + // this is not permitted under lifetime elision rules + check_assist_not_applicable( + introduce_named_lifetime, + r#"fn my_fun(f: &Foo, b: &Bar) -> X<'_$0>"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_self_ref_param() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun<'other>(&self, f: &Foo, b: &'other Bar) -> X<'_$0>"#, + r#"fn my_fun<'other, 'a>(&'a self, f: &Foo, b: &'other Bar) -> X<'a>"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_param_with_non_ref_self() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun<'other>(self, f: &Foo, b: &'other Bar) -> X<'_$0>"#, + r#"fn my_fun<'other, 'a>(self, f: &'a Foo, b: &'other Bar) -> X<'a>"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_self_ref_mut() { + check_assist( + introduce_named_lifetime, + r#"fn foo(&mut self) -> &'_$0 ()"#, + r#"fn foo<'a>(&'a mut self) -> &'a ()"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/invert_if.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/invert_if.rs new file mode 100644 index 000000000..547158e29 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/invert_if.rs @@ -0,0 +1,144 @@ +use ide_db::syntax_helpers::node_ext::is_pattern_cond; +use syntax::{ + ast::{self, AstNode}, + T, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::invert_boolean_expression, + AssistId, AssistKind, +}; + +// Assist: invert_if +// +// This transforms if expressions of the form `if !x {A} else {B}` into `if x {B} else {A}` +// This also works with `!=`. This assist can only be applied with the cursor on `if`. +// +// ``` +// fn main() { +// if$0 !y { A } else { B } +// } +// ``` +// -> +// ``` +// fn main() { +// if y { B } else { A } +// } +// ``` +pub(crate) fn invert_if(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let if_keyword = ctx.find_token_syntax_at_offset(T![if])?; + let expr = ast::IfExpr::cast(if_keyword.parent()?)?; + let if_range = if_keyword.text_range(); + let cursor_in_range = if_range.contains_range(ctx.selection_trimmed()); + if !cursor_in_range { + return None; + } + + let cond = expr.condition()?; + // This assist should not apply for if-let. + if is_pattern_cond(cond.clone()) { + return None; + } + + let then_node = expr.then_branch()?.syntax().clone(); + let else_block = match expr.else_branch()? { + ast::ElseBranch::Block(it) => it, + ast::ElseBranch::IfExpr(_) => return None, + }; + + acc.add(AssistId("invert_if", AssistKind::RefactorRewrite), "Invert if", if_range, |edit| { + let flip_cond = invert_boolean_expression(cond.clone()); + edit.replace_ast(cond, flip_cond); + + let else_node = else_block.syntax(); + let else_range = else_node.text_range(); + let then_range = then_node.text_range(); + + edit.replace(else_range, then_node.text()); + edit.replace(then_range, else_node.text()); + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn invert_if_composite_condition() { + check_assist( + invert_if, + "fn f() { i$0f x == 3 || x == 4 || x == 5 { 1 } else { 3 * 2 } }", + "fn f() { if !(x == 3 || x == 4 || x == 5) { 3 * 2 } else { 1 } }", + ) + } + + #[test] + fn invert_if_remove_not_parentheses() { + check_assist( + invert_if, + "fn f() { i$0f !(x == 3 || x == 4 || x == 5) { 3 * 2 } else { 1 } }", + "fn f() { if x == 3 || x == 4 || x == 5 { 1 } else { 3 * 2 } }", + ) + } + + #[test] + fn invert_if_remove_inequality() { + check_assist( + invert_if, + "fn f() { i$0f x != 3 { 1 } else { 3 + 2 } }", + "fn f() { if x == 3 { 3 + 2 } else { 1 } }", + ) + } + + #[test] + fn invert_if_remove_not() { + check_assist( + invert_if, + "fn f() { $0if !cond { 3 * 2 } else { 1 } }", + "fn f() { if cond { 1 } else { 3 * 2 } }", + ) + } + + #[test] + fn invert_if_general_case() { + check_assist( + invert_if, + "fn f() { i$0f cond { 3 * 2 } else { 1 } }", + "fn f() { if !cond { 1 } else { 3 * 2 } }", + ) + } + + #[test] + fn invert_if_doesnt_apply_with_cursor_not_on_if() { + check_assist_not_applicable(invert_if, "fn f() { if !$0cond { 3 * 2 } else { 1 } }") + } + + #[test] + fn invert_if_doesnt_apply_with_if_let() { + check_assist_not_applicable( + invert_if, + "fn f() { i$0f let Some(_) = Some(1) { 1 } else { 0 } }", + ) + } + + #[test] + fn invert_if_option_case() { + check_assist( + invert_if, + "fn f() { if$0 doc_style.is_some() { Class::DocComment } else { Class::Comment } }", + "fn f() { if doc_style.is_none() { Class::Comment } else { Class::DocComment } }", + ) + } + + #[test] + fn invert_if_result_case() { + check_assist( + invert_if, + "fn f() { i$0f doc_style.is_err() { Class::Err } else { Class::Ok } }", + "fn f() { if doc_style.is_ok() { Class::Ok } else { Class::Err } }", + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_imports.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_imports.rs new file mode 100644 index 000000000..7e102ceba --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_imports.rs @@ -0,0 +1,570 @@ +use either::Either; +use ide_db::imports::merge_imports::{try_merge_imports, try_merge_trees, MergeBehavior}; +use syntax::{algo::neighbor, ast, match_ast, ted, AstNode, SyntaxElement, SyntaxNode}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::next_prev, + AssistId, AssistKind, +}; + +use Edit::*; + +// Assist: merge_imports +// +// Merges two imports with a common prefix. +// +// ``` +// use std::$0fmt::Formatter; +// use std::io; +// ``` +// -> +// ``` +// use std::{fmt::Formatter, io}; +// ``` +pub(crate) fn merge_imports(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (target, edits) = if ctx.has_empty_selection() { + // Merge a neighbor + let tree: ast::UseTree = ctx.find_node_at_offset()?; + let target = tree.syntax().text_range(); + + let edits = if let Some(use_item) = tree.syntax().parent().and_then(ast::Use::cast) { + let mut neighbor = next_prev().find_map(|dir| neighbor(&use_item, dir)).into_iter(); + use_item.try_merge_from(&mut neighbor) + } else { + let mut neighbor = next_prev().find_map(|dir| neighbor(&tree, dir)).into_iter(); + tree.try_merge_from(&mut neighbor) + }; + (target, edits?) + } else { + // Merge selected + let selection_range = ctx.selection_trimmed(); + let parent_node = match ctx.covering_element() { + SyntaxElement::Node(n) => n, + SyntaxElement::Token(t) => t.parent()?, + }; + let mut selected_nodes = + parent_node.children().filter(|it| selection_range.contains_range(it.text_range())); + + let first_selected = selected_nodes.next()?; + let edits = match_ast! { + match first_selected { + ast::Use(use_item) => { + use_item.try_merge_from(&mut selected_nodes.filter_map(ast::Use::cast)) + }, + ast::UseTree(use_tree) => { + use_tree.try_merge_from(&mut selected_nodes.filter_map(ast::UseTree::cast)) + }, + _ => return None, + } + }; + (selection_range, edits?) + }; + + acc.add( + AssistId("merge_imports", AssistKind::RefactorRewrite), + "Merge imports", + target, + |builder| { + let edits_mut: Vec = edits + .into_iter() + .map(|it| match it { + Remove(Either::Left(it)) => Remove(Either::Left(builder.make_mut(it))), + Remove(Either::Right(it)) => Remove(Either::Right(builder.make_mut(it))), + Replace(old, new) => Replace(builder.make_syntax_mut(old), new), + }) + .collect(); + for edit in edits_mut { + match edit { + Remove(it) => it.as_ref().either(ast::Use::remove, ast::UseTree::remove), + Replace(old, new) => ted::replace(old, new), + } + } + }, + ) +} + +trait Merge: AstNode + Clone { + fn try_merge_from(self, items: &mut dyn Iterator) -> Option> { + let mut edits = Vec::new(); + let mut merged = self.clone(); + while let Some(item) = items.next() { + merged = merged.try_merge(&item)?; + edits.push(Edit::Remove(item.into_either())); + } + if !edits.is_empty() { + edits.push(Edit::replace(self, merged)); + Some(edits) + } else { + None + } + } + fn try_merge(&self, other: &Self) -> Option; + fn into_either(self) -> Either; +} + +impl Merge for ast::Use { + fn try_merge(&self, other: &Self) -> Option { + try_merge_imports(self, other, MergeBehavior::Crate) + } + fn into_either(self) -> Either { + Either::Left(self) + } +} + +impl Merge for ast::UseTree { + fn try_merge(&self, other: &Self) -> Option { + try_merge_trees(self, other, MergeBehavior::Crate) + } + fn into_either(self) -> Either { + Either::Right(self) + } +} + +enum Edit { + Remove(Either), + Replace(SyntaxNode, SyntaxNode), +} + +impl Edit { + fn replace(old: impl AstNode, new: impl AstNode) -> Self { + Edit::Replace(old.syntax().clone(), new.syntax().clone()) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_merge_equal() { + check_assist( + merge_imports, + r" +use std::fmt$0::{Display, Debug}; +use std::fmt::{Display, Debug}; +", + r" +use std::fmt::{Display, Debug}; +", + ) + } + + #[test] + fn test_merge_first() { + check_assist( + merge_imports, + r" +use std::fmt$0::Debug; +use std::fmt::Display; +", + r" +use std::fmt::{Debug, Display}; +", + ) + } + + #[test] + fn test_merge_second() { + check_assist( + merge_imports, + r" +use std::fmt::Debug; +use std::fmt$0::Display; +", + r" +use std::fmt::{Display, Debug}; +", + ); + } + + #[test] + fn merge_self1() { + check_assist( + merge_imports, + r" +use std::fmt$0; +use std::fmt::Display; +", + r" +use std::fmt::{self, Display}; +", + ); + } + + #[test] + fn merge_self2() { + check_assist( + merge_imports, + r" +use std::{fmt, $0fmt::Display}; +", + r" +use std::{fmt::{Display, self}}; +", + ); + } + + #[test] + fn skip_pub1() { + check_assist_not_applicable( + merge_imports, + r" +pub use std::fmt$0::Debug; +use std::fmt::Display; +", + ); + } + + #[test] + fn skip_pub_last() { + check_assist_not_applicable( + merge_imports, + r" +use std::fmt$0::Debug; +pub use std::fmt::Display; +", + ); + } + + #[test] + fn skip_pub_crate_pub() { + check_assist_not_applicable( + merge_imports, + r" +pub(crate) use std::fmt$0::Debug; +pub use std::fmt::Display; +", + ); + } + + #[test] + fn skip_pub_pub_crate() { + check_assist_not_applicable( + merge_imports, + r" +pub use std::fmt$0::Debug; +pub(crate) use std::fmt::Display; +", + ); + } + + #[test] + fn merge_pub() { + check_assist( + merge_imports, + r" +pub use std::fmt$0::Debug; +pub use std::fmt::Display; +", + r" +pub use std::fmt::{Debug, Display}; +", + ) + } + + #[test] + fn merge_pub_crate() { + check_assist( + merge_imports, + r" +pub(crate) use std::fmt$0::Debug; +pub(crate) use std::fmt::Display; +", + r" +pub(crate) use std::fmt::{Debug, Display}; +", + ) + } + + #[test] + fn merge_pub_in_path_crate() { + check_assist( + merge_imports, + r" +pub(in this::path) use std::fmt$0::Debug; +pub(in this::path) use std::fmt::Display; +", + r" +pub(in this::path) use std::fmt::{Debug, Display}; +", + ) + } + + #[test] + fn test_merge_nested() { + check_assist( + merge_imports, + r" +use std::{fmt$0::Debug, fmt::Display}; +", + r" +use std::{fmt::{Debug, Display}}; +", + ); + } + + #[test] + fn test_merge_nested2() { + check_assist( + merge_imports, + r" +use std::{fmt::Debug, fmt$0::Display}; +", + r" +use std::{fmt::{Display, Debug}}; +", + ); + } + + #[test] + fn test_merge_with_nested_self_item() { + check_assist( + merge_imports, + r" +use std$0::{fmt::{Write, Display}}; +use std::{fmt::{self, Debug}}; +", + r" +use std::{fmt::{Write, Display, self, Debug}}; +", + ); + } + + #[test] + fn test_merge_with_nested_self_item2() { + check_assist( + merge_imports, + r" +use std$0::{fmt::{self, Debug}}; +use std::{fmt::{Write, Display}}; +", + r" +use std::{fmt::{self, Debug, Write, Display}}; +", + ); + } + + #[test] + fn test_merge_self_with_nested_self_item() { + check_assist( + merge_imports, + r" +use std::{fmt$0::{self, Debug}, fmt::{Write, Display}}; +", + r" +use std::{fmt::{self, Debug, Write, Display}}; +", + ); + } + + #[test] + fn test_merge_nested_self_and_empty() { + check_assist( + merge_imports, + r" +use foo::$0{bar::{self}}; +use foo::{bar}; +", + r" +use foo::{bar::{self}}; +", + ) + } + + #[test] + fn test_merge_nested_empty_and_self() { + check_assist( + merge_imports, + r" +use foo::$0{bar}; +use foo::{bar::{self}}; +", + r" +use foo::{bar::{self}}; +", + ) + } + + #[test] + fn test_merge_nested_list_self_and_glob() { + check_assist( + merge_imports, + r" +use std$0::{fmt::*}; +use std::{fmt::{self, Display}}; +", + r" +use std::{fmt::{*, self, Display}}; +", + ) + } + + #[test] + fn test_merge_single_wildcard_diff_prefixes() { + check_assist( + merge_imports, + r" +use std$0::cell::*; +use std::str; +", + r" +use std::{cell::*, str}; +", + ) + } + + #[test] + fn test_merge_both_wildcard_diff_prefixes() { + check_assist( + merge_imports, + r" +use std$0::cell::*; +use std::str::*; +", + r" +use std::{cell::*, str::*}; +", + ) + } + + #[test] + fn removes_just_enough_whitespace() { + check_assist( + merge_imports, + r" +use foo$0::bar; +use foo::baz; + +/// Doc comment +", + r" +use foo::{bar, baz}; + +/// Doc comment +", + ); + } + + #[test] + fn works_with_trailing_comma() { + check_assist( + merge_imports, + r" +use { + foo$0::bar, + foo::baz, +}; +", + r" +use { + foo::{bar, baz}, +}; +", + ); + check_assist( + merge_imports, + r" +use { + foo::baz, + foo$0::bar, +}; +", + r" +use { + foo::{bar, baz}, +}; +", + ); + } + + #[test] + fn test_double_comma() { + check_assist( + merge_imports, + r" +use foo::bar::baz; +use foo::$0{ + FooBar, +}; +", + r" +use foo::{ + FooBar, bar::baz, +}; +", + ) + } + + #[test] + fn test_empty_use() { + check_assist_not_applicable( + merge_imports, + r" +use std::$0 +fn main() {}", + ); + } + + #[test] + fn split_glob() { + check_assist( + merge_imports, + r" +use foo::$0*; +use foo::bar::Baz; +", + r" +use foo::{*, bar::Baz}; +", + ); + } + + #[test] + fn merge_selection_uses() { + check_assist( + merge_imports, + r" +use std::fmt::Error; +$0use std::fmt::Display; +use std::fmt::Debug; +use std::fmt::Write; +$0use std::fmt::Result; +", + r" +use std::fmt::Error; +use std::fmt::{Display, Debug, Write}; +use std::fmt::Result; +", + ); + } + + #[test] + fn merge_selection_use_trees() { + check_assist( + merge_imports, + r" +use std::{ + fmt::Error, + $0fmt::Display, + fmt::Debug, + fmt::Write,$0 + fmt::Result, +};", + r" +use std::{ + fmt::Error, + fmt::{Display, Debug, Write}, + fmt::Result, +};", + ); + // FIXME: Remove redundant braces. See also unnecessary-braces diagnostic. + check_assist( + merge_imports, + r"use std::$0{fmt::Display, fmt::Debug}$0;", + r"use std::{fmt::{Display, Debug}};", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_match_arms.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_match_arms.rs new file mode 100644 index 000000000..c24015b1c --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_match_arms.rs @@ -0,0 +1,822 @@ +use hir::TypeInfo; +use std::{collections::HashMap, iter::successors}; +use syntax::{ + algo::neighbor, + ast::{self, AstNode, HasName}, + Direction, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists, TextRange}; + +// Assist: merge_match_arms +// +// Merges the current match arm with the following if their bodies are identical. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// $0Action::Move(..) => foo(), +// Action::Stop => foo(), +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move(..) | Action::Stop => foo(), +// } +// } +// ``` +pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let current_arm = ctx.find_node_at_offset::()?; + // Don't try to handle arms with guards for now - can add support for this later + if current_arm.guard().is_some() { + return None; + } + let current_expr = current_arm.expr()?; + let current_text_range = current_arm.syntax().text_range(); + let current_arm_types = get_arm_types(ctx, ¤t_arm); + + // We check if the following match arms match this one. We could, but don't, + // compare to the previous match arm as well. + let arms_to_merge = successors(Some(current_arm), |it| neighbor(it, Direction::Next)) + .take_while(|arm| match arm.expr() { + Some(expr) if arm.guard().is_none() => { + let same_text = expr.syntax().text() == current_expr.syntax().text(); + if !same_text { + return false; + } + + are_same_types(¤t_arm_types, arm, ctx) + } + _ => false, + }) + .collect::>(); + + if arms_to_merge.len() <= 1 { + return None; + } + + acc.add( + AssistId("merge_match_arms", AssistKind::RefactorRewrite), + "Merge match arms", + current_text_range, + |edit| { + let pats = if arms_to_merge.iter().any(contains_placeholder) { + "_".into() + } else { + arms_to_merge + .iter() + .filter_map(ast::MatchArm::pat) + .map(|x| x.syntax().to_string()) + .collect::>() + .join(" | ") + }; + + let arm = format!("{} => {},", pats, current_expr.syntax().text()); + + if let [first, .., last] = &*arms_to_merge { + let start = first.syntax().text_range().start(); + let end = last.syntax().text_range().end(); + + edit.replace(TextRange::new(start, end), arm); + } + }, + ) +} + +fn contains_placeholder(a: &ast::MatchArm) -> bool { + matches!(a.pat(), Some(ast::Pat::WildcardPat(..))) +} + +fn are_same_types( + current_arm_types: &HashMap>, + arm: &ast::MatchArm, + ctx: &AssistContext<'_>, +) -> bool { + let arm_types = get_arm_types(ctx, arm); + for (other_arm_type_name, other_arm_type) in arm_types { + match (current_arm_types.get(&other_arm_type_name), other_arm_type) { + (Some(Some(current_arm_type)), Some(other_arm_type)) + if other_arm_type.original == current_arm_type.original => {} + _ => return false, + } + } + + true +} + +fn get_arm_types( + context: &AssistContext<'_>, + arm: &ast::MatchArm, +) -> HashMap> { + let mut mapping: HashMap> = HashMap::new(); + + fn recurse( + map: &mut HashMap>, + ctx: &AssistContext<'_>, + pat: &Option, + ) { + if let Some(local_pat) = pat { + match pat { + Some(ast::Pat::TupleStructPat(tuple)) => { + for field in tuple.fields() { + recurse(map, ctx, &Some(field)); + } + } + Some(ast::Pat::TuplePat(tuple)) => { + for field in tuple.fields() { + recurse(map, ctx, &Some(field)); + } + } + Some(ast::Pat::RecordPat(record)) => { + if let Some(field_list) = record.record_pat_field_list() { + for field in field_list.fields() { + recurse(map, ctx, &field.pat()); + } + } + } + Some(ast::Pat::ParenPat(parentheses)) => { + recurse(map, ctx, &parentheses.pat()); + } + Some(ast::Pat::SlicePat(slice)) => { + for slice_pat in slice.pats() { + recurse(map, ctx, &Some(slice_pat)); + } + } + Some(ast::Pat::IdentPat(ident_pat)) => { + if let Some(name) = ident_pat.name() { + let pat_type = ctx.sema.type_of_pat(local_pat); + map.insert(name.text().to_string(), pat_type); + } + } + _ => (), + } + } + } + + recurse(&mut mapping, context, &arm.pat()); + mapping +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn merge_match_arms_single_patterns() { + check_assist( + merge_match_arms, + r#" +#[derive(Debug)] +enum X { A, B, C } + +fn main() { + let x = X::A; + let y = match x { + X::A => { 1i32$0 } + X::B => { 1i32 } + X::C => { 2i32 } + } +} +"#, + r#" +#[derive(Debug)] +enum X { A, B, C } + +fn main() { + let x = X::A; + let y = match x { + X::A | X::B => { 1i32 }, + X::C => { 2i32 } + } +} +"#, + ); + } + + #[test] + fn merge_match_arms_multiple_patterns() { + check_assist( + merge_match_arms, + r#" +#[derive(Debug)] +enum X { A, B, C, D, E } + +fn main() { + let x = X::A; + let y = match x { + X::A | X::B => {$0 1i32 }, + X::C | X::D => { 1i32 }, + X::E => { 2i32 }, + } +} +"#, + r#" +#[derive(Debug)] +enum X { A, B, C, D, E } + +fn main() { + let x = X::A; + let y = match x { + X::A | X::B | X::C | X::D => { 1i32 }, + X::E => { 2i32 }, + } +} +"#, + ); + } + + #[test] + fn merge_match_arms_placeholder_pattern() { + check_assist( + merge_match_arms, + r#" +#[derive(Debug)] +enum X { A, B, C, D, E } + +fn main() { + let x = X::A; + let y = match x { + X::A => { 1i32 }, + X::B => { 2i$032 }, + _ => { 2i32 } + } +} +"#, + r#" +#[derive(Debug)] +enum X { A, B, C, D, E } + +fn main() { + let x = X::A; + let y = match x { + X::A => { 1i32 }, + _ => { 2i32 }, + } +} +"#, + ); + } + + #[test] + fn merges_all_subsequent_arms() { + check_assist( + merge_match_arms, + r#" +enum X { A, B, C, D, E } + +fn main() { + match X::A { + X::A$0 => 92, + X::B => 92, + X::C => 92, + X::D => 62, + _ => panic!(), + } +} +"#, + r#" +enum X { A, B, C, D, E } + +fn main() { + match X::A { + X::A | X::B | X::C => 92, + X::D => 62, + _ => panic!(), + } +} +"#, + ) + } + + #[test] + fn merge_match_arms_rejects_guards() { + check_assist_not_applicable( + merge_match_arms, + r#" +#[derive(Debug)] +enum X { + A(i32), + B, + C +} + +fn main() { + let x = X::A; + let y = match x { + X::A(a) if a > 5 => { $01i32 }, + X::B => { 1i32 }, + X::C => { 2i32 } + } +} +"#, + ); + } + + #[test] + fn merge_match_arms_different_type() { + check_assist_not_applicable( + merge_match_arms, + r#" +//- minicore: result +fn func() { + match Result::::Ok(0f64) { + Ok(x) => $0x.classify(), + Err(x) => x.classify() + }; +} +"#, + ); + } + + #[test] + fn merge_match_arms_different_type_multiple_fields() { + check_assist_not_applicable( + merge_match_arms, + r#" +//- minicore: result +fn func() { + match Result::<(f64, f64), (f32, f32)>::Ok((0f64, 0f64)) { + Ok(x) => $0x.1.classify(), + Err(x) => x.1.classify() + }; +} +"#, + ); + } + + #[test] + fn merge_match_arms_same_type_multiple_fields() { + check_assist( + merge_match_arms, + r#" +//- minicore: result +fn func() { + match Result::<(f64, f64), (f64, f64)>::Ok((0f64, 0f64)) { + Ok(x) => $0x.1.classify(), + Err(x) => x.1.classify() + }; +} +"#, + r#" +fn func() { + match Result::<(f64, f64), (f64, f64)>::Ok((0f64, 0f64)) { + Ok(x) | Err(x) => x.1.classify(), + }; +} +"#, + ); + } + + #[test] + fn merge_match_arms_same_type_subsequent_arm_with_different_type_in_other() { + check_assist( + merge_match_arms, + r#" +enum MyEnum { + OptionA(f32), + OptionB(f32), + OptionC(f64) +} + +fn func(e: MyEnum) { + match e { + MyEnum::OptionA(x) => $0x.classify(), + MyEnum::OptionB(x) => x.classify(), + MyEnum::OptionC(x) => x.classify(), + }; +} +"#, + r#" +enum MyEnum { + OptionA(f32), + OptionB(f32), + OptionC(f64) +} + +fn func(e: MyEnum) { + match e { + MyEnum::OptionA(x) | MyEnum::OptionB(x) => x.classify(), + MyEnum::OptionC(x) => x.classify(), + }; +} +"#, + ); + } + + #[test] + fn merge_match_arms_same_type_skip_arm_with_different_type_in_between() { + check_assist_not_applicable( + merge_match_arms, + r#" +enum MyEnum { + OptionA(f32), + OptionB(f64), + OptionC(f32) +} + +fn func(e: MyEnum) { + match e { + MyEnum::OptionA(x) => $0x.classify(), + MyEnum::OptionB(x) => x.classify(), + MyEnum::OptionC(x) => x.classify(), + }; +} +"#, + ); + } + + #[test] + fn merge_match_arms_same_type_different_number_of_fields() { + check_assist_not_applicable( + merge_match_arms, + r#" +//- minicore: result +fn func() { + match Result::<(f64, f64, f64), (f64, f64)>::Ok((0f64, 0f64, 0f64)) { + Ok(x) => $0x.1.classify(), + Err(x) => x.1.classify() + }; +} +"#, + ); + } + + #[test] + fn merge_match_same_destructuring_different_types() { + check_assist_not_applicable( + merge_match_arms, + r#" +struct Point { + x: i32, + y: i32, +} + +fn func() { + let p = Point { x: 0, y: 7 }; + + match p { + Point { x, y: 0 } => $0"", + Point { x: 0, y } => "", + Point { x, y } => "", + }; +} +"#, + ); + } + + #[test] + fn merge_match_arms_range() { + check_assist( + merge_match_arms, + r#" +fn func() { + let x = 'c'; + + match x { + 'a'..='j' => $0"", + 'c'..='z' => "", + _ => "other", + }; +} +"#, + r#" +fn func() { + let x = 'c'; + + match x { + 'a'..='j' | 'c'..='z' => "", + _ => "other", + }; +} +"#, + ); + } + + #[test] + fn merge_match_arms_enum_without_field() { + check_assist_not_applicable( + merge_match_arms, + r#" +enum MyEnum { + NoField, + AField(u8) +} + +fn func(x: MyEnum) { + match x { + MyEnum::NoField => $0"", + MyEnum::AField(x) => "" + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_enum_destructuring_different_types() { + check_assist_not_applicable( + merge_match_arms, + r#" +enum MyEnum { + Move { x: i32, y: i32 }, + Write(String), +} + +fn func(x: MyEnum) { + match x { + MyEnum::Move { x, y } => $0"", + MyEnum::Write(text) => "", + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_enum_destructuring_same_types() { + check_assist( + merge_match_arms, + r#" +enum MyEnum { + Move { x: i32, y: i32 }, + Crawl { x: i32, y: i32 } +} + +fn func(x: MyEnum) { + match x { + MyEnum::Move { x, y } => $0"", + MyEnum::Crawl { x, y } => "", + }; +} + "#, + r#" +enum MyEnum { + Move { x: i32, y: i32 }, + Crawl { x: i32, y: i32 } +} + +fn func(x: MyEnum) { + match x { + MyEnum::Move { x, y } | MyEnum::Crawl { x, y } => "", + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_enum_destructuring_same_types_different_name() { + check_assist_not_applicable( + merge_match_arms, + r#" +enum MyEnum { + Move { x: i32, y: i32 }, + Crawl { a: i32, b: i32 } +} + +fn func(x: MyEnum) { + match x { + MyEnum::Move { x, y } => $0"", + MyEnum::Crawl { a, b } => "", + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_enum_nested_pattern_different_names() { + check_assist_not_applicable( + merge_match_arms, + r#" +enum Color { + Rgb(i32, i32, i32), + Hsv(i32, i32, i32), +} + +enum Message { + Quit, + Move { x: i32, y: i32 }, + Write(String), + ChangeColor(Color), +} + +fn main(msg: Message) { + match msg { + Message::ChangeColor(Color::Rgb(r, g, b)) => $0"", + Message::ChangeColor(Color::Hsv(h, s, v)) => "", + _ => "other" + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_enum_nested_pattern_same_names() { + check_assist( + merge_match_arms, + r#" +enum Color { + Rgb(i32, i32, i32), + Hsv(i32, i32, i32), +} + +enum Message { + Quit, + Move { x: i32, y: i32 }, + Write(String), + ChangeColor(Color), +} + +fn main(msg: Message) { + match msg { + Message::ChangeColor(Color::Rgb(a, b, c)) => $0"", + Message::ChangeColor(Color::Hsv(a, b, c)) => "", + _ => "other" + }; +} + "#, + r#" +enum Color { + Rgb(i32, i32, i32), + Hsv(i32, i32, i32), +} + +enum Message { + Quit, + Move { x: i32, y: i32 }, + Write(String), + ChangeColor(Color), +} + +fn main(msg: Message) { + match msg { + Message::ChangeColor(Color::Rgb(a, b, c)) | Message::ChangeColor(Color::Hsv(a, b, c)) => "", + _ => "other" + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_enum_destructuring_with_ignore() { + check_assist( + merge_match_arms, + r#" +enum MyEnum { + Move { x: i32, a: i32 }, + Crawl { x: i32, b: i32 } +} + +fn func(x: MyEnum) { + match x { + MyEnum::Move { x, .. } => $0"", + MyEnum::Crawl { x, .. } => "", + }; +} + "#, + r#" +enum MyEnum { + Move { x: i32, a: i32 }, + Crawl { x: i32, b: i32 } +} + +fn func(x: MyEnum) { + match x { + MyEnum::Move { x, .. } | MyEnum::Crawl { x, .. } => "", + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_nested_with_conflicting_identifier() { + check_assist_not_applicable( + merge_match_arms, + r#" +enum Color { + Rgb(i32, i32, i32), + Hsv(i32, i32, i32), +} + +enum Message { + Move { x: i32, y: i32 }, + ChangeColor(u8, Color), +} + +fn main(msg: Message) { + match msg { + Message::ChangeColor(x, Color::Rgb(y, b, c)) => $0"", + Message::ChangeColor(y, Color::Hsv(x, b, c)) => "", + _ => "other" + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_tuple() { + check_assist_not_applicable( + merge_match_arms, + r#" +fn func() { + match (0, "boo") { + (x, y) => $0"", + (y, x) => "", + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_parentheses() { + check_assist_not_applicable( + merge_match_arms, + r#" +fn func(x: i32) { + let variable = 2; + match x { + 1 => $0"", + ((((variable)))) => "", + _ => "other" + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_refpat() { + check_assist_not_applicable( + merge_match_arms, + r#" +fn func() { + let name = Some(String::from("")); + let n = String::from(""); + match name { + Some(ref n) => $0"", + Some(n) => "", + _ => "other", + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_slice() { + check_assist_not_applicable( + merge_match_arms, + r#" +fn func(binary: &[u8]) { + let space = b' '; + match binary { + [0x7f, b'E', b'L', b'F', ..] => $0"", + [space] => "", + _ => "other", + }; +} + "#, + ) + } + + #[test] + fn merge_match_arms_slice_identical() { + check_assist( + merge_match_arms, + r#" +fn func(binary: &[u8]) { + let space = b' '; + match binary { + [space, 5u8] => $0"", + [space] => "", + _ => "other", + }; +} + "#, + r#" +fn func(binary: &[u8]) { + let space = b' '; + match binary { + [space, 5u8] | [space] => "", + _ => "other", + }; +} + "#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs new file mode 100644 index 000000000..176a3bf58 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs @@ -0,0 +1,122 @@ +use syntax::{ + ast::{self, edit_in_place::GenericParamsOwnerEdit, make, AstNode, HasName, HasTypeBounds}, + match_ast, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: move_bounds_to_where_clause +// +// Moves inline type bounds to a where clause. +// +// ``` +// fn apply U>(f: F, x: T) -> U { +// f(x) +// } +// ``` +// -> +// ``` +// fn apply(f: F, x: T) -> U where F: FnOnce(T) -> U { +// f(x) +// } +// ``` +pub(crate) fn move_bounds_to_where_clause( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let type_param_list = ctx.find_node_at_offset::()?; + + let mut type_params = type_param_list.type_or_const_params(); + if type_params.all(|p| match p { + ast::TypeOrConstParam::Type(t) => t.type_bound_list().is_none(), + ast::TypeOrConstParam::Const(_) => true, + }) { + return None; + } + + let parent = type_param_list.syntax().parent()?; + + let target = type_param_list.syntax().text_range(); + acc.add( + AssistId("move_bounds_to_where_clause", AssistKind::RefactorRewrite), + "Move to where clause", + target, + |edit| { + let type_param_list = edit.make_mut(type_param_list); + let parent = edit.make_syntax_mut(parent); + + let where_clause: ast::WhereClause = match_ast! { + match parent { + ast::Fn(it) => it.get_or_create_where_clause(), + ast::Trait(it) => it.get_or_create_where_clause(), + ast::Impl(it) => it.get_or_create_where_clause(), + ast::Enum(it) => it.get_or_create_where_clause(), + ast::Struct(it) => it.get_or_create_where_clause(), + _ => return, + } + }; + + for toc_param in type_param_list.type_or_const_params() { + let type_param = match toc_param { + ast::TypeOrConstParam::Type(x) => x, + ast::TypeOrConstParam::Const(_) => continue, + }; + if let Some(tbl) = type_param.type_bound_list() { + if let Some(predicate) = build_predicate(type_param) { + where_clause.add_predicate(predicate) + } + tbl.remove() + } + } + }, + ) +} + +fn build_predicate(param: ast::TypeParam) -> Option { + let path = make::ext::ident_path(¶m.name()?.syntax().to_string()); + let predicate = make::where_pred(path, param.type_bound_list()?.bounds()); + Some(predicate.clone_for_update()) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::check_assist; + + #[test] + fn move_bounds_to_where_clause_fn() { + check_assist( + move_bounds_to_where_clause, + r#"fn foo T>() {}"#, + r#"fn foo() where T: u32, F: FnOnce(T) -> T {}"#, + ); + } + + #[test] + fn move_bounds_to_where_clause_impl() { + check_assist( + move_bounds_to_where_clause, + r#"impl A {}"#, + r#"impl A where U: u32 {}"#, + ); + } + + #[test] + fn move_bounds_to_where_clause_struct() { + check_assist( + move_bounds_to_where_clause, + r#"struct A<$0T: Iterator> {}"#, + r#"struct A where T: Iterator {}"#, + ); + } + + #[test] + fn move_bounds_to_where_clause_tuple_struct() { + check_assist( + move_bounds_to_where_clause, + r#"struct Pair<$0T: u32>(T, T);"#, + r#"struct Pair(T, T) where T: u32;"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_from_mod_rs.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_from_mod_rs.rs new file mode 100644 index 000000000..a6c85a2b1 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_from_mod_rs.rs @@ -0,0 +1,130 @@ +use ide_db::{ + assists::{AssistId, AssistKind}, + base_db::AnchoredPathBuf, +}; +use syntax::{ast, AstNode}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::trimmed_text_range, +}; + +// Assist: move_from_mod_rs +// +// Moves xxx/mod.rs to xxx.rs. +// +// ``` +// //- /main.rs +// mod a; +// //- /a/mod.rs +// $0fn t() {}$0 +// ``` +// -> +// ``` +// fn t() {} +// ``` +pub(crate) fn move_from_mod_rs(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let source_file = ctx.find_node_at_offset::()?; + let module = ctx.sema.to_module_def(ctx.file_id())?; + // Enable this assist if the user select all "meaningful" content in the source file + let trimmed_selected_range = trimmed_text_range(&source_file, ctx.selection_trimmed()); + let trimmed_file_range = trimmed_text_range(&source_file, source_file.syntax().text_range()); + if !module.is_mod_rs(ctx.db()) { + cov_mark::hit!(not_mod_rs); + return None; + } + if trimmed_selected_range != trimmed_file_range { + cov_mark::hit!(not_all_selected); + return None; + } + + let target = source_file.syntax().text_range(); + let module_name = module.name(ctx.db())?.to_string(); + let path = format!("../{}.rs", module_name); + let dst = AnchoredPathBuf { anchor: ctx.file_id(), path }; + acc.add( + AssistId("move_from_mod_rs", AssistKind::Refactor), + format!("Convert {}/mod.rs to {}.rs", module_name, module_name), + target, + |builder| { + builder.move_file(ctx.file_id(), dst); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn trivial() { + check_assist( + move_from_mod_rs, + r#" +//- /main.rs +mod a; +//- /a/mod.rs +$0fn t() {} +$0"#, + r#" +//- /a.rs +fn t() {} +"#, + ); + } + + #[test] + fn must_select_all_file() { + cov_mark::check!(not_all_selected); + check_assist_not_applicable( + move_from_mod_rs, + r#" +//- /main.rs +mod a; +//- /a/mod.rs +fn t() {}$0 +"#, + ); + cov_mark::check!(not_all_selected); + check_assist_not_applicable( + move_from_mod_rs, + r#" +//- /main.rs +mod a; +//- /a/mod.rs +$0fn$0 t() {} +"#, + ); + } + + #[test] + fn cannot_move_not_mod_rs() { + cov_mark::check!(not_mod_rs); + check_assist_not_applicable( + move_from_mod_rs, + r#"//- /main.rs +mod a; +//- /a.rs +$0fn t() {}$0 +"#, + ); + } + + #[test] + fn cannot_downgrade_main_and_lib_rs() { + check_assist_not_applicable( + move_from_mod_rs, + r#"//- /main.rs +$0fn t() {}$0 +"#, + ); + check_assist_not_applicable( + move_from_mod_rs, + r#"//- /lib.rs +$0fn t() {}$0 +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_guard.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_guard.rs new file mode 100644 index 000000000..b8f1b36de --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_guard.rs @@ -0,0 +1,997 @@ +use syntax::{ + ast::{edit::AstNodeEdit, make, AstNode, BlockExpr, ElseBranch, Expr, IfExpr, MatchArm, Pat}, + SyntaxKind::WHITESPACE, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: move_guard_to_arm_body +// +// Moves match guard into match arm body. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } $0if distance > 10 => foo(), +// _ => (), +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } => if distance > 10 { +// foo() +// }, +// _ => (), +// } +// } +// ``` +pub(crate) fn move_guard_to_arm_body(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let match_arm = ctx.find_node_at_offset::()?; + let guard = match_arm.guard()?; + if ctx.offset() > guard.syntax().text_range().end() { + cov_mark::hit!(move_guard_unapplicable_in_arm_body); + return None; + } + let space_before_guard = guard.syntax().prev_sibling_or_token(); + + let guard_condition = guard.condition()?; + let arm_expr = match_arm.expr()?; + let if_expr = + make::expr_if(guard_condition, make::block_expr(None, Some(arm_expr.clone())), None) + .indent(arm_expr.indent_level()); + + let target = guard.syntax().text_range(); + acc.add( + AssistId("move_guard_to_arm_body", AssistKind::RefactorRewrite), + "Move guard to arm body", + target, + |edit| { + match space_before_guard { + Some(element) if element.kind() == WHITESPACE => { + edit.delete(element.text_range()); + } + _ => (), + }; + + edit.delete(guard.syntax().text_range()); + edit.replace_ast(arm_expr, if_expr); + }, + ) +} + +// Assist: move_arm_cond_to_match_guard +// +// Moves if expression from match arm body into a guard. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } => $0if distance > 10 { foo() }, +// _ => (), +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } if distance > 10 => foo(), +// _ => (), +// } +// } +// ``` +pub(crate) fn move_arm_cond_to_match_guard( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let match_arm: MatchArm = ctx.find_node_at_offset::()?; + let match_pat = match_arm.pat()?; + let arm_body = match_arm.expr()?; + + let mut replace_node = None; + let if_expr: IfExpr = IfExpr::cast(arm_body.syntax().clone()).or_else(|| { + let block_expr = BlockExpr::cast(arm_body.syntax().clone())?; + if let Expr::IfExpr(e) = block_expr.tail_expr()? { + replace_node = Some(block_expr.syntax().clone()); + Some(e) + } else { + None + } + })?; + if ctx.offset() > if_expr.then_branch()?.syntax().text_range().start() { + return None; + } + + let replace_node = replace_node.unwrap_or_else(|| if_expr.syntax().clone()); + let needs_dedent = replace_node != *if_expr.syntax(); + let (conds_blocks, tail) = parse_if_chain(if_expr)?; + + acc.add( + AssistId("move_arm_cond_to_match_guard", AssistKind::RefactorRewrite), + "Move condition to match guard", + replace_node.text_range(), + |edit| { + edit.delete(match_arm.syntax().text_range()); + // Dedent if if_expr is in a BlockExpr + let dedent = if needs_dedent { + cov_mark::hit!(move_guard_ifelse_in_block); + 1 + } else { + cov_mark::hit!(move_guard_ifelse_else_block); + 0 + }; + let then_arm_end = match_arm.syntax().text_range().end(); + let indent_level = match_arm.indent_level(); + let spaces = " ".repeat(indent_level.0 as _); + + let mut first = true; + for (cond, block) in conds_blocks { + if !first { + edit.insert(then_arm_end, format!("\n{}", spaces)); + } else { + first = false; + } + let guard = format!("{} if {} => ", match_pat, cond.syntax().text()); + edit.insert(then_arm_end, guard); + let only_expr = block.statements().next().is_none(); + match &block.tail_expr() { + Some(then_expr) if only_expr => { + edit.insert(then_arm_end, then_expr.syntax().text()); + edit.insert(then_arm_end, ","); + } + _ => { + let to_insert = block.dedent(dedent.into()).syntax().text(); + edit.insert(then_arm_end, to_insert) + } + } + } + if let Some(e) = tail { + cov_mark::hit!(move_guard_ifelse_else_tail); + let guard = format!("\n{}{} => ", spaces, match_pat); + edit.insert(then_arm_end, guard); + let only_expr = e.statements().next().is_none(); + match &e.tail_expr() { + Some(expr) if only_expr => { + cov_mark::hit!(move_guard_ifelse_expr_only); + edit.insert(then_arm_end, expr.syntax().text()); + edit.insert(then_arm_end, ","); + } + _ => { + let to_insert = e.dedent(dedent.into()).syntax().text(); + edit.insert(then_arm_end, to_insert) + } + } + } else { + // There's no else branch. Add a pattern without guard, unless the following match + // arm is `_ => ...` + cov_mark::hit!(move_guard_ifelse_notail); + match match_arm.syntax().next_sibling().and_then(MatchArm::cast) { + Some(next_arm) + if matches!(next_arm.pat(), Some(Pat::WildcardPat(_))) + && next_arm.guard().is_none() => + { + cov_mark::hit!(move_guard_ifelse_has_wildcard); + } + _ => edit.insert(then_arm_end, format!("\n{}{} => {{}}", spaces, match_pat)), + } + } + }, + ) +} + +// Parses an if-else-if chain to get the conditions and the then branches until we encounter an else +// branch or the end. +fn parse_if_chain(if_expr: IfExpr) -> Option<(Vec<(Expr, BlockExpr)>, Option)> { + let mut conds_blocks = Vec::new(); + let mut curr_if = if_expr; + let tail = loop { + let cond = curr_if.condition()?; + conds_blocks.push((cond, curr_if.then_branch()?)); + match curr_if.else_branch() { + Some(ElseBranch::IfExpr(e)) => { + curr_if = e; + } + Some(ElseBranch::Block(b)) => { + break Some(b); + } + None => break None, + } + }; + Some((conds_blocks, tail)) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn move_guard_to_arm_body_range() { + cov_mark::check!(move_guard_unapplicable_in_arm_body); + check_assist_not_applicable( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x if x > 10 => $0false, + _ => true + } +} +"#, + ); + } + #[test] + fn move_guard_to_arm_body_target() { + check_assist_target( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x $0if x > 10 => false, + _ => true + } +} +"#, + r#"if x > 10"#, + ); + } + + #[test] + fn move_guard_to_arm_body_works() { + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x $0if x > 10 => false, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x => if x > 10 { + false + }, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_let_guard_to_arm_body_works() { + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x $0if (let 1 = x) => false, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x => if (let 1 = x) { + false + }, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_guard_to_arm_body_works_complex_match() { + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + $0x @ 4 | x @ 5 if x > 5 => true, + _ => false + } +} +"#, + r#" +fn main() { + match 92 { + x @ 4 | x @ 5 => if x > 5 { + true + }, + _ => false + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if x > 10$0 { false }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_in_block_to_match_guard_works() { + cov_mark::check!(move_guard_ifelse_has_wildcard); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => { + $0if x > 10 { + false + } + }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_in_block_to_match_guard_no_wildcard_works() { + cov_mark::check_count!(move_guard_ifelse_has_wildcard, 0); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => { + $0if x > 10 { + false + } + } + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + x => {} + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_in_block_to_match_guard_wildcard_guard_works() { + cov_mark::check_count!(move_guard_ifelse_has_wildcard, 0); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => { + $0if x > 10 { + false + } + } + _ if x > 10 => true, + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + x => {} + _ if x > 10 => true, + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_in_block_to_match_guard_add_comma_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => { + $0if x > 10 { + false + } + } + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_if_let_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if let 62 = x $0&& true { false }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if let 62 = x && true => false, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_if_empty_body_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if x $0> 10 { }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => { } + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_if_multiline_body_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if$0 x > 10 { + 92; + false + }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => { + 92; + false + } + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_in_block_to_match_guard_if_multiline_body_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => { + if x > $010 { + 92; + false + } + } + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => { + 92; + false + } + _ => true + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_with_else_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if x > $010 { + false + } else { + true + } + _ => true, + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + x => true, + _ => true, + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_with_else_block_works() { + cov_mark::check!(move_guard_ifelse_expr_only); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => { + if x $0> 10 { + false + } else { + true + } + } + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + x => true, + _ => true + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_else_if_empty_body_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if x > $010 { } else { }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => { } + x => { } + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_with_else_multiline_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if$0 x > 10 { + 92; + false + } else { + true + } + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => { + 92; + false + } + x => true, + _ => true + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_with_else_multiline_else_works() { + cov_mark::check!(move_guard_ifelse_else_block); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if x $0> 10 { + false + } else { + 42; + true + } + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + x => { + 42; + true + } + _ => true + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_with_else_multiline_else_block_works() { + cov_mark::check!(move_guard_ifelse_in_block); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => { + if x > $010 { + false + } else { + 42; + true + } + } + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + x => { + 42; + true + } + _ => true + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_with_else_last_arm_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + 3 => true, + x => { + if x > $010 { + false + } else { + 92; + true + } + } + } +} +"#, + r#" +fn main() { + match 92 { + 3 => true, + x if x > 10 => false, + x => { + 92; + true + } + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_with_else_comma_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + 3 => true, + x => if x > $010 { + false + } else { + 92; + true + }, + } +} +"#, + r#" +fn main() { + match 92 { + 3 => true, + x if x > 10 => false, + x => { + 92; + true + } + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_elseif() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + 3 => true, + x => if x $0> 10 { + false + } else if x > 5 { + true + } else if x > 4 { + false + } else { + true + }, + } +} +"#, + r#" +fn main() { + match 92 { + 3 => true, + x if x > 10 => false, + x if x > 5 => true, + x if x > 4 => false, + x => true, + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_elseif_in_block() { + cov_mark::check!(move_guard_ifelse_in_block); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + 3 => true, + x => { + if x > $010 { + false + } else if x > 5 { + true + } else if x > 4 { + false + } else { + true + } + } + } +} +"#, + r#" +fn main() { + match 92 { + 3 => true, + x if x > 10 => false, + x if x > 5 => true, + x if x > 4 => false, + x => true, + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_elseif_chain() { + cov_mark::check!(move_guard_ifelse_else_tail); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + 3 => 0, + x => if x $0> 10 { + 1 + } else if x > 5 { + 2 + } else if x > 3 { + 42; + 3 + } else { + 4 + }, + } +} +"#, + r#" +fn main() { + match 92 { + 3 => 0, + x if x > 10 => 1, + x if x > 5 => 2, + x if x > 3 => { + 42; + 3 + } + x => 4, + } +} +"#, + ) + } + + #[test] + fn move_arm_cond_to_match_guard_elseif_iflet() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + 3 => 0, + x => if x $0> 10 { + 1 + } else if x > 5 { + 2 + } else if let 4 = 4 { + 42; + 3 + } else { + 4 + }, + } +}"#, + r#" +fn main() { + match 92 { + 3 => 0, + x if x > 10 => 1, + x if x > 5 => 2, + x if let 4 = 4 => { + 42; + 3 + } + x => 4, + } +}"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_elseif_notail() { + cov_mark::check!(move_guard_ifelse_notail); + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + 3 => 0, + x => if x > $010 { + 1 + } else if x > 5 { + 2 + } else if x > 4 { + 42; + 3 + }, + } +} +"#, + r#" +fn main() { + match 92 { + 3 => 0, + x if x > 10 => 1, + x if x > 5 => 2, + x if x > 4 => { + 42; + 3 + } + x => {} + } +} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_module_to_file.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_module_to_file.rs new file mode 100644 index 000000000..7468318a5 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_module_to_file.rs @@ -0,0 +1,337 @@ +use std::iter; + +use ast::edit::IndentLevel; +use ide_db::base_db::AnchoredPathBuf; +use itertools::Itertools; +use stdx::format_to; +use syntax::{ + ast::{self, edit::AstNodeEdit, HasName}, + AstNode, SmolStr, TextRange, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: move_module_to_file +// +// Moves inline module's contents to a separate file. +// +// ``` +// mod $0foo { +// fn t() {} +// } +// ``` +// -> +// ``` +// mod foo; +// ``` +pub(crate) fn move_module_to_file(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let module_ast = ctx.find_node_at_offset::()?; + let module_items = module_ast.item_list()?; + + let l_curly_offset = module_items.syntax().text_range().start(); + if l_curly_offset <= ctx.offset() { + cov_mark::hit!(available_before_curly); + return None; + } + let target = TextRange::new(module_ast.syntax().text_range().start(), l_curly_offset); + + let module_name = module_ast.name()?; + + // get to the outermost module syntax so we can grab the module of file we are in + let outermost_mod_decl = + iter::successors(Some(module_ast.clone()), |module| module.parent()).last()?; + let module_def = ctx.sema.to_def(&outermost_mod_decl)?; + let parent_module = module_def.parent(ctx.db())?; + + acc.add( + AssistId("move_module_to_file", AssistKind::RefactorExtract), + "Extract module to file", + target, + |builder| { + let path = { + let mut buf = String::from("./"); + match parent_module.name(ctx.db()) { + Some(name) if !parent_module.is_mod_rs(ctx.db()) => { + format_to!(buf, "{}/", name) + } + _ => (), + } + let segments = iter::successors(Some(module_ast.clone()), |module| module.parent()) + .filter_map(|it| it.name()) + .map(|name| SmolStr::from(name.text().trim_start_matches("r#"))) + .collect::>(); + + format_to!(buf, "{}", segments.into_iter().rev().format("/")); + + // We need to special case mod named `r#mod` and place the file in a + // subdirectory as "mod.rs" would be of its parent module otherwise. + if module_name.text() == "r#mod" { + format_to!(buf, "/mod.rs"); + } else { + format_to!(buf, ".rs"); + } + buf + }; + let contents = { + let items = module_items.dedent(IndentLevel(1)).to_string(); + let mut items = + items.trim_start_matches('{').trim_end_matches('}').trim().to_string(); + if !items.is_empty() { + items.push('\n'); + } + items + }; + + let buf = format!("mod {};", module_name); + + let replacement_start = match module_ast.mod_token() { + Some(mod_token) => mod_token.text_range(), + None => module_ast.syntax().text_range(), + } + .start(); + + builder.replace( + TextRange::new(replacement_start, module_ast.syntax().text_range().end()), + buf, + ); + + let dst = AnchoredPathBuf { anchor: ctx.file_id(), path }; + builder.create_file(dst, contents); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn extract_from_root() { + check_assist( + move_module_to_file, + r#" +mod $0tests { + #[test] fn t() {} +} +"#, + r#" +//- /main.rs +mod tests; +//- /tests.rs +#[test] fn t() {} +"#, + ); + } + + #[test] + fn extract_from_submodule() { + check_assist( + move_module_to_file, + r#" +//- /main.rs +mod submod; +//- /submod.rs +$0mod inner { + fn f() {} +} +fn g() {} +"#, + r#" +//- /submod.rs +mod inner; +fn g() {} +//- /submod/inner.rs +fn f() {} +"#, + ); + } + + #[test] + fn extract_from_mod_rs() { + check_assist( + move_module_to_file, + r#" +//- /main.rs +mod submodule; +//- /submodule/mod.rs +mod inner$0 { + fn f() {} +} +fn g() {} +"#, + r#" +//- /submodule/mod.rs +mod inner; +fn g() {} +//- /submodule/inner.rs +fn f() {} +"#, + ); + } + + #[test] + fn extract_public() { + check_assist( + move_module_to_file, + r#" +pub mod $0tests { + #[test] fn t() {} +} +"#, + r#" +//- /main.rs +pub mod tests; +//- /tests.rs +#[test] fn t() {} +"#, + ); + } + + #[test] + fn extract_public_crate() { + check_assist( + move_module_to_file, + r#" +pub(crate) mod $0tests { + #[test] fn t() {} +} +"#, + r#" +//- /main.rs +pub(crate) mod tests; +//- /tests.rs +#[test] fn t() {} +"#, + ); + } + + #[test] + fn available_before_curly() { + cov_mark::check!(available_before_curly); + check_assist_not_applicable(move_module_to_file, r#"mod m { $0 }"#); + } + + #[test] + fn keep_outer_comments_and_attributes() { + check_assist( + move_module_to_file, + r#" +/// doc comment +#[attribute] +mod $0tests { + #[test] fn t() {} +} +"#, + r#" +//- /main.rs +/// doc comment +#[attribute] +mod tests; +//- /tests.rs +#[test] fn t() {} +"#, + ); + } + + #[test] + fn extract_nested() { + check_assist( + move_module_to_file, + r#" +//- /lib.rs +mod foo; +//- /foo.rs +mod bar { + mod baz { + mod qux$0 {} + } +} +"#, + r#" +//- /foo.rs +mod bar { + mod baz { + mod qux; + } +} +//- /foo/bar/baz/qux.rs +"#, + ); + } + + #[test] + fn extract_mod_with_raw_ident() { + check_assist( + move_module_to_file, + r#" +//- /main.rs +mod $0r#static {} +"#, + r#" +//- /main.rs +mod r#static; +//- /static.rs +"#, + ) + } + + #[test] + fn extract_r_mod() { + check_assist( + move_module_to_file, + r#" +//- /main.rs +mod $0r#mod {} +"#, + r#" +//- /main.rs +mod r#mod; +//- /mod/mod.rs +"#, + ) + } + + #[test] + fn extract_r_mod_from_mod_rs() { + check_assist( + move_module_to_file, + r#" +//- /main.rs +mod foo; +//- /foo/mod.rs +mod $0r#mod {} +"#, + r#" +//- /foo/mod.rs +mod r#mod; +//- /foo/mod/mod.rs +"#, + ) + } + + #[test] + fn extract_nested_r_mod() { + check_assist( + move_module_to_file, + r#" +//- /main.rs +mod r#mod { + mod foo { + mod $0r#mod {} + } +} +"#, + r#" +//- /main.rs +mod r#mod { + mod foo { + mod r#mod; + } +} +//- /mod/foo/mod/mod.rs +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_to_mod_rs.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_to_mod_rs.rs new file mode 100644 index 000000000..a909ce8b2 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_to_mod_rs.rs @@ -0,0 +1,151 @@ +use ide_db::{ + assists::{AssistId, AssistKind}, + base_db::AnchoredPathBuf, +}; +use syntax::{ast, AstNode}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::trimmed_text_range, +}; + +// Assist: move_to_mod_rs +// +// Moves xxx.rs to xxx/mod.rs. +// +// ``` +// //- /main.rs +// mod a; +// //- /a.rs +// $0fn t() {}$0 +// ``` +// -> +// ``` +// fn t() {} +// ``` +pub(crate) fn move_to_mod_rs(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let source_file = ctx.find_node_at_offset::()?; + let module = ctx.sema.to_module_def(ctx.file_id())?; + // Enable this assist if the user select all "meaningful" content in the source file + let trimmed_selected_range = trimmed_text_range(&source_file, ctx.selection_trimmed()); + let trimmed_file_range = trimmed_text_range(&source_file, source_file.syntax().text_range()); + if module.is_mod_rs(ctx.db()) { + cov_mark::hit!(already_mod_rs); + return None; + } + if trimmed_selected_range != trimmed_file_range { + cov_mark::hit!(not_all_selected); + return None; + } + + let target = source_file.syntax().text_range(); + let module_name = module.name(ctx.db())?.to_string(); + let path = format!("./{}/mod.rs", module_name); + let dst = AnchoredPathBuf { anchor: ctx.file_id(), path }; + acc.add( + AssistId("move_to_mod_rs", AssistKind::Refactor), + format!("Convert {}.rs to {}/mod.rs", module_name, module_name), + target, + |builder| { + builder.move_file(ctx.file_id(), dst); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn trivial() { + check_assist( + move_to_mod_rs, + r#" +//- /main.rs +mod a; +//- /a.rs +$0fn t() {} +$0"#, + r#" +//- /a/mod.rs +fn t() {} +"#, + ); + } + + #[test] + fn must_select_all_file() { + cov_mark::check!(not_all_selected); + check_assist_not_applicable( + move_to_mod_rs, + r#" +//- /main.rs +mod a; +//- /a.rs +fn t() {}$0 +"#, + ); + cov_mark::check!(not_all_selected); + check_assist_not_applicable( + move_to_mod_rs, + r#" +//- /main.rs +mod a; +//- /a.rs +$0fn$0 t() {} +"#, + ); + } + + #[test] + fn cannot_promote_mod_rs() { + cov_mark::check!(already_mod_rs); + check_assist_not_applicable( + move_to_mod_rs, + r#"//- /main.rs +mod a; +//- /a/mod.rs +$0fn t() {}$0 +"#, + ); + } + + #[test] + fn cannot_promote_main_and_lib_rs() { + check_assist_not_applicable( + move_to_mod_rs, + r#"//- /main.rs +$0fn t() {}$0 +"#, + ); + check_assist_not_applicable( + move_to_mod_rs, + r#"//- /lib.rs +$0fn t() {}$0 +"#, + ); + } + + #[test] + fn works_in_mod() { + // note: /a/b.rs remains untouched + check_assist( + move_to_mod_rs, + r#"//- /main.rs +mod a; +//- /a.rs +$0mod b; +fn t() {}$0 +//- /a/b.rs +fn t1() {} +"#, + r#" +//- /a/mod.rs +mod b; +fn t() {} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/number_representation.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/number_representation.rs new file mode 100644 index 000000000..424db7437 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/number_representation.rs @@ -0,0 +1,183 @@ +use syntax::{ast, ast::Radix, AstToken}; + +use crate::{AssistContext, AssistId, AssistKind, Assists, GroupLabel}; + +const MIN_NUMBER_OF_DIGITS_TO_FORMAT: usize = 5; + +// Assist: reformat_number_literal +// +// Adds or removes separators from integer literal. +// +// ``` +// const _: i32 = 1012345$0; +// ``` +// -> +// ``` +// const _: i32 = 1_012_345; +// ``` +pub(crate) fn reformat_number_literal(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let literal = ctx.find_node_at_offset::()?; + let literal = match literal.kind() { + ast::LiteralKind::IntNumber(it) => it, + _ => return None, + }; + + let text = literal.text(); + if text.contains('_') { + return remove_separators(acc, literal); + } + + let (prefix, value, suffix) = literal.split_into_parts(); + if value.len() < MIN_NUMBER_OF_DIGITS_TO_FORMAT { + return None; + } + + let radix = literal.radix(); + let mut converted = prefix.to_string(); + converted.push_str(&add_group_separators(value, group_size(radix))); + converted.push_str(suffix); + + let group_id = GroupLabel("Reformat number literal".into()); + let label = format!("Convert {} to {}", literal, converted); + let range = literal.syntax().text_range(); + acc.add_group( + &group_id, + AssistId("reformat_number_literal", AssistKind::RefactorInline), + label, + range, + |builder| builder.replace(range, converted), + ) +} + +fn remove_separators(acc: &mut Assists, literal: ast::IntNumber) -> Option<()> { + let group_id = GroupLabel("Reformat number literal".into()); + let range = literal.syntax().text_range(); + acc.add_group( + &group_id, + AssistId("reformat_number_literal", AssistKind::RefactorInline), + "Remove digit separators", + range, + |builder| builder.replace(range, literal.text().replace('_', "")), + ) +} + +const fn group_size(r: Radix) -> usize { + match r { + Radix::Binary => 4, + Radix::Octal => 3, + Radix::Decimal => 3, + Radix::Hexadecimal => 4, + } +} + +fn add_group_separators(s: &str, group_size: usize) -> String { + let mut chars = Vec::new(); + for (i, ch) in s.chars().filter(|&ch| ch != '_').rev().enumerate() { + if i > 0 && i % group_size == 0 { + chars.push('_'); + } + chars.push(ch); + } + + chars.into_iter().rev().collect() +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist_by_label, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn group_separators() { + let cases = vec![ + ("", 4, ""), + ("1", 4, "1"), + ("12", 4, "12"), + ("123", 4, "123"), + ("1234", 4, "1234"), + ("12345", 4, "1_2345"), + ("123456", 4, "12_3456"), + ("1234567", 4, "123_4567"), + ("12345678", 4, "1234_5678"), + ("123456789", 4, "1_2345_6789"), + ("1234567890", 4, "12_3456_7890"), + ("1_2_3_4_5_6_7_8_9_0_", 4, "12_3456_7890"), + ("1234567890", 3, "1_234_567_890"), + ("1234567890", 2, "12_34_56_78_90"), + ("1234567890", 1, "1_2_3_4_5_6_7_8_9_0"), + ]; + + for case in cases { + let (input, group_size, expected) = case; + assert_eq!(add_group_separators(input, group_size), expected) + } + } + + #[test] + fn good_targets() { + let cases = vec![ + ("const _: i32 = 0b11111$0", "0b11111"), + ("const _: i32 = 0o77777$0;", "0o77777"), + ("const _: i32 = 10000$0;", "10000"), + ("const _: i32 = 0xFFFFF$0;", "0xFFFFF"), + ("const _: i32 = 10000i32$0;", "10000i32"), + ("const _: i32 = 0b_10_0i32$0;", "0b_10_0i32"), + ]; + + for case in cases { + check_assist_target(reformat_number_literal, case.0, case.1); + } + } + + #[test] + fn bad_targets() { + let cases = vec![ + "const _: i32 = 0b111$0", + "const _: i32 = 0b1111$0", + "const _: i32 = 0o77$0;", + "const _: i32 = 0o777$0;", + "const _: i32 = 10$0;", + "const _: i32 = 999$0;", + "const _: i32 = 0xFF$0;", + "const _: i32 = 0xFFFF$0;", + ]; + + for case in cases { + check_assist_not_applicable(reformat_number_literal, case); + } + } + + #[test] + fn labels() { + let cases = vec![ + ("const _: i32 = 10000$0", "const _: i32 = 10_000", "Convert 10000 to 10_000"), + ( + "const _: i32 = 0xFF0000$0;", + "const _: i32 = 0xFF_0000;", + "Convert 0xFF0000 to 0xFF_0000", + ), + ( + "const _: i32 = 0b11111111$0;", + "const _: i32 = 0b1111_1111;", + "Convert 0b11111111 to 0b1111_1111", + ), + ( + "const _: i32 = 0o377211$0;", + "const _: i32 = 0o377_211;", + "Convert 0o377211 to 0o377_211", + ), + ( + "const _: i32 = 10000i32$0;", + "const _: i32 = 10_000i32;", + "Convert 10000i32 to 10_000i32", + ), + ("const _: i32 = 1_0_0_0_i32$0;", "const _: i32 = 1000i32;", "Remove digit separators"), + ]; + + for case in cases { + let (before, after, label) = case; + check_assist_by_label(reformat_number_literal, before, after, label); + } + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/promote_local_to_const.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/promote_local_to_const.rs new file mode 100644 index 000000000..cbbea6c1e --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/promote_local_to_const.rs @@ -0,0 +1,221 @@ +use hir::{HirDisplay, ModuleDef, PathResolution, Semantics}; +use ide_db::{ + assists::{AssistId, AssistKind}, + defs::Definition, + syntax_helpers::node_ext::preorder_expr, + RootDatabase, +}; +use stdx::to_upper_snake_case; +use syntax::{ + ast::{self, make, HasName}, + AstNode, WalkEvent, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::{render_snippet, Cursor}, +}; + +// Assist: promote_local_to_const +// +// Promotes a local variable to a const item changing its name to a `SCREAMING_SNAKE_CASE` variant +// if the local uses no non-const expressions. +// +// ``` +// fn main() { +// let foo$0 = true; +// +// if foo { +// println!("It's true"); +// } else { +// println!("It's false"); +// } +// } +// ``` +// -> +// ``` +// fn main() { +// const $0FOO: bool = true; +// +// if FOO { +// println!("It's true"); +// } else { +// println!("It's false"); +// } +// } +// ``` +pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let pat = ctx.find_node_at_offset::()?; + let name = pat.name()?; + if !pat.is_simple_ident() { + cov_mark::hit!(promote_local_non_simple_ident); + return None; + } + let let_stmt = pat.syntax().parent().and_then(ast::LetStmt::cast)?; + + let module = ctx.sema.scope(pat.syntax())?.module(); + let local = ctx.sema.to_def(&pat)?; + let ty = ctx.sema.type_of_pat(&pat.into())?.original; + + if ty.contains_unknown() || ty.is_closure() { + cov_mark::hit!(promote_lcoal_not_applicable_if_ty_not_inferred); + return None; + } + let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; + + let initializer = let_stmt.initializer()?; + if !is_body_const(&ctx.sema, &initializer) { + cov_mark::hit!(promote_local_non_const); + return None; + } + let target = let_stmt.syntax().text_range(); + acc.add( + AssistId("promote_local_to_const", AssistKind::Refactor), + "Promote local to constant", + target, + |builder| { + let name = to_upper_snake_case(&name.to_string()); + let usages = Definition::Local(local).usages(&ctx.sema).all(); + if let Some(usages) = usages.references.get(&ctx.file_id()) { + for usage in usages { + builder.replace(usage.range, &name); + } + } + + let item = make::item_const(None, make::name(&name), make::ty(&ty), initializer); + match ctx.config.snippet_cap.zip(item.name()) { + Some((cap, name)) => builder.replace_snippet( + cap, + target, + render_snippet(cap, item.syntax(), Cursor::Before(name.syntax())), + ), + None => builder.replace(target, item.to_string()), + } + }, + ) +} + +fn is_body_const(sema: &Semantics<'_, RootDatabase>, expr: &ast::Expr) -> bool { + let mut is_const = true; + preorder_expr(expr, &mut |ev| { + let expr = match ev { + WalkEvent::Enter(_) if !is_const => return true, + WalkEvent::Enter(expr) => expr, + WalkEvent::Leave(_) => return false, + }; + match expr { + ast::Expr::CallExpr(call) => { + if let Some(ast::Expr::PathExpr(path_expr)) = call.expr() { + if let Some(PathResolution::Def(ModuleDef::Function(func))) = + path_expr.path().and_then(|path| sema.resolve_path(&path)) + { + is_const &= func.is_const(sema.db); + } + } + } + ast::Expr::MethodCallExpr(call) => { + is_const &= + sema.resolve_method_call(&call).map(|it| it.is_const(sema.db)).unwrap_or(true) + } + ast::Expr::BoxExpr(_) + | ast::Expr::ForExpr(_) + | ast::Expr::ReturnExpr(_) + | ast::Expr::TryExpr(_) + | ast::Expr::YieldExpr(_) + | ast::Expr::AwaitExpr(_) => is_const = false, + _ => (), + } + !is_const + }); + is_const +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn simple() { + check_assist( + promote_local_to_const, + r" +fn foo() { + let x$0 = 0; + let y = x; +} +", + r" +fn foo() { + const $0X: i32 = 0; + let y = X; +} +", + ); + } + + #[test] + fn not_applicable_non_const_meth_call() { + cov_mark::check!(promote_local_non_const); + check_assist_not_applicable( + promote_local_to_const, + r" +struct Foo; +impl Foo { + fn foo(self) {} +} +fn foo() { + let x$0 = Foo.foo(); +} +", + ); + } + + #[test] + fn not_applicable_non_const_call() { + check_assist_not_applicable( + promote_local_to_const, + r" +fn bar(self) {} +fn foo() { + let x$0 = bar(); +} +", + ); + } + + #[test] + fn not_applicable_unknown_ty() { + cov_mark::check!(promote_lcoal_not_applicable_if_ty_not_inferred); + check_assist_not_applicable( + promote_local_to_const, + r" +fn foo() { + let x$0 = bar(); +} +", + ); + } + + #[test] + fn not_applicable_non_simple_ident() { + cov_mark::check!(promote_local_non_simple_ident); + check_assist_not_applicable( + promote_local_to_const, + r" +fn foo() { + let ref x$0 = (); +} +", + ); + check_assist_not_applicable( + promote_local_to_const, + r" +fn foo() { + let mut x$0 = (); +} +", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs new file mode 100644 index 000000000..4cfe6c99b --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs @@ -0,0 +1,507 @@ +use syntax::{ + ast::{self, make}, + ted, AstNode, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: pull_assignment_up +// +// Extracts variable assignment to outside an if or match statement. +// +// ``` +// fn main() { +// let mut foo = 6; +// +// if true { +// $0foo = 5; +// } else { +// foo = 4; +// } +// } +// ``` +// -> +// ``` +// fn main() { +// let mut foo = 6; +// +// foo = if true { +// 5 +// } else { +// 4 +// }; +// } +// ``` +pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let assign_expr = ctx.find_node_at_offset::()?; + + let op_kind = assign_expr.op_kind()?; + if op_kind != (ast::BinaryOp::Assignment { op: None }) { + cov_mark::hit!(test_cant_pull_non_assignments); + return None; + } + + let mut collector = AssignmentsCollector { + sema: &ctx.sema, + common_lhs: assign_expr.lhs()?, + assignments: Vec::new(), + }; + + let tgt: ast::Expr = if let Some(if_expr) = ctx.find_node_at_offset::() { + collector.collect_if(&if_expr)?; + if_expr.into() + } else if let Some(match_expr) = ctx.find_node_at_offset::() { + collector.collect_match(&match_expr)?; + match_expr.into() + } else { + return None; + }; + + if let Some(parent) = tgt.syntax().parent() { + if matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT) { + return None; + } + } + + acc.add( + AssistId("pull_assignment_up", AssistKind::RefactorExtract), + "Pull assignment up", + tgt.syntax().text_range(), + move |edit| { + let assignments: Vec<_> = collector + .assignments + .into_iter() + .map(|(stmt, rhs)| (edit.make_mut(stmt), rhs.clone_for_update())) + .collect(); + + let tgt = edit.make_mut(tgt); + + for (stmt, rhs) in assignments { + let mut stmt = stmt.syntax().clone(); + if let Some(parent) = stmt.parent() { + if ast::ExprStmt::cast(parent.clone()).is_some() { + stmt = parent.clone(); + } + } + ted::replace(stmt, rhs.syntax()); + } + let assign_expr = make::expr_assignment(collector.common_lhs, tgt.clone()); + let assign_stmt = make::expr_stmt(assign_expr); + + ted::replace(tgt.syntax(), assign_stmt.syntax().clone_for_update()); + }, + ) +} + +struct AssignmentsCollector<'a> { + sema: &'a hir::Semantics<'a, ide_db::RootDatabase>, + common_lhs: ast::Expr, + assignments: Vec<(ast::BinExpr, ast::Expr)>, +} + +impl<'a> AssignmentsCollector<'a> { + fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> { + for arm in match_expr.match_arm_list()?.arms() { + match arm.expr()? { + ast::Expr::BlockExpr(block) => self.collect_block(&block)?, + ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?, + _ => return None, + } + } + + Some(()) + } + fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> { + let then_branch = if_expr.then_branch()?; + self.collect_block(&then_branch)?; + + match if_expr.else_branch()? { + ast::ElseBranch::Block(block) => self.collect_block(&block), + ast::ElseBranch::IfExpr(expr) => { + cov_mark::hit!(test_pull_assignment_up_chained_if); + self.collect_if(&expr) + } + } + } + fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> { + let last_expr = block.tail_expr().or_else(|| match block.statements().last()? { + ast::Stmt::ExprStmt(stmt) => stmt.expr(), + ast::Stmt::Item(_) | ast::Stmt::LetStmt(_) => None, + })?; + + if let ast::Expr::BinExpr(expr) = last_expr { + return self.collect_expr(&expr); + } + + None + } + + fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> { + if expr.op_kind()? == (ast::BinaryOp::Assignment { op: None }) + && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs) + { + self.assignments.push((expr.clone(), expr.rhs()?)); + return Some(()); + } + None + } +} + +fn is_equivalent( + sema: &hir::Semantics<'_, ide_db::RootDatabase>, + expr0: &ast::Expr, + expr1: &ast::Expr, +) -> bool { + match (expr0, expr1) { + (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => { + cov_mark::hit!(test_pull_assignment_up_field_assignment); + sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1) + } + (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => { + let path0 = path0.path(); + let path1 = path1.path(); + if let (Some(path0), Some(path1)) = (path0, path1) { + sema.resolve_path(&path0) == sema.resolve_path(&path1) + } else { + false + } + } + (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1)) + if prefix0.op_kind() == Some(ast::UnaryOp::Deref) + && prefix1.op_kind() == Some(ast::UnaryOp::Deref) => + { + cov_mark::hit!(test_pull_assignment_up_deref); + if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) { + is_equivalent(sema, &prefix0, &prefix1) + } else { + false + } + } + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn test_pull_assignment_up_if() { + check_assist( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + if true { + $0a = 2; + } else { + a = 3; + } +}"#, + r#" +fn foo() { + let mut a = 1; + + a = if true { + 2 + } else { + 3 + }; +}"#, + ); + } + + #[test] + fn test_pull_assignment_up_match() { + check_assist( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + match 1 { + 1 => { + $0a = 2; + }, + 2 => { + a = 3; + }, + 3 => { + a = 4; + } + } +}"#, + r#" +fn foo() { + let mut a = 1; + + a = match 1 { + 1 => { + 2 + }, + 2 => { + 3 + }, + 3 => { + 4 + } + }; +}"#, + ); + } + + #[test] + fn test_pull_assignment_up_assignment_expressions() { + check_assist( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + match 1 { + 1 => { $0a = 2; }, + 2 => a = 3, + 3 => { + a = 4 + } + } +}"#, + r#" +fn foo() { + let mut a = 1; + + a = match 1 { + 1 => { 2 }, + 2 => 3, + 3 => { + 4 + } + }; +}"#, + ); + } + + #[test] + fn test_pull_assignment_up_not_last_not_applicable() { + check_assist_not_applicable( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + if true { + $0a = 2; + b = a; + } else { + a = 3; + } +}"#, + ) + } + + #[test] + fn test_pull_assignment_up_chained_if() { + cov_mark::check!(test_pull_assignment_up_chained_if); + check_assist( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + if true { + $0a = 2; + } else if false { + a = 3; + } else { + a = 4; + } +}"#, + r#" +fn foo() { + let mut a = 1; + + a = if true { + 2 + } else if false { + 3 + } else { + 4 + }; +}"#, + ); + } + + #[test] + fn test_pull_assignment_up_retains_stmts() { + check_assist( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + if true { + let b = 2; + $0a = 2; + } else { + let b = 3; + a = 3; + } +}"#, + r#" +fn foo() { + let mut a = 1; + + a = if true { + let b = 2; + 2 + } else { + let b = 3; + 3 + }; +}"#, + ) + } + + #[test] + fn pull_assignment_up_let_stmt_not_applicable() { + check_assist_not_applicable( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + let b = if true { + $0a = 2 + } else { + a = 3 + }; +}"#, + ) + } + + #[test] + fn pull_assignment_up_if_missing_assigment_not_applicable() { + check_assist_not_applicable( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + if true { + $0a = 2; + } else {} +}"#, + ) + } + + #[test] + fn pull_assignment_up_match_missing_assigment_not_applicable() { + check_assist_not_applicable( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + + match 1 { + 1 => { + $0a = 2; + }, + 2 => { + a = 3; + }, + 3 => {}, + } +}"#, + ) + } + + #[test] + fn test_pull_assignment_up_field_assignment() { + cov_mark::check!(test_pull_assignment_up_field_assignment); + check_assist( + pull_assignment_up, + r#" +struct A(usize); + +fn foo() { + let mut a = A(1); + + if true { + $0a.0 = 2; + } else { + a.0 = 3; + } +}"#, + r#" +struct A(usize); + +fn foo() { + let mut a = A(1); + + a.0 = if true { + 2 + } else { + 3 + }; +}"#, + ) + } + + #[test] + fn test_pull_assignment_up_deref() { + cov_mark::check!(test_pull_assignment_up_deref); + check_assist( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + let b = &mut a; + + if true { + $0*b = 2; + } else { + *b = 3; + } +} +"#, + r#" +fn foo() { + let mut a = 1; + let b = &mut a; + + *b = if true { + 2 + } else { + 3 + }; +} +"#, + ) + } + + #[test] + fn test_cant_pull_non_assignments() { + cov_mark::check!(test_cant_pull_non_assignments); + check_assist_not_applicable( + pull_assignment_up, + r#" +fn foo() { + let mut a = 1; + let b = &mut a; + + if true { + $0*b + 2; + } else { + *b + 3; + } +} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_method_call.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_method_call.rs new file mode 100644 index 000000000..121f8b4a1 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_method_call.rs @@ -0,0 +1,548 @@ +use hir::{db::HirDatabase, AsAssocItem, AssocItem, AssocItemContainer, ItemInNs, ModuleDef}; +use ide_db::assists::{AssistId, AssistKind}; +use syntax::{ast, AstNode}; + +use crate::{ + assist_context::{AssistContext, Assists}, + handlers::qualify_path::QualifyCandidate, +}; + +// Assist: qualify_method_call +// +// Replaces the method call with a qualified function call. +// +// ``` +// struct Foo; +// impl Foo { +// fn foo(&self) {} +// } +// fn main() { +// let foo = Foo; +// foo.fo$0o(); +// } +// ``` +// -> +// ``` +// struct Foo; +// impl Foo { +// fn foo(&self) {} +// } +// fn main() { +// let foo = Foo; +// Foo::foo(&foo); +// } +// ``` +pub(crate) fn qualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let name: ast::NameRef = ctx.find_node_at_offset()?; + let call = name.syntax().parent().and_then(ast::MethodCallExpr::cast)?; + + let ident = name.ident_token()?; + + let range = call.syntax().text_range(); + let resolved_call = ctx.sema.resolve_method_call(&call)?; + + let current_module = ctx.sema.scope(call.syntax())?.module(); + let target_module_def = ModuleDef::from(resolved_call); + let item_in_ns = ItemInNs::from(target_module_def); + let receiver_path = current_module + .find_use_path(ctx.sema.db, item_for_path_search(ctx.sema.db, item_in_ns)?)?; + + let qualify_candidate = QualifyCandidate::ImplMethod(ctx.sema.db, call, resolved_call); + + acc.add( + AssistId("qualify_method_call", AssistKind::RefactorInline), + format!("Qualify `{}` method call", ident.text()), + range, + |builder| { + qualify_candidate.qualify( + |replace_with: String| builder.replace(range, replace_with), + &receiver_path, + item_in_ns, + ) + }, + ); + Some(()) +} + +fn item_for_path_search(db: &dyn HirDatabase, item: ItemInNs) -> Option { + Some(match item { + ItemInNs::Types(_) | ItemInNs::Values(_) => match item_as_assoc(db, item) { + Some(assoc_item) => match assoc_item.container(db) { + AssocItemContainer::Trait(trait_) => ItemInNs::from(ModuleDef::from(trait_)), + AssocItemContainer::Impl(impl_) => match impl_.trait_(db) { + None => ItemInNs::from(ModuleDef::from(impl_.self_ty(db).as_adt()?)), + Some(trait_) => ItemInNs::from(ModuleDef::from(trait_)), + }, + }, + None => item, + }, + ItemInNs::Macros(_) => item, + }) +} + +fn item_as_assoc(db: &dyn HirDatabase, item: ItemInNs) -> Option { + item.as_module_def().and_then(|module_def| module_def.as_assoc_item(db)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn struct_method() { + check_assist( + qualify_method_call, + r#" +struct Foo; +impl Foo { + fn foo(&self) {} +} + +fn main() { + let foo = Foo {}; + foo.fo$0o() +} +"#, + r#" +struct Foo; +impl Foo { + fn foo(&self) {} +} + +fn main() { + let foo = Foo {}; + Foo::foo(&foo) +} +"#, + ); + } + + #[test] + fn struct_method_multi_params() { + check_assist( + qualify_method_call, + r#" +struct Foo; +impl Foo { + fn foo(&self, p1: i32, p2: u32) {} +} + +fn main() { + let foo = Foo {}; + foo.fo$0o(9, 9u) +} +"#, + r#" +struct Foo; +impl Foo { + fn foo(&self, p1: i32, p2: u32) {} +} + +fn main() { + let foo = Foo {}; + Foo::foo(&foo, 9, 9u) +} +"#, + ); + } + + #[test] + fn struct_method_consume() { + check_assist( + qualify_method_call, + r#" +struct Foo; +impl Foo { + fn foo(self, p1: i32, p2: u32) {} +} + +fn main() { + let foo = Foo {}; + foo.fo$0o(9, 9u) +} +"#, + r#" +struct Foo; +impl Foo { + fn foo(self, p1: i32, p2: u32) {} +} + +fn main() { + let foo = Foo {}; + Foo::foo(foo, 9, 9u) +} +"#, + ); + } + + #[test] + fn struct_method_exclusive() { + check_assist( + qualify_method_call, + r#" +struct Foo; +impl Foo { + fn foo(&mut self, p1: i32, p2: u32) {} +} + +fn main() { + let foo = Foo {}; + foo.fo$0o(9, 9u) +} +"#, + r#" +struct Foo; +impl Foo { + fn foo(&mut self, p1: i32, p2: u32) {} +} + +fn main() { + let foo = Foo {}; + Foo::foo(&mut foo, 9, 9u) +} +"#, + ); + } + + #[test] + fn struct_method_cross_crate() { + check_assist( + qualify_method_call, + r#" +//- /main.rs crate:main deps:dep +fn main() { + let foo = dep::test_mod::Foo {}; + foo.fo$0o(9, 9u) +} +//- /dep.rs crate:dep +pub mod test_mod { + pub struct Foo; + impl Foo { + pub fn foo(&mut self, p1: i32, p2: u32) {} + } +} +"#, + r#" +fn main() { + let foo = dep::test_mod::Foo {}; + dep::test_mod::Foo::foo(&mut foo, 9, 9u) +} +"#, + ); + } + + #[test] + fn struct_method_generic() { + check_assist( + qualify_method_call, + r#" +struct Foo; +impl Foo { + fn foo(&self) {} +} + +fn main() { + let foo = Foo {}; + foo.fo$0o::<()>() +} +"#, + r#" +struct Foo; +impl Foo { + fn foo(&self) {} +} + +fn main() { + let foo = Foo {}; + Foo::foo::<()>(&foo) +} +"#, + ); + } + + #[test] + fn trait_method() { + check_assist( + qualify_method_call, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od() +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + TestTrait::test_method(&test_struct) +} +"#, + ); + } + + #[test] + fn trait_method_multi_params() { + check_assist( + qualify_method_call, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self, p1: i32, p2: u32); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self, p1: i32, p2: u32) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od(12, 32u) +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self, p1: i32, p2: u32); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self, p1: i32, p2: u32) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + TestTrait::test_method(&test_struct, 12, 32u) +} +"#, + ); + } + + #[test] + fn trait_method_consume() { + check_assist( + qualify_method_call, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(self, p1: i32, p2: u32); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(self, p1: i32, p2: u32) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od(12, 32u) +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(self, p1: i32, p2: u32); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(self, p1: i32, p2: u32) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + TestTrait::test_method(test_struct, 12, 32u) +} +"#, + ); + } + + #[test] + fn trait_method_exclusive() { + check_assist( + qualify_method_call, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&mut self, p1: i32, p2: u32); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&mut self, p1: i32, p2: u32); + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od(12, 32u) +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&mut self, p1: i32, p2: u32); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&mut self, p1: i32, p2: u32); + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + TestTrait::test_method(&mut test_struct, 12, 32u) +} +"#, + ); + } + + #[test] + fn trait_method_cross_crate() { + check_assist( + qualify_method_call, + r#" +//- /main.rs crate:main deps:dep +fn main() { + let foo = dep::test_mod::Foo {}; + foo.fo$0o(9, 9u) +} +//- /dep.rs crate:dep +pub mod test_mod { + pub struct Foo; + impl Foo { + pub fn foo(&mut self, p1: i32, p2: u32) {} + } +} +"#, + r#" +fn main() { + let foo = dep::test_mod::Foo {}; + dep::test_mod::Foo::foo(&mut foo, 9, 9u) +} +"#, + ); + } + + #[test] + fn trait_method_generic() { + check_assist( + qualify_method_call, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = TestStruct {}; + test_struct.test_meth$0od::<()>() +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = TestStruct {}; + TestTrait::test_method::<()>(&test_struct) +} +"#, + ); + } + + #[test] + fn struct_method_over_stuct_instance() { + check_assist_not_applicable( + qualify_method_call, + r#" +struct Foo; +impl Foo { + fn foo(&self) {} +} + +fn main() { + let foo = Foo {}; + f$0oo.foo() +} +"#, + ); + } + + #[test] + fn trait_method_over_stuct_instance() { + check_assist_not_applicable( + qualify_method_call, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +use test_mod::*; + +fn main() { + let test_struct = test_mod::TestStruct {}; + tes$0t_struct.test_method() +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_path.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_path.rs new file mode 100644 index 000000000..0c2e9da38 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_path.rs @@ -0,0 +1,1297 @@ +use std::iter; + +use hir::AsAssocItem; +use ide_db::RootDatabase; +use ide_db::{ + helpers::mod_path_to_ast, + imports::import_assets::{ImportCandidate, LocatedImport}, +}; +use syntax::{ + ast, + ast::{make, HasArgList}, + AstNode, NodeOrToken, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + handlers::auto_import::find_importable_node, + AssistId, AssistKind, GroupLabel, +}; + +// Assist: qualify_path +// +// If the name is unresolved, provides all possible qualified paths for it. +// +// ``` +// fn main() { +// let map = HashMap$0::new(); +// } +// # pub mod std { pub mod collections { pub struct HashMap { } } } +// ``` +// -> +// ``` +// fn main() { +// let map = std::collections::HashMap::new(); +// } +// # pub mod std { pub mod collections { pub struct HashMap { } } } +// ``` +pub(crate) fn qualify_path(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (import_assets, syntax_under_caret) = find_importable_node(ctx)?; + let mut proposed_imports = import_assets.search_for_relative_paths(&ctx.sema); + if proposed_imports.is_empty() { + return None; + } + + let range = match &syntax_under_caret { + NodeOrToken::Node(node) => ctx.sema.original_range(node).range, + NodeOrToken::Token(token) => token.text_range(), + }; + let candidate = import_assets.import_candidate(); + let qualify_candidate = match syntax_under_caret { + NodeOrToken::Node(syntax_under_caret) => match candidate { + ImportCandidate::Path(candidate) if candidate.qualifier.is_some() => { + cov_mark::hit!(qualify_path_qualifier_start); + let path = ast::Path::cast(syntax_under_caret)?; + let (prev_segment, segment) = (path.qualifier()?.segment()?, path.segment()?); + QualifyCandidate::QualifierStart(segment, prev_segment.generic_arg_list()) + } + ImportCandidate::Path(_) => { + cov_mark::hit!(qualify_path_unqualified_name); + let path = ast::Path::cast(syntax_under_caret)?; + let generics = path.segment()?.generic_arg_list(); + QualifyCandidate::UnqualifiedName(generics) + } + ImportCandidate::TraitAssocItem(_) => { + cov_mark::hit!(qualify_path_trait_assoc_item); + let path = ast::Path::cast(syntax_under_caret)?; + let (qualifier, segment) = (path.qualifier()?, path.segment()?); + QualifyCandidate::TraitAssocItem(qualifier, segment) + } + ImportCandidate::TraitMethod(_) => { + cov_mark::hit!(qualify_path_trait_method); + let mcall_expr = ast::MethodCallExpr::cast(syntax_under_caret)?; + QualifyCandidate::TraitMethod(ctx.sema.db, mcall_expr) + } + }, + // derive attribute path + NodeOrToken::Token(_) => QualifyCandidate::UnqualifiedName(None), + }; + + // we aren't interested in different namespaces + proposed_imports.dedup_by(|a, b| a.import_path == b.import_path); + + let group_label = group_label(candidate); + for import in proposed_imports { + acc.add_group( + &group_label, + AssistId("qualify_path", AssistKind::QuickFix), + label(candidate, &import), + range, + |builder| { + qualify_candidate.qualify( + |replace_with: String| builder.replace(range, replace_with), + &import.import_path, + import.item_to_import, + ) + }, + ); + } + Some(()) +} +pub(crate) enum QualifyCandidate<'db> { + QualifierStart(ast::PathSegment, Option), + UnqualifiedName(Option), + TraitAssocItem(ast::Path, ast::PathSegment), + TraitMethod(&'db RootDatabase, ast::MethodCallExpr), + ImplMethod(&'db RootDatabase, ast::MethodCallExpr, hir::Function), +} + +impl QualifyCandidate<'_> { + pub(crate) fn qualify( + &self, + mut replacer: impl FnMut(String), + import: &hir::ModPath, + item: hir::ItemInNs, + ) { + let import = mod_path_to_ast(import); + match self { + QualifyCandidate::QualifierStart(segment, generics) => { + let generics = generics.as_ref().map_or_else(String::new, ToString::to_string); + replacer(format!("{}{}::{}", import, generics, segment)); + } + QualifyCandidate::UnqualifiedName(generics) => { + let generics = generics.as_ref().map_or_else(String::new, ToString::to_string); + replacer(format!("{}{}", import, generics)); + } + QualifyCandidate::TraitAssocItem(qualifier, segment) => { + replacer(format!("<{} as {}>::{}", qualifier, import, segment)); + } + QualifyCandidate::TraitMethod(db, mcall_expr) => { + Self::qualify_trait_method(db, mcall_expr, replacer, import, item); + } + QualifyCandidate::ImplMethod(db, mcall_expr, hir_fn) => { + Self::qualify_fn_call(db, mcall_expr, replacer, import, hir_fn); + } + } + } + + fn qualify_fn_call( + db: &RootDatabase, + mcall_expr: &ast::MethodCallExpr, + mut replacer: impl FnMut(String), + import: ast::Path, + hir_fn: &hir::Function, + ) -> Option<()> { + let receiver = mcall_expr.receiver()?; + let method_name = mcall_expr.name_ref()?; + let generics = + mcall_expr.generic_arg_list().as_ref().map_or_else(String::new, ToString::to_string); + let arg_list = mcall_expr.arg_list().map(|arg_list| arg_list.args()); + + if let Some(self_access) = hir_fn.self_param(db).map(|sp| sp.access(db)) { + let receiver = match self_access { + hir::Access::Shared => make::expr_ref(receiver, false), + hir::Access::Exclusive => make::expr_ref(receiver, true), + hir::Access::Owned => receiver, + }; + replacer(format!( + "{}::{}{}{}", + import, + method_name, + generics, + match arg_list { + Some(args) => make::arg_list(iter::once(receiver).chain(args)), + None => make::arg_list(iter::once(receiver)), + } + )); + } + Some(()) + } + + fn qualify_trait_method( + db: &RootDatabase, + mcall_expr: &ast::MethodCallExpr, + replacer: impl FnMut(String), + import: ast::Path, + item: hir::ItemInNs, + ) -> Option<()> { + let trait_method_name = mcall_expr.name_ref()?; + let trait_ = item_as_trait(db, item)?; + let method = find_trait_method(db, trait_, &trait_method_name)?; + Self::qualify_fn_call(db, mcall_expr, replacer, import, &method) + } +} + +fn find_trait_method( + db: &RootDatabase, + trait_: hir::Trait, + trait_method_name: &ast::NameRef, +) -> Option { + if let Some(hir::AssocItem::Function(method)) = + trait_.items(db).into_iter().find(|item: &hir::AssocItem| { + item.name(db) + .map(|name| name.to_string() == trait_method_name.to_string()) + .unwrap_or(false) + }) + { + Some(method) + } else { + None + } +} + +fn item_as_trait(db: &RootDatabase, item: hir::ItemInNs) -> Option { + let item_module_def = item.as_module_def()?; + + match item_module_def { + hir::ModuleDef::Trait(trait_) => Some(trait_), + _ => item_module_def.as_assoc_item(db)?.containing_trait(db), + } +} + +fn group_label(candidate: &ImportCandidate) -> GroupLabel { + let name = match candidate { + ImportCandidate::Path(it) => &it.name, + ImportCandidate::TraitAssocItem(it) | ImportCandidate::TraitMethod(it) => { + &it.assoc_item_name + } + } + .text(); + GroupLabel(format!("Qualify {}", name)) +} + +fn label(candidate: &ImportCandidate, import: &LocatedImport) -> String { + match candidate { + ImportCandidate::Path(candidate) if candidate.qualifier.is_none() => { + format!("Qualify as `{}`", import.import_path) + } + _ => format!("Qualify with `{}`", import.import_path), + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn applicable_when_found_an_import_partial() { + cov_mark::check!(qualify_path_unqualified_name); + check_assist( + qualify_path, + r#" +mod std { + pub mod fmt { + pub struct Formatter; + } +} + +use std::fmt; + +$0Formatter +"#, + r#" +mod std { + pub mod fmt { + pub struct Formatter; + } +} + +use std::fmt; + +fmt::Formatter +"#, + ); + } + + #[test] + fn applicable_when_found_an_import() { + check_assist( + qualify_path, + r#" +$0PubStruct + +pub mod PubMod { + pub struct PubStruct; +} +"#, + r#" +PubMod::PubStruct + +pub mod PubMod { + pub struct PubStruct; +} +"#, + ); + } + + #[test] + fn applicable_in_macros() { + check_assist( + qualify_path, + r#" +macro_rules! foo { + ($i:ident) => { fn foo(a: $i) {} } +} +foo!(Pub$0Struct); + +pub mod PubMod { + pub struct PubStruct; +} +"#, + r#" +macro_rules! foo { + ($i:ident) => { fn foo(a: $i) {} } +} +foo!(PubMod::PubStruct); + +pub mod PubMod { + pub struct PubStruct; +} +"#, + ); + } + + #[test] + fn applicable_when_found_multiple_imports() { + check_assist( + qualify_path, + r#" +PubSt$0ruct + +pub mod PubMod1 { + pub struct PubStruct; +} +pub mod PubMod2 { + pub struct PubStruct; +} +pub mod PubMod3 { + pub struct PubStruct; +} +"#, + r#" +PubMod3::PubStruct + +pub mod PubMod1 { + pub struct PubStruct; +} +pub mod PubMod2 { + pub struct PubStruct; +} +pub mod PubMod3 { + pub struct PubStruct; +} +"#, + ); + } + + #[test] + fn not_applicable_for_already_imported_types() { + check_assist_not_applicable( + qualify_path, + r#" +use PubMod::PubStruct; + +PubStruct$0 + +pub mod PubMod { + pub struct PubStruct; +} +"#, + ); + } + + #[test] + fn not_applicable_for_types_with_private_paths() { + check_assist_not_applicable( + qualify_path, + r#" +PrivateStruct$0 + +pub mod PubMod { + struct PrivateStruct; +} +"#, + ); + } + + #[test] + fn not_applicable_when_no_imports_found() { + check_assist_not_applicable(qualify_path, r#"PubStruct$0"#); + } + + #[test] + fn qualify_function() { + check_assist( + qualify_path, + r#" +test_function$0 + +pub mod PubMod { + pub fn test_function() {}; +} +"#, + r#" +PubMod::test_function + +pub mod PubMod { + pub fn test_function() {}; +} +"#, + ); + } + + #[test] + fn qualify_macro() { + check_assist( + qualify_path, + r#" +//- /lib.rs crate:crate_with_macro +#[macro_export] +macro_rules! foo { + () => () +} + +//- /main.rs crate:main deps:crate_with_macro +fn main() { + foo$0 +} +"#, + r#" +fn main() { + crate_with_macro::foo +} +"#, + ); + } + + #[test] + fn qualify_path_target() { + check_assist_target( + qualify_path, + r#" +struct AssistInfo { + group_label: Option<$0GroupLabel>, +} + +mod m { pub struct GroupLabel; } +"#, + "GroupLabel", + ) + } + + #[test] + fn not_applicable_when_path_start_is_imported() { + check_assist_not_applicable( + qualify_path, + r#" +pub mod mod1 { + pub mod mod2 { + pub mod mod3 { + pub struct TestStruct; + } + } +} + +use mod1::mod2; +fn main() { + mod2::mod3::TestStruct$0 +} +"#, + ); + } + + #[test] + fn not_applicable_for_imported_function() { + check_assist_not_applicable( + qualify_path, + r#" +pub mod test_mod { + pub fn test_function() {} +} + +use test_mod::test_function; +fn main() { + test_function$0 +} +"#, + ); + } + + #[test] + fn associated_struct_function() { + check_assist( + qualify_path, + r#" +mod test_mod { + pub struct TestStruct {} + impl TestStruct { + pub fn test_function() {} + } +} + +fn main() { + TestStruct::test_function$0 +} +"#, + r#" +mod test_mod { + pub struct TestStruct {} + impl TestStruct { + pub fn test_function() {} + } +} + +fn main() { + test_mod::TestStruct::test_function +} +"#, + ); + } + + #[test] + fn associated_struct_const() { + cov_mark::check!(qualify_path_qualifier_start); + check_assist( + qualify_path, + r#" +mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + TestStruct::TEST_CONST$0 +} +"#, + r#" +mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + test_mod::TestStruct::TEST_CONST +} +"#, + ); + } + + #[test] + fn associated_struct_const_unqualified() { + // FIXME: non-trait assoc items completion is unsupported yet, see FIXME in the import_assets.rs for more details + check_assist_not_applicable( + qualify_path, + r#" +mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + TEST_CONST$0 +} +"#, + ); + } + + #[test] + fn associated_trait_function() { + check_assist( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } +} + +fn main() { + test_mod::TestStruct::test_function$0 +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } +} + +fn main() { + ::test_function +} +"#, + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_function() { + check_assist_not_applicable( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub trait TestTrait2 { + fn test_function(); + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + fn test_function() {} + } + impl TestTrait for TestEnum { + fn test_function() {} + } +} + +use test_mod::TestTrait2; +fn main() { + test_mod::TestEnum::test_function$0; +} +"#, + ) + } + + #[test] + fn associated_trait_const() { + cov_mark::check!(qualify_path_trait_assoc_item); + check_assist( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + test_mod::TestStruct::TEST_CONST$0 +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + ::TEST_CONST +} +"#, + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_const() { + check_assist_not_applicable( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub trait TestTrait2 { + const TEST_CONST: f64; + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + const TEST_CONST: f64 = 42.0; + } + impl TestTrait for TestEnum { + const TEST_CONST: u8 = 42; + } +} + +use test_mod::TestTrait2; +fn main() { + test_mod::TestEnum::TEST_CONST$0; +} +"#, + ) + } + + #[test] + fn trait_method() { + cov_mark::check!(qualify_path_trait_method); + check_assist( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od() +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_mod::TestTrait::test_method(&test_struct) +} +"#, + ); + } + + #[test] + fn trait_method_multi_params() { + check_assist( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self, test: i32); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self, test: i32) {} + } +} + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od(42) +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self, test: i32); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self, test: i32) {} + } +} + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_mod::TestTrait::test_method(&test_struct, 42) +} +"#, + ); + } + + #[test] + fn trait_method_consume() { + check_assist( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(self) {} + } +} + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od() +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(self) {} + } +} + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_mod::TestTrait::test_method(test_struct) +} +"#, + ); + } + + #[test] + fn trait_method_cross_crate() { + check_assist( + qualify_path, + r#" +//- /main.rs crate:main deps:dep +fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_meth$0od() +} +//- /dep.rs crate:dep +pub mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} +"#, + r#" +fn main() { + let test_struct = dep::test_mod::TestStruct {}; + dep::test_mod::TestTrait::test_method(&test_struct) +} +"#, + ); + } + + #[test] + fn assoc_fn_cross_crate() { + check_assist( + qualify_path, + r#" +//- /main.rs crate:main deps:dep +fn main() { + dep::test_mod::TestStruct::test_func$0tion +} +//- /dep.rs crate:dep +pub mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } +} +"#, + r#" +fn main() { + ::test_function +} +"#, + ); + } + + #[test] + fn assoc_const_cross_crate() { + check_assist( + qualify_path, + r#" +//- /main.rs crate:main deps:dep +fn main() { + dep::test_mod::TestStruct::CONST$0 +} +//- /dep.rs crate:dep +pub mod test_mod { + pub trait TestTrait { + const CONST: bool; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const CONST: bool = true; + } +} +"#, + r#" +fn main() { + ::CONST +} +"#, + ); + } + + #[test] + fn assoc_fn_as_method_cross_crate() { + check_assist_not_applicable( + qualify_path, + r#" +//- /main.rs crate:main deps:dep +fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_func$0tion() +} +//- /dep.rs crate:dep +pub mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } +} +"#, + ); + } + + #[test] + fn private_trait_cross_crate() { + check_assist_not_applicable( + qualify_path, + r#" +//- /main.rs crate:main deps:dep +fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_meth$0od() +} +//- /dep.rs crate:dep +pub mod test_mod { + trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} +"#, + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_method() { + check_assist_not_applicable( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub trait TestTrait2 { + fn test_method(&self); + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + fn test_method(&self) {} + } + impl TestTrait for TestEnum { + fn test_method(&self) {} + } +} + +use test_mod::TestTrait2; +fn main() { + let one = test_mod::TestEnum::One; + one.test$0_method(); +} +"#, + ) + } + + #[test] + fn dep_import() { + check_assist( + qualify_path, + r" +//- /lib.rs crate:dep +pub struct Struct; + +//- /main.rs crate:main deps:dep +fn main() { + Struct$0 +} +", + r" +fn main() { + dep::Struct +} +", + ); + } + + #[test] + fn whole_segment() { + // Tests that only imports whose last segment matches the identifier get suggested. + check_assist( + qualify_path, + r" +//- /lib.rs crate:dep +pub mod fmt { + pub trait Display {} +} + +pub fn panic_fmt() {} + +//- /main.rs crate:main deps:dep +struct S; + +impl f$0mt::Display for S {} +", + r" +struct S; + +impl dep::fmt::Display for S {} +", + ); + } + + #[test] + fn macro_generated() { + // Tests that macro-generated items are suggested from external crates. + check_assist( + qualify_path, + r" +//- /lib.rs crate:dep +macro_rules! mac { + () => { + pub struct Cheese; + }; +} + +mac!(); + +//- /main.rs crate:main deps:dep +fn main() { + Cheese$0; +} +", + r" +fn main() { + dep::Cheese; +} +", + ); + } + + #[test] + fn casing() { + // Tests that differently cased names don't interfere and we only suggest the matching one. + check_assist( + qualify_path, + r" +//- /lib.rs crate:dep +pub struct FMT; +pub struct fmt; + +//- /main.rs crate:main deps:dep +fn main() { + FMT$0; +} +", + r" +fn main() { + dep::FMT; +} +", + ); + } + + #[test] + fn keep_generic_annotations() { + check_assist( + qualify_path, + r" +//- /lib.rs crate:dep +pub mod generic { pub struct Thing<'a, T>(&'a T); } + +//- /main.rs crate:main deps:dep +fn foo() -> Thin$0g<'static, ()> {} + +fn main() {} +", + r" +fn foo() -> dep::generic::Thing<'static, ()> {} + +fn main() {} +", + ); + } + + #[test] + fn keep_generic_annotations_leading_colon() { + check_assist( + qualify_path, + r#" +//- /lib.rs crate:dep +pub mod generic { pub struct Thing<'a, T>(&'a T); } + +//- /main.rs crate:main deps:dep +fn foo() -> Thin$0g::<'static, ()> {} + +fn main() {} +"#, + r" +fn foo() -> dep::generic::Thing::<'static, ()> {} + +fn main() {} +", + ); + } + + #[test] + fn associated_struct_const_generic() { + check_assist( + qualify_path, + r#" +mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + TestStruct::<()>::TEST_CONST$0 +} +"#, + r#" +mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + test_mod::TestStruct::<()>::TEST_CONST +} +"#, + ); + } + + #[test] + fn associated_trait_const_generic() { + check_assist( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + test_mod::TestStruct::<()>::TEST_CONST$0 +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const TEST_CONST: u8 = 42; + } +} + +fn main() { + as test_mod::TestTrait>::TEST_CONST +} +"#, + ); + } + + #[test] + fn trait_method_generic() { + check_assist( + qualify_path, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth$0od::<()>() +} +"#, + r#" +mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } +} + +fn main() { + let test_struct = test_mod::TestStruct {}; + test_mod::TestTrait::test_method::<()>(&test_struct) +} +"#, + ); + } + + #[test] + fn works_in_derives() { + check_assist( + qualify_path, + r#" +//- minicore:derive +mod foo { + #[rustc_builtin_macro] + pub macro Copy {} +} +#[derive(Copy$0)] +struct Foo; +"#, + r#" +mod foo { + #[rustc_builtin_macro] + pub macro Copy {} +} +#[derive(foo::Copy)] +struct Foo; +"#, + ); + } + + #[test] + fn works_in_use_start() { + check_assist( + qualify_path, + r#" +mod bar { + pub mod foo { + pub struct Foo; + } +} +use foo$0::Foo; +"#, + r#" +mod bar { + pub mod foo { + pub struct Foo; + } +} +use bar::foo::Foo; +"#, + ); + } + + #[test] + fn not_applicable_in_non_start_use() { + check_assist_not_applicable( + qualify_path, + r" +mod bar { + pub mod foo { + pub struct Foo; + } +} +use foo::Foo$0; +", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/raw_string.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/raw_string.rs new file mode 100644 index 000000000..dbe8cb7bf --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/raw_string.rs @@ -0,0 +1,509 @@ +use std::borrow::Cow; + +use syntax::{ast, ast::IsString, AstToken, TextRange, TextSize}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: make_raw_string +// +// Adds `r#` to a plain string literal. +// +// ``` +// fn main() { +// "Hello,$0 World!"; +// } +// ``` +// -> +// ``` +// fn main() { +// r#"Hello, World!"#; +// } +// ``` +pub(crate) fn make_raw_string(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let token = ctx.find_token_at_offset::()?; + if token.is_raw() { + return None; + } + let value = token.value()?; + let target = token.syntax().text_range(); + acc.add( + AssistId("make_raw_string", AssistKind::RefactorRewrite), + "Rewrite as raw string", + target, + |edit| { + let hashes = "#".repeat(required_hashes(&value).max(1)); + if matches!(value, Cow::Borrowed(_)) { + // Avoid replacing the whole string to better position the cursor. + edit.insert(token.syntax().text_range().start(), format!("r{}", hashes)); + edit.insert(token.syntax().text_range().end(), hashes); + } else { + edit.replace( + token.syntax().text_range(), + format!("r{}\"{}\"{}", hashes, value, hashes), + ); + } + }, + ) +} + +// Assist: make_usual_string +// +// Turns a raw string into a plain string. +// +// ``` +// fn main() { +// r#"Hello,$0 "World!""#; +// } +// ``` +// -> +// ``` +// fn main() { +// "Hello, \"World!\""; +// } +// ``` +pub(crate) fn make_usual_string(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let token = ctx.find_token_at_offset::()?; + if !token.is_raw() { + return None; + } + let value = token.value()?; + let target = token.syntax().text_range(); + acc.add( + AssistId("make_usual_string", AssistKind::RefactorRewrite), + "Rewrite as regular string", + target, + |edit| { + // parse inside string to escape `"` + let escaped = value.escape_default().to_string(); + if let Some(offsets) = token.quote_offsets() { + if token.text()[offsets.contents - token.syntax().text_range().start()] == escaped { + edit.replace(offsets.quotes.0, "\""); + edit.replace(offsets.quotes.1, "\""); + return; + } + } + + edit.replace(token.syntax().text_range(), format!("\"{}\"", escaped)); + }, + ) +} + +// Assist: add_hash +// +// Adds a hash to a raw string literal. +// +// ``` +// fn main() { +// r#"Hello,$0 World!"#; +// } +// ``` +// -> +// ``` +// fn main() { +// r##"Hello, World!"##; +// } +// ``` +pub(crate) fn add_hash(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let token = ctx.find_token_at_offset::()?; + if !token.is_raw() { + return None; + } + let text_range = token.syntax().text_range(); + let target = text_range; + acc.add(AssistId("add_hash", AssistKind::Refactor), "Add #", target, |edit| { + edit.insert(text_range.start() + TextSize::of('r'), "#"); + edit.insert(text_range.end(), "#"); + }) +} + +// Assist: remove_hash +// +// Removes a hash from a raw string literal. +// +// ``` +// fn main() { +// r#"Hello,$0 World!"#; +// } +// ``` +// -> +// ``` +// fn main() { +// r"Hello, World!"; +// } +// ``` +pub(crate) fn remove_hash(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let token = ctx.find_token_at_offset::()?; + if !token.is_raw() { + return None; + } + + let text = token.text(); + if !text.starts_with("r#") && text.ends_with('#') { + return None; + } + + let existing_hashes = text.chars().skip(1).take_while(|&it| it == '#').count(); + + let text_range = token.syntax().text_range(); + let internal_text = &text[token.text_range_between_quotes()? - text_range.start()]; + + if existing_hashes == required_hashes(internal_text) { + cov_mark::hit!(cant_remove_required_hash); + return None; + } + + acc.add(AssistId("remove_hash", AssistKind::RefactorRewrite), "Remove #", text_range, |edit| { + edit.delete(TextRange::at(text_range.start() + TextSize::of('r'), TextSize::of('#'))); + edit.delete(TextRange::new(text_range.end() - TextSize::of('#'), text_range.end())); + }) +} + +fn required_hashes(s: &str) -> usize { + let mut res = 0usize; + for idx in s.match_indices('"').map(|(i, _)| i) { + let (_, sub) = s.split_at(idx + 1); + let n_hashes = sub.chars().take_while(|c| *c == '#').count(); + res = res.max(n_hashes + 1) + } + res +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn test_required_hashes() { + assert_eq!(0, required_hashes("abc")); + assert_eq!(0, required_hashes("###")); + assert_eq!(1, required_hashes("\"")); + assert_eq!(2, required_hashes("\"#abc")); + assert_eq!(0, required_hashes("#abc")); + assert_eq!(3, required_hashes("#ab\"##c")); + assert_eq!(5, required_hashes("#ab\"##\"####c")); + } + + #[test] + fn make_raw_string_target() { + check_assist_target( + make_raw_string, + r#" + fn f() { + let s = $0"random\nstring"; + } + "#, + r#""random\nstring""#, + ); + } + + #[test] + fn make_raw_string_works() { + check_assist( + make_raw_string, + r#" +fn f() { + let s = $0"random\nstring"; +} +"#, + r##" +fn f() { + let s = r#"random +string"#; +} +"##, + ) + } + + #[test] + fn make_raw_string_works_inside_macros() { + check_assist( + make_raw_string, + r#" + fn f() { + format!($0"x = {}", 92) + } + "#, + r##" + fn f() { + format!(r#"x = {}"#, 92) + } + "##, + ) + } + + #[test] + fn make_raw_string_hashes_inside_works() { + check_assist( + make_raw_string, + r###" +fn f() { + let s = $0"#random##\nstring"; +} +"###, + r####" +fn f() { + let s = r#"#random## +string"#; +} +"####, + ) + } + + #[test] + fn make_raw_string_closing_hashes_inside_works() { + check_assist( + make_raw_string, + r###" +fn f() { + let s = $0"#random\"##\nstring"; +} +"###, + r####" +fn f() { + let s = r###"#random"## +string"###; +} +"####, + ) + } + + #[test] + fn make_raw_string_nothing_to_unescape_works() { + check_assist( + make_raw_string, + r#" + fn f() { + let s = $0"random string"; + } + "#, + r##" + fn f() { + let s = r#"random string"#; + } + "##, + ) + } + + #[test] + fn make_raw_string_not_works_on_partial_string() { + check_assist_not_applicable( + make_raw_string, + r#" + fn f() { + let s = "foo$0 + } + "#, + ) + } + + #[test] + fn make_usual_string_not_works_on_partial_string() { + check_assist_not_applicable( + make_usual_string, + r#" + fn main() { + let s = r#"bar$0 + } + "#, + ) + } + + #[test] + fn add_hash_target() { + check_assist_target( + add_hash, + r#" + fn f() { + let s = $0r"random string"; + } + "#, + r#"r"random string""#, + ); + } + + #[test] + fn add_hash_works() { + check_assist( + add_hash, + r#" + fn f() { + let s = $0r"random string"; + } + "#, + r##" + fn f() { + let s = r#"random string"#; + } + "##, + ) + } + + #[test] + fn add_more_hash_works() { + check_assist( + add_hash, + r##" + fn f() { + let s = $0r#"random"string"#; + } + "##, + r###" + fn f() { + let s = r##"random"string"##; + } + "###, + ) + } + + #[test] + fn add_hash_not_works() { + check_assist_not_applicable( + add_hash, + r#" + fn f() { + let s = $0"random string"; + } + "#, + ); + } + + #[test] + fn remove_hash_target() { + check_assist_target( + remove_hash, + r##" + fn f() { + let s = $0r#"random string"#; + } + "##, + r##"r#"random string"#"##, + ); + } + + #[test] + fn remove_hash_works() { + check_assist( + remove_hash, + r##"fn f() { let s = $0r#"random string"#; }"##, + r#"fn f() { let s = r"random string"; }"#, + ) + } + + #[test] + fn cant_remove_required_hash() { + cov_mark::check!(cant_remove_required_hash); + check_assist_not_applicable( + remove_hash, + r##" + fn f() { + let s = $0r#"random"str"ing"#; + } + "##, + ) + } + + #[test] + fn remove_more_hash_works() { + check_assist( + remove_hash, + r###" + fn f() { + let s = $0r##"random string"##; + } + "###, + r##" + fn f() { + let s = r#"random string"#; + } + "##, + ) + } + + #[test] + fn remove_hash_doesnt_work() { + check_assist_not_applicable(remove_hash, r#"fn f() { let s = $0"random string"; }"#); + } + + #[test] + fn remove_hash_no_hash_doesnt_work() { + check_assist_not_applicable(remove_hash, r#"fn f() { let s = $0r"random string"; }"#); + } + + #[test] + fn make_usual_string_target() { + check_assist_target( + make_usual_string, + r##" + fn f() { + let s = $0r#"random string"#; + } + "##, + r##"r#"random string"#"##, + ); + } + + #[test] + fn make_usual_string_works() { + check_assist( + make_usual_string, + r##" + fn f() { + let s = $0r#"random string"#; + } + "##, + r#" + fn f() { + let s = "random string"; + } + "#, + ) + } + + #[test] + fn make_usual_string_with_quote_works() { + check_assist( + make_usual_string, + r##" + fn f() { + let s = $0r#"random"str"ing"#; + } + "##, + r#" + fn f() { + let s = "random\"str\"ing"; + } + "#, + ) + } + + #[test] + fn make_usual_string_more_hash_works() { + check_assist( + make_usual_string, + r###" + fn f() { + let s = $0r##"random string"##; + } + "###, + r##" + fn f() { + let s = "random string"; + } + "##, + ) + } + + #[test] + fn make_usual_string_not_works() { + check_assist_not_applicable( + make_usual_string, + r#" + fn f() { + let s = $0"random string"; + } + "#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_dbg.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_dbg.rs new file mode 100644 index 000000000..afaa7c933 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_dbg.rs @@ -0,0 +1,241 @@ +use itertools::Itertools; +use syntax::{ + ast::{self, AstNode, AstToken}, + match_ast, NodeOrToken, SyntaxElement, TextSize, T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: remove_dbg +// +// Removes `dbg!()` macro call. +// +// ``` +// fn main() { +// $0dbg!(92); +// } +// ``` +// -> +// ``` +// fn main() { +// 92; +// } +// ``` +pub(crate) fn remove_dbg(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let macro_call = ctx.find_node_at_offset::()?; + let tt = macro_call.token_tree()?; + let r_delim = NodeOrToken::Token(tt.right_delimiter_token()?); + if macro_call.path()?.segment()?.name_ref()?.text() != "dbg" + || macro_call.excl_token().is_none() + { + return None; + } + + let mac_input = tt.syntax().children_with_tokens().skip(1).take_while(|it| *it != r_delim); + let input_expressions = mac_input.group_by(|tok| tok.kind() == T![,]); + let input_expressions = input_expressions + .into_iter() + .filter_map(|(is_sep, group)| (!is_sep).then(|| group)) + .map(|mut tokens| syntax::hacks::parse_expr_from_str(&tokens.join(""))) + .collect::>>()?; + + let macro_expr = ast::MacroExpr::cast(macro_call.syntax().parent()?)?; + let parent = macro_expr.syntax().parent()?; + let (range, text) = match &*input_expressions { + // dbg!() + [] => { + match_ast! { + match parent { + ast::StmtList(__) => { + let range = macro_expr.syntax().text_range(); + let range = match whitespace_start(macro_expr.syntax().prev_sibling_or_token()) { + Some(start) => range.cover_offset(start), + None => range, + }; + (range, String::new()) + }, + ast::ExprStmt(it) => { + let range = it.syntax().text_range(); + let range = match whitespace_start(it.syntax().prev_sibling_or_token()) { + Some(start) => range.cover_offset(start), + None => range, + }; + (range, String::new()) + }, + _ => (macro_call.syntax().text_range(), "()".to_owned()) + } + } + } + // dbg!(expr0) + [expr] => { + let wrap = match ast::Expr::cast(parent) { + Some(parent) => match (expr, parent) { + (ast::Expr::CastExpr(_), ast::Expr::CastExpr(_)) => false, + ( + ast::Expr::BoxExpr(_) | ast::Expr::PrefixExpr(_) | ast::Expr::RefExpr(_), + ast::Expr::AwaitExpr(_) + | ast::Expr::CallExpr(_) + | ast::Expr::CastExpr(_) + | ast::Expr::FieldExpr(_) + | ast::Expr::IndexExpr(_) + | ast::Expr::MethodCallExpr(_) + | ast::Expr::RangeExpr(_) + | ast::Expr::TryExpr(_), + ) => true, + ( + ast::Expr::BinExpr(_) | ast::Expr::CastExpr(_) | ast::Expr::RangeExpr(_), + ast::Expr::AwaitExpr(_) + | ast::Expr::BinExpr(_) + | ast::Expr::CallExpr(_) + | ast::Expr::CastExpr(_) + | ast::Expr::FieldExpr(_) + | ast::Expr::IndexExpr(_) + | ast::Expr::MethodCallExpr(_) + | ast::Expr::PrefixExpr(_) + | ast::Expr::RangeExpr(_) + | ast::Expr::RefExpr(_) + | ast::Expr::TryExpr(_), + ) => true, + _ => false, + }, + None => false, + }; + ( + macro_call.syntax().text_range(), + if wrap { format!("({})", expr) } else { expr.to_string() }, + ) + } + // dbg!(expr0, expr1, ...) + exprs => (macro_call.syntax().text_range(), format!("({})", exprs.iter().format(", "))), + }; + + acc.add(AssistId("remove_dbg", AssistKind::Refactor), "Remove dbg!()", range, |builder| { + builder.replace(range, text); + }) +} + +fn whitespace_start(it: Option) -> Option { + Some(it?.into_token().and_then(ast::Whitespace::cast)?.syntax().text_range().start()) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + fn check(ra_fixture_before: &str, ra_fixture_after: &str) { + check_assist( + remove_dbg, + &format!("fn main() {{\n{}\n}}", ra_fixture_before), + &format!("fn main() {{\n{}\n}}", ra_fixture_after), + ); + } + + #[test] + fn test_remove_dbg() { + check("$0dbg!(1 + 1)", "1 + 1"); + check("dbg!$0(1 + 1)", "1 + 1"); + check("dbg!(1 $0+ 1)", "1 + 1"); + check("dbg![$01 + 1]", "1 + 1"); + check("dbg!{$01 + 1}", "1 + 1"); + } + + #[test] + fn test_remove_dbg_not_applicable() { + check_assist_not_applicable(remove_dbg, "fn main() {$0vec![1, 2, 3]}"); + check_assist_not_applicable(remove_dbg, "fn main() {$0dbg(5, 6, 7)}"); + check_assist_not_applicable(remove_dbg, "fn main() {$0dbg!(5, 6, 7}"); + } + + #[test] + fn test_remove_dbg_keep_semicolon_in_let() { + // https://github.com/rust-lang/rust-analyzer/issues/5129#issuecomment-651399779 + check( + r#"let res = $0dbg!(1 * 20); // needless comment"#, + r#"let res = 1 * 20; // needless comment"#, + ); + check(r#"let res = $0dbg!(); // needless comment"#, r#"let res = (); // needless comment"#); + check( + r#"let res = $0dbg!(1, 2); // needless comment"#, + r#"let res = (1, 2); // needless comment"#, + ); + } + + #[test] + fn test_remove_dbg_cast_cast() { + check(r#"let res = $0dbg!(x as u32) as u32;"#, r#"let res = x as u32 as u32;"#); + } + + #[test] + fn test_remove_dbg_prefix() { + check(r#"let res = $0dbg!(&result).foo();"#, r#"let res = (&result).foo();"#); + check(r#"let res = &$0dbg!(&result);"#, r#"let res = &&result;"#); + check(r#"let res = $0dbg!(!result) && true;"#, r#"let res = !result && true;"#); + } + + #[test] + fn test_remove_dbg_post_expr() { + check(r#"let res = $0dbg!(fut.await).foo();"#, r#"let res = fut.await.foo();"#); + check(r#"let res = $0dbg!(result?).foo();"#, r#"let res = result?.foo();"#); + check(r#"let res = $0dbg!(foo as u32).foo();"#, r#"let res = (foo as u32).foo();"#); + check(r#"let res = $0dbg!(array[3]).foo();"#, r#"let res = array[3].foo();"#); + check(r#"let res = $0dbg!(tuple.3).foo();"#, r#"let res = tuple.3.foo();"#); + } + + #[test] + fn test_remove_dbg_range_expr() { + check(r#"let res = $0dbg!(foo..bar).foo();"#, r#"let res = (foo..bar).foo();"#); + check(r#"let res = $0dbg!(foo..=bar).foo();"#, r#"let res = (foo..=bar).foo();"#); + } + + #[test] + fn test_remove_empty_dbg() { + check_assist(remove_dbg, r#"fn foo() { $0dbg!(); }"#, r#"fn foo() { }"#); + check_assist( + remove_dbg, + r#" +fn foo() { + $0dbg!(); +} +"#, + r#" +fn foo() { +} +"#, + ); + check_assist( + remove_dbg, + r#" +fn foo() { + let test = $0dbg!(); +}"#, + r#" +fn foo() { + let test = (); +}"#, + ); + check_assist( + remove_dbg, + r#" +fn foo() { + let t = { + println!("Hello, world"); + $0dbg!() + }; +}"#, + r#" +fn foo() { + let t = { + println!("Hello, world"); + }; +}"#, + ); + } + + #[test] + fn test_remove_multi_dbg() { + check(r#"$0dbg!(0, 1)"#, r#"(0, 1)"#); + check(r#"$0dbg!(0, (1, 2))"#, r#"(0, (1, 2))"#); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_mut.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_mut.rs new file mode 100644 index 000000000..0b299e834 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_mut.rs @@ -0,0 +1,37 @@ +use syntax::{SyntaxKind, TextRange, T}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: remove_mut +// +// Removes the `mut` keyword. +// +// ``` +// impl Walrus { +// fn feed(&mut$0 self, amount: u32) {} +// } +// ``` +// -> +// ``` +// impl Walrus { +// fn feed(&self, amount: u32) {} +// } +// ``` +pub(crate) fn remove_mut(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let mut_token = ctx.find_token_syntax_at_offset(T![mut])?; + let delete_from = mut_token.text_range().start(); + let delete_to = match mut_token.next_token() { + Some(it) if it.kind() == SyntaxKind::WHITESPACE => it.text_range().end(), + _ => mut_token.text_range().end(), + }; + + let target = mut_token.text_range(); + acc.add( + AssistId("remove_mut", AssistKind::Refactor), + "Remove `mut` keyword", + target, + |builder| { + builder.delete(TextRange::new(delete_from, delete_to)); + }, + ) +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_unused_param.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_unused_param.rs new file mode 100644 index 000000000..59ea94ea1 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_unused_param.rs @@ -0,0 +1,409 @@ +use ide_db::{base_db::FileId, defs::Definition, search::FileReference}; +use syntax::{ + algo::find_node_at_range, + ast::{self, HasArgList}, + AstNode, SourceFile, SyntaxKind, SyntaxNode, TextRange, T, +}; + +use SyntaxKind::WHITESPACE; + +use crate::{ + assist_context::AssistBuilder, utils::next_prev, AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: remove_unused_param +// +// Removes unused function parameter. +// +// ``` +// fn frobnicate(x: i32$0) {} +// +// fn main() { +// frobnicate(92); +// } +// ``` +// -> +// ``` +// fn frobnicate() {} +// +// fn main() { +// frobnicate(); +// } +// ``` +pub(crate) fn remove_unused_param(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let param: ast::Param = ctx.find_node_at_offset()?; + let ident_pat = match param.pat()? { + ast::Pat::IdentPat(it) => it, + _ => return None, + }; + let func = param.syntax().ancestors().find_map(ast::Fn::cast)?; + let is_self_present = + param.syntax().parent()?.children().find_map(ast::SelfParam::cast).is_some(); + + // check if fn is in impl Trait for .. + if func + .syntax() + .parent() // AssocItemList + .and_then(|x| x.parent()) + .and_then(ast::Impl::cast) + .map_or(false, |imp| imp.trait_().is_some()) + { + cov_mark::hit!(trait_impl); + return None; + } + + let mut param_position = func.param_list()?.params().position(|it| it == param)?; + // param_list() does not take the self param into consideration, hence this additional check + // is required. For associated functions, param_position is incremented here. For inherent + // calls we revet the increment below, in process_usage, as those calls will not have an + // explicit self parameter. + if is_self_present { + param_position += 1; + } + let fn_def = { + let func = ctx.sema.to_def(&func)?; + Definition::Function(func) + }; + + let param_def = { + let local = ctx.sema.to_def(&ident_pat)?; + Definition::Local(local) + }; + if param_def.usages(&ctx.sema).at_least_one() { + cov_mark::hit!(keep_used); + return None; + } + acc.add( + AssistId("remove_unused_param", AssistKind::Refactor), + "Remove unused parameter", + param.syntax().text_range(), + |builder| { + builder.delete(range_to_remove(param.syntax())); + for (file_id, references) in fn_def.usages(&ctx.sema).all() { + process_usages(ctx, builder, file_id, references, param_position, is_self_present); + } + }, + ) +} + +fn process_usages( + ctx: &AssistContext<'_>, + builder: &mut AssistBuilder, + file_id: FileId, + references: Vec, + arg_to_remove: usize, + is_self_present: bool, +) { + let source_file = ctx.sema.parse(file_id); + builder.edit_file(file_id); + let possible_ranges = references + .into_iter() + .filter_map(|usage| process_usage(&source_file, usage, arg_to_remove, is_self_present)); + + let mut ranges_to_delete: Vec = vec![]; + for range in possible_ranges { + if !ranges_to_delete.iter().any(|it| it.contains_range(range)) { + ranges_to_delete.push(range) + } + } + + for range in ranges_to_delete { + builder.delete(range) + } +} + +fn process_usage( + source_file: &SourceFile, + FileReference { range, .. }: FileReference, + mut arg_to_remove: usize, + is_self_present: bool, +) -> Option { + let call_expr_opt: Option = find_node_at_range(source_file.syntax(), range); + if let Some(call_expr) = call_expr_opt { + let call_expr_range = call_expr.expr()?.syntax().text_range(); + if !call_expr_range.contains_range(range) { + return None; + } + + let arg = call_expr.arg_list()?.args().nth(arg_to_remove)?; + return Some(range_to_remove(arg.syntax())); + } + + let method_call_expr_opt: Option = + find_node_at_range(source_file.syntax(), range); + if let Some(method_call_expr) = method_call_expr_opt { + let method_call_expr_range = method_call_expr.name_ref()?.syntax().text_range(); + if !method_call_expr_range.contains_range(range) { + return None; + } + + if is_self_present { + arg_to_remove -= 1; + } + + let arg = method_call_expr.arg_list()?.args().nth(arg_to_remove)?; + return Some(range_to_remove(arg.syntax())); + } + + None +} + +pub(crate) fn range_to_remove(node: &SyntaxNode) -> TextRange { + let up_to_comma = next_prev().find_map(|dir| { + node.siblings_with_tokens(dir) + .filter_map(|it| it.into_token()) + .find(|it| it.kind() == T![,]) + .map(|it| (dir, it)) + }); + if let Some((dir, token)) = up_to_comma { + if node.next_sibling().is_some() { + let up_to_space = token + .siblings_with_tokens(dir) + .skip(1) + .take_while(|it| it.kind() == WHITESPACE) + .last() + .and_then(|it| it.into_token()); + return node + .text_range() + .cover(up_to_space.map_or(token.text_range(), |it| it.text_range())); + } + node.text_range().cover(token.text_range()) + } else { + node.text_range() + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn remove_unused() { + check_assist( + remove_unused_param, + r#" +fn a() { foo(9, 2) } +fn foo(x: i32, $0y: i32) { x; } +fn b() { foo(9, 2,) } +"#, + r#" +fn a() { foo(9) } +fn foo(x: i32) { x; } +fn b() { foo(9, ) } +"#, + ); + } + + #[test] + fn remove_unused_first_param() { + check_assist( + remove_unused_param, + r#" +fn foo($0x: i32, y: i32) { y; } +fn a() { foo(1, 2) } +fn b() { foo(1, 2,) } +"#, + r#" +fn foo(y: i32) { y; } +fn a() { foo(2) } +fn b() { foo(2,) } +"#, + ); + } + + #[test] + fn remove_unused_single_param() { + check_assist( + remove_unused_param, + r#" +fn foo($0x: i32) { 0; } +fn a() { foo(1) } +fn b() { foo(1, ) } +"#, + r#" +fn foo() { 0; } +fn a() { foo() } +fn b() { foo( ) } +"#, + ); + } + + #[test] + fn remove_unused_surrounded_by_parms() { + check_assist( + remove_unused_param, + r#" +fn foo(x: i32, $0y: i32, z: i32) { x; } +fn a() { foo(1, 2, 3) } +fn b() { foo(1, 2, 3,) } +"#, + r#" +fn foo(x: i32, z: i32) { x; } +fn a() { foo(1, 3) } +fn b() { foo(1, 3,) } +"#, + ); + } + + #[test] + fn remove_unused_qualified_call() { + check_assist( + remove_unused_param, + r#" +mod bar { pub fn foo(x: i32, $0y: i32) { x; } } +fn b() { bar::foo(9, 2) } +"#, + r#" +mod bar { pub fn foo(x: i32) { x; } } +fn b() { bar::foo(9) } +"#, + ); + } + + #[test] + fn remove_unused_turbofished_func() { + check_assist( + remove_unused_param, + r#" +pub fn foo(x: T, $0y: i32) { x; } +fn b() { foo::(9, 2) } +"#, + r#" +pub fn foo(x: T) { x; } +fn b() { foo::(9) } +"#, + ); + } + + #[test] + fn remove_unused_generic_unused_param_func() { + check_assist( + remove_unused_param, + r#" +pub fn foo(x: i32, $0y: T) { x; } +fn b() { foo::(9, 2) } +fn b2() { foo(9, 2) } +"#, + r#" +pub fn foo(x: i32) { x; } +fn b() { foo::(9) } +fn b2() { foo(9) } +"#, + ); + } + + #[test] + fn keep_used() { + cov_mark::check!(keep_used); + check_assist_not_applicable( + remove_unused_param, + r#" +fn foo(x: i32, $0y: i32) { y; } +fn main() { foo(9, 2) } +"#, + ); + } + + #[test] + fn trait_impl() { + cov_mark::check!(trait_impl); + check_assist_not_applicable( + remove_unused_param, + r#" +trait Trait { + fn foo(x: i32); +} +impl Trait for () { + fn foo($0x: i32) {} +} +"#, + ); + } + + #[test] + fn remove_across_files() { + check_assist( + remove_unused_param, + r#" +//- /main.rs +fn foo(x: i32, $0y: i32) { x; } + +mod foo; + +//- /foo.rs +use super::foo; + +fn bar() { + let _ = foo(1, 2); +} +"#, + r#" +//- /main.rs +fn foo(x: i32) { x; } + +mod foo; + +//- /foo.rs +use super::foo; + +fn bar() { + let _ = foo(1); +} +"#, + ) + } + + #[test] + fn test_remove_method_param() { + check_assist( + remove_unused_param, + r#" +struct S; +impl S { fn f(&self, $0_unused: i32) {} } +fn main() { + S.f(92); + S.f(); + S.f(93, 92); + S::f(&S, 92); +} +"#, + r#" +struct S; +impl S { fn f(&self) {} } +fn main() { + S.f(); + S.f(); + S.f(92); + S::f(&S); +} +"#, + ) + } + + #[test] + fn nested_call() { + check_assist( + remove_unused_param, + r#" +fn foo(x: i32, $0y: i32) -> i32 { + x +} + +fn bar() { + foo(1, foo(2, 3)); +} +"#, + r#" +fn foo(x: i32) -> i32 { + x +} + +fn bar() { + foo(1); +} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs new file mode 100644 index 000000000..a899c7a64 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs @@ -0,0 +1,212 @@ +use either::Either; +use ide_db::FxHashMap; +use itertools::Itertools; +use syntax::{ast, ted, AstNode}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: reorder_fields +// +// Reorder the fields of record literals and record patterns in the same order as in +// the definition. +// +// ``` +// struct Foo {foo: i32, bar: i32}; +// const test: Foo = $0Foo {bar: 0, foo: 1} +// ``` +// -> +// ``` +// struct Foo {foo: i32, bar: i32}; +// const test: Foo = Foo {foo: 1, bar: 0} +// ``` +pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let record = ctx + .find_node_at_offset::() + .map(Either::Left) + .or_else(|| ctx.find_node_at_offset::().map(Either::Right))?; + + let path = record.as_ref().either(|it| it.path(), |it| it.path())?; + let ranks = compute_fields_ranks(&path, ctx)?; + let get_rank_of_field = + |of: Option<_>| *ranks.get(&of.unwrap_or_default()).unwrap_or(&usize::MAX); + + let field_list = match &record { + Either::Left(it) => Either::Left(it.record_expr_field_list()?), + Either::Right(it) => Either::Right(it.record_pat_field_list()?), + }; + let fields = match field_list { + Either::Left(it) => Either::Left(( + it.fields() + .sorted_unstable_by_key(|field| { + get_rank_of_field(field.field_name().map(|it| it.to_string())) + }) + .collect::>(), + it, + )), + Either::Right(it) => Either::Right(( + it.fields() + .sorted_unstable_by_key(|field| { + get_rank_of_field(field.field_name().map(|it| it.to_string())) + }) + .collect::>(), + it, + )), + }; + + let is_sorted = fields.as_ref().either( + |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b), + |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b), + ); + if is_sorted { + cov_mark::hit!(reorder_sorted_fields); + return None; + } + let target = record.as_ref().either(AstNode::syntax, AstNode::syntax).text_range(); + acc.add( + AssistId("reorder_fields", AssistKind::RefactorRewrite), + "Reorder record fields", + target, + |builder| match fields { + Either::Left((sorted, field_list)) => { + replace(builder.make_mut(field_list).fields(), sorted) + } + Either::Right((sorted, field_list)) => { + replace(builder.make_mut(field_list).fields(), sorted) + } + }, + ) +} + +fn replace( + fields: impl Iterator, + sorted_fields: impl IntoIterator, +) { + fields.zip(sorted_fields).for_each(|(field, sorted_field)| { + ted::replace(field.syntax(), sorted_field.syntax().clone_for_update()) + }); +} + +fn compute_fields_ranks( + path: &ast::Path, + ctx: &AssistContext<'_>, +) -> Option> { + let strukt = match ctx.sema.resolve_path(path) { + Some(hir::PathResolution::Def(hir::ModuleDef::Adt(hir::Adt::Struct(it)))) => it, + _ => return None, + }; + + let res = strukt + .fields(ctx.db()) + .into_iter() + .enumerate() + .map(|(idx, field)| (field.name(ctx.db()).to_string(), idx)) + .collect(); + + Some(res) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn reorder_sorted_fields() { + cov_mark::check!(reorder_sorted_fields); + check_assist_not_applicable( + reorder_fields, + r#" +struct Foo { foo: i32, bar: i32 } +const test: Foo = $0Foo { foo: 0, bar: 0 }; +"#, + ) + } + + #[test] + fn trivial_empty_fields() { + check_assist_not_applicable( + reorder_fields, + r#" +struct Foo {} +const test: Foo = $0Foo {}; +"#, + ) + } + + #[test] + fn reorder_struct_fields() { + check_assist( + reorder_fields, + r#" +struct Foo { foo: i32, bar: i32 } +const test: Foo = $0Foo { bar: 0, foo: 1 }; +"#, + r#" +struct Foo { foo: i32, bar: i32 } +const test: Foo = Foo { foo: 1, bar: 0 }; +"#, + ) + } + #[test] + fn reorder_struct_pattern() { + check_assist( + reorder_fields, + r#" +struct Foo { foo: i64, bar: i64, baz: i64 } + +fn f(f: Foo) -> { + match f { + $0Foo { baz: 0, ref mut bar, .. } => (), + _ => () + } +} +"#, + r#" +struct Foo { foo: i64, bar: i64, baz: i64 } + +fn f(f: Foo) -> { + match f { + Foo { ref mut bar, baz: 0, .. } => (), + _ => () + } +} +"#, + ) + } + + #[test] + fn reorder_with_extra_field() { + check_assist( + reorder_fields, + r#" +struct Foo { foo: String, bar: String } + +impl Foo { + fn new() -> Foo { + let foo = String::new(); + $0Foo { + bar: foo.clone(), + extra: "Extra field", + foo, + } + } +} +"#, + r#" +struct Foo { foo: String, bar: String } + +impl Foo { + fn new() -> Foo { + let foo = String::new(); + Foo { + foo, + bar: foo.clone(), + extra: "Extra field", + } + } +} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_impl_items.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_impl_items.rs new file mode 100644 index 000000000..208c3e109 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_impl_items.rs @@ -0,0 +1,284 @@ +use hir::{PathResolution, Semantics}; +use ide_db::{FxHashMap, RootDatabase}; +use itertools::Itertools; +use syntax::{ + ast::{self, HasName}, + ted, AstNode, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: reorder_impl_items +// +// Reorder the items of an `impl Trait`. The items will be ordered +// in the same order as in the trait definition. +// +// ``` +// trait Foo { +// type A; +// const B: u8; +// fn c(); +// } +// +// struct Bar; +// $0impl Foo for Bar { +// const B: u8 = 17; +// fn c() {} +// type A = String; +// } +// ``` +// -> +// ``` +// trait Foo { +// type A; +// const B: u8; +// fn c(); +// } +// +// struct Bar; +// impl Foo for Bar { +// type A = String; +// const B: u8 = 17; +// fn c() {} +// } +// ``` +pub(crate) fn reorder_impl_items(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let impl_ast = ctx.find_node_at_offset::()?; + let items = impl_ast.assoc_item_list()?; + let assoc_items = items.assoc_items().collect::>(); + + let path = impl_ast + .trait_() + .and_then(|t| match t { + ast::Type::PathType(path) => Some(path), + _ => None, + })? + .path()?; + + let ranks = compute_item_ranks(&path, ctx)?; + let sorted: Vec<_> = assoc_items + .iter() + .cloned() + .sorted_by_key(|i| { + let name = match i { + ast::AssocItem::Const(c) => c.name(), + ast::AssocItem::Fn(f) => f.name(), + ast::AssocItem::TypeAlias(t) => t.name(), + ast::AssocItem::MacroCall(_) => None, + }; + + name.and_then(|n| ranks.get(&n.to_string()).copied()).unwrap_or(usize::max_value()) + }) + .collect(); + + // Don't edit already sorted methods: + if assoc_items == sorted { + cov_mark::hit!(not_applicable_if_sorted); + return None; + } + + let target = items.syntax().text_range(); + acc.add( + AssistId("reorder_impl_items", AssistKind::RefactorRewrite), + "Sort items by trait definition", + target, + |builder| { + let assoc_items = + assoc_items.into_iter().map(|item| builder.make_mut(item)).collect::>(); + assoc_items + .into_iter() + .zip(sorted) + .for_each(|(old, new)| ted::replace(old.syntax(), new.clone_for_update().syntax())); + }, + ) +} + +fn compute_item_ranks( + path: &ast::Path, + ctx: &AssistContext<'_>, +) -> Option> { + let td = trait_definition(path, &ctx.sema)?; + + Some( + td.items(ctx.db()) + .iter() + .flat_map(|i| i.name(ctx.db())) + .enumerate() + .map(|(idx, name)| (name.to_string(), idx)) + .collect(), + ) +} + +fn trait_definition(path: &ast::Path, sema: &Semantics<'_, RootDatabase>) -> Option { + match sema.resolve_path(path)? { + PathResolution::Def(hir::ModuleDef::Trait(trait_)) => Some(trait_), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn not_applicable_if_sorted() { + cov_mark::check!(not_applicable_if_sorted); + check_assist_not_applicable( + reorder_impl_items, + r#" +trait Bar { + type T; + const C: (); + fn a() {} + fn z() {} + fn b() {} +} +struct Foo; +$0impl Bar for Foo { + type T = (); + const C: () = (); + fn a() {} + fn z() {} + fn b() {} +} + "#, + ) + } + + #[test] + fn reorder_impl_trait_functions() { + check_assist( + reorder_impl_items, + r#" +trait Bar { + fn a() {} + fn c() {} + fn b() {} + fn d() {} +} + +struct Foo; +$0impl Bar for Foo { + fn d() {} + fn b() {} + fn c() {} + fn a() {} +} +"#, + r#" +trait Bar { + fn a() {} + fn c() {} + fn b() {} + fn d() {} +} + +struct Foo; +impl Bar for Foo { + fn a() {} + fn c() {} + fn b() {} + fn d() {} +} +"#, + ) + } + + #[test] + fn not_applicable_if_empty() { + check_assist_not_applicable( + reorder_impl_items, + r#" +trait Bar {}; +struct Foo; +$0impl Bar for Foo {} + "#, + ) + } + + #[test] + fn reorder_impl_trait_items() { + check_assist( + reorder_impl_items, + r#" +trait Bar { + fn a() {} + type T0; + fn c() {} + const C1: (); + fn b() {} + type T1; + fn d() {} + const C0: (); +} + +struct Foo; +$0impl Bar for Foo { + type T1 = (); + fn d() {} + fn b() {} + fn c() {} + const C1: () = (); + fn a() {} + type T0 = (); + const C0: () = (); +} + "#, + r#" +trait Bar { + fn a() {} + type T0; + fn c() {} + const C1: (); + fn b() {} + type T1; + fn d() {} + const C0: (); +} + +struct Foo; +impl Bar for Foo { + fn a() {} + type T0 = (); + fn c() {} + const C1: () = (); + fn b() {} + type T1 = (); + fn d() {} + const C0: () = (); +} + "#, + ) + } + + #[test] + fn reorder_impl_trait_items_uneven_ident_lengths() { + check_assist( + reorder_impl_items, + r#" +trait Bar { + type Foo; + type Fooo; +} + +struct Foo; +impl Bar for Foo { + type Fooo = (); + type Foo = ();$0 +}"#, + r#" +trait Bar { + type Foo; + type Fooo; +} + +struct Foo; +impl Bar for Foo { + type Foo = (); + type Fooo = (); +}"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs new file mode 100644 index 000000000..bd50208da --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs @@ -0,0 +1,1250 @@ +use hir::{InFile, ModuleDef}; +use ide_db::{ + helpers::mod_path_to_ast, imports::import_assets::NameToImport, items_locator, + syntax_helpers::insert_whitespace_into_node::insert_ws_into, +}; +use itertools::Itertools; +use syntax::{ + ast::{self, AstNode, HasName}, + SyntaxKind::WHITESPACE, +}; + +use crate::{ + assist_context::{AssistBuilder, AssistContext, Assists}, + utils::{ + add_trait_assoc_items_to_impl, filter_assoc_items, gen_trait_fn_body, + generate_trait_impl_text, render_snippet, Cursor, DefaultMethods, + }, + AssistId, AssistKind, +}; + +// Assist: replace_derive_with_manual_impl +// +// Converts a `derive` impl into a manual one. +// +// ``` +// # //- minicore: derive +// # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; } +// #[derive(Deb$0ug, Display)] +// struct S; +// ``` +// -> +// ``` +// # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; } +// #[derive(Display)] +// struct S; +// +// impl Debug for S { +// $0fn fmt(&self, f: &mut Formatter) -> Result<()> { +// f.debug_struct("S").finish() +// } +// } +// ``` +pub(crate) fn replace_derive_with_manual_impl( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let attr = ctx.find_node_at_offset_with_descend::()?; + let path = attr.path()?; + let hir_file = ctx.sema.hir_file_for(attr.syntax()); + if !hir_file.is_derive_attr_pseudo_expansion(ctx.db()) { + return None; + } + + let InFile { file_id, value } = hir_file.call_node(ctx.db())?; + if file_id.is_macro() { + // FIXME: make this work in macro files + return None; + } + // collect the derive paths from the #[derive] expansion + let current_derives = ctx + .sema + .parse_or_expand(hir_file)? + .descendants() + .filter_map(ast::Attr::cast) + .filter_map(|attr| attr.path()) + .collect::>(); + + let adt = value.parent().and_then(ast::Adt::cast)?; + let attr = ast::Attr::cast(value)?; + let args = attr.token_tree()?; + + let current_module = ctx.sema.scope(adt.syntax())?.module(); + let current_crate = current_module.krate(); + + let found_traits = items_locator::items_with_name( + &ctx.sema, + current_crate, + NameToImport::exact_case_sensitive(path.segments().last()?.to_string()), + items_locator::AssocItemSearch::Exclude, + Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()), + ) + .filter_map(|item| match item.as_module_def()? { + ModuleDef::Trait(trait_) => Some(trait_), + _ => None, + }) + .flat_map(|trait_| { + current_module + .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_)) + .as_ref() + .map(mod_path_to_ast) + .zip(Some(trait_)) + }); + + let mut no_traits_found = true; + for (replace_trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) { + add_assist( + acc, + ctx, + &attr, + ¤t_derives, + &args, + &path, + &replace_trait_path, + Some(trait_), + &adt, + )?; + } + if no_traits_found { + add_assist(acc, ctx, &attr, ¤t_derives, &args, &path, &path, None, &adt)?; + } + Some(()) +} + +fn add_assist( + acc: &mut Assists, + ctx: &AssistContext<'_>, + attr: &ast::Attr, + old_derives: &[ast::Path], + old_tree: &ast::TokenTree, + old_trait_path: &ast::Path, + replace_trait_path: &ast::Path, + trait_: Option, + adt: &ast::Adt, +) -> Option<()> { + let target = attr.syntax().text_range(); + let annotated_name = adt.name()?; + let label = format!("Convert to manual `impl {} for {}`", replace_trait_path, annotated_name); + + acc.add( + AssistId("replace_derive_with_manual_impl", AssistKind::Refactor), + label, + target, + |builder| { + let insert_pos = adt.syntax().text_range().end(); + let impl_def_with_items = + impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, replace_trait_path); + update_attribute(builder, old_derives, old_tree, old_trait_path, attr); + let trait_path = replace_trait_path.to_string(); + match (ctx.config.snippet_cap, impl_def_with_items) { + (None, _) => { + builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, "")) + } + (Some(cap), None) => builder.insert_snippet( + cap, + insert_pos, + generate_trait_impl_text(adt, &trait_path, " $0"), + ), + (Some(cap), Some((impl_def, first_assoc_item))) => { + let mut cursor = Cursor::Before(first_assoc_item.syntax()); + let placeholder; + if let ast::AssocItem::Fn(ref func) = first_assoc_item { + if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast) + { + if m.syntax().text() == "todo!()" { + placeholder = m; + cursor = Cursor::Replace(placeholder.syntax()); + } + } + } + + builder.insert_snippet( + cap, + insert_pos, + format!("\n\n{}", render_snippet(cap, impl_def.syntax(), cursor)), + ) + } + }; + }, + ) +} + +fn impl_def_from_trait( + sema: &hir::Semantics<'_, ide_db::RootDatabase>, + adt: &ast::Adt, + annotated_name: &ast::Name, + trait_: Option, + trait_path: &ast::Path, +) -> Option<(ast::Impl, ast::AssocItem)> { + let trait_ = trait_?; + let target_scope = sema.scope(annotated_name.syntax())?; + let trait_items = filter_assoc_items(sema, &trait_.items(sema.db), DefaultMethods::No); + if trait_items.is_empty() { + return None; + } + let impl_def = { + use syntax::ast::Impl; + let text = generate_trait_impl_text(adt, trait_path.to_string().as_str(), ""); + let parse = syntax::SourceFile::parse(&text); + let node = match parse.tree().syntax().descendants().find_map(Impl::cast) { + Some(it) => it, + None => { + panic!( + "Failed to make ast node `{}` from text {}", + std::any::type_name::(), + text + ) + } + }; + let node = node.clone_subtree(); + assert_eq!(node.syntax().text_range().start(), 0.into()); + node + }; + + let trait_items = trait_items + .into_iter() + .map(|it| { + if sema.hir_file_for(it.syntax()).is_macro() { + if let Some(it) = ast::AssocItem::cast(insert_ws_into(it.syntax().clone())) { + return it; + } + } + it.clone_for_update() + }) + .collect(); + let (impl_def, first_assoc_item) = + add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope); + + // Generate a default `impl` function body for the derived trait. + if let ast::AssocItem::Fn(ref func) = first_assoc_item { + let _ = gen_trait_fn_body(func, trait_path, adt); + }; + + Some((impl_def, first_assoc_item)) +} + +fn update_attribute( + builder: &mut AssistBuilder, + old_derives: &[ast::Path], + old_tree: &ast::TokenTree, + old_trait_path: &ast::Path, + attr: &ast::Attr, +) { + let new_derives = old_derives + .iter() + .filter(|t| t.to_string() != old_trait_path.to_string()) + .collect::>(); + let has_more_derives = !new_derives.is_empty(); + + if has_more_derives { + let new_derives = format!("({})", new_derives.iter().format(", ")); + builder.replace(old_tree.syntax().text_range(), new_derives); + } else { + let attr_range = attr.syntax().text_range(); + builder.delete(attr_range); + + if let Some(line_break_range) = attr + .syntax() + .next_sibling_or_token() + .filter(|t| t.kind() == WHITESPACE) + .map(|t| t.text_range()) + { + builder.delete(line_break_range); + } + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn add_custom_impl_debug_record_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: fmt, derive +#[derive(Debu$0g)] +struct Foo { + bar: String, +} +"#, + r#" +struct Foo { + bar: String, +} + +impl core::fmt::Debug for Foo { + $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Foo").field("bar", &self.bar).finish() + } +} +"#, + ) + } + #[test] + fn add_custom_impl_debug_tuple_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: fmt, derive +#[derive(Debu$0g)] +struct Foo(String, usize); +"#, + r#"struct Foo(String, usize); + +impl core::fmt::Debug for Foo { + $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_tuple("Foo").field(&self.0).field(&self.1).finish() + } +} +"#, + ) + } + #[test] + fn add_custom_impl_debug_empty_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: fmt, derive +#[derive(Debu$0g)] +struct Foo; +"#, + r#" +struct Foo; + +impl core::fmt::Debug for Foo { + $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Foo").finish() + } +} +"#, + ) + } + #[test] + fn add_custom_impl_debug_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: fmt, derive +#[derive(Debu$0g)] +enum Foo { + Bar, + Baz, +} +"#, + r#" +enum Foo { + Bar, + Baz, +} + +impl core::fmt::Debug for Foo { + $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Bar => write!(f, "Bar"), + Self::Baz => write!(f, "Baz"), + } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_debug_tuple_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: fmt, derive +#[derive(Debu$0g)] +enum Foo { + Bar(usize, usize), + Baz, +} +"#, + r#" +enum Foo { + Bar(usize, usize), + Baz, +} + +impl core::fmt::Debug for Foo { + $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Bar(arg0, arg1) => f.debug_tuple("Bar").field(arg0).field(arg1).finish(), + Self::Baz => write!(f, "Baz"), + } + } +} +"#, + ) + } + #[test] + fn add_custom_impl_debug_record_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: fmt, derive +#[derive(Debu$0g)] +enum Foo { + Bar { + baz: usize, + qux: usize, + }, + Baz, +} +"#, + r#" +enum Foo { + Bar { + baz: usize, + qux: usize, + }, + Baz, +} + +impl core::fmt::Debug for Foo { + $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Bar { baz, qux } => f.debug_struct("Bar").field("baz", baz).field("qux", qux).finish(), + Self::Baz => write!(f, "Baz"), + } + } +} +"#, + ) + } + #[test] + fn add_custom_impl_default_record_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: default, derive +#[derive(Defau$0lt)] +struct Foo { + foo: usize, +} +"#, + r#" +struct Foo { + foo: usize, +} + +impl Default for Foo { + $0fn default() -> Self { + Self { foo: Default::default() } + } +} +"#, + ) + } + #[test] + fn add_custom_impl_default_tuple_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: default, derive +#[derive(Defau$0lt)] +struct Foo(usize); +"#, + r#" +struct Foo(usize); + +impl Default for Foo { + $0fn default() -> Self { + Self(Default::default()) + } +} +"#, + ) + } + #[test] + fn add_custom_impl_default_empty_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: default, derive +#[derive(Defau$0lt)] +struct Foo; +"#, + r#" +struct Foo; + +impl Default for Foo { + $0fn default() -> Self { + Self { } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_hash_record_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: hash, derive +#[derive(Has$0h)] +struct Foo { + bin: usize, + bar: usize, +} +"#, + r#" +struct Foo { + bin: usize, + bar: usize, +} + +impl core::hash::Hash for Foo { + $0fn hash(&self, state: &mut H) { + self.bin.hash(state); + self.bar.hash(state); + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_hash_tuple_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: hash, derive +#[derive(Has$0h)] +struct Foo(usize, usize); +"#, + r#" +struct Foo(usize, usize); + +impl core::hash::Hash for Foo { + $0fn hash(&self, state: &mut H) { + self.0.hash(state); + self.1.hash(state); + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_hash_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: hash, derive +#[derive(Has$0h)] +enum Foo { + Bar, + Baz, +} +"#, + r#" +enum Foo { + Bar, + Baz, +} + +impl core::hash::Hash for Foo { + $0fn hash(&self, state: &mut H) { + core::mem::discriminant(self).hash(state); + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_clone_record_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive +#[derive(Clo$0ne)] +struct Foo { + bin: usize, + bar: usize, +} +"#, + r#" +struct Foo { + bin: usize, + bar: usize, +} + +impl Clone for Foo { + $0fn clone(&self) -> Self { + Self { bin: self.bin.clone(), bar: self.bar.clone() } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_clone_tuple_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive +#[derive(Clo$0ne)] +struct Foo(usize, usize); +"#, + r#" +struct Foo(usize, usize); + +impl Clone for Foo { + $0fn clone(&self) -> Self { + Self(self.0.clone(), self.1.clone()) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_clone_empty_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive +#[derive(Clo$0ne)] +struct Foo; +"#, + r#" +struct Foo; + +impl Clone for Foo { + $0fn clone(&self) -> Self { + Self { } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_clone_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive +#[derive(Clo$0ne)] +enum Foo { + Bar, + Baz, +} +"#, + r#" +enum Foo { + Bar, + Baz, +} + +impl Clone for Foo { + $0fn clone(&self) -> Self { + match self { + Self::Bar => Self::Bar, + Self::Baz => Self::Baz, + } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_clone_tuple_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive +#[derive(Clo$0ne)] +enum Foo { + Bar(String), + Baz, +} +"#, + r#" +enum Foo { + Bar(String), + Baz, +} + +impl Clone for Foo { + $0fn clone(&self) -> Self { + match self { + Self::Bar(arg0) => Self::Bar(arg0.clone()), + Self::Baz => Self::Baz, + } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_clone_record_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive +#[derive(Clo$0ne)] +enum Foo { + Bar { + bin: String, + }, + Baz, +} +"#, + r#" +enum Foo { + Bar { + bin: String, + }, + Baz, +} + +impl Clone for Foo { + $0fn clone(&self) -> Self { + match self { + Self::Bar { bin } => Self::Bar { bin: bin.clone() }, + Self::Baz => Self::Baz, + } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_ord_record_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord, derive +#[derive(Partial$0Ord)] +struct Foo { + bin: usize, +} +"#, + r#" +struct Foo { + bin: usize, +} + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + self.bin.partial_cmp(&other.bin) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_ord_record_struct_multi_field() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord, derive +#[derive(Partial$0Ord)] +struct Foo { + bin: usize, + bar: usize, + baz: usize, +} +"#, + r#" +struct Foo { + bin: usize, + bar: usize, + baz: usize, +} + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + match self.bin.partial_cmp(&other.bin) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + match self.bar.partial_cmp(&other.bar) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.baz.partial_cmp(&other.baz) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_ord_tuple_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord, derive +#[derive(Partial$0Ord)] +struct Foo(usize, usize, usize); +"#, + r#" +struct Foo(usize, usize, usize); + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + match self.0.partial_cmp(&other.0) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + match self.1.partial_cmp(&other.1) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.2.partial_cmp(&other.2) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_eq_record_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: eq, derive +#[derive(Partial$0Eq)] +struct Foo { + bin: usize, + bar: usize, +} +"#, + r#" +struct Foo { + bin: usize, + bar: usize, +} + +impl PartialEq for Foo { + $0fn eq(&self, other: &Self) -> bool { + self.bin == other.bin && self.bar == other.bar + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_eq_tuple_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: eq, derive +#[derive(Partial$0Eq)] +struct Foo(usize, usize); +"#, + r#" +struct Foo(usize, usize); + +impl PartialEq for Foo { + $0fn eq(&self, other: &Self) -> bool { + self.0 == other.0 && self.1 == other.1 + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_eq_empty_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: eq, derive +#[derive(Partial$0Eq)] +struct Foo; +"#, + r#" +struct Foo; + +impl PartialEq for Foo { + $0fn eq(&self, other: &Self) -> bool { + true + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_eq_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: eq, derive +#[derive(Partial$0Eq)] +enum Foo { + Bar, + Baz, +} +"#, + r#" +enum Foo { + Bar, + Baz, +} + +impl PartialEq for Foo { + $0fn eq(&self, other: &Self) -> bool { + core::mem::discriminant(self) == core::mem::discriminant(other) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_eq_tuple_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: eq, derive +#[derive(Partial$0Eq)] +enum Foo { + Bar(String), + Baz, +} +"#, + r#" +enum Foo { + Bar(String), + Baz, +} + +impl PartialEq for Foo { + $0fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Bar(l0), Self::Bar(r0)) => l0 == r0, + _ => core::mem::discriminant(self) == core::mem::discriminant(other), + } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_eq_record_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: eq, derive +#[derive(Partial$0Eq)] +enum Foo { + Bar { + bin: String, + }, + Baz { + qux: String, + fez: String, + }, + Qux {}, + Bin, +} +"#, + r#" +enum Foo { + Bar { + bin: String, + }, + Baz { + qux: String, + fez: String, + }, + Qux {}, + Bin, +} + +impl PartialEq for Foo { + $0fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin, + (Self::Baz { qux: l_qux, fez: l_fez }, Self::Baz { qux: r_qux, fez: r_fez }) => l_qux == r_qux && l_fez == r_fez, + _ => core::mem::discriminant(self) == core::mem::discriminant(other), + } + } +} +"#, + ) + } + #[test] + fn add_custom_impl_all() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: derive +mod foo { + pub trait Bar { + type Qux; + const Baz: usize = 42; + const Fez: usize; + fn foo(); + fn bar() {} + } +} + +#[derive($0Bar)] +struct Foo { + bar: String, +} +"#, + r#" +mod foo { + pub trait Bar { + type Qux; + const Baz: usize = 42; + const Fez: usize; + fn foo(); + fn bar() {} + } +} + +struct Foo { + bar: String, +} + +impl foo::Bar for Foo { + $0type Qux; + + const Baz: usize = 42; + + const Fez: usize; + + fn foo() { + todo!() + } +} +"#, + ) + } + #[test] + fn add_custom_impl_for_unique_input_unknown() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: derive +#[derive(Debu$0g)] +struct Foo { + bar: String, +} + "#, + r#" +struct Foo { + bar: String, +} + +impl Debug for Foo { + $0 +} + "#, + ) + } + + #[test] + fn add_custom_impl_for_with_visibility_modifier() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: derive +#[derive(Debug$0)] +pub struct Foo { + bar: String, +} + "#, + r#" +pub struct Foo { + bar: String, +} + +impl Debug for Foo { + $0 +} + "#, + ) + } + + #[test] + fn add_custom_impl_when_multiple_inputs() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: derive +#[derive(Display, Debug$0, Serialize)] +struct Foo {} + "#, + r#" +#[derive(Display, Serialize)] +struct Foo {} + +impl Debug for Foo { + $0 +} + "#, + ) + } + + #[test] + fn add_custom_impl_default_generic_record_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: default, derive +#[derive(Defau$0lt)] +struct Foo { + foo: T, + bar: U, +} +"#, + r#" +struct Foo { + foo: T, + bar: U, +} + +impl Default for Foo { + $0fn default() -> Self { + Self { foo: Default::default(), bar: Default::default() } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_clone_generic_tuple_struct_with_bounds() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive +#[derive(Clo$0ne)] +struct Foo(T, usize); +"#, + r#" +struct Foo(T, usize); + +impl Clone for Foo { + $0fn clone(&self) -> Self { + Self(self.0.clone(), self.1.clone()) + } +} +"#, + ) + } + + #[test] + fn test_ignore_derive_macro_without_input() { + check_assist_not_applicable( + replace_derive_with_manual_impl, + r#" +//- minicore: derive +#[derive($0)] +struct Foo {} + "#, + ) + } + + #[test] + fn test_ignore_if_cursor_on_param() { + check_assist_not_applicable( + replace_derive_with_manual_impl, + r#" +//- minicore: derive, fmt +#[derive$0(Debug)] +struct Foo {} + "#, + ); + + check_assist_not_applicable( + replace_derive_with_manual_impl, + r#" +//- minicore: derive, fmt +#[derive(Debug)$0] +struct Foo {} + "#, + ) + } + + #[test] + fn test_ignore_if_not_derive() { + check_assist_not_applicable( + replace_derive_with_manual_impl, + r#" +//- minicore: derive +#[allow(non_camel_$0case_types)] +struct Foo {} + "#, + ) + } + + #[test] + fn works_at_start_of_file() { + check_assist_not_applicable( + replace_derive_with_manual_impl, + r#" +//- minicore: derive, fmt +$0#[derive(Debug)] +struct S; + "#, + ); + } + + #[test] + fn add_custom_impl_keep_path() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone, derive +#[derive(std::fmt::Debug, Clo$0ne)] +pub struct Foo; +"#, + r#" +#[derive(std::fmt::Debug)] +pub struct Foo; + +impl Clone for Foo { + $0fn clone(&self) -> Self { + Self { } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_replace_path() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: fmt, derive +#[derive(core::fmt::Deb$0ug, Clone)] +pub struct Foo; +"#, + r#" +#[derive(Clone)] +pub struct Foo; + +impl core::fmt::Debug for Foo { + $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Foo").finish() + } +} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_if_let_with_match.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_if_let_with_match.rs new file mode 100644 index 000000000..484c27387 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_if_let_with_match.rs @@ -0,0 +1,999 @@ +use std::iter::{self, successors}; + +use either::Either; +use ide_db::{ + defs::NameClass, + syntax_helpers::node_ext::{is_pattern_cond, single_let}, + ty_filter::TryEnum, + RootDatabase, +}; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, HasName, + }, + AstNode, TextRange, +}; + +use crate::{ + utils::{does_nested_pattern, does_pat_match_variant, unwrap_trivial_block}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: replace_if_let_with_match +// +// Replaces a `if let` expression with a `match` expression. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// $0if let Action::Move { distance } = action { +// foo(distance) +// } else { +// bar() +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } => foo(distance), +// _ => bar(), +// } +// } +// ``` +pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let if_expr: ast::IfExpr = ctx.find_node_at_offset()?; + let available_range = TextRange::new( + if_expr.syntax().text_range().start(), + if_expr.then_branch()?.syntax().text_range().start(), + ); + let cursor_in_range = available_range.contains_range(ctx.selection_trimmed()); + if !cursor_in_range { + return None; + } + let mut else_block = None; + let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? { + ast::ElseBranch::IfExpr(expr) => Some(expr), + ast::ElseBranch::Block(block) => { + else_block = Some(block); + None + } + }); + let scrutinee_to_be_expr = if_expr.condition()?; + let scrutinee_to_be_expr = match single_let(scrutinee_to_be_expr.clone()) { + Some(cond) => cond.expr()?, + None => scrutinee_to_be_expr, + }; + + let mut pat_seen = false; + let mut cond_bodies = Vec::new(); + for if_expr in if_exprs { + let cond = if_expr.condition()?; + let cond = match single_let(cond.clone()) { + Some(let_) => { + let pat = let_.pat()?; + let expr = let_.expr()?; + // FIXME: If one `let` is wrapped in parentheses and the second is not, + // we'll exit here. + if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() { + // Only if all condition expressions are equal we can merge them into a match + return None; + } + pat_seen = true; + Either::Left(pat) + } + // Multiple `let`, unsupported. + None if is_pattern_cond(cond.clone()) => return None, + None => Either::Right(cond), + }; + let body = if_expr.then_branch()?; + cond_bodies.push((cond, body)); + } + + if !pat_seen { + // Don't offer turning an if (chain) without patterns into a match + return None; + } + + acc.add( + AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite), + "Replace if let with match", + available_range, + move |edit| { + let match_expr = { + let else_arm = make_else_arm(ctx, else_block, &cond_bodies); + let make_match_arm = |(pat, body): (_, ast::BlockExpr)| { + let body = body.reset_indent().indent(IndentLevel(1)); + match pat { + Either::Left(pat) => { + make::match_arm(iter::once(pat), None, unwrap_trivial_block(body)) + } + Either::Right(expr) => make::match_arm( + iter::once(make::wildcard_pat().into()), + Some(expr), + unwrap_trivial_block(body), + ), + } + }; + let arms = cond_bodies.into_iter().map(make_match_arm).chain(iter::once(else_arm)); + let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms)); + match_expr.indent(IndentLevel::from_node(if_expr.syntax())) + }; + + let has_preceding_if_expr = + if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind())); + let expr = if has_preceding_if_expr { + // make sure we replace the `else if let ...` with a block so we don't end up with `else expr` + make::block_expr(None, Some(match_expr)).into() + } else { + match_expr + }; + edit.replace_ast::(if_expr.into(), expr); + }, + ) +} + +fn make_else_arm( + ctx: &AssistContext<'_>, + else_block: Option, + conditionals: &[(Either, ast::BlockExpr)], +) -> ast::MatchArm { + if let Some(else_block) = else_block { + let pattern = if let [(Either::Left(pat), _)] = conditionals { + ctx.sema + .type_of_pat(pat) + .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted())) + .zip(Some(pat)) + } else { + None + }; + let pattern = match pattern { + Some((it, pat)) => { + if does_pat_match_variant(pat, &it.sad_pattern()) { + it.happy_pattern_wildcard() + } else if does_nested_pattern(pat) { + make::wildcard_pat().into() + } else { + it.sad_pattern() + } + } + None => make::wildcard_pat().into(), + }; + make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block)) + } else { + make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit()) + } +} + +// Assist: replace_match_with_if_let +// +// Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// $0match action { +// Action::Move { distance } => foo(distance), +// _ => bar(), +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// if let Action::Move { distance } = action { +// foo(distance) +// } else { +// bar() +// } +// } +// ``` +pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?; + + let mut arms = match_expr.match_arm_list()?.arms(); + let (first_arm, second_arm) = (arms.next()?, arms.next()?); + if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() { + return None; + } + + let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order( + &ctx.sema, + first_arm.pat()?, + second_arm.pat()?, + first_arm.expr()?, + second_arm.expr()?, + )?; + let scrutinee = match_expr.expr()?; + + let target = match_expr.syntax().text_range(); + acc.add( + AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite), + "Replace match with if let", + target, + move |edit| { + fn make_block_expr(expr: ast::Expr) -> ast::BlockExpr { + // Blocks with modifiers (unsafe, async, etc.) are parsed as BlockExpr, but are + // formatted without enclosing braces. If we encounter such block exprs, + // wrap them in another BlockExpr. + match expr { + ast::Expr::BlockExpr(block) if block.modifier().is_none() => block, + expr => make::block_expr(iter::empty(), Some(expr)), + } + } + + let condition = make::expr_let(if_let_pat, scrutinee); + let then_block = make_block_expr(then_expr.reset_indent()); + let else_expr = if is_empty_expr(&else_expr) { None } else { Some(else_expr) }; + let if_let_expr = make::expr_if( + condition.into(), + then_block, + else_expr.map(make_block_expr).map(ast::ElseBranch::Block), + ) + .indent(IndentLevel::from_node(match_expr.syntax())); + + edit.replace_ast::(match_expr.into(), if_let_expr); + }, + ) +} + +/// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order. +fn pick_pattern_and_expr_order( + sema: &hir::Semantics<'_, RootDatabase>, + pat: ast::Pat, + pat2: ast::Pat, + expr: ast::Expr, + expr2: ast::Expr, +) -> Option<(ast::Pat, ast::Expr, ast::Expr)> { + let res = match (pat, pat2) { + (ast::Pat::WildcardPat(_), _) => return None, + (pat, ast::Pat::WildcardPat(_)) => (pat, expr, expr2), + (pat, _) if is_empty_expr(&expr2) => (pat, expr, expr2), + (_, pat) if is_empty_expr(&expr) => (pat, expr2, expr), + (pat, pat2) => match (binds_name(sema, &pat), binds_name(sema, &pat2)) { + (true, true) => return None, + (true, false) => (pat, expr, expr2), + (false, true) => (pat2, expr2, expr), + _ if is_sad_pat(sema, &pat) => (pat2, expr2, expr), + (false, false) => (pat, expr, expr2), + }, + }; + Some(res) +} + +fn is_empty_expr(expr: &ast::Expr) -> bool { + match expr { + ast::Expr::BlockExpr(expr) => match expr.stmt_list() { + Some(it) => it.statements().next().is_none() && it.tail_expr().is_none(), + None => true, + }, + ast::Expr::TupleExpr(expr) => expr.fields().next().is_none(), + _ => false, + } +} + +fn binds_name(sema: &hir::Semantics<'_, RootDatabase>, pat: &ast::Pat) -> bool { + let binds_name_v = |pat| binds_name(sema, &pat); + match pat { + ast::Pat::IdentPat(pat) => !matches!( + pat.name().and_then(|name| NameClass::classify(sema, &name)), + Some(NameClass::ConstReference(_)) + ), + ast::Pat::MacroPat(_) => true, + ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v), + ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v), + ast::Pat::TuplePat(it) => it.fields().any(binds_name_v), + ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v), + ast::Pat::RecordPat(it) => it + .record_pat_field_list() + .map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)), + ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v), + ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v), + ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v), + _ => false, + } +} + +fn is_sad_pat(sema: &hir::Semantics<'_, RootDatabase>, pat: &ast::Pat) -> bool { + sema.type_of_pat(pat) + .and_then(|ty| TryEnum::from_ty(sema, &ty.adjusted())) + .map_or(false, |it| does_pat_match_variant(pat, &it.sad_pattern())) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn test_if_let_with_match_unapplicable_for_simple_ifs() { + check_assist_not_applicable( + replace_if_let_with_match, + r#" +fn main() { + if $0true {} else if false {} else {} +} +"#, + ) + } + + #[test] + fn test_if_let_with_match_no_else() { + check_assist( + replace_if_let_with_match, + r#" +impl VariantData { + pub fn foo(&self) { + if $0let VariantData::Struct(..) = *self { + self.foo(); + } + } +} +"#, + r#" +impl VariantData { + pub fn foo(&self) { + match *self { + VariantData::Struct(..) => { + self.foo(); + } + _ => (), + } + } +} +"#, + ) + } + + #[test] + fn test_if_let_with_match_available_range_left() { + check_assist_not_applicable( + replace_if_let_with_match, + r#" +impl VariantData { + pub fn foo(&self) { + $0 if let VariantData::Struct(..) = *self { + self.foo(); + } + } +} +"#, + ) + } + + #[test] + fn test_if_let_with_match_available_range_right() { + check_assist_not_applicable( + replace_if_let_with_match, + r#" +impl VariantData { + pub fn foo(&self) { + if let VariantData::Struct(..) = *self {$0 + self.foo(); + } + } +} +"#, + ) + } + + #[test] + fn test_if_let_with_match_let_chain() { + check_assist_not_applicable( + replace_if_let_with_match, + r#" +fn main() { + if $0let true = true && let Some(1) = None {} +} +"#, + ) + } + + #[test] + fn test_if_let_with_match_basic() { + check_assist( + replace_if_let_with_match, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + if $0let VariantData::Struct(..) = *self { + true + } else if let VariantData::Tuple(..) = *self { + false + } else if cond() { + true + } else { + bar( + 123 + ) + } + } +} +"#, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + match *self { + VariantData::Struct(..) => true, + VariantData::Tuple(..) => false, + _ if cond() => true, + _ => { + bar( + 123 + ) + } + } + } +} +"#, + ) + } + + #[test] + fn test_if_let_with_match_on_tail_if_let() { + check_assist( + replace_if_let_with_match, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + if let VariantData::Struct(..) = *self { + true + } else if let$0 VariantData::Tuple(..) = *self { + false + } else { + false + } + } +} +"#, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + if let VariantData::Struct(..) = *self { + true + } else { + match *self { + VariantData::Tuple(..) => false, + _ => false, + } +} + } +} +"#, + ) + } + + #[test] + fn special_case_option() { + check_assist( + replace_if_let_with_match, + r#" +//- minicore: option +fn foo(x: Option) { + $0if let Some(x) = x { + println!("{}", x) + } else { + println!("none") + } +} +"#, + r#" +fn foo(x: Option) { + match x { + Some(x) => println!("{}", x), + None => println!("none"), + } +} +"#, + ); + } + + #[test] + fn special_case_inverted_option() { + check_assist( + replace_if_let_with_match, + r#" +//- minicore: option +fn foo(x: Option) { + $0if let None = x { + println!("none") + } else { + println!("some") + } +} +"#, + r#" +fn foo(x: Option) { + match x { + None => println!("none"), + Some(_) => println!("some"), + } +} +"#, + ); + } + + #[test] + fn special_case_result() { + check_assist( + replace_if_let_with_match, + r#" +//- minicore: result +fn foo(x: Result) { + $0if let Ok(x) = x { + println!("{}", x) + } else { + println!("none") + } +} +"#, + r#" +fn foo(x: Result) { + match x { + Ok(x) => println!("{}", x), + Err(_) => println!("none"), + } +} +"#, + ); + } + + #[test] + fn special_case_inverted_result() { + check_assist( + replace_if_let_with_match, + r#" +//- minicore: result +fn foo(x: Result) { + $0if let Err(x) = x { + println!("{}", x) + } else { + println!("ok") + } +} +"#, + r#" +fn foo(x: Result) { + match x { + Err(x) => println!("{}", x), + Ok(_) => println!("ok"), + } +} +"#, + ); + } + + #[test] + fn nested_indent() { + check_assist( + replace_if_let_with_match, + r#" +fn main() { + if true { + $0if let Ok(rel_path) = path.strip_prefix(root_path) { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } else { + None + } + } +} +"#, + r#" +fn main() { + if true { + match path.strip_prefix(root_path) { + Ok(rel_path) => { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } + _ => None, + } + } +} +"#, + ) + } + + #[test] + fn nested_type() { + check_assist( + replace_if_let_with_match, + r#" +//- minicore: result +fn foo(x: Result) { + let bar: Result<_, ()> = Ok(Some(1)); + $0if let Ok(Some(_)) = bar { + () + } else { + () + } +} +"#, + r#" +fn foo(x: Result) { + let bar: Result<_, ()> = Ok(Some(1)); + match bar { + Ok(Some(_)) => (), + _ => (), + } +} +"#, + ); + } + + #[test] + fn test_replace_match_with_if_let_unwraps_simple_expressions() { + check_assist( + replace_match_with_if_let, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + $0match *self { + VariantData::Struct(..) => true, + _ => false, + } + } +} "#, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + if let VariantData::Struct(..) = *self { + true + } else { + false + } + } +} "#, + ) + } + + #[test] + fn test_replace_match_with_if_let_doesnt_unwrap_multiline_expressions() { + check_assist( + replace_match_with_if_let, + r#" +fn foo() { + $0match a { + VariantData::Struct(..) => { + bar( + 123 + ) + } + _ => false, + } +} "#, + r#" +fn foo() { + if let VariantData::Struct(..) = a { + bar( + 123 + ) + } else { + false + } +} "#, + ) + } + + #[test] + fn replace_match_with_if_let_target() { + check_assist_target( + replace_match_with_if_let, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + $0match *self { + VariantData::Struct(..) => true, + _ => false, + } + } +} "#, + r#"match *self { + VariantData::Struct(..) => true, + _ => false, + }"#, + ); + } + + #[test] + fn special_case_option_match_to_if_let() { + check_assist( + replace_match_with_if_let, + r#" +//- minicore: option +fn foo(x: Option) { + $0match x { + Some(x) => println!("{}", x), + None => println!("none"), + } +} +"#, + r#" +fn foo(x: Option) { + if let Some(x) = x { + println!("{}", x) + } else { + println!("none") + } +} +"#, + ); + } + + #[test] + fn special_case_result_match_to_if_let() { + check_assist( + replace_match_with_if_let, + r#" +//- minicore: result +fn foo(x: Result) { + $0match x { + Ok(x) => println!("{}", x), + Err(_) => println!("none"), + } +} +"#, + r#" +fn foo(x: Result) { + if let Ok(x) = x { + println!("{}", x) + } else { + println!("none") + } +} +"#, + ); + } + + #[test] + fn nested_indent_match_to_if_let() { + check_assist( + replace_match_with_if_let, + r#" +fn main() { + if true { + $0match path.strip_prefix(root_path) { + Ok(rel_path) => { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } + _ => None, + } + } +} +"#, + r#" +fn main() { + if true { + if let Ok(rel_path) = path.strip_prefix(root_path) { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } else { + None + } + } +} +"#, + ) + } + + #[test] + fn replace_match_with_if_let_empty_wildcard_expr() { + check_assist( + replace_match_with_if_let, + r#" +fn main() { + $0match path.strip_prefix(root_path) { + Ok(rel_path) => println!("{}", rel_path), + _ => (), + } +} +"#, + r#" +fn main() { + if let Ok(rel_path) = path.strip_prefix(root_path) { + println!("{}", rel_path) + } +} +"#, + ) + } + + #[test] + fn replace_match_with_if_let_number_body() { + check_assist( + replace_match_with_if_let, + r#" +fn main() { + $0match Ok(()) { + Ok(()) => {}, + Err(_) => 0, + } +} +"#, + r#" +fn main() { + if let Err(_) = Ok(()) { + 0 + } +} +"#, + ) + } + + #[test] + fn replace_match_with_if_let_exhaustive() { + check_assist( + replace_match_with_if_let, + r#" +fn print_source(def_source: ModuleSource) { + match def_so$0urce { + ModuleSource::SourceFile(..) => { println!("source file"); } + ModuleSource::Module(..) => { println!("module"); } + } +} +"#, + r#" +fn print_source(def_source: ModuleSource) { + if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); } +} +"#, + ) + } + + #[test] + fn replace_match_with_if_let_prefer_name_bind() { + check_assist( + replace_match_with_if_let, + r#" +fn foo() { + match $0Foo(0) { + Foo(_) => (), + Bar(bar) => println!("bar {}", bar), + } +} +"#, + r#" +fn foo() { + if let Bar(bar) = Foo(0) { + println!("bar {}", bar) + } +} +"#, + ); + check_assist( + replace_match_with_if_let, + r#" +fn foo() { + match $0Foo(0) { + Bar(bar) => println!("bar {}", bar), + Foo(_) => (), + } +} +"#, + r#" +fn foo() { + if let Bar(bar) = Foo(0) { + println!("bar {}", bar) + } +} +"#, + ); + } + + #[test] + fn replace_match_with_if_let_prefer_nonempty_body() { + check_assist( + replace_match_with_if_let, + r#" +fn foo() { + match $0Ok(0) { + Ok(value) => {}, + Err(err) => eprintln!("{}", err), + } +} +"#, + r#" +fn foo() { + if let Err(err) = Ok(0) { + eprintln!("{}", err) + } +} +"#, + ); + check_assist( + replace_match_with_if_let, + r#" +fn foo() { + match $0Ok(0) { + Err(err) => eprintln!("{}", err), + Ok(value) => {}, + } +} +"#, + r#" +fn foo() { + if let Err(err) = Ok(0) { + eprintln!("{}", err) + } +} +"#, + ); + } + + #[test] + fn replace_match_with_if_let_rejects_double_name_bindings() { + check_assist_not_applicable( + replace_match_with_if_let, + r#" +fn foo() { + match $0Foo(0) { + Foo(foo) => println!("bar {}", foo), + Bar(bar) => println!("bar {}", bar), + } +} +"#, + ); + } + + #[test] + fn test_replace_match_with_if_let_keeps_unsafe_block() { + check_assist( + replace_match_with_if_let, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + $0match *self { + VariantData::Struct(..) => true, + _ => unsafe { unreachable_unchecked() }, + } + } +} "#, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + if let VariantData::Struct(..) = *self { + true + } else { + unsafe { unreachable_unchecked() } + } + } +} "#, + ) + } + + #[test] + fn test_replace_match_with_if_let_forces_else() { + check_assist( + replace_match_with_if_let, + r#" +fn main() { + match$0 0 { + 0 => (), + _ => code(), + } +} +"#, + r#" +fn main() { + if let 0 = 0 { + () + } else { + code() + } +} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_let_with_if_let.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_let_with_if_let.rs new file mode 100644 index 000000000..c2be4593b --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_let_with_if_let.rs @@ -0,0 +1,100 @@ +use std::iter::once; + +use ide_db::ty_filter::TryEnum; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, + }, + AstNode, T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: replace_let_with_if_let +// +// Replaces `let` with an `if let`. +// +// ``` +// # enum Option { Some(T), None } +// +// fn main(action: Action) { +// $0let x = compute(); +// } +// +// fn compute() -> Option { None } +// ``` +// -> +// ``` +// # enum Option { Some(T), None } +// +// fn main(action: Action) { +// if let Some(x) = compute() { +// } +// } +// +// fn compute() -> Option { None } +// ``` +pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let let_kw = ctx.find_token_syntax_at_offset(T![let])?; + let let_stmt = let_kw.parent().and_then(ast::LetStmt::cast)?; + let init = let_stmt.initializer()?; + let original_pat = let_stmt.pat()?; + + let target = let_kw.text_range(); + acc.add( + AssistId("replace_let_with_if_let", AssistKind::RefactorRewrite), + "Replace let with if let", + target, + |edit| { + let ty = ctx.sema.type_of_expr(&init); + let happy_variant = ty + .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted())) + .map(|it| it.happy_case()); + let pat = match happy_variant { + None => original_pat, + Some(var_name) => { + make::tuple_struct_pat(make::ext::ident_path(var_name), once(original_pat)) + .into() + } + }; + + let block = + make::ext::empty_block_expr().indent(IndentLevel::from_node(let_stmt.syntax())); + let if_ = make::expr_if(make::expr_let(pat, init).into(), block, None); + let stmt = make::expr_stmt(if_); + + edit.replace_ast(ast::Stmt::from(let_stmt), ast::Stmt::from(stmt)); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::check_assist; + + use super::*; + + #[test] + fn replace_let_unknown_enum() { + check_assist( + replace_let_with_if_let, + r" +enum E { X(T), Y(T) } + +fn main() { + $0let x = E::X(92); +} + ", + r" +enum E { X(T), Y(T) } + +fn main() { + if let x = E::X(92) { + } +} + ", + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs new file mode 100644 index 000000000..2419fa11c --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs @@ -0,0 +1,438 @@ +use hir::AsAssocItem; +use ide_db::{ + helpers::mod_path_to_ast, + imports::insert_use::{insert_use, ImportScope}, +}; +use syntax::{ + ast::{self, make}, + match_ast, ted, AstNode, SyntaxNode, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: replace_qualified_name_with_use +// +// Adds a use statement for a given fully-qualified name. +// +// ``` +// # mod std { pub mod collections { pub struct HashMap(T, U); } } +// fn process(map: std::collections::$0HashMap) {} +// ``` +// -> +// ``` +// use std::collections::HashMap; +// +// # mod std { pub mod collections { pub struct HashMap(T, U); } } +// fn process(map: HashMap) {} +// ``` +pub(crate) fn replace_qualified_name_with_use( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let path: ast::Path = ctx.find_node_at_offset()?; + // We don't want to mess with use statements + if path.syntax().ancestors().find_map(ast::UseTree::cast).is_some() { + cov_mark::hit!(not_applicable_in_use); + return None; + } + + if path.qualifier().is_none() { + cov_mark::hit!(dont_import_trivial_paths); + return None; + } + + // only offer replacement for non assoc items + match ctx.sema.resolve_path(&path)? { + hir::PathResolution::Def(def) if def.as_assoc_item(ctx.sema.db).is_none() => (), + _ => return None, + } + // then search for an import for the first path segment of what we want to replace + // that way it is less likely that we import the item from a different location due re-exports + let module = match ctx.sema.resolve_path(&path.first_qualifier_or_self())? { + hir::PathResolution::Def(module @ hir::ModuleDef::Module(_)) => module, + _ => return None, + }; + + let starts_with_name_ref = !matches!( + path.first_segment().and_then(|it| it.kind()), + Some( + ast::PathSegmentKind::CrateKw + | ast::PathSegmentKind::SuperKw + | ast::PathSegmentKind::SelfKw + ) + ); + let path_to_qualifier = starts_with_name_ref + .then(|| { + ctx.sema.scope(path.syntax())?.module().find_use_path_prefixed( + ctx.sema.db, + module, + ctx.config.insert_use.prefix_kind, + ) + }) + .flatten(); + + let scope = ImportScope::find_insert_use_container(path.syntax(), &ctx.sema)?; + let target = path.syntax().text_range(); + acc.add( + AssistId("replace_qualified_name_with_use", AssistKind::RefactorRewrite), + "Replace qualified path with use", + target, + |builder| { + // Now that we've brought the name into scope, re-qualify all paths that could be + // affected (that is, all paths inside the node we added the `use` to). + let scope = match scope { + ImportScope::File(it) => ImportScope::File(builder.make_mut(it)), + ImportScope::Module(it) => ImportScope::Module(builder.make_mut(it)), + ImportScope::Block(it) => ImportScope::Block(builder.make_mut(it)), + }; + shorten_paths(scope.as_syntax_node(), &path); + let path = drop_generic_args(&path); + // stick the found import in front of the to be replaced path + let path = match path_to_qualifier.and_then(|it| mod_path_to_ast(&it).qualifier()) { + Some(qualifier) => make::path_concat(qualifier, path), + None => path, + }; + insert_use(&scope, path, &ctx.config.insert_use); + }, + ) +} + +fn drop_generic_args(path: &ast::Path) -> ast::Path { + let path = path.clone_for_update(); + if let Some(segment) = path.segment() { + if let Some(generic_args) = segment.generic_arg_list() { + ted::remove(generic_args.syntax()); + } + } + path +} + +/// Mutates `node` to shorten `path` in all descendants of `node`. +fn shorten_paths(node: &SyntaxNode, path: &ast::Path) { + for child in node.children() { + match_ast! { + match child { + // Don't modify `use` items, as this can break the `use` item when injecting a new + // import into the use tree. + ast::Use(_) => continue, + // Don't descend into submodules, they don't have the same `use` items in scope. + // FIXME: This isn't true due to `super::*` imports? + ast::Module(_) => continue, + ast::Path(p) => if maybe_replace_path(p.clone(), path.clone()).is_none() { + shorten_paths(p.syntax(), path); + }, + _ => shorten_paths(&child, path), + } + } + } +} + +fn maybe_replace_path(path: ast::Path, target: ast::Path) -> Option<()> { + if !path_eq_no_generics(path.clone(), target) { + return None; + } + + // Shorten `path`, leaving only its last segment. + if let Some(parent) = path.qualifier() { + ted::remove(parent.syntax()); + } + if let Some(double_colon) = path.coloncolon_token() { + ted::remove(&double_colon); + } + + Some(()) +} + +fn path_eq_no_generics(lhs: ast::Path, rhs: ast::Path) -> bool { + let mut lhs_curr = lhs; + let mut rhs_curr = rhs; + loop { + match lhs_curr.segment().zip(rhs_curr.segment()) { + Some((lhs, rhs)) + if lhs.coloncolon_token().is_some() == rhs.coloncolon_token().is_some() + && lhs + .name_ref() + .zip(rhs.name_ref()) + .map_or(false, |(lhs, rhs)| lhs.text() == rhs.text()) => {} + _ => return false, + } + + match (lhs_curr.qualifier(), rhs_curr.qualifier()) { + (Some(lhs), Some(rhs)) => { + lhs_curr = lhs; + rhs_curr = rhs; + } + (None, None) => return true, + _ => return false, + } + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_replace_already_imported() { + check_assist( + replace_qualified_name_with_use, + r" +mod std { pub mod fs { pub struct Path; } } +use std::fs; + +fn main() { + std::f$0s::Path +}", + r" +mod std { pub mod fs { pub struct Path; } } +use std::fs; + +fn main() { + fs::Path +}", + ) + } + + #[test] + fn test_replace_add_use_no_anchor() { + check_assist( + replace_qualified_name_with_use, + r" +mod std { pub mod fs { pub struct Path; } } +std::fs::Path$0 + ", + r" +use std::fs::Path; + +mod std { pub mod fs { pub struct Path; } } +Path + ", + ); + } + + #[test] + fn test_replace_add_use_no_anchor_middle_segment() { + check_assist( + replace_qualified_name_with_use, + r" +mod std { pub mod fs { pub struct Path; } } +std::fs$0::Path + ", + r" +use std::fs; + +mod std { pub mod fs { pub struct Path; } } +fs::Path + ", + ); + } + + #[test] + fn dont_import_trivial_paths() { + cov_mark::check!(dont_import_trivial_paths); + check_assist_not_applicable(replace_qualified_name_with_use, r"impl foo$0 for () {}"); + } + + #[test] + fn test_replace_not_applicable_in_use() { + cov_mark::check!(not_applicable_in_use); + check_assist_not_applicable(replace_qualified_name_with_use, r"use std::fmt$0;"); + } + + #[test] + fn replaces_all_affected_paths() { + check_assist( + replace_qualified_name_with_use, + r" +mod std { pub mod fmt { pub trait Debug {} } } +fn main() { + std::fmt::Debug$0; + let x: std::fmt::Debug = std::fmt::Debug; +} + ", + r" +use std::fmt::Debug; + +mod std { pub mod fmt { pub trait Debug {} } } +fn main() { + Debug; + let x: Debug = Debug; +} + ", + ); + } + + #[test] + fn does_not_replace_in_submodules() { + check_assist( + replace_qualified_name_with_use, + r" +mod std { pub mod fmt { pub trait Debug {} } } +fn main() { + std::fmt::Debug$0; +} + +mod sub { + fn f() { + std::fmt::Debug; + } +} + ", + r" +use std::fmt::Debug; + +mod std { pub mod fmt { pub trait Debug {} } } +fn main() { + Debug; +} + +mod sub { + fn f() { + std::fmt::Debug; + } +} + ", + ); + } + + #[test] + fn does_not_replace_in_use() { + check_assist( + replace_qualified_name_with_use, + r" +mod std { pub mod fmt { pub trait Display {} } } +use std::fmt::Display; + +fn main() { + std::fmt$0; +} + ", + r" +mod std { pub mod fmt { pub trait Display {} } } +use std::fmt::{Display, self}; + +fn main() { + fmt; +} + ", + ); + } + + #[test] + fn does_not_replace_assoc_item_path() { + check_assist_not_applicable( + replace_qualified_name_with_use, + r" +pub struct Foo; +impl Foo { + pub fn foo() {} +} + +fn main() { + Foo::foo$0(); +} +", + ); + } + + #[test] + fn replace_reuses_path_qualifier() { + check_assist( + replace_qualified_name_with_use, + r" +pub mod foo { + pub struct Foo; +} + +mod bar { + pub use super::foo::Foo as Bar; +} + +fn main() { + foo::Foo$0; +} +", + r" +use foo::Foo; + +pub mod foo { + pub struct Foo; +} + +mod bar { + pub use super::foo::Foo as Bar; +} + +fn main() { + Foo; +} +", + ); + } + + #[test] + fn replace_does_not_always_try_to_replace_by_full_item_path() { + check_assist( + replace_qualified_name_with_use, + r" +use std::mem; + +mod std { + pub mod mem { + pub fn drop(_: T) {} + } +} + +fn main() { + mem::drop$0(0); +} +", + r" +use std::mem::{self, drop}; + +mod std { + pub mod mem { + pub fn drop(_: T) {} + } +} + +fn main() { + drop(0); +} +", + ); + } + + #[test] + fn replace_should_drop_generic_args_in_use() { + check_assist( + replace_qualified_name_with_use, + r" +mod std { + pub mod mem { + pub fn drop(_: T) {} + } +} + +fn main() { + std::mem::drop::$0(0); +} +", + r" +use std::mem::drop; + +mod std { + pub mod mem { + pub fn drop(_: T) {} + } +} + +fn main() { + drop::(0); +} +", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_string_with_char.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_string_with_char.rs new file mode 100644 index 000000000..decb5fb62 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_string_with_char.rs @@ -0,0 +1,307 @@ +use syntax::{ + ast, + ast::IsString, + AstToken, + SyntaxKind::{CHAR, STRING}, + TextRange, TextSize, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: replace_string_with_char +// +// Replace string literal with char literal. +// +// ``` +// fn main() { +// find("{$0"); +// } +// ``` +// -> +// ``` +// fn main() { +// find('{'); +// } +// ``` +pub(crate) fn replace_string_with_char(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let token = ctx.find_token_syntax_at_offset(STRING).and_then(ast::String::cast)?; + let value = token.value()?; + let target = token.syntax().text_range(); + + if value.chars().take(2).count() != 1 { + return None; + } + let quote_offets = token.quote_offsets()?; + + acc.add( + AssistId("replace_string_with_char", AssistKind::RefactorRewrite), + "Replace string with char", + target, + |edit| { + let (left, right) = quote_offets.quotes; + edit.replace(left, '\''); + edit.replace(right, '\''); + if value == "'" { + edit.insert(left.end(), '\\'); + } + }, + ) +} + +// Assist: replace_char_with_string +// +// Replace a char literal with a string literal. +// +// ``` +// fn main() { +// find('{$0'); +// } +// ``` +// -> +// ``` +// fn main() { +// find("{"); +// } +// ``` +pub(crate) fn replace_char_with_string(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let token = ctx.find_token_syntax_at_offset(CHAR)?; + let target = token.text_range(); + + acc.add( + AssistId("replace_char_with_string", AssistKind::RefactorRewrite), + "Replace char with string", + target, + |edit| { + if token.text() == "'\"'" { + edit.replace(token.text_range(), r#""\"""#); + } else { + let len = TextSize::of('\''); + edit.replace(TextRange::at(target.start(), len), '"'); + edit.replace(TextRange::at(target.end() - len, len), '"'); + } + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn replace_string_with_char_assist() { + check_assist( + replace_string_with_char, + r#" +fn f() { + let s = "$0c"; +} +"#, + r##" +fn f() { + let s = 'c'; +} +"##, + ) + } + + #[test] + fn replace_string_with_char_assist_with_multi_byte_char() { + check_assist( + replace_string_with_char, + r#" +fn f() { + let s = "$0😀"; +} +"#, + r##" +fn f() { + let s = '😀'; +} +"##, + ) + } + + #[test] + fn replace_string_with_char_multiple_chars() { + check_assist_not_applicable( + replace_string_with_char, + r#" +fn f() { + let s = "$0test"; +} +"#, + ) + } + + #[test] + fn replace_string_with_char_works_inside_macros() { + check_assist( + replace_string_with_char, + r#" +fn f() { + format!($0"x", 92) +} +"#, + r##" +fn f() { + format!('x', 92) +} +"##, + ) + } + + #[test] + fn replace_string_with_char_newline() { + check_assist( + replace_string_with_char, + r#" +fn f() { + find($0"\n"); +} +"#, + r##" +fn f() { + find('\n'); +} +"##, + ) + } + + #[test] + fn replace_string_with_char_unicode_escape() { + check_assist( + replace_string_with_char, + r#" +fn f() { + find($0"\u{7FFF}"); +} +"#, + r##" +fn f() { + find('\u{7FFF}'); +} +"##, + ) + } + + #[test] + fn replace_raw_string_with_char() { + check_assist( + replace_string_with_char, + r##" +fn f() { + $0r#"X"# +} +"##, + r##" +fn f() { + 'X' +} +"##, + ) + } + + #[test] + fn replace_char_with_string_assist() { + check_assist( + replace_char_with_string, + r" +fn f() { + let s = '$0c'; +} +", + r#" +fn f() { + let s = "c"; +} +"#, + ) + } + + #[test] + fn replace_char_with_string_assist_with_multi_byte_char() { + check_assist( + replace_char_with_string, + r" +fn f() { + let s = '$0😀'; +} +", + r#" +fn f() { + let s = "😀"; +} +"#, + ) + } + + #[test] + fn replace_char_with_string_newline() { + check_assist( + replace_char_with_string, + r" +fn f() { + find($0'\n'); +} +", + r#" +fn f() { + find("\n"); +} +"#, + ) + } + + #[test] + fn replace_char_with_string_unicode_escape() { + check_assist( + replace_char_with_string, + r" +fn f() { + find($0'\u{7FFF}'); +} +", + r#" +fn f() { + find("\u{7FFF}"); +} +"#, + ) + } + + #[test] + fn replace_char_with_string_quote() { + check_assist( + replace_char_with_string, + r#" +fn f() { + find($0'"'); +} +"#, + r#" +fn f() { + find("\""); +} +"#, + ) + } + + #[test] + fn replace_string_with_char_quote() { + check_assist( + replace_string_with_char, + r#" +fn f() { + find($0"'"); +} +"#, + r#" +fn f() { + find('\''); +} +"#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_try_expr_with_match.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_try_expr_with_match.rs new file mode 100644 index 000000000..38fccb338 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_try_expr_with_match.rs @@ -0,0 +1,150 @@ +use std::iter; + +use ide_db::{ + assists::{AssistId, AssistKind}, + ty_filter::TryEnum, +}; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, + }, + AstNode, T, +}; + +use crate::assist_context::{AssistContext, Assists}; + +// Assist: replace_try_expr_with_match +// +// Replaces a `try` expression with a `match` expression. +// +// ``` +// # //- minicore:option +// fn handle() { +// let pat = Some(true)$0?; +// } +// ``` +// -> +// ``` +// fn handle() { +// let pat = match Some(true) { +// Some(it) => it, +// None => return None, +// }; +// } +// ``` +pub(crate) fn replace_try_expr_with_match( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let qm_kw = ctx.find_token_syntax_at_offset(T![?])?; + let qm_kw_parent = qm_kw.parent().and_then(ast::TryExpr::cast)?; + + let expr = qm_kw_parent.expr()?; + let expr_type_info = ctx.sema.type_of_expr(&expr)?; + + let try_enum = TryEnum::from_ty(&ctx.sema, &expr_type_info.original)?; + + let target = qm_kw_parent.syntax().text_range(); + acc.add( + AssistId("replace_try_expr_with_match", AssistKind::RefactorRewrite), + "Replace try expression with match", + target, + |edit| { + let sad_pat = match try_enum { + TryEnum::Option => make::path_pat(make::ext::ident_path("None")), + TryEnum::Result => make::tuple_struct_pat( + make::ext::ident_path("Err"), + iter::once(make::path_pat(make::ext::ident_path("err"))), + ) + .into(), + }; + let sad_expr = match try_enum { + TryEnum::Option => { + make::expr_return(Some(make::expr_path(make::ext::ident_path("None")))) + } + TryEnum::Result => make::expr_return(Some(make::expr_call( + make::expr_path(make::ext::ident_path("Err")), + make::arg_list(iter::once(make::expr_path(make::ext::ident_path("err")))), + ))), + }; + + let happy_arm = make::match_arm( + iter::once( + try_enum.happy_pattern(make::ident_pat(false, false, make::name("it")).into()), + ), + None, + make::expr_path(make::ext::ident_path("it")), + ); + let sad_arm = make::match_arm(iter::once(sad_pat), None, sad_expr); + + let match_arm_list = make::match_arm_list([happy_arm, sad_arm]); + + let expr_match = make::expr_match(expr, match_arm_list) + .indent(IndentLevel::from_node(qm_kw_parent.syntax())); + edit.replace_ast::(qm_kw_parent.into(), expr_match); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn test_replace_try_expr_with_match_not_applicable() { + check_assist_not_applicable( + replace_try_expr_with_match, + r#" + fn test() { + let pat: u32 = 25$0; + } + "#, + ); + } + + #[test] + fn test_replace_try_expr_with_match_option() { + check_assist( + replace_try_expr_with_match, + r#" +//- minicore:option +fn test() { + let pat = Some(true)$0?; +} + "#, + r#" +fn test() { + let pat = match Some(true) { + Some(it) => it, + None => return None, + }; +} + "#, + ); + } + + #[test] + fn test_replace_try_expr_with_match_result() { + check_assist( + replace_try_expr_with_match, + r#" +//- minicore:result +fn test() { + let pat = Ok(true)$0?; +} + "#, + r#" +fn test() { + let pat = match Ok(true) { + Ok(it) => it, + Err(err) => return Err(err), + }; +} + "#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_turbofish_with_explicit_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_turbofish_with_explicit_type.rs new file mode 100644 index 000000000..6112e0945 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_turbofish_with_explicit_type.rs @@ -0,0 +1,243 @@ +use syntax::{ + ast::{Expr, GenericArg}, + ast::{LetStmt, Type::InferType}, + AstNode, TextRange, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: replace_turbofish_with_explicit_type +// +// Converts `::<_>` to an explicit type assignment. +// +// ``` +// fn make() -> T { ) } +// fn main() { +// let a = make$0::(); +// } +// ``` +// -> +// ``` +// fn make() -> T { ) } +// fn main() { +// let a: i32 = make(); +// } +// ``` +pub(crate) fn replace_turbofish_with_explicit_type( + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let let_stmt = ctx.find_node_at_offset::()?; + + let initializer = let_stmt.initializer()?; + + let generic_args = match &initializer { + Expr::MethodCallExpr(ce) => ce.generic_arg_list()?, + Expr::CallExpr(ce) => { + if let Expr::PathExpr(pe) = ce.expr()? { + pe.path()?.segment()?.generic_arg_list()? + } else { + cov_mark::hit!(not_applicable_if_non_path_function_call); + return None; + } + } + _ => { + cov_mark::hit!(not_applicable_if_non_function_call_initializer); + return None; + } + }; + + // Find range of ::<_> + let colon2 = generic_args.coloncolon_token()?; + let r_angle = generic_args.r_angle_token()?; + let turbofish_range = TextRange::new(colon2.text_range().start(), r_angle.text_range().end()); + + let turbofish_args: Vec = generic_args.generic_args().into_iter().collect(); + + // Find type of ::<_> + if turbofish_args.len() != 1 { + cov_mark::hit!(not_applicable_if_not_single_arg); + return None; + } + + // An improvement would be to check that this is correctly part of the return value of the + // function call, or sub in the actual return type. + let turbofish_type = &turbofish_args[0]; + + let initializer_start = initializer.syntax().text_range().start(); + if ctx.offset() > turbofish_range.end() || ctx.offset() < initializer_start { + cov_mark::hit!(not_applicable_outside_turbofish); + return None; + } + + if let None = let_stmt.colon_token() { + // If there's no colon in a let statement, then there is no explicit type. + // let x = fn::<...>(); + let ident_range = let_stmt.pat()?.syntax().text_range(); + + return acc.add( + AssistId("replace_turbofish_with_explicit_type", AssistKind::RefactorRewrite), + "Replace turbofish with explicit type", + TextRange::new(initializer_start, turbofish_range.end()), + |builder| { + builder.insert(ident_range.end(), format!(": {}", turbofish_type)); + builder.delete(turbofish_range); + }, + ); + } else if let Some(InferType(t)) = let_stmt.ty() { + // If there's a type inferrence underscore, we can offer to replace it with the type in + // the turbofish. + // let x: _ = fn::<...>(); + let underscore_range = t.syntax().text_range(); + + return acc.add( + AssistId("replace_turbofish_with_explicit_type", AssistKind::RefactorRewrite), + "Replace `_` with turbofish type", + turbofish_range, + |builder| { + builder.replace(underscore_range, turbofish_type.to_string()); + builder.delete(turbofish_range); + }, + ); + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn replaces_turbofish_for_vec_string() { + check_assist( + replace_turbofish_with_explicit_type, + r#" +fn make() -> T {} +fn main() { + let a = make$0::>(); +} +"#, + r#" +fn make() -> T {} +fn main() { + let a: Vec = make(); +} +"#, + ); + } + + #[test] + fn replaces_method_calls() { + // foo.make() is a method call which uses a different expr in the let initializer + check_assist( + replace_turbofish_with_explicit_type, + r#" +fn make() -> T {} +fn main() { + let a = foo.make$0::>(); +} +"#, + r#" +fn make() -> T {} +fn main() { + let a: Vec = foo.make(); +} +"#, + ); + } + + #[test] + fn replace_turbofish_target() { + check_assist_target( + replace_turbofish_with_explicit_type, + r#" +fn make() -> T {} +fn main() { + let a = $0make::>(); +} +"#, + r#"make::>"#, + ); + } + + #[test] + fn not_applicable_outside_turbofish() { + cov_mark::check!(not_applicable_outside_turbofish); + check_assist_not_applicable( + replace_turbofish_with_explicit_type, + r#" +fn make() -> T {} +fn main() { + let $0a = make::>(); +} +"#, + ); + } + + #[test] + fn replace_inferred_type_placeholder() { + check_assist( + replace_turbofish_with_explicit_type, + r#" +fn make() -> T {} +fn main() { + let a: _ = make$0::>(); +} +"#, + r#" +fn make() -> T {} +fn main() { + let a: Vec = make(); +} +"#, + ); + } + + #[test] + fn not_applicable_constant_initializer() { + cov_mark::check!(not_applicable_if_non_function_call_initializer); + check_assist_not_applicable( + replace_turbofish_with_explicit_type, + r#" +fn make() -> T {} +fn main() { + let a = "foo"$0; +} +"#, + ); + } + + #[test] + fn not_applicable_non_path_function_call() { + cov_mark::check!(not_applicable_if_non_path_function_call); + check_assist_not_applicable( + replace_turbofish_with_explicit_type, + r#" +fn make() -> T {} +fn main() { + $0let a = (|| {})(); +} +"#, + ); + } + + #[test] + fn non_applicable_multiple_generic_args() { + cov_mark::check!(not_applicable_if_not_single_arg); + check_assist_not_applicable( + replace_turbofish_with_explicit_type, + r#" +fn make() -> T {} +fn main() { + let a = make$0::, i32>(); +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/sort_items.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/sort_items.rs new file mode 100644 index 000000000..a93704b39 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/sort_items.rs @@ -0,0 +1,588 @@ +use std::cmp::Ordering; + +use itertools::Itertools; + +use syntax::{ + ast::{self, HasName}, + ted, AstNode, TextRange, +}; + +use crate::{utils::get_methods, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: sort_items +// +// Sorts item members alphabetically: fields, enum variants and methods. +// +// ``` +// struct $0Foo$0 { second: u32, first: String } +// ``` +// -> +// ``` +// struct Foo { first: String, second: u32 } +// ``` +// --- +// ``` +// trait $0Bar$0 { +// fn second(&self) -> u32; +// fn first(&self) -> String; +// } +// ``` +// -> +// ``` +// trait Bar { +// fn first(&self) -> String; +// fn second(&self) -> u32; +// } +// ``` +// --- +// ``` +// struct Baz; +// impl $0Baz$0 { +// fn second(&self) -> u32; +// fn first(&self) -> String; +// } +// ``` +// -> +// ``` +// struct Baz; +// impl Baz { +// fn first(&self) -> String; +// fn second(&self) -> u32; +// } +// ``` +// --- +// There is a difference between sorting enum variants: +// +// ``` +// enum $0Animal$0 { +// Dog(String, f64), +// Cat { weight: f64, name: String }, +// } +// ``` +// -> +// ``` +// enum Animal { +// Cat { weight: f64, name: String }, +// Dog(String, f64), +// } +// ``` +// and sorting a single enum struct variant: +// +// ``` +// enum Animal { +// Dog(String, f64), +// Cat $0{ weight: f64, name: String }$0, +// } +// ``` +// -> +// ``` +// enum Animal { +// Dog(String, f64), +// Cat { name: String, weight: f64 }, +// } +// ``` +pub(crate) fn sort_items(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + if ctx.has_empty_selection() { + cov_mark::hit!(not_applicable_if_no_selection); + return None; + } + + if let Some(trait_ast) = ctx.find_node_at_offset::() { + add_sort_methods_assist(acc, trait_ast.assoc_item_list()?) + } else if let Some(impl_ast) = ctx.find_node_at_offset::() { + add_sort_methods_assist(acc, impl_ast.assoc_item_list()?) + } else if let Some(struct_ast) = ctx.find_node_at_offset::() { + add_sort_field_list_assist(acc, struct_ast.field_list()) + } else if let Some(union_ast) = ctx.find_node_at_offset::() { + add_sort_fields_assist(acc, union_ast.record_field_list()?) + } else if let Some(variant_ast) = ctx.find_node_at_offset::() { + add_sort_field_list_assist(acc, variant_ast.field_list()) + } else if let Some(enum_struct_variant_ast) = ctx.find_node_at_offset::() + { + // should be above enum and below struct + add_sort_fields_assist(acc, enum_struct_variant_ast) + } else if let Some(enum_ast) = ctx.find_node_at_offset::() { + add_sort_variants_assist(acc, enum_ast.variant_list()?) + } else { + None + } +} + +trait AddRewrite { + fn add_rewrite( + &mut self, + label: &str, + old: Vec, + new: Vec, + target: TextRange, + ) -> Option<()>; +} + +impl AddRewrite for Assists { + fn add_rewrite( + &mut self, + label: &str, + old: Vec, + new: Vec, + target: TextRange, + ) -> Option<()> { + self.add(AssistId("sort_items", AssistKind::RefactorRewrite), label, target, |builder| { + let mutable: Vec = old.into_iter().map(|it| builder.make_mut(it)).collect(); + mutable + .into_iter() + .zip(new) + .for_each(|(old, new)| ted::replace(old.syntax(), new.clone_for_update().syntax())); + }) + } +} + +fn add_sort_field_list_assist(acc: &mut Assists, field_list: Option) -> Option<()> { + match field_list { + Some(ast::FieldList::RecordFieldList(it)) => add_sort_fields_assist(acc, it), + _ => { + cov_mark::hit!(not_applicable_if_sorted_or_empty_or_single); + None + } + } +} + +fn add_sort_methods_assist(acc: &mut Assists, item_list: ast::AssocItemList) -> Option<()> { + let methods = get_methods(&item_list); + let sorted = sort_by_name(&methods); + + if methods == sorted { + cov_mark::hit!(not_applicable_if_sorted_or_empty_or_single); + return None; + } + + acc.add_rewrite("Sort methods alphabetically", methods, sorted, item_list.syntax().text_range()) +} + +fn add_sort_fields_assist( + acc: &mut Assists, + record_field_list: ast::RecordFieldList, +) -> Option<()> { + let fields: Vec<_> = record_field_list.fields().collect(); + let sorted = sort_by_name(&fields); + + if fields == sorted { + cov_mark::hit!(not_applicable_if_sorted_or_empty_or_single); + return None; + } + + acc.add_rewrite( + "Sort fields alphabetically", + fields, + sorted, + record_field_list.syntax().text_range(), + ) +} + +fn add_sort_variants_assist(acc: &mut Assists, variant_list: ast::VariantList) -> Option<()> { + let variants: Vec<_> = variant_list.variants().collect(); + let sorted = sort_by_name(&variants); + + if variants == sorted { + cov_mark::hit!(not_applicable_if_sorted_or_empty_or_single); + return None; + } + + acc.add_rewrite( + "Sort variants alphabetically", + variants, + sorted, + variant_list.syntax().text_range(), + ) +} + +fn sort_by_name(initial: &[T]) -> Vec { + initial + .iter() + .cloned() + .sorted_by(|a, b| match (a.name(), b.name()) { + (Some(a), Some(b)) => Ord::cmp(&a.to_string(), &b.to_string()), + + // unexpected, but just in case + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + }) + .collect() +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn not_applicable_if_no_selection() { + cov_mark::check!(not_applicable_if_no_selection); + + check_assist_not_applicable( + sort_items, + r#" +t$0rait Bar { + fn b(); + fn a(); +} + "#, + ) + } + + #[test] + fn not_applicable_if_trait_empty() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +t$0rait Bar$0 { +} + "#, + ) + } + + #[test] + fn not_applicable_if_impl_empty() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +struct Bar; +$0impl Bar$0 { +} + "#, + ) + } + + #[test] + fn not_applicable_if_struct_empty() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +$0struct Bar$0 ; + "#, + ) + } + + #[test] + fn not_applicable_if_struct_empty2() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +$0struct Bar$0 { }; + "#, + ) + } + + #[test] + fn not_applicable_if_enum_empty() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +$0enum ZeroVariants$0 {}; + "#, + ) + } + + #[test] + fn not_applicable_if_trait_sorted() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +t$0rait Bar$0 { + fn a() {} + fn b() {} + fn c() {} +} + "#, + ) + } + + #[test] + fn not_applicable_if_impl_sorted() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +struct Bar; +$0impl Bar$0 { + fn a() {} + fn b() {} + fn c() {} +} + "#, + ) + } + + #[test] + fn not_applicable_if_struct_sorted() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +$0struct Bar$0 { + a: u32, + b: u8, + c: u64, +} + "#, + ) + } + + #[test] + fn not_applicable_if_union_sorted() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +$0union Bar$0 { + a: u32, + b: u8, + c: u64, +} + "#, + ) + } + + #[test] + fn not_applicable_if_enum_sorted() { + cov_mark::check!(not_applicable_if_sorted_or_empty_or_single); + + check_assist_not_applicable( + sort_items, + r#" +$0enum Bar$0 { + a, + b, + c, +} + "#, + ) + } + + #[test] + fn sort_trait() { + check_assist( + sort_items, + r#" +$0trait Bar$0 { + fn a() { + + } + + // comment for c + fn c() {} + fn z() {} + fn b() {} +} + "#, + r#" +trait Bar { + fn a() { + + } + + fn b() {} + // comment for c + fn c() {} + fn z() {} +} + "#, + ) + } + + #[test] + fn sort_impl() { + check_assist( + sort_items, + r#" +struct Bar; +$0impl Bar$0 { + fn c() {} + fn a() {} + /// long + /// doc + /// comment + fn z() {} + fn d() {} +} + "#, + r#" +struct Bar; +impl Bar { + fn a() {} + fn c() {} + fn d() {} + /// long + /// doc + /// comment + fn z() {} +} + "#, + ) + } + + #[test] + fn sort_struct() { + check_assist( + sort_items, + r#" +$0struct Bar$0 { + b: u8, + a: u32, + c: u64, +} + "#, + r#" +struct Bar { + a: u32, + b: u8, + c: u64, +} + "#, + ) + } + + #[test] + fn sort_generic_struct_with_lifetime() { + check_assist( + sort_items, + r#" +$0struct Bar<'a,$0 T> { + d: &'a str, + b: u8, + a: T, + c: u64, +} + "#, + r#" +struct Bar<'a, T> { + a: T, + b: u8, + c: u64, + d: &'a str, +} + "#, + ) + } + + #[test] + fn sort_struct_fields_diff_len() { + check_assist( + sort_items, + r#" +$0struct Bar $0{ + aaa: u8, + a: usize, + b: u8, +} + "#, + r#" +struct Bar { + a: usize, + aaa: u8, + b: u8, +} + "#, + ) + } + + #[test] + fn sort_union() { + check_assist( + sort_items, + r#" +$0union Bar$0 { + b: u8, + a: u32, + c: u64, +} + "#, + r#" +union Bar { + a: u32, + b: u8, + c: u64, +} + "#, + ) + } + + #[test] + fn sort_enum() { + check_assist( + sort_items, + r#" +$0enum Bar $0{ + d{ first: u32, second: usize}, + b = 14, + a, + c(u32, usize), +} + "#, + r#" +enum Bar { + a, + b = 14, + c(u32, usize), + d{ first: u32, second: usize}, +} + "#, + ) + } + + #[test] + fn sort_struct_enum_variant_fields() { + check_assist( + sort_items, + r#" +enum Bar { + d$0{ second: usize, first: u32 }$0, + b = 14, + a, + c(u32, usize), +} + "#, + r#" +enum Bar { + d{ first: u32, second: usize }, + b = 14, + a, + c(u32, usize), +} + "#, + ) + } + + #[test] + fn sort_struct_enum_variant() { + check_assist( + sort_items, + r#" +enum Bar { + $0d$0{ second: usize, first: u32 }, +} + "#, + r#" +enum Bar { + d{ first: u32, second: usize }, +} + "#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/split_import.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/split_import.rs new file mode 100644 index 000000000..775ededec --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/split_import.rs @@ -0,0 +1,82 @@ +use syntax::{ast, AstNode, T}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: split_import +// +// Wraps the tail of import into braces. +// +// ``` +// use std::$0collections::HashMap; +// ``` +// -> +// ``` +// use std::{collections::HashMap}; +// ``` +pub(crate) fn split_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let colon_colon = ctx.find_token_syntax_at_offset(T![::])?; + let path = ast::Path::cast(colon_colon.parent()?)?.qualifier()?; + + let use_tree = path.top_path().syntax().ancestors().find_map(ast::UseTree::cast)?; + + let has_errors = use_tree + .syntax() + .descendants_with_tokens() + .any(|it| it.kind() == syntax::SyntaxKind::ERROR); + let last_segment = use_tree.path().and_then(|it| it.segment()); + if has_errors || last_segment.is_none() { + return None; + } + + let target = colon_colon.text_range(); + acc.add(AssistId("split_import", AssistKind::RefactorRewrite), "Split import", target, |edit| { + let use_tree = edit.make_mut(use_tree.clone()); + let path = edit.make_mut(path); + use_tree.split_prefix(&path); + }) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn test_split_import() { + check_assist( + split_import, + "use crate::$0db::RootDatabase;", + "use crate::{db::RootDatabase};", + ) + } + + #[test] + fn split_import_works_with_trees() { + check_assist( + split_import, + "use crate:$0:db::{RootDatabase, FileSymbol}", + "use crate::{db::{RootDatabase, FileSymbol}}", + ) + } + + #[test] + fn split_import_target() { + check_assist_target(split_import, "use crate::$0db::{RootDatabase, FileSymbol}", "::"); + } + + #[test] + fn issue4044() { + check_assist_not_applicable(split_import, "use crate::$0:::self;") + } + + #[test] + fn test_empty_use() { + check_assist_not_applicable( + split_import, + r" +use std::$0 +fn main() {}", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/toggle_ignore.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/toggle_ignore.rs new file mode 100644 index 000000000..b7d57f02b --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/toggle_ignore.rs @@ -0,0 +1,98 @@ +use syntax::{ + ast::{self, HasAttrs}, + AstNode, AstToken, +}; + +use crate::{utils::test_related_attribute, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: toggle_ignore +// +// Adds `#[ignore]` attribute to the test. +// +// ``` +// $0#[test] +// fn arithmetics { +// assert_eq!(2 + 2, 5); +// } +// ``` +// -> +// ``` +// #[test] +// #[ignore] +// fn arithmetics { +// assert_eq!(2 + 2, 5); +// } +// ``` +pub(crate) fn toggle_ignore(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let attr: ast::Attr = ctx.find_node_at_offset()?; + let func = attr.syntax().parent().and_then(ast::Fn::cast)?; + let attr = test_related_attribute(&func)?; + + match has_ignore_attribute(&func) { + None => acc.add( + AssistId("toggle_ignore", AssistKind::None), + "Ignore this test", + attr.syntax().text_range(), + |builder| builder.insert(attr.syntax().text_range().end(), "\n#[ignore]"), + ), + Some(ignore_attr) => acc.add( + AssistId("toggle_ignore", AssistKind::None), + "Re-enable this test", + ignore_attr.syntax().text_range(), + |builder| { + builder.delete(ignore_attr.syntax().text_range()); + let whitespace = ignore_attr + .syntax() + .next_sibling_or_token() + .and_then(|x| x.into_token()) + .and_then(ast::Whitespace::cast); + if let Some(whitespace) = whitespace { + builder.delete(whitespace.syntax().text_range()); + } + }, + ), + } +} + +fn has_ignore_attribute(fn_def: &ast::Fn) -> Option { + fn_def.attrs().find(|attr| attr.path().map(|it| it.syntax().text() == "ignore") == Some(true)) +} + +#[cfg(test)] +mod tests { + use crate::tests::check_assist; + + use super::*; + + #[test] + fn test_base_case() { + check_assist( + toggle_ignore, + r#" + #[test$0] + fn test() {} + "#, + r#" + #[test] + #[ignore] + fn test() {} + "#, + ) + } + + #[test] + fn test_unignore() { + check_assist( + toggle_ignore, + r#" + #[test$0] + #[ignore] + fn test() {} + "#, + r#" + #[test] + fn test() {} + "#, + ) + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_use.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_use.rs new file mode 100644 index 000000000..3ce028e93 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_use.rs @@ -0,0 +1,237 @@ +use syntax::{ + ast::{self, make, HasVisibility}, + ted::{self, Position}, + AstNode, SyntaxKind, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: unmerge_use +// +// Extracts single use item from use list. +// +// ``` +// use std::fmt::{Debug, Display$0}; +// ``` +// -> +// ``` +// use std::fmt::{Debug}; +// use std::fmt::Display; +// ``` +pub(crate) fn unmerge_use(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let tree: ast::UseTree = ctx.find_node_at_offset::()?.clone_for_update(); + + let tree_list = tree.syntax().parent().and_then(ast::UseTreeList::cast)?; + if tree_list.use_trees().count() < 2 { + cov_mark::hit!(skip_single_use_item); + return None; + } + + let use_: ast::Use = tree_list.syntax().ancestors().find_map(ast::Use::cast)?; + let path = resolve_full_path(&tree)?; + + let old_parent_range = use_.syntax().parent()?.text_range(); + let new_parent = use_.syntax().parent()?; + + let target = tree.syntax().text_range(); + acc.add( + AssistId("unmerge_use", AssistKind::RefactorRewrite), + "Unmerge use", + target, + |builder| { + let new_use = make::use_( + use_.visibility(), + make::use_tree( + path, + tree.use_tree_list(), + tree.rename(), + tree.star_token().is_some(), + ), + ) + .clone_for_update(); + + tree.remove(); + ted::insert(Position::after(use_.syntax()), new_use.syntax()); + + builder.replace(old_parent_range, new_parent.to_string()); + }, + ) +} + +fn resolve_full_path(tree: &ast::UseTree) -> Option { + let paths = tree + .syntax() + .ancestors() + .take_while(|n| n.kind() != SyntaxKind::USE) + .filter_map(ast::UseTree::cast) + .filter_map(|t| t.path()); + + let final_path = paths.reduce(|prev, next| make::path_concat(next, prev))?; + if final_path.segment().map_or(false, |it| it.self_token().is_some()) { + final_path.qualifier() + } else { + Some(final_path) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn skip_single_use_item() { + cov_mark::check!(skip_single_use_item); + check_assist_not_applicable( + unmerge_use, + r" +use std::fmt::Debug$0; +", + ); + check_assist_not_applicable( + unmerge_use, + r" +use std::fmt::{Debug$0}; +", + ); + check_assist_not_applicable( + unmerge_use, + r" +use std::fmt::Debug as Dbg$0; +", + ); + } + + #[test] + fn skip_single_glob_import() { + check_assist_not_applicable( + unmerge_use, + r" +use std::fmt::*$0; +", + ); + } + + #[test] + fn unmerge_use_item() { + check_assist( + unmerge_use, + r" +use std::fmt::{Debug, Display$0}; +", + r" +use std::fmt::{Debug}; +use std::fmt::Display; +", + ); + + check_assist( + unmerge_use, + r" +use std::fmt::{Debug, format$0, Display}; +", + r" +use std::fmt::{Debug, Display}; +use std::fmt::format; +", + ); + } + + #[test] + fn unmerge_glob_import() { + check_assist( + unmerge_use, + r" +use std::fmt::{*$0, Display}; +", + r" +use std::fmt::{Display}; +use std::fmt::*; +", + ); + } + + #[test] + fn unmerge_renamed_use_item() { + check_assist( + unmerge_use, + r" +use std::fmt::{Debug, Display as Disp$0}; +", + r" +use std::fmt::{Debug}; +use std::fmt::Display as Disp; +", + ); + } + + #[test] + fn unmerge_indented_use_item() { + check_assist( + unmerge_use, + r" +mod format { + use std::fmt::{Debug, Display$0 as Disp, format}; +} +", + r" +mod format { + use std::fmt::{Debug, format}; + use std::fmt::Display as Disp; +} +", + ); + } + + #[test] + fn unmerge_nested_use_item() { + check_assist( + unmerge_use, + r" +use foo::bar::{baz::{qux$0, foobar}, barbaz}; +", + r" +use foo::bar::{baz::{foobar}, barbaz}; +use foo::bar::baz::qux; +", + ); + check_assist( + unmerge_use, + r" +use foo::bar::{baz$0::{qux, foobar}, barbaz}; +", + r" +use foo::bar::{barbaz}; +use foo::bar::baz::{qux, foobar}; +", + ); + } + + #[test] + fn unmerge_use_item_with_visibility() { + check_assist( + unmerge_use, + r" +pub use std::fmt::{Debug, Display$0}; +", + r" +pub use std::fmt::{Debug}; +pub use std::fmt::Display; +", + ); + } + + #[test] + fn unmerge_use_item_on_self() { + check_assist( + unmerge_use, + r"use std::process::{Command, self$0};", + r"use std::process::{Command}; +use std::process;", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unnecessary_async.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unnecessary_async.rs new file mode 100644 index 000000000..d5cd2d551 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unnecessary_async.rs @@ -0,0 +1,257 @@ +use ide_db::{ + assists::{AssistId, AssistKind}, + base_db::FileId, + defs::Definition, + search::FileReference, + syntax_helpers::node_ext::full_path_of_name_ref, +}; +use syntax::{ + ast::{self, NameLike, NameRef}, + AstNode, SyntaxKind, TextRange, +}; + +use crate::{AssistContext, Assists}; + +// Assist: unnecessary_async +// +// Removes the `async` mark from functions which have no `.await` in their body. +// Looks for calls to the functions and removes the `.await` on the call site. +// +// ``` +// pub async f$0n foo() {} +// pub async fn bar() { foo().await } +// ``` +// -> +// ``` +// pub fn foo() {} +// pub async fn bar() { foo() } +// ``` +pub(crate) fn unnecessary_async(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let function: ast::Fn = ctx.find_node_at_offset()?; + + // Do nothing if the cursor is not on the prototype. This is so that the check does not pollute + // when the user asks us for assists when in the middle of the function body. + // We consider the prototype to be anything that is before the body of the function. + let cursor_position = ctx.offset(); + if cursor_position >= function.body()?.syntax().text_range().start() { + return None; + } + // Do nothing if the function isn't async. + if let None = function.async_token() { + return None; + } + // Do nothing if the function has an `await` expression in its body. + if function.body()?.syntax().descendants().find_map(ast::AwaitExpr::cast).is_some() { + return None; + } + + // Remove the `async` keyword plus whitespace after it, if any. + let async_range = { + let async_token = function.async_token()?; + let next_token = async_token.next_token()?; + if matches!(next_token.kind(), SyntaxKind::WHITESPACE) { + TextRange::new(async_token.text_range().start(), next_token.text_range().end()) + } else { + async_token.text_range() + } + }; + + // Otherwise, we may remove the `async` keyword. + acc.add( + AssistId("unnecessary_async", AssistKind::QuickFix), + "Remove unnecessary async", + async_range, + |edit| { + // Remove async on the function definition. + edit.replace(async_range, ""); + + // Remove all `.await`s from calls to the function we remove `async` from. + if let Some(fn_def) = ctx.sema.to_def(&function) { + for await_expr in find_all_references(ctx, &Definition::Function(fn_def)) + // Keep only references that correspond NameRefs. + .filter_map(|(_, reference)| match reference.name { + NameLike::NameRef(nameref) => Some(nameref), + _ => None, + }) + // Keep only references that correspond to await expressions + .filter_map(|nameref| find_await_expression(ctx, &nameref)) + { + if let Some(await_token) = &await_expr.await_token() { + edit.replace(await_token.text_range(), ""); + } + if let Some(dot_token) = &await_expr.dot_token() { + edit.replace(dot_token.text_range(), ""); + } + } + } + }, + ) +} + +fn find_all_references( + ctx: &AssistContext<'_>, + def: &Definition, +) -> impl Iterator { + def.usages(&ctx.sema).all().into_iter().flat_map(|(file_id, references)| { + references.into_iter().map(move |reference| (file_id, reference)) + }) +} + +/// Finds the await expression for the given `NameRef`. +/// If no await expression is found, returns None. +fn find_await_expression(ctx: &AssistContext<'_>, nameref: &NameRef) -> Option { + // From the nameref, walk up the tree to the await expression. + let await_expr = if let Some(path) = full_path_of_name_ref(&nameref) { + // Function calls. + path.syntax() + .parent() + .and_then(ast::PathExpr::cast)? + .syntax() + .parent() + .and_then(ast::CallExpr::cast)? + .syntax() + .parent() + .and_then(ast::AwaitExpr::cast) + } else { + // Method calls. + nameref + .syntax() + .parent() + .and_then(ast::MethodCallExpr::cast)? + .syntax() + .parent() + .and_then(ast::AwaitExpr::cast) + }; + + ctx.sema.original_ast_node(await_expr?) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn applies_on_empty_function() { + check_assist(unnecessary_async, "pub async f$0n f() {}", "pub fn f() {}") + } + + #[test] + fn applies_and_removes_whitespace() { + check_assist(unnecessary_async, "pub async f$0n f() {}", "pub fn f() {}") + } + + #[test] + fn does_not_apply_on_non_async_function() { + check_assist_not_applicable(unnecessary_async, "pub f$0n f() {}") + } + + #[test] + fn applies_on_function_with_a_non_await_expr() { + check_assist(unnecessary_async, "pub async f$0n f() { f2() }", "pub fn f() { f2() }") + } + + #[test] + fn does_not_apply_on_function_with_an_await_expr() { + check_assist_not_applicable(unnecessary_async, "pub async f$0n f() { f2().await }") + } + + #[test] + fn applies_and_removes_await_on_reference() { + check_assist( + unnecessary_async, + r#" +pub async fn f4() { } +pub async f$0n f2() { } +pub async fn f() { f2().await } +pub async fn f3() { f2().await }"#, + r#" +pub async fn f4() { } +pub fn f2() { } +pub async fn f() { f2() } +pub async fn f3() { f2() }"#, + ) + } + + #[test] + fn applies_and_removes_await_from_within_module() { + check_assist( + unnecessary_async, + r#" +pub async fn f4() { } +mod a { pub async f$0n f2() { } } +pub async fn f() { a::f2().await } +pub async fn f3() { a::f2().await }"#, + r#" +pub async fn f4() { } +mod a { pub fn f2() { } } +pub async fn f() { a::f2() } +pub async fn f3() { a::f2() }"#, + ) + } + + #[test] + fn applies_and_removes_await_on_inner_await() { + check_assist( + unnecessary_async, + // Ensure that it is the first await on the 3rd line that is removed + r#" +pub async fn f() { f2().await } +pub async f$0n f2() -> i32 { 1 } +pub async fn f3() { f4(f2().await).await } +pub async fn f4(i: i32) { }"#, + r#" +pub async fn f() { f2() } +pub fn f2() -> i32 { 1 } +pub async fn f3() { f4(f2()).await } +pub async fn f4(i: i32) { }"#, + ) + } + + #[test] + fn applies_and_removes_await_on_outer_await() { + check_assist( + unnecessary_async, + // Ensure that it is the second await on the 3rd line that is removed + r#" +pub async fn f() { f2().await } +pub async f$0n f2(i: i32) { } +pub async fn f3() { f2(f4().await).await } +pub async fn f4() -> i32 { 1 }"#, + r#" +pub async fn f() { f2() } +pub fn f2(i: i32) { } +pub async fn f3() { f2(f4().await) } +pub async fn f4() -> i32 { 1 }"#, + ) + } + + #[test] + fn applies_on_method_call() { + check_assist( + unnecessary_async, + r#" +pub struct S { } +impl S { pub async f$0n f2(&self) { } } +pub async fn f(s: &S) { s.f2().await }"#, + r#" +pub struct S { } +impl S { pub fn f2(&self) { } } +pub async fn f(s: &S) { s.f2() }"#, + ) + } + + #[test] + fn does_not_apply_on_function_with_a_nested_await_expr() { + check_assist_not_applicable( + unnecessary_async, + "async f$0n f() { if true { loop { f2().await } } }", + ) + } + + #[test] + fn does_not_apply_when_not_on_prototype() { + check_assist_not_applicable(unnecessary_async, "pub async fn f() { $0f2() }") + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs new file mode 100644 index 000000000..7969a4918 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs @@ -0,0 +1,719 @@ +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + }, + AstNode, SyntaxKind, TextRange, T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: unwrap_block +// +// This assist removes if...else, for, while and loop control statements to just keep the body. +// +// ``` +// fn foo() { +// if true {$0 +// println!("foo"); +// } +// } +// ``` +// -> +// ``` +// fn foo() { +// println!("foo"); +// } +// ``` +pub(crate) fn unwrap_block(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let assist_id = AssistId("unwrap_block", AssistKind::RefactorRewrite); + let assist_label = "Unwrap block"; + + let l_curly_token = ctx.find_token_syntax_at_offset(T!['{'])?; + let mut block = ast::BlockExpr::cast(l_curly_token.parent_ancestors().nth(1)?)?; + let target = block.syntax().text_range(); + let mut parent = block.syntax().parent()?; + if ast::MatchArm::can_cast(parent.kind()) { + parent = parent.ancestors().find(|it| ast::MatchExpr::can_cast(it.kind()))? + } + + if matches!(parent.kind(), SyntaxKind::STMT_LIST | SyntaxKind::EXPR_STMT) { + return acc.add(assist_id, assist_label, target, |builder| { + builder.replace(block.syntax().text_range(), update_expr_string(block.to_string())); + }); + } + + let parent = ast::Expr::cast(parent)?; + + match parent.clone() { + ast::Expr::ForExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::LoopExpr(_) => (), + ast::Expr::MatchExpr(_) => block = block.dedent(IndentLevel(1)), + ast::Expr::IfExpr(if_expr) => { + let then_branch = if_expr.then_branch()?; + if then_branch == block { + if let Some(ancestor) = if_expr.syntax().parent().and_then(ast::IfExpr::cast) { + // For `else if` blocks + let ancestor_then_branch = ancestor.then_branch()?; + + return acc.add(assist_id, assist_label, target, |edit| { + let range_to_del_else_if = TextRange::new( + ancestor_then_branch.syntax().text_range().end(), + l_curly_token.text_range().start(), + ); + let range_to_del_rest = TextRange::new( + then_branch.syntax().text_range().end(), + if_expr.syntax().text_range().end(), + ); + + edit.delete(range_to_del_rest); + edit.delete(range_to_del_else_if); + edit.replace( + target, + update_expr_string_without_newline(then_branch.to_string()), + ); + }); + } + } else { + return acc.add(assist_id, assist_label, target, |edit| { + let range_to_del = TextRange::new( + then_branch.syntax().text_range().end(), + l_curly_token.text_range().start(), + ); + + edit.delete(range_to_del); + edit.replace(target, update_expr_string_without_newline(block.to_string())); + }); + } + } + _ => return None, + }; + + acc.add(assist_id, assist_label, target, |builder| { + builder.replace(parent.syntax().text_range(), update_expr_string(block.to_string())); + }) +} + +fn update_expr_string(expr_string: String) -> String { + update_expr_string_with_pat(expr_string, &[' ', '\n']) +} + +fn update_expr_string_without_newline(expr_string: String) -> String { + update_expr_string_with_pat(expr_string, &[' ']) +} + +fn update_expr_string_with_pat(expr_str: String, whitespace_pat: &[char]) -> String { + // Remove leading whitespace, index [1..] to remove the leading '{', + // then continue to remove leading whitespace. + let expr_str = + expr_str.trim_start_matches(whitespace_pat)[1..].trim_start_matches(whitespace_pat); + + // Remove trailing whitespace, index [..expr_str.len() - 1] to remove the trailing '}', + // then continue to remove trailing whitespace. + let expr_str = expr_str.trim_end_matches(whitespace_pat); + let expr_str = expr_str[..expr_str.len() - 1].trim_end_matches(whitespace_pat); + + expr_str + .lines() + .map(|line| line.replacen(" ", "", 1)) // Delete indentation + .collect::>() + .join("\n") +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn unwrap_tail_expr_block() { + check_assist( + unwrap_block, + r#" +fn main() { + $0{ + 92 + } +} +"#, + r#" +fn main() { + 92 +} +"#, + ) + } + + #[test] + fn unwrap_stmt_expr_block() { + check_assist( + unwrap_block, + r#" +fn main() { + $0{ + 92; + } + () +} +"#, + r#" +fn main() { + 92; + () +} +"#, + ); + // Pedantically, we should add an `;` here... + check_assist( + unwrap_block, + r#" +fn main() { + $0{ + 92 + } + () +} +"#, + r#" +fn main() { + 92 + () +} +"#, + ); + } + + #[test] + fn simple_if() { + check_assist( + unwrap_block, + r#" +fn main() { + bar(); + if true {$0 + foo(); + + // comment + bar(); + } else { + println!("bar"); + } +} +"#, + r#" +fn main() { + bar(); + foo(); + + // comment + bar(); +} +"#, + ); + } + + #[test] + fn simple_if_else() { + check_assist( + unwrap_block, + r#" +fn main() { + bar(); + if true { + foo(); + + // comment + bar(); + } else {$0 + println!("bar"); + } +} +"#, + r#" +fn main() { + bar(); + if true { + foo(); + + // comment + bar(); + } + println!("bar"); +} +"#, + ); + } + + #[test] + fn simple_if_else_if() { + check_assist( + unwrap_block, + r#" +fn main() { + // bar(); + if true { + println!("true"); + + // comment + // bar(); + } else if false {$0 + println!("bar"); + } else { + println!("foo"); + } +} +"#, + r#" +fn main() { + // bar(); + if true { + println!("true"); + + // comment + // bar(); + } + println!("bar"); +} +"#, + ); + } + + #[test] + fn simple_if_else_if_nested() { + check_assist( + unwrap_block, + r#" +fn main() { + // bar(); + if true { + println!("true"); + + // comment + // bar(); + } else if false { + println!("bar"); + } else if true {$0 + println!("foo"); + } +} +"#, + r#" +fn main() { + // bar(); + if true { + println!("true"); + + // comment + // bar(); + } else if false { + println!("bar"); + } + println!("foo"); +} +"#, + ); + } + + #[test] + fn simple_if_else_if_nested_else() { + check_assist( + unwrap_block, + r#" +fn main() { + // bar(); + if true { + println!("true"); + + // comment + // bar(); + } else if false { + println!("bar"); + } else if true { + println!("foo"); + } else {$0 + println!("else"); + } +} +"#, + r#" +fn main() { + // bar(); + if true { + println!("true"); + + // comment + // bar(); + } else if false { + println!("bar"); + } else if true { + println!("foo"); + } + println!("else"); +} +"#, + ); + } + + #[test] + fn simple_if_else_if_nested_middle() { + check_assist( + unwrap_block, + r#" +fn main() { + // bar(); + if true { + println!("true"); + + // comment + // bar(); + } else if false { + println!("bar"); + } else if true {$0 + println!("foo"); + } else { + println!("else"); + } +} +"#, + r#" +fn main() { + // bar(); + if true { + println!("true"); + + // comment + // bar(); + } else if false { + println!("bar"); + } + println!("foo"); +} +"#, + ); + } + + #[test] + fn simple_if_bad_cursor_position() { + check_assist_not_applicable( + unwrap_block, + r#" +fn main() { + bar();$0 + if true { + foo(); + + // comment + bar(); + } else { + println!("bar"); + } +} +"#, + ); + } + + #[test] + fn simple_for() { + check_assist( + unwrap_block, + r#" +fn main() { + for i in 0..5 {$0 + if true { + foo(); + + // comment + bar(); + } else { + println!("bar"); + } + } +} +"#, + r#" +fn main() { + if true { + foo(); + + // comment + bar(); + } else { + println!("bar"); + } +} +"#, + ); + } + + #[test] + fn simple_if_in_for() { + check_assist( + unwrap_block, + r#" +fn main() { + for i in 0..5 { + if true {$0 + foo(); + + // comment + bar(); + } else { + println!("bar"); + } + } +} +"#, + r#" +fn main() { + for i in 0..5 { + foo(); + + // comment + bar(); + } +} +"#, + ); + } + + #[test] + fn simple_loop() { + check_assist( + unwrap_block, + r#" +fn main() { + loop {$0 + if true { + foo(); + + // comment + bar(); + } else { + println!("bar"); + } + } +} +"#, + r#" +fn main() { + if true { + foo(); + + // comment + bar(); + } else { + println!("bar"); + } +} +"#, + ); + } + + #[test] + fn simple_while() { + check_assist( + unwrap_block, + r#" +fn main() { + while true {$0 + if true { + foo(); + + // comment + bar(); + } else { + println!("bar"); + } + } +} +"#, + r#" +fn main() { + if true { + foo(); + + // comment + bar(); + } else { + println!("bar"); + } +} +"#, + ); + } + + #[test] + fn unwrap_match_arm() { + check_assist( + unwrap_block, + r#" +fn main() { + match rel_path { + Ok(rel_path) => {$0 + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } + Err(_) => None, + } +} +"#, + r#" +fn main() { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) +} +"#, + ); + } + + #[test] + fn simple_if_in_while_bad_cursor_position() { + check_assist_not_applicable( + unwrap_block, + r#" +fn main() { + while true { + if true { + foo();$0 + + // comment + bar(); + } else { + println!("bar"); + } + } +} +"#, + ); + } + + #[test] + fn simple_single_line() { + check_assist( + unwrap_block, + r#" +fn main() { + {$0 0 } +} +"#, + r#" +fn main() { + 0 +} +"#, + ); + } + + #[test] + fn simple_nested_block() { + check_assist( + unwrap_block, + r#" +fn main() { + $0{ + { + 3 + } + } +} +"#, + r#" +fn main() { + { + 3 + } +} +"#, + ); + } + + #[test] + fn nested_single_line() { + check_assist( + unwrap_block, + r#" +fn main() { + {$0 { println!("foo"); } } +} +"#, + r#" +fn main() { + { println!("foo"); } +} +"#, + ); + + check_assist( + unwrap_block, + r#" +fn main() { + {$0 { 0 } } +} +"#, + r#" +fn main() { + { 0 } +} +"#, + ); + } + + #[test] + fn simple_if_single_line() { + check_assist( + unwrap_block, + r#" +fn main() { + if true {$0 /* foo */ foo() } else { bar() /* bar */} +} +"#, + r#" +fn main() { + /* foo */ foo() +} +"#, + ); + } + + #[test] + fn if_single_statement() { + check_assist( + unwrap_block, + r#" +fn main() { + if true {$0 + return 3; + } +} +"#, + r#" +fn main() { + return 3; +} +"#, + ); + } + + #[test] + fn multiple_statements() { + check_assist( + unwrap_block, + r#" +fn main() -> i32 { + if 2 > 1 {$0 + let a = 5; + return 3; + } + 5 +} +"#, + r#" +fn main() -> i32 { + let a = 5; + return 3; + 5 +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_result_return_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_result_return_type.rs new file mode 100644 index 000000000..9ef4ae047 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_result_return_type.rs @@ -0,0 +1,1020 @@ +use ide_db::{ + famous_defs::FamousDefs, + syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, +}; +use itertools::Itertools; +use syntax::{ + ast::{self, Expr}, + match_ast, AstNode, TextRange, TextSize, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: unwrap_result_return_type +// +// Unwrap the function's return type. +// +// ``` +// # //- minicore: result +// fn foo() -> Result$0 { Ok(42i32) } +// ``` +// -> +// ``` +// fn foo() -> i32 { 42i32 } +// ``` +pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let ret_type = ctx.find_node_at_offset::()?; + let parent = ret_type.syntax().parent()?; + let body = match_ast! { + match parent { + ast::Fn(func) => func.body()?, + ast::ClosureExpr(closure) => match closure.body()? { + Expr::BlockExpr(block) => block, + // closures require a block when a return type is specified + _ => return None, + }, + _ => return None, + } + }; + + let type_ref = &ret_type.ty()?; + let ty = ctx.sema.resolve_type(type_ref)?.as_adt(); + let result_enum = + FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?; + + if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) { + return None; + } + + acc.add( + AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite), + "Unwrap Result return type", + type_ref.syntax().text_range(), + |builder| { + let body = ast::Expr::BlockExpr(body); + + let mut exprs_to_unwrap = Vec::new(); + let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_unwrap, e); + walk_expr(&body, &mut |expr| { + if let Expr::ReturnExpr(ret_expr) = expr { + if let Some(ret_expr_arg) = &ret_expr.expr() { + for_each_tail_expr(ret_expr_arg, tail_cb); + } + } + }); + for_each_tail_expr(&body, tail_cb); + + let mut is_unit_type = false; + if let Some((_, inner_type)) = type_ref.to_string().split_once('<') { + let inner_type = match inner_type.split_once(',') { + Some((success_inner_type, _)) => success_inner_type, + None => inner_type, + }; + let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type); + if new_ret_type == "()" { + is_unit_type = true; + let text_range = TextRange::new( + ret_type.syntax().text_range().start(), + ret_type.syntax().text_range().end() + TextSize::from(1u32), + ); + builder.delete(text_range) + } else { + builder.replace( + type_ref.syntax().text_range(), + inner_type.strip_suffix('>').unwrap_or(inner_type), + ) + } + } + + for ret_expr_arg in exprs_to_unwrap { + let ret_expr_str = ret_expr_arg.to_string(); + if ret_expr_str.starts_with("Ok(") || ret_expr_str.starts_with("Err(") { + let arg_list = ret_expr_arg.syntax().children().find_map(ast::ArgList::cast); + if let Some(arg_list) = arg_list { + if is_unit_type { + match ret_expr_arg.syntax().prev_sibling_or_token() { + // Useful to delete the entire line without leaving trailing whitespaces + Some(whitespace) => { + let new_range = TextRange::new( + whitespace.text_range().start(), + ret_expr_arg.syntax().text_range().end(), + ); + builder.delete(new_range); + } + None => { + builder.delete(ret_expr_arg.syntax().text_range()); + } + } + } else { + builder.replace( + ret_expr_arg.syntax().text_range(), + arg_list.args().join(", "), + ); + } + } + } + } + }, + ) +} + +fn tail_cb_impl(acc: &mut Vec, e: &ast::Expr) { + match e { + Expr::BreakExpr(break_expr) => { + if let Some(break_expr_arg) = break_expr.expr() { + for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e)) + } + } + Expr::ReturnExpr(ret_expr) => { + if let Some(ret_expr_arg) = &ret_expr.expr() { + for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e)); + } + } + e => acc.push(e.clone()), + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn unwrap_result_return_type_simple() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + let test = "test"; + return Ok(42i32); +} +"#, + r#" +fn foo() -> i32 { + let test = "test"; + return 42i32; +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_unit_type() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result<(), Box> { + Ok(()) +} +"#, + r#" +fn foo() { +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_ending_with_parent() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result> { + if true { + Ok(42) + } else { + foo() + } +} +"#, + r#" +fn foo() -> i32 { + if true { + 42 + } else { + foo() + } +} +"#, + ); + } + + #[test] + fn unwrap_return_type_break_split_tail() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + loop { + break if true { + Ok(1) + } else { + Ok(0) + }; + } +} +"#, + r#" +fn foo() -> i32 { + loop { + break if true { + 1 + } else { + 0 + }; + } +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_closure() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() { + || -> Result { + let test = "test"; + return Ok(42i32); + }; +} +"#, + r#" +fn foo() { + || -> i32 { + let test = "test"; + return 42i32; + }; +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_return_type_bad_cursor() { + check_assist_not_applicable( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> i32 { + let test = "test";$0 + return 42i32; +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_return_type_bad_cursor_closure() { + check_assist_not_applicable( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() { + || -> i32 { + let test = "test";$0 + return 42i32; + }; +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_closure_non_block() { + check_assist_not_applicable( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() { || -> i$032 3; } +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_return_type_already_not_result_std() { + check_assist_not_applicable( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> i32$0 { + let test = "test"; + return 42i32; +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_return_type_already_not_result_closure() { + check_assist_not_applicable( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() { + || -> i32$0 { + let test = "test"; + return 42i32; + }; +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_tail() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() ->$0 Result { + let test = "test"; + Ok(42i32) +} +"#, + r#" +fn foo() -> i32 { + let test = "test"; + 42i32 +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_tail_closure() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() { + || ->$0 Result { + let test = "test"; + Ok(42i32) + }; +} +"#, + r#" +fn foo() { + || -> i32 { + let test = "test"; + 42i32 + }; +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_tail_only() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { Ok(42i32) } +"#, + r#" +fn foo() -> i32 { 42i32 } +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_tail_block_like() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result$0 { + if true { + Ok(42i32) + } else { + Ok(24i32) + } +} +"#, + r#" +fn foo() -> i32 { + if true { + 42i32 + } else { + 24i32 + } +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_without_block_closure() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() { + || -> Result$0 { + if true { + Ok(42i32) + } else { + Ok(24i32) + } + }; +} +"#, + r#" +fn foo() { + || -> i32 { + if true { + 42i32 + } else { + 24i32 + } + }; +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_nested_if() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result$0 { + if true { + if false { + Ok(1) + } else { + Ok(2) + } + } else { + Ok(24i32) + } +} +"#, + r#" +fn foo() -> i32 { + if true { + if false { + 1 + } else { + 2 + } + } else { + 24i32 + } +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_await() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +async fn foo() -> Result { + if true { + if false { + Ok(1.await) + } else { + Ok(2.await) + } + } else { + Ok(24i32.await) + } +} +"#, + r#" +async fn foo() -> i32 { + if true { + if false { + 1.await + } else { + 2.await + } + } else { + 24i32.await + } +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_array() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result<[i32; 3]$0> { Ok([1, 2, 3]) } +"#, + r#" +fn foo() -> [i32; 3] { [1, 2, 3] } +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_cast() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -$0> Result { + if true { + if false { + Ok(1 as i32) + } else { + Ok(2 as i32) + } + } else { + Ok(24 as i32) + } +} +"#, + r#" +fn foo() -> i32 { + if true { + if false { + 1 as i32 + } else { + 2 as i32 + } + } else { + 24 as i32 + } +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_tail_block_like_match() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + let my_var = 5; + match my_var { + 5 => Ok(42i32), + _ => Ok(24i32), + } +} +"#, + r#" +fn foo() -> i32 { + let my_var = 5; + match my_var { + 5 => 42i32, + _ => 24i32, + } +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_loop_with_tail() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + let my_var = 5; + loop { + println!("test"); + 5 + } + Ok(my_var) +} +"#, + r#" +fn foo() -> i32 { + let my_var = 5; + loop { + println!("test"); + 5 + } + my_var +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_loop_in_let_stmt() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + let my_var = let x = loop { + break 1; + }; + Ok(my_var) +} +"#, + r#" +fn foo() -> i32 { + let my_var = let x = loop { + break 1; + }; + my_var +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_tail_block_like_match_return_expr() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result$0 { + let my_var = 5; + let res = match my_var { + 5 => 42i32, + _ => return Ok(24i32), + }; + Ok(res) +} +"#, + r#" +fn foo() -> i32 { + let my_var = 5; + let res = match my_var { + 5 => 42i32, + _ => return 24i32, + }; + res +} +"#, + ); + + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + let my_var = 5; + let res = if my_var == 5 { + 42i32 + } else { + return Ok(24i32); + }; + Ok(res) +} +"#, + r#" +fn foo() -> i32 { + let my_var = 5; + let res = if my_var == 5 { + 42i32 + } else { + return 24i32; + }; + res +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_tail_block_like_match_deeper() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + let my_var = 5; + match my_var { + 5 => { + if true { + Ok(42i32) + } else { + Ok(25i32) + } + }, + _ => { + let test = "test"; + if test == "test" { + return Ok(bar()); + } + Ok(53i32) + }, + } +} +"#, + r#" +fn foo() -> i32 { + let my_var = 5; + match my_var { + 5 => { + if true { + 42i32 + } else { + 25i32 + } + }, + _ => { + let test = "test"; + if test == "test" { + return bar(); + } + 53i32 + }, + } +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_tail_block_like_early_return() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + Ok(53i32) +} +"#, + r#" +fn foo() -> i32 { + let test = "test"; + if test == "test" { + return 24i32; + } + 53i32 +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_closure() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo(the_field: u32) -> Result { + let true_closure = || { return true; }; + if the_field < 5 { + let mut i = 0; + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + Ok(the_field) +} +"#, + r#" +fn foo(the_field: u32) -> u32 { + let true_closure = || { return true; }; + if the_field < 5 { + let mut i = 0; + if true_closure() { + return 99; + } else { + return 0; + } + } + the_field +} +"#, + ); + + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo(the_field: u32) -> Result { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + let t = None; + + Ok(t.unwrap_or_else(|| the_field)) +} +"#, + r#" +fn foo(the_field: u32) -> u32 { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return 99; + } else { + return 0; + } + } + let t = None; + + t.unwrap_or_else(|| the_field) +} +"#, + ); + } + + #[test] + fn unwrap_result_return_type_simple_with_weird_forms() { + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo() -> Result { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + let mut i = 0; + loop { + if i == 1 { + break Ok(55); + } + i += 1; + } +} +"#, + r#" +fn foo() -> i32 { + let test = "test"; + if test == "test" { + return 24i32; + } + let mut i = 0; + loop { + if i == 1 { + break 55; + } + i += 1; + } +} +"#, + ); + + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + loop { + if i > 5 { + return Ok(55u32); + } + i += 3; + } + match i { + 5 => return Ok(99), + _ => return Ok(0), + }; + } + Ok(the_field) +} +"#, + r#" +fn foo(the_field: u32) -> u32 { + if the_field < 5 { + let mut i = 0; + loop { + if i > 5 { + return 55u32; + } + i += 3; + } + match i { + 5 => return 99, + _ => return 0, + }; + } + the_field +} +"#, + ); + + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + match i { + 5 => return Ok(99), + _ => return Ok(0), + } + } + Ok(the_field) +} +"#, + r#" +fn foo(the_field: u32) -> u32 { + if the_field < 5 { + let mut i = 0; + match i { + 5 => return 99, + _ => return 0, + } + } + the_field +} +"#, + ); + + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + if i == 5 { + return Ok(99) + } else { + return Ok(0) + } + } + Ok(the_field) +} +"#, + r#" +fn foo(the_field: u32) -> u32 { + if the_field < 5 { + let mut i = 0; + if i == 5 { + return 99 + } else { + return 0 + } + } + the_field +} +"#, + ); + + check_assist( + unwrap_result_return_type, + r#" +//- minicore: result +fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + if i == 5 { + return Ok(99); + } else { + return Ok(0); + } + } + Ok(the_field) +} +"#, + r#" +fn foo(the_field: u32) -> u32 { + if the_field < 5 { + let mut i = 0; + if i == 5 { + return 99; + } else { + return 0; + } + } + the_field +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs new file mode 100644 index 000000000..83446387d --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs @@ -0,0 +1,980 @@ +use std::iter; + +use ide_db::{ + famous_defs::FamousDefs, + syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, +}; +use syntax::{ + ast::{self, make, Expr}, + match_ast, AstNode, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: wrap_return_type_in_result +// +// Wrap the function's return type into Result. +// +// ``` +// # //- minicore: result +// fn foo() -> i32$0 { 42i32 } +// ``` +// -> +// ``` +// fn foo() -> Result { Ok(42i32) } +// ``` +pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let ret_type = ctx.find_node_at_offset::()?; + let parent = ret_type.syntax().parent()?; + let body = match_ast! { + match parent { + ast::Fn(func) => func.body()?, + ast::ClosureExpr(closure) => match closure.body()? { + Expr::BlockExpr(block) => block, + // closures require a block when a return type is specified + _ => return None, + }, + _ => return None, + } + }; + + let type_ref = &ret_type.ty()?; + let ty = ctx.sema.resolve_type(type_ref)?.as_adt(); + let result_enum = + FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?; + + if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) { + cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result); + return None; + } + + acc.add( + AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite), + "Wrap return type in Result", + type_ref.syntax().text_range(), + |builder| { + let body = ast::Expr::BlockExpr(body); + + let mut exprs_to_wrap = Vec::new(); + let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); + walk_expr(&body, &mut |expr| { + if let Expr::ReturnExpr(ret_expr) = expr { + if let Some(ret_expr_arg) = &ret_expr.expr() { + for_each_tail_expr(ret_expr_arg, tail_cb); + } + } + }); + for_each_tail_expr(&body, tail_cb); + + for ret_expr_arg in exprs_to_wrap { + let ok_wrapped = make::expr_call( + make::expr_path(make::ext::ident_path("Ok")), + make::arg_list(iter::once(ret_expr_arg.clone())), + ); + builder.replace_ast(ret_expr_arg, ok_wrapped); + } + + match ctx.config.snippet_cap { + Some(cap) => { + let snippet = format!("Result<{}, ${{0:_}}>", type_ref); + builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet) + } + None => builder + .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)), + } + }, + ) +} + +fn tail_cb_impl(acc: &mut Vec, e: &ast::Expr) { + match e { + Expr::BreakExpr(break_expr) => { + if let Some(break_expr_arg) = break_expr.expr() { + for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e)) + } + } + Expr::ReturnExpr(ret_expr) => { + if let Some(ret_expr_arg) = &ret_expr.expr() { + for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e)); + } + } + e => acc.push(e.clone()), + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn wrap_return_type_in_result_simple() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i3$02 { + let test = "test"; + return 42i32; +} +"#, + r#" +fn foo() -> Result { + let test = "test"; + return Ok(42i32); +} +"#, + ); + } + + #[test] + fn wrap_return_type_break_split_tail() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i3$02 { + loop { + break if true { + 1 + } else { + 0 + }; + } +} +"#, + r#" +fn foo() -> Result { + loop { + break if true { + Ok(1) + } else { + Ok(0) + }; + } +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_closure() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() { + || -> i32$0 { + let test = "test"; + return 42i32; + }; +} +"#, + r#" +fn foo() { + || -> Result { + let test = "test"; + return Ok(42i32); + }; +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_return_type_bad_cursor() { + check_assist_not_applicable( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32 { + let test = "test";$0 + return 42i32; +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_return_type_bad_cursor_closure() { + check_assist_not_applicable( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() { + || -> i32 { + let test = "test";$0 + return 42i32; + }; +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_closure_non_block() { + check_assist_not_applicable( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() { || -> i$032 3; } +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_return_type_already_result_std() { + check_assist_not_applicable( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> core::result::Result { + let test = "test"; + return 42i32; +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_return_type_already_result() { + cov_mark::check!(wrap_return_type_in_result_simple_return_type_already_result); + check_assist_not_applicable( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> Result { + let test = "test"; + return 42i32; +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_return_type_already_result_closure() { + check_assist_not_applicable( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() { + || -> Result { + let test = "test"; + return 42i32; + }; +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_cursor() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> $0i32 { + let test = "test"; + return 42i32; +} +"#, + r#" +fn foo() -> Result { + let test = "test"; + return Ok(42i32); +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_tail() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() ->$0 i32 { + let test = "test"; + 42i32 +} +"#, + r#" +fn foo() -> Result { + let test = "test"; + Ok(42i32) +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_tail_closure() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() { + || ->$0 i32 { + let test = "test"; + 42i32 + }; +} +"#, + r#" +fn foo() { + || -> Result { + let test = "test"; + Ok(42i32) + }; +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_tail_only() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { 42i32 } +"#, + r#" +fn foo() -> Result { Ok(42i32) } +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_tail_block_like() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + if true { + 42i32 + } else { + 24i32 + } +} +"#, + r#" +fn foo() -> Result { + if true { + Ok(42i32) + } else { + Ok(24i32) + } +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_without_block_closure() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() { + || -> i32$0 { + if true { + 42i32 + } else { + 24i32 + } + }; +} +"#, + r#" +fn foo() { + || -> Result { + if true { + Ok(42i32) + } else { + Ok(24i32) + } + }; +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_nested_if() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + if true { + if false { + 1 + } else { + 2 + } + } else { + 24i32 + } +} +"#, + r#" +fn foo() -> Result { + if true { + if false { + Ok(1) + } else { + Ok(2) + } + } else { + Ok(24i32) + } +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_await() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +async fn foo() -> i$032 { + if true { + if false { + 1.await + } else { + 2.await + } + } else { + 24i32.await + } +} +"#, + r#" +async fn foo() -> Result { + if true { + if false { + Ok(1.await) + } else { + Ok(2.await) + } + } else { + Ok(24i32.await) + } +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_array() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> [i32;$0 3] { [1, 2, 3] } +"#, + r#" +fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) } +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_cast() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -$0> i32 { + if true { + if false { + 1 as i32 + } else { + 2 as i32 + } + } else { + 24 as i32 + } +} +"#, + r#" +fn foo() -> Result { + if true { + if false { + Ok(1 as i32) + } else { + Ok(2 as i32) + } + } else { + Ok(24 as i32) + } +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_tail_block_like_match() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + let my_var = 5; + match my_var { + 5 => 42i32, + _ => 24i32, + } +} +"#, + r#" +fn foo() -> Result { + let my_var = 5; + match my_var { + 5 => Ok(42i32), + _ => Ok(24i32), + } +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_loop_with_tail() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + let my_var = 5; + loop { + println!("test"); + 5 + } + my_var +} +"#, + r#" +fn foo() -> Result { + let my_var = 5; + loop { + println!("test"); + 5 + } + Ok(my_var) +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + let my_var = let x = loop { + break 1; + }; + my_var +} +"#, + r#" +fn foo() -> Result { + let my_var = let x = loop { + break 1; + }; + Ok(my_var) +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + let my_var = 5; + let res = match my_var { + 5 => 42i32, + _ => return 24i32, + }; + res +} +"#, + r#" +fn foo() -> Result { + let my_var = 5; + let res = match my_var { + 5 => 42i32, + _ => return Ok(24i32), + }; + Ok(res) +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + let my_var = 5; + let res = if my_var == 5 { + 42i32 + } else { + return 24i32; + }; + res +} +"#, + r#" +fn foo() -> Result { + let my_var = 5; + let res = if my_var == 5 { + 42i32 + } else { + return Ok(24i32); + }; + Ok(res) +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + let my_var = 5; + match my_var { + 5 => { + if true { + 42i32 + } else { + 25i32 + } + }, + _ => { + let test = "test"; + if test == "test" { + return bar(); + } + 53i32 + }, + } +} +"#, + r#" +fn foo() -> Result { + let my_var = 5; + match my_var { + 5 => { + if true { + Ok(42i32) + } else { + Ok(25i32) + } + }, + _ => { + let test = "test"; + if test == "test" { + return Ok(bar()); + } + Ok(53i32) + }, + } +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i$032 { + let test = "test"; + if test == "test" { + return 24i32; + } + 53i32 +} +"#, + r#" +fn foo() -> Result { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + Ok(53i32) +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_closure() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo(the_field: u32) ->$0 u32 { + let true_closure = || { return true; }; + if the_field < 5 { + let mut i = 0; + if true_closure() { + return 99; + } else { + return 0; + } + } + the_field +} +"#, + r#" +fn foo(the_field: u32) -> Result { + let true_closure = || { return true; }; + if the_field < 5 { + let mut i = 0; + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + Ok(the_field) +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo(the_field: u32) -> u32$0 { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return 99; + } else { + return 0; + } + } + let t = None; + + t.unwrap_or_else(|| the_field) +} +"#, + r#" +fn foo(the_field: u32) -> Result { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + let t = None; + + Ok(t.unwrap_or_else(|| the_field)) +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_result_simple_with_weird_forms() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i32$0 { + let test = "test"; + if test == "test" { + return 24i32; + } + let mut i = 0; + loop { + if i == 1 { + break 55; + } + i += 1; + } +} +"#, + r#" +fn foo() -> Result { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + let mut i = 0; + loop { + if i == 1 { + break Ok(55); + } + i += 1; + } +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo(the_field: u32) -> u32$0 { + if the_field < 5 { + let mut i = 0; + loop { + if i > 5 { + return 55u32; + } + i += 3; + } + match i { + 5 => return 99, + _ => return 0, + }; + } + the_field +} +"#, + r#" +fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + loop { + if i > 5 { + return Ok(55u32); + } + i += 3; + } + match i { + 5 => return Ok(99), + _ => return Ok(0), + }; + } + Ok(the_field) +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo(the_field: u32) -> u3$02 { + if the_field < 5 { + let mut i = 0; + match i { + 5 => return 99, + _ => return 0, + } + } + the_field +} +"#, + r#" +fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + match i { + 5 => return Ok(99), + _ => return Ok(0), + } + } + Ok(the_field) +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo(the_field: u32) -> u32$0 { + if the_field < 5 { + let mut i = 0; + if i == 5 { + return 99 + } else { + return 0 + } + } + the_field +} +"#, + r#" +fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + if i == 5 { + return Ok(99) + } else { + return Ok(0) + } + } + Ok(the_field) +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo(the_field: u32) -> $0u32 { + if the_field < 5 { + let mut i = 0; + if i == 5 { + return 99; + } else { + return 0; + } + } + the_field +} +"#, + r#" +fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + if i == 5 { + return Ok(99); + } else { + return Ok(0); + } + } + Ok(the_field) +} +"#, + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs b/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs new file mode 100644 index 000000000..fe87aa15f --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs @@ -0,0 +1,309 @@ +//! `assists` crate provides a bunch of code assists, also known as code actions +//! (in LSP) or intentions (in IntelliJ). +//! +//! An assist is a micro-refactoring, which is automatically activated in +//! certain context. For example, if the cursor is over `,`, a "swap `,`" assist +//! becomes available. +//! +//! ## Assists Guidelines +//! +//! Assists are the main mechanism to deliver advanced IDE features to the user, +//! so we should pay extra attention to the UX. +//! +//! The power of assists comes from their context-awareness. The main problem +//! with IDE features is that there are a lot of them, and it's hard to teach +//! the user what's available. Assists solve this problem nicely: 💡 signifies +//! that *something* is possible, and clicking on it reveals a *short* list of +//! actions. Contrast it with Emacs `M-x`, which just spits an infinite list of +//! all the features. +//! +//! Here are some considerations when creating a new assist: +//! +//! * It's good to preserve semantics, and it's good to keep the code compiling, +//! but it isn't necessary. Example: "flip binary operation" might change +//! semantics. +//! * Assist shouldn't necessary make the code "better". A lot of assist come in +//! pairs: "if let <-> match". +//! * Assists should have as narrow scope as possible. Each new assists greatly +//! improves UX for cases where the user actually invokes it, but it makes UX +//! worse for every case where the user clicks 💡 to invoke some *other* +//! assist. So, a rarely useful assist which is always applicable can be a net +//! negative. +//! * Rarely useful actions are tricky. Sometimes there are features which are +//! clearly useful to some users, but are just noise most of the time. We +//! don't have a good solution here, our current approach is to make this +//! functionality available only if assist is applicable to the whole +//! selection. Example: `sort_items` sorts items alphabetically. Naively, it +//! should be available more or less everywhere, which isn't useful. So +//! instead we only show it if the user *selects* the items they want to sort. +//! * Consider grouping related assists together (see [`Assists::add_group`]). +//! * Make assists robust. If the assist depends on results of type-inference too +//! much, it might only fire in fully-correct code. This makes assist less +//! useful and (worse) less predictable. The user should have a clear +//! intuition when each particular assist is available. +//! * Make small assists, which compose. Example: rather than auto-importing +//! enums in `add_missing_match_arms`, we use fully-qualified names. There's a +//! separate assist to shorten a fully-qualified name. +//! * Distinguish between assists and fixits for diagnostics. Internally, fixits +//! and assists are equivalent. They have the same "show a list + invoke a +//! single element" workflow, and both use [`Assist`] data structure. The main +//! difference is in the UX: while 💡 looks only at the cursor position, +//! diagnostics squigglies and fixits are calculated for the whole file and +//! are presented to the user eagerly. So, diagnostics should be fixable +//! errors, while assists can be just suggestions for an alternative way to do +//! something. If something *could* be a diagnostic, it should be a +//! diagnostic. Conversely, it might be valuable to turn a diagnostic with a +//! lot of false errors into an assist. +//! +//! See also this post: +//! + +#![warn(rust_2018_idioms, unused_lifetimes, semicolon_in_expressions_from_macros)] + +#[allow(unused)] +macro_rules! eprintln { + ($($tt:tt)*) => { stdx::eprintln!($($tt)*) }; +} + +mod assist_config; +mod assist_context; +#[cfg(test)] +mod tests; +pub mod utils; + +use hir::Semantics; +use ide_db::{base_db::FileRange, RootDatabase}; +use syntax::TextRange; + +pub(crate) use crate::assist_context::{AssistContext, Assists}; + +pub use assist_config::AssistConfig; +pub use ide_db::assists::{ + Assist, AssistId, AssistKind, AssistResolveStrategy, GroupLabel, SingleResolve, +}; + +/// Return all the assists applicable at the given position. +/// +// NOTE: We don't have a `Feature: ` section for assists, they are special-cased +// in the manual. +pub fn assists( + db: &RootDatabase, + config: &AssistConfig, + resolve: AssistResolveStrategy, + range: FileRange, +) -> Vec { + let sema = Semantics::new(db); + let ctx = AssistContext::new(sema, config, range); + let mut acc = Assists::new(&ctx, resolve); + handlers::all().iter().for_each(|handler| { + handler(&mut acc, &ctx); + }); + acc.finish() +} + +mod handlers { + use crate::{AssistContext, Assists}; + + pub(crate) type Handler = fn(&mut Assists, &AssistContext<'_>) -> Option<()>; + + mod add_explicit_type; + mod add_label_to_loop; + mod add_lifetime_to_type; + mod add_missing_impl_members; + mod add_turbo_fish; + mod apply_demorgan; + mod auto_import; + mod change_visibility; + mod convert_bool_then; + mod convert_comment_block; + mod convert_integer_literal; + mod convert_into_to_from; + mod convert_iter_for_each_to_for; + mod convert_let_else_to_match; + mod convert_tuple_struct_to_named_struct; + mod convert_to_guarded_return; + mod convert_while_to_loop; + mod destructure_tuple_binding; + mod expand_glob_import; + mod extract_function; + mod extract_module; + mod extract_struct_from_enum_variant; + mod extract_type_alias; + mod extract_variable; + mod add_missing_match_arms; + mod fix_visibility; + mod flip_binexpr; + mod flip_comma; + mod flip_trait_bound; + mod generate_constant; + mod generate_default_from_enum_variant; + mod generate_default_from_new; + mod generate_deref; + mod generate_derive; + mod generate_documentation_template; + mod generate_enum_is_method; + mod generate_enum_projection_method; + mod generate_enum_variant; + mod generate_from_impl_for_enum; + mod generate_function; + mod generate_getter; + mod generate_impl; + mod generate_is_empty_from_len; + mod generate_new; + mod generate_setter; + mod generate_delegate_methods; + mod add_return_type; + mod inline_call; + mod inline_local_variable; + mod inline_type_alias; + mod introduce_named_lifetime; + mod invert_if; + mod merge_imports; + mod merge_match_arms; + mod move_bounds; + mod move_guard; + mod move_module_to_file; + mod move_to_mod_rs; + mod move_from_mod_rs; + mod number_representation; + mod promote_local_to_const; + mod pull_assignment_up; + mod qualify_path; + mod qualify_method_call; + mod raw_string; + mod remove_dbg; + mod remove_mut; + mod remove_unused_param; + mod reorder_fields; + mod reorder_impl_items; + mod replace_try_expr_with_match; + mod replace_derive_with_manual_impl; + mod replace_if_let_with_match; + mod introduce_named_generic; + mod replace_let_with_if_let; + mod replace_qualified_name_with_use; + mod replace_string_with_char; + mod replace_turbofish_with_explicit_type; + mod split_import; + mod sort_items; + mod toggle_ignore; + mod unmerge_use; + mod unnecessary_async; + mod unwrap_block; + mod unwrap_result_return_type; + mod wrap_return_type_in_result; + + pub(crate) fn all() -> &'static [Handler] { + &[ + // These are alphabetic for the foolish consistency + add_explicit_type::add_explicit_type, + add_label_to_loop::add_label_to_loop, + add_missing_match_arms::add_missing_match_arms, + add_lifetime_to_type::add_lifetime_to_type, + add_return_type::add_return_type, + add_turbo_fish::add_turbo_fish, + apply_demorgan::apply_demorgan, + auto_import::auto_import, + change_visibility::change_visibility, + convert_bool_then::convert_bool_then_to_if, + convert_bool_then::convert_if_to_bool_then, + convert_comment_block::convert_comment_block, + convert_integer_literal::convert_integer_literal, + convert_into_to_from::convert_into_to_from, + convert_iter_for_each_to_for::convert_iter_for_each_to_for, + convert_iter_for_each_to_for::convert_for_loop_with_for_each, + convert_let_else_to_match::convert_let_else_to_match, + convert_to_guarded_return::convert_to_guarded_return, + convert_tuple_struct_to_named_struct::convert_tuple_struct_to_named_struct, + convert_while_to_loop::convert_while_to_loop, + destructure_tuple_binding::destructure_tuple_binding, + expand_glob_import::expand_glob_import, + extract_struct_from_enum_variant::extract_struct_from_enum_variant, + extract_type_alias::extract_type_alias, + fix_visibility::fix_visibility, + flip_binexpr::flip_binexpr, + flip_comma::flip_comma, + flip_trait_bound::flip_trait_bound, + generate_constant::generate_constant, + generate_default_from_enum_variant::generate_default_from_enum_variant, + generate_default_from_new::generate_default_from_new, + generate_derive::generate_derive, + generate_documentation_template::generate_documentation_template, + generate_documentation_template::generate_doc_example, + generate_enum_is_method::generate_enum_is_method, + generate_enum_projection_method::generate_enum_as_method, + generate_enum_projection_method::generate_enum_try_into_method, + generate_enum_variant::generate_enum_variant, + generate_from_impl_for_enum::generate_from_impl_for_enum, + generate_function::generate_function, + generate_impl::generate_impl, + generate_is_empty_from_len::generate_is_empty_from_len, + generate_new::generate_new, + inline_call::inline_call, + inline_call::inline_into_callers, + inline_local_variable::inline_local_variable, + inline_type_alias::inline_type_alias, + introduce_named_generic::introduce_named_generic, + introduce_named_lifetime::introduce_named_lifetime, + invert_if::invert_if, + merge_imports::merge_imports, + merge_match_arms::merge_match_arms, + move_bounds::move_bounds_to_where_clause, + move_guard::move_arm_cond_to_match_guard, + move_guard::move_guard_to_arm_body, + move_module_to_file::move_module_to_file, + move_to_mod_rs::move_to_mod_rs, + move_from_mod_rs::move_from_mod_rs, + number_representation::reformat_number_literal, + pull_assignment_up::pull_assignment_up, + promote_local_to_const::promote_local_to_const, + qualify_path::qualify_path, + qualify_method_call::qualify_method_call, + raw_string::add_hash, + raw_string::make_usual_string, + raw_string::remove_hash, + remove_dbg::remove_dbg, + remove_mut::remove_mut, + remove_unused_param::remove_unused_param, + reorder_fields::reorder_fields, + reorder_impl_items::reorder_impl_items, + replace_try_expr_with_match::replace_try_expr_with_match, + replace_derive_with_manual_impl::replace_derive_with_manual_impl, + replace_if_let_with_match::replace_if_let_with_match, + replace_if_let_with_match::replace_match_with_if_let, + replace_let_with_if_let::replace_let_with_if_let, + replace_turbofish_with_explicit_type::replace_turbofish_with_explicit_type, + replace_qualified_name_with_use::replace_qualified_name_with_use, + sort_items::sort_items, + split_import::split_import, + toggle_ignore::toggle_ignore, + unmerge_use::unmerge_use, + unnecessary_async::unnecessary_async, + unwrap_block::unwrap_block, + unwrap_result_return_type::unwrap_result_return_type, + wrap_return_type_in_result::wrap_return_type_in_result, + // These are manually sorted for better priorities. By default, + // priority is determined by the size of the target range (smaller + // target wins). If the ranges are equal, position in this list is + // used as a tie-breaker. + add_missing_impl_members::add_missing_impl_members, + add_missing_impl_members::add_missing_default_members, + // + replace_string_with_char::replace_string_with_char, + replace_string_with_char::replace_char_with_string, + raw_string::make_raw_string, + // + extract_variable::extract_variable, + extract_function::extract_function, + extract_module::extract_module, + // + generate_getter::generate_getter, + generate_getter::generate_getter_mut, + generate_setter::generate_setter, + generate_delegate_methods::generate_delegate_methods, + generate_deref::generate_deref, + // Are you sure you want to add new assist here, and not to the + // sorted list above? + ] + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/tests.rs b/src/tools/rust-analyzer/crates/ide-assists/src/tests.rs new file mode 100644 index 000000000..9cd66c6b3 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/tests.rs @@ -0,0 +1,558 @@ +mod generated; +#[cfg(not(feature = "in-rust-tree"))] +mod sourcegen; + +use expect_test::expect; +use hir::{db::DefDatabase, Semantics}; +use ide_db::{ + base_db::{fixture::WithFixture, FileId, FileRange, SourceDatabaseExt}, + imports::insert_use::{ImportGranularity, InsertUseConfig}, + source_change::FileSystemEdit, + RootDatabase, SnippetCap, +}; +use stdx::{format_to, trim_indent}; +use syntax::TextRange; +use test_utils::{assert_eq_text, extract_offset}; + +use crate::{ + assists, handlers::Handler, Assist, AssistConfig, AssistContext, AssistKind, + AssistResolveStrategy, Assists, SingleResolve, +}; + +pub(crate) const TEST_CONFIG: AssistConfig = AssistConfig { + snippet_cap: SnippetCap::new(true), + allowed: None, + insert_use: InsertUseConfig { + granularity: ImportGranularity::Crate, + prefix_kind: hir::PrefixKind::Plain, + enforce_granularity: true, + group: true, + skip_glob_imports: true, + }, +}; + +pub(crate) fn with_single_file(text: &str) -> (RootDatabase, FileId) { + RootDatabase::with_single_file(text) +} + +#[track_caller] +pub(crate) fn check_assist(assist: Handler, ra_fixture_before: &str, ra_fixture_after: &str) { + let ra_fixture_after = trim_indent(ra_fixture_after); + check(assist, ra_fixture_before, ExpectedResult::After(&ra_fixture_after), None); +} + +// There is no way to choose what assist within a group you want to test against, +// so this is here to allow you choose. +pub(crate) fn check_assist_by_label( + assist: Handler, + ra_fixture_before: &str, + ra_fixture_after: &str, + label: &str, +) { + let ra_fixture_after = trim_indent(ra_fixture_after); + check(assist, ra_fixture_before, ExpectedResult::After(&ra_fixture_after), Some(label)); +} + +// FIXME: instead of having a separate function here, maybe use +// `extract_ranges` and mark the target as ` ` in the +// fixture? +#[track_caller] +pub(crate) fn check_assist_target(assist: Handler, ra_fixture: &str, target: &str) { + check(assist, ra_fixture, ExpectedResult::Target(target), None); +} + +#[track_caller] +pub(crate) fn check_assist_not_applicable(assist: Handler, ra_fixture: &str) { + check(assist, ra_fixture, ExpectedResult::NotApplicable, None); +} + +/// Check assist in unresolved state. Useful to check assists for lazy computation. +#[track_caller] +pub(crate) fn check_assist_unresolved(assist: Handler, ra_fixture: &str) { + check(assist, ra_fixture, ExpectedResult::Unresolved, None); +} + +#[track_caller] +fn check_doc_test(assist_id: &str, before: &str, after: &str) { + let after = trim_indent(after); + let (db, file_id, selection) = RootDatabase::with_range_or_offset(before); + let before = db.file_text(file_id).to_string(); + let frange = FileRange { file_id, range: selection.into() }; + + let assist = assists(&db, &TEST_CONFIG, AssistResolveStrategy::All, frange) + .into_iter() + .find(|assist| assist.id.0 == assist_id) + .unwrap_or_else(|| { + panic!( + "\n\nAssist is not applicable: {}\nAvailable assists: {}", + assist_id, + assists(&db, &TEST_CONFIG, AssistResolveStrategy::None, frange) + .into_iter() + .map(|assist| assist.id.0) + .collect::>() + .join(", ") + ) + }); + + let actual = { + let source_change = + assist.source_change.expect("Assist did not contain any source changes"); + let mut actual = before; + if let Some(source_file_edit) = source_change.get_source_edit(file_id) { + source_file_edit.apply(&mut actual); + } + actual + }; + assert_eq_text!(&after, &actual); +} + +enum ExpectedResult<'a> { + NotApplicable, + Unresolved, + After(&'a str), + Target(&'a str), +} + +#[track_caller] +fn check(handler: Handler, before: &str, expected: ExpectedResult<'_>, assist_label: Option<&str>) { + let (mut db, file_with_caret_id, range_or_offset) = RootDatabase::with_range_or_offset(before); + db.set_enable_proc_attr_macros(true); + let text_without_caret = db.file_text(file_with_caret_id).to_string(); + + let frange = FileRange { file_id: file_with_caret_id, range: range_or_offset.into() }; + + let sema = Semantics::new(&db); + let config = TEST_CONFIG; + let ctx = AssistContext::new(sema, &config, frange); + let resolve = match expected { + ExpectedResult::Unresolved => AssistResolveStrategy::None, + _ => AssistResolveStrategy::All, + }; + let mut acc = Assists::new(&ctx, resolve); + handler(&mut acc, &ctx); + let mut res = acc.finish(); + + let assist = match assist_label { + Some(label) => res.into_iter().find(|resolved| resolved.label == label), + None => res.pop(), + }; + + match (assist, expected) { + (Some(assist), ExpectedResult::After(after)) => { + let source_change = + assist.source_change.expect("Assist did not contain any source changes"); + let skip_header = source_change.source_file_edits.len() == 1 + && source_change.file_system_edits.len() == 0; + + let mut buf = String::new(); + for (file_id, edit) in source_change.source_file_edits { + let mut text = db.file_text(file_id).as_ref().to_owned(); + edit.apply(&mut text); + if !skip_header { + let sr = db.file_source_root(file_id); + let sr = db.source_root(sr); + let path = sr.path_for_file(&file_id).unwrap(); + format_to!(buf, "//- {}\n", path) + } + buf.push_str(&text); + } + + for file_system_edit in source_change.file_system_edits { + let (dst, contents) = match file_system_edit { + FileSystemEdit::CreateFile { dst, initial_contents } => (dst, initial_contents), + FileSystemEdit::MoveFile { src, dst } => { + (dst, db.file_text(src).as_ref().to_owned()) + } + FileSystemEdit::MoveDir { src, src_id, dst } => { + // temporary placeholder for MoveDir since we are not using MoveDir in ide assists yet. + (dst, format!("{:?}\n{:?}", src_id, src)) + } + }; + let sr = db.file_source_root(dst.anchor); + let sr = db.source_root(sr); + let mut base = sr.path_for_file(&dst.anchor).unwrap().clone(); + base.pop(); + let created_file_path = base.join(&dst.path).unwrap(); + format_to!(buf, "//- {}\n", created_file_path); + buf.push_str(&contents); + } + + assert_eq_text!(after, &buf); + } + (Some(assist), ExpectedResult::Target(target)) => { + let range = assist.target; + assert_eq_text!(&text_without_caret[range], target); + } + (Some(assist), ExpectedResult::Unresolved) => assert!( + assist.source_change.is_none(), + "unresolved assist should not contain source changes" + ), + (Some(_), ExpectedResult::NotApplicable) => panic!("assist should not be applicable!"), + ( + None, + ExpectedResult::After(_) | ExpectedResult::Target(_) | ExpectedResult::Unresolved, + ) => { + panic!("code action is not applicable") + } + (None, ExpectedResult::NotApplicable) => (), + }; +} + +fn labels(assists: &[Assist]) -> String { + let mut labels = assists + .iter() + .map(|assist| { + let mut label = match &assist.group { + Some(g) => g.0.clone(), + None => assist.label.to_string(), + }; + label.push('\n'); + label + }) + .collect::>(); + labels.dedup(); + labels.into_iter().collect::() +} + +#[test] +fn assist_order_field_struct() { + let before = "struct Foo { $0bar: u32 }"; + let (before_cursor_pos, before) = extract_offset(before); + let (db, file_id) = with_single_file(&before); + let frange = FileRange { file_id, range: TextRange::empty(before_cursor_pos) }; + let assists = assists(&db, &TEST_CONFIG, AssistResolveStrategy::None, frange); + let mut assists = assists.iter(); + + assert_eq!(assists.next().expect("expected assist").label, "Change visibility to pub(crate)"); + assert_eq!(assists.next().expect("expected assist").label, "Generate a getter method"); + assert_eq!(assists.next().expect("expected assist").label, "Generate a mut getter method"); + assert_eq!(assists.next().expect("expected assist").label, "Generate a setter method"); + assert_eq!(assists.next().expect("expected assist").label, "Add `#[derive]`"); +} + +#[test] +fn assist_order_if_expr() { + let (db, frange) = RootDatabase::with_range( + r#" +pub fn test_some_range(a: int) -> bool { + if let 2..6 = $05$0 { + true + } else { + false + } +} +"#, + ); + + let assists = assists(&db, &TEST_CONFIG, AssistResolveStrategy::None, frange); + let expected = labels(&assists); + + expect![[r#" + Convert integer base + Extract into variable + Extract into function + Replace if let with match + "#]] + .assert_eq(&expected); +} + +#[test] +fn assist_filter_works() { + let (db, frange) = RootDatabase::with_range( + r#" +pub fn test_some_range(a: int) -> bool { + if let 2..6 = $05$0 { + true + } else { + false + } +} +"#, + ); + { + let mut cfg = TEST_CONFIG; + cfg.allowed = Some(vec![AssistKind::Refactor]); + + let assists = assists(&db, &cfg, AssistResolveStrategy::None, frange); + let expected = labels(&assists); + + expect![[r#" + Convert integer base + Extract into variable + Extract into function + Replace if let with match + "#]] + .assert_eq(&expected); + } + + { + let mut cfg = TEST_CONFIG; + cfg.allowed = Some(vec![AssistKind::RefactorExtract]); + let assists = assists(&db, &cfg, AssistResolveStrategy::None, frange); + let expected = labels(&assists); + + expect![[r#" + Extract into variable + Extract into function + "#]] + .assert_eq(&expected); + } + + { + let mut cfg = TEST_CONFIG; + cfg.allowed = Some(vec![AssistKind::QuickFix]); + let assists = assists(&db, &cfg, AssistResolveStrategy::None, frange); + let expected = labels(&assists); + + expect![[r#""#]].assert_eq(&expected); + } +} + +#[test] +fn various_resolve_strategies() { + let (db, frange) = RootDatabase::with_range( + r#" +pub fn test_some_range(a: int) -> bool { + if let 2..6 = $05$0 { + true + } else { + false + } +} +"#, + ); + + let mut cfg = TEST_CONFIG; + cfg.allowed = Some(vec![AssistKind::RefactorExtract]); + + { + let assists = assists(&db, &cfg, AssistResolveStrategy::None, frange); + assert_eq!(2, assists.len()); + let mut assists = assists.into_iter(); + + let extract_into_variable_assist = assists.next().unwrap(); + expect![[r#" + Assist { + id: AssistId( + "extract_variable", + RefactorExtract, + ), + label: "Extract into variable", + group: None, + target: 59..60, + source_change: None, + trigger_signature_help: false, + } + "#]] + .assert_debug_eq(&extract_into_variable_assist); + + let extract_into_function_assist = assists.next().unwrap(); + expect![[r#" + Assist { + id: AssistId( + "extract_function", + RefactorExtract, + ), + label: "Extract into function", + group: None, + target: 59..60, + source_change: None, + trigger_signature_help: false, + } + "#]] + .assert_debug_eq(&extract_into_function_assist); + } + + { + let assists = assists( + &db, + &cfg, + AssistResolveStrategy::Single(SingleResolve { + assist_id: "SOMETHING_MISMATCHING".to_string(), + assist_kind: AssistKind::RefactorExtract, + }), + frange, + ); + assert_eq!(2, assists.len()); + let mut assists = assists.into_iter(); + + let extract_into_variable_assist = assists.next().unwrap(); + expect![[r#" + Assist { + id: AssistId( + "extract_variable", + RefactorExtract, + ), + label: "Extract into variable", + group: None, + target: 59..60, + source_change: None, + trigger_signature_help: false, + } + "#]] + .assert_debug_eq(&extract_into_variable_assist); + + let extract_into_function_assist = assists.next().unwrap(); + expect![[r#" + Assist { + id: AssistId( + "extract_function", + RefactorExtract, + ), + label: "Extract into function", + group: None, + target: 59..60, + source_change: None, + trigger_signature_help: false, + } + "#]] + .assert_debug_eq(&extract_into_function_assist); + } + + { + let assists = assists( + &db, + &cfg, + AssistResolveStrategy::Single(SingleResolve { + assist_id: "extract_variable".to_string(), + assist_kind: AssistKind::RefactorExtract, + }), + frange, + ); + assert_eq!(2, assists.len()); + let mut assists = assists.into_iter(); + + let extract_into_variable_assist = assists.next().unwrap(); + expect![[r#" + Assist { + id: AssistId( + "extract_variable", + RefactorExtract, + ), + label: "Extract into variable", + group: None, + target: 59..60, + source_change: Some( + SourceChange { + source_file_edits: { + FileId( + 0, + ): TextEdit { + indels: [ + Indel { + insert: "let $0var_name = 5;\n ", + delete: 45..45, + }, + Indel { + insert: "var_name", + delete: 59..60, + }, + ], + }, + }, + file_system_edits: [], + is_snippet: true, + }, + ), + trigger_signature_help: false, + } + "#]] + .assert_debug_eq(&extract_into_variable_assist); + + let extract_into_function_assist = assists.next().unwrap(); + expect![[r#" + Assist { + id: AssistId( + "extract_function", + RefactorExtract, + ), + label: "Extract into function", + group: None, + target: 59..60, + source_change: None, + trigger_signature_help: false, + } + "#]] + .assert_debug_eq(&extract_into_function_assist); + } + + { + let assists = assists(&db, &cfg, AssistResolveStrategy::All, frange); + assert_eq!(2, assists.len()); + let mut assists = assists.into_iter(); + + let extract_into_variable_assist = assists.next().unwrap(); + expect![[r#" + Assist { + id: AssistId( + "extract_variable", + RefactorExtract, + ), + label: "Extract into variable", + group: None, + target: 59..60, + source_change: Some( + SourceChange { + source_file_edits: { + FileId( + 0, + ): TextEdit { + indels: [ + Indel { + insert: "let $0var_name = 5;\n ", + delete: 45..45, + }, + Indel { + insert: "var_name", + delete: 59..60, + }, + ], + }, + }, + file_system_edits: [], + is_snippet: true, + }, + ), + trigger_signature_help: false, + } + "#]] + .assert_debug_eq(&extract_into_variable_assist); + + let extract_into_function_assist = assists.next().unwrap(); + expect![[r#" + Assist { + id: AssistId( + "extract_function", + RefactorExtract, + ), + label: "Extract into function", + group: None, + target: 59..60, + source_change: Some( + SourceChange { + source_file_edits: { + FileId( + 0, + ): TextEdit { + indels: [ + Indel { + insert: "fun_name()", + delete: 59..60, + }, + Indel { + insert: "\n\nfn $0fun_name() -> i32 {\n 5\n}", + delete: 110..110, + }, + ], + }, + }, + file_system_edits: [], + is_snippet: true, + }, + ), + trigger_signature_help: false, + } + "#]] + .assert_debug_eq(&extract_into_function_assist); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs b/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs new file mode 100644 index 000000000..6eaab48a3 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs @@ -0,0 +1,2259 @@ +//! Generated by `sourcegen_assists_docs`, do not edit by hand. + +use super::check_doc_test; + +#[test] +fn doctest_add_explicit_type() { + check_doc_test( + "add_explicit_type", + r#####" +fn main() { + let x$0 = 92; +} +"#####, + r#####" +fn main() { + let x: i32 = 92; +} +"#####, + ) +} + +#[test] +fn doctest_add_hash() { + check_doc_test( + "add_hash", + r#####" +fn main() { + r#"Hello,$0 World!"#; +} +"#####, + r#####" +fn main() { + r##"Hello, World!"##; +} +"#####, + ) +} + +#[test] +fn doctest_add_impl_default_members() { + check_doc_test( + "add_impl_default_members", + r#####" +trait Trait { + type X; + fn foo(&self); + fn bar(&self) {} +} + +impl Trait for () { + type X = (); + fn foo(&self) {}$0 +} +"#####, + r#####" +trait Trait { + type X; + fn foo(&self); + fn bar(&self) {} +} + +impl Trait for () { + type X = (); + fn foo(&self) {} + + $0fn bar(&self) {} +} +"#####, + ) +} + +#[test] +fn doctest_add_impl_missing_members() { + check_doc_test( + "add_impl_missing_members", + r#####" +trait Trait { + type X; + fn foo(&self) -> T; + fn bar(&self) {} +} + +impl Trait for () {$0 + +} +"#####, + r#####" +trait Trait { + type X; + fn foo(&self) -> T; + fn bar(&self) {} +} + +impl Trait for () { + $0type X; + + fn foo(&self) -> u32 { + todo!() + } +} +"#####, + ) +} + +#[test] +fn doctest_add_label_to_loop() { + check_doc_test( + "add_label_to_loop", + r#####" +fn main() { + loop$0 { + break; + continue; + } +} +"#####, + r#####" +fn main() { + 'l: loop { + break 'l; + continue 'l; + } +} +"#####, + ) +} + +#[test] +fn doctest_add_lifetime_to_type() { + check_doc_test( + "add_lifetime_to_type", + r#####" +struct Point { + x: &$0u32, + y: u32, +} +"#####, + r#####" +struct Point<'a> { + x: &'a u32, + y: u32, +} +"#####, + ) +} + +#[test] +fn doctest_add_missing_match_arms() { + check_doc_test( + "add_missing_match_arms", + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + $0 + } +} +"#####, + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + $0Action::Move { distance } => todo!(), + Action::Stop => todo!(), + } +} +"#####, + ) +} + +#[test] +fn doctest_add_return_type() { + check_doc_test( + "add_return_type", + r#####" +fn foo() { 4$02i32 } +"#####, + r#####" +fn foo() -> i32 { 42i32 } +"#####, + ) +} + +#[test] +fn doctest_add_turbo_fish() { + check_doc_test( + "add_turbo_fish", + r#####" +fn make() -> T { todo!() } +fn main() { + let x = make$0(); +} +"#####, + r#####" +fn make() -> T { todo!() } +fn main() { + let x = make::<${0:_}>(); +} +"#####, + ) +} + +#[test] +fn doctest_apply_demorgan() { + check_doc_test( + "apply_demorgan", + r#####" +fn main() { + if x != 4 ||$0 y < 3.14 {} +} +"#####, + r#####" +fn main() { + if !(x == 4 && y >= 3.14) {} +} +"#####, + ) +} + +#[test] +fn doctest_auto_import() { + check_doc_test( + "auto_import", + r#####" +fn main() { + let map = HashMap$0::new(); +} +pub mod std { pub mod collections { pub struct HashMap { } } } +"#####, + r#####" +use std::collections::HashMap; + +fn main() { + let map = HashMap::new(); +} +pub mod std { pub mod collections { pub struct HashMap { } } } +"#####, + ) +} + +#[test] +fn doctest_change_visibility() { + check_doc_test( + "change_visibility", + r#####" +$0fn frobnicate() {} +"#####, + r#####" +pub(crate) fn frobnicate() {} +"#####, + ) +} + +#[test] +fn doctest_convert_bool_then_to_if() { + check_doc_test( + "convert_bool_then_to_if", + r#####" +//- minicore: bool_impl +fn main() { + (0 == 0).then$0(|| val) +} +"#####, + r#####" +fn main() { + if 0 == 0 { + Some(val) + } else { + None + } +} +"#####, + ) +} + +#[test] +fn doctest_convert_for_loop_with_for_each() { + check_doc_test( + "convert_for_loop_with_for_each", + r#####" +fn main() { + let x = vec![1, 2, 3]; + for$0 v in x { + let y = v * 2; + } +} +"#####, + r#####" +fn main() { + let x = vec![1, 2, 3]; + x.into_iter().for_each(|v| { + let y = v * 2; + }); +} +"#####, + ) +} + +#[test] +fn doctest_convert_if_to_bool_then() { + check_doc_test( + "convert_if_to_bool_then", + r#####" +//- minicore: option +fn main() { + if$0 cond { + Some(val) + } else { + None + } +} +"#####, + r#####" +fn main() { + cond.then(|| val) +} +"#####, + ) +} + +#[test] +fn doctest_convert_integer_literal() { + check_doc_test( + "convert_integer_literal", + r#####" +const _: i32 = 10$0; +"#####, + r#####" +const _: i32 = 0b1010; +"#####, + ) +} + +#[test] +fn doctest_convert_into_to_from() { + check_doc_test( + "convert_into_to_from", + r#####" +//- minicore: from +impl $0Into for usize { + fn into(self) -> Thing { + Thing { + b: self.to_string(), + a: self + } + } +} +"#####, + r#####" +impl From for Thing { + fn from(val: usize) -> Self { + Thing { + b: val.to_string(), + a: val + } + } +} +"#####, + ) +} + +#[test] +fn doctest_convert_iter_for_each_to_for() { + check_doc_test( + "convert_iter_for_each_to_for", + r#####" +//- minicore: iterators +use core::iter; +fn main() { + let iter = iter::repeat((9, 2)); + iter.for_each$0(|(x, y)| { + println!("x: {}, y: {}", x, y); + }); +} +"#####, + r#####" +use core::iter; +fn main() { + let iter = iter::repeat((9, 2)); + for (x, y) in iter { + println!("x: {}, y: {}", x, y); + } +} +"#####, + ) +} + +#[test] +fn doctest_convert_let_else_to_match() { + check_doc_test( + "convert_let_else_to_match", + r#####" +fn main() { + let Ok(mut x) = f() else$0 { return }; +} +"#####, + r#####" +fn main() { + let mut x = match f() { + Ok(x) => x, + _ => return, + }; +} +"#####, + ) +} + +#[test] +fn doctest_convert_to_guarded_return() { + check_doc_test( + "convert_to_guarded_return", + r#####" +fn main() { + $0if cond { + foo(); + bar(); + } +} +"#####, + r#####" +fn main() { + if !cond { + return; + } + foo(); + bar(); +} +"#####, + ) +} + +#[test] +fn doctest_convert_tuple_struct_to_named_struct() { + check_doc_test( + "convert_tuple_struct_to_named_struct", + r#####" +struct Point$0(f32, f32); + +impl Point { + pub fn new(x: f32, y: f32) -> Self { + Point(x, y) + } + + pub fn x(&self) -> f32 { + self.0 + } + + pub fn y(&self) -> f32 { + self.1 + } +} +"#####, + r#####" +struct Point { field1: f32, field2: f32 } + +impl Point { + pub fn new(x: f32, y: f32) -> Self { + Point { field1: x, field2: y } + } + + pub fn x(&self) -> f32 { + self.field1 + } + + pub fn y(&self) -> f32 { + self.field2 + } +} +"#####, + ) +} + +#[test] +fn doctest_convert_while_to_loop() { + check_doc_test( + "convert_while_to_loop", + r#####" +fn main() { + $0while cond { + foo(); + } +} +"#####, + r#####" +fn main() { + loop { + if !cond { + break; + } + foo(); + } +} +"#####, + ) +} + +#[test] +fn doctest_destructure_tuple_binding() { + check_doc_test( + "destructure_tuple_binding", + r#####" +fn main() { + let $0t = (1,2); + let v = t.0; +} +"#####, + r#####" +fn main() { + let ($0_0, _1) = (1,2); + let v = _0; +} +"#####, + ) +} + +#[test] +fn doctest_expand_glob_import() { + check_doc_test( + "expand_glob_import", + r#####" +mod foo { + pub struct Bar; + pub struct Baz; +} + +use foo::*$0; + +fn qux(bar: Bar, baz: Baz) {} +"#####, + r#####" +mod foo { + pub struct Bar; + pub struct Baz; +} + +use foo::{Bar, Baz}; + +fn qux(bar: Bar, baz: Baz) {} +"#####, + ) +} + +#[test] +fn doctest_extract_function() { + check_doc_test( + "extract_function", + r#####" +fn main() { + let n = 1; + $0let m = n + 2; + // calculate + let k = m + n;$0 + let g = 3; +} +"#####, + r#####" +fn main() { + let n = 1; + fun_name(n); + let g = 3; +} + +fn $0fun_name(n: i32) { + let m = n + 2; + // calculate + let k = m + n; +} +"#####, + ) +} + +#[test] +fn doctest_extract_module() { + check_doc_test( + "extract_module", + r#####" +$0fn foo(name: i32) -> i32 { + name + 1 +}$0 + +fn bar(name: i32) -> i32 { + name + 2 +} +"#####, + r#####" +mod modname { + pub(crate) fn foo(name: i32) -> i32 { + name + 1 + } +} + +fn bar(name: i32) -> i32 { + name + 2 +} +"#####, + ) +} + +#[test] +fn doctest_extract_struct_from_enum_variant() { + check_doc_test( + "extract_struct_from_enum_variant", + r#####" +enum A { $0One(u32, u32) } +"#####, + r#####" +struct One(u32, u32); + +enum A { One(One) } +"#####, + ) +} + +#[test] +fn doctest_extract_type_alias() { + check_doc_test( + "extract_type_alias", + r#####" +struct S { + field: $0(u8, u8, u8)$0, +} +"#####, + r#####" +type $0Type = (u8, u8, u8); + +struct S { + field: Type, +} +"#####, + ) +} + +#[test] +fn doctest_extract_variable() { + check_doc_test( + "extract_variable", + r#####" +fn main() { + $0(1 + 2)$0 * 4; +} +"#####, + r#####" +fn main() { + let $0var_name = (1 + 2); + var_name * 4; +} +"#####, + ) +} + +#[test] +fn doctest_fix_visibility() { + check_doc_test( + "fix_visibility", + r#####" +mod m { + fn frobnicate() {} +} +fn main() { + m::frobnicate$0() {} +} +"#####, + r#####" +mod m { + $0pub(crate) fn frobnicate() {} +} +fn main() { + m::frobnicate() {} +} +"#####, + ) +} + +#[test] +fn doctest_flip_binexpr() { + check_doc_test( + "flip_binexpr", + r#####" +fn main() { + let _ = 90 +$0 2; +} +"#####, + r#####" +fn main() { + let _ = 2 + 90; +} +"#####, + ) +} + +#[test] +fn doctest_flip_comma() { + check_doc_test( + "flip_comma", + r#####" +fn main() { + ((1, 2),$0 (3, 4)); +} +"#####, + r#####" +fn main() { + ((3, 4), (1, 2)); +} +"#####, + ) +} + +#[test] +fn doctest_flip_trait_bound() { + check_doc_test( + "flip_trait_bound", + r#####" +fn foo() { } +"#####, + r#####" +fn foo() { } +"#####, + ) +} + +#[test] +fn doctest_generate_constant() { + check_doc_test( + "generate_constant", + r#####" +struct S { i: usize } +impl S { pub fn new(n: usize) {} } +fn main() { + let v = S::new(CAPA$0CITY); +} +"#####, + r#####" +struct S { i: usize } +impl S { pub fn new(n: usize) {} } +fn main() { + const CAPACITY: usize = $0; + let v = S::new(CAPACITY); +} +"#####, + ) +} + +#[test] +fn doctest_generate_default_from_enum_variant() { + check_doc_test( + "generate_default_from_enum_variant", + r#####" +enum Version { + Undefined, + Minor$0, + Major, +} +"#####, + r#####" +enum Version { + Undefined, + Minor, + Major, +} + +impl Default for Version { + fn default() -> Self { + Self::Minor + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_default_from_new() { + check_doc_test( + "generate_default_from_new", + r#####" +struct Example { _inner: () } + +impl Example { + pub fn n$0ew() -> Self { + Self { _inner: () } + } +} +"#####, + r#####" +struct Example { _inner: () } + +impl Example { + pub fn new() -> Self { + Self { _inner: () } + } +} + +impl Default for Example { + fn default() -> Self { + Self::new() + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_delegate_methods() { + check_doc_test( + "generate_delegate_methods", + r#####" +struct Age(u8); +impl Age { + fn age(&self) -> u8 { + self.0 + } +} + +struct Person { + ag$0e: Age, +} +"#####, + r#####" +struct Age(u8); +impl Age { + fn age(&self) -> u8 { + self.0 + } +} + +struct Person { + age: Age, +} + +impl Person { + $0fn age(&self) -> u8 { + self.age.age() + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_deref() { + check_doc_test( + "generate_deref", + r#####" +//- minicore: deref, deref_mut +struct A; +struct B { + $0a: A +} +"#####, + r#####" +struct A; +struct B { + a: A +} + +impl core::ops::Deref for B { + type Target = A; + + fn deref(&self) -> &Self::Target { + &self.a + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_derive() { + check_doc_test( + "generate_derive", + r#####" +struct Point { + x: u32, + y: u32,$0 +} +"#####, + r#####" +#[derive($0)] +struct Point { + x: u32, + y: u32, +} +"#####, + ) +} + +#[test] +fn doctest_generate_doc_example() { + check_doc_test( + "generate_doc_example", + r#####" +/// Adds two numbers.$0 +pub fn add(a: i32, b: i32) -> i32 { a + b } +"#####, + r#####" +/// Adds two numbers. +/// +/// # Examples +/// +/// ``` +/// use test::add; +/// +/// assert_eq!(add(a, b), ); +/// ``` +pub fn add(a: i32, b: i32) -> i32 { a + b } +"#####, + ) +} + +#[test] +fn doctest_generate_documentation_template() { + check_doc_test( + "generate_documentation_template", + r#####" +pub struct S; +impl S { + pub unsafe fn set_len$0(&mut self, len: usize) -> Result<(), std::io::Error> { + /* ... */ + } +} +"#####, + r#####" +pub struct S; +impl S { + /// Sets the length of this [`S`]. + /// + /// # Errors + /// + /// This function will return an error if . + /// + /// # Safety + /// + /// . + pub unsafe fn set_len(&mut self, len: usize) -> Result<(), std::io::Error> { + /* ... */ + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_enum_as_method() { + check_doc_test( + "generate_enum_as_method", + r#####" +enum Value { + Number(i32), + Text(String)$0, +} +"#####, + r#####" +enum Value { + Number(i32), + Text(String), +} + +impl Value { + fn as_text(&self) -> Option<&String> { + if let Self::Text(v) = self { + Some(v) + } else { + None + } + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_enum_is_method() { + check_doc_test( + "generate_enum_is_method", + r#####" +enum Version { + Undefined, + Minor$0, + Major, +} +"#####, + r#####" +enum Version { + Undefined, + Minor, + Major, +} + +impl Version { + /// Returns `true` if the version is [`Minor`]. + /// + /// [`Minor`]: Version::Minor + #[must_use] + fn is_minor(&self) -> bool { + matches!(self, Self::Minor) + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_enum_try_into_method() { + check_doc_test( + "generate_enum_try_into_method", + r#####" +enum Value { + Number(i32), + Text(String)$0, +} +"#####, + r#####" +enum Value { + Number(i32), + Text(String), +} + +impl Value { + fn try_into_text(self) -> Result { + if let Self::Text(v) = self { + Ok(v) + } else { + Err(self) + } + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_enum_variant() { + check_doc_test( + "generate_enum_variant", + r#####" +enum Countries { + Ghana, +} + +fn main() { + let country = Countries::Lesotho$0; +} +"#####, + r#####" +enum Countries { + Ghana, + Lesotho, +} + +fn main() { + let country = Countries::Lesotho; +} +"#####, + ) +} + +#[test] +fn doctest_generate_from_impl_for_enum() { + check_doc_test( + "generate_from_impl_for_enum", + r#####" +enum A { $0One(u32) } +"#####, + r#####" +enum A { One(u32) } + +impl From for A { + fn from(v: u32) -> Self { + Self::One(v) + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_function() { + check_doc_test( + "generate_function", + r#####" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + bar$0("", baz()); +} + +"#####, + r#####" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + bar("", baz()); +} + +fn bar(arg: &str, baz: Baz) ${0:-> _} { + todo!() +} + +"#####, + ) +} + +#[test] +fn doctest_generate_getter() { + check_doc_test( + "generate_getter", + r#####" +//- minicore: as_ref +pub struct String; +impl AsRef for String { + fn as_ref(&self) -> &str { + "" + } +} + +struct Person { + nam$0e: String, +} +"#####, + r#####" +pub struct String; +impl AsRef for String { + fn as_ref(&self) -> &str { + "" + } +} + +struct Person { + name: String, +} + +impl Person { + fn $0name(&self) -> &str { + self.name.as_ref() + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_getter_mut() { + check_doc_test( + "generate_getter_mut", + r#####" +struct Person { + nam$0e: String, +} +"#####, + r#####" +struct Person { + name: String, +} + +impl Person { + fn $0name_mut(&mut self) -> &mut String { + &mut self.name + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_impl() { + check_doc_test( + "generate_impl", + r#####" +struct Ctx { + data: T,$0 +} +"#####, + r#####" +struct Ctx { + data: T, +} + +impl Ctx { + $0 +} +"#####, + ) +} + +#[test] +fn doctest_generate_is_empty_from_len() { + check_doc_test( + "generate_is_empty_from_len", + r#####" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + p$0ub fn len(&self) -> usize { + self.data.len() + } +} +"#####, + r#####" +struct MyStruct { data: Vec } + +impl MyStruct { + #[must_use] + pub fn len(&self) -> usize { + self.data.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +"#####, + ) +} + +#[test] +fn doctest_generate_new() { + check_doc_test( + "generate_new", + r#####" +struct Ctx { + data: T,$0 +} +"#####, + r#####" +struct Ctx { + data: T, +} + +impl Ctx { + fn $0new(data: T) -> Self { Self { data } } +} +"#####, + ) +} + +#[test] +fn doctest_generate_setter() { + check_doc_test( + "generate_setter", + r#####" +struct Person { + nam$0e: String, +} +"#####, + r#####" +struct Person { + name: String, +} + +impl Person { + fn set_name(&mut self, name: String) { + self.name = name; + } +} +"#####, + ) +} + +#[test] +fn doctest_inline_call() { + check_doc_test( + "inline_call", + r#####" +//- minicore: option +fn foo(name: Option<&str>) { + let name = name.unwrap$0(); +} +"#####, + r#####" +fn foo(name: Option<&str>) { + let name = match name { + Some(val) => val, + None => panic!("called `Option::unwrap()` on a `None` value"), + }; +} +"#####, + ) +} + +#[test] +fn doctest_inline_into_callers() { + check_doc_test( + "inline_into_callers", + r#####" +fn print(_: &str) {} +fn foo$0(word: &str) { + if !word.is_empty() { + print(word); + } +} +fn bar() { + foo("안녕하세요"); + foo("여러분"); +} +"#####, + r#####" +fn print(_: &str) {} + +fn bar() { + { + let word = "안녕하세요"; + if !word.is_empty() { + print(word); + } + }; + { + let word = "여러분"; + if !word.is_empty() { + print(word); + } + }; +} +"#####, + ) +} + +#[test] +fn doctest_inline_local_variable() { + check_doc_test( + "inline_local_variable", + r#####" +fn main() { + let x$0 = 1 + 2; + x * 4; +} +"#####, + r#####" +fn main() { + (1 + 2) * 4; +} +"#####, + ) +} + +#[test] +fn doctest_inline_type_alias() { + check_doc_test( + "inline_type_alias", + r#####" +type A = Vec; + +fn main() { + let a: $0A; +} +"#####, + r#####" +type A = Vec; + +fn main() { + let a: Vec; +} +"#####, + ) +} + +#[test] +fn doctest_introduce_named_generic() { + check_doc_test( + "introduce_named_generic", + r#####" +fn foo(bar: $0impl Bar) {} +"#####, + r#####" +fn foo(bar: B) {} +"#####, + ) +} + +#[test] +fn doctest_introduce_named_lifetime() { + check_doc_test( + "introduce_named_lifetime", + r#####" +impl Cursor<'_$0> { + fn node(self) -> &SyntaxNode { + match self { + Cursor::Replace(node) | Cursor::Before(node) => node, + } + } +} +"#####, + r#####" +impl<'a> Cursor<'a> { + fn node(self) -> &SyntaxNode { + match self { + Cursor::Replace(node) | Cursor::Before(node) => node, + } + } +} +"#####, + ) +} + +#[test] +fn doctest_invert_if() { + check_doc_test( + "invert_if", + r#####" +fn main() { + if$0 !y { A } else { B } +} +"#####, + r#####" +fn main() { + if y { B } else { A } +} +"#####, + ) +} + +#[test] +fn doctest_line_to_block() { + check_doc_test( + "line_to_block", + r#####" + // Multi-line$0 + // comment +"#####, + r#####" + /* + Multi-line + comment + */ +"#####, + ) +} + +#[test] +fn doctest_make_raw_string() { + check_doc_test( + "make_raw_string", + r#####" +fn main() { + "Hello,$0 World!"; +} +"#####, + r#####" +fn main() { + r#"Hello, World!"#; +} +"#####, + ) +} + +#[test] +fn doctest_make_usual_string() { + check_doc_test( + "make_usual_string", + r#####" +fn main() { + r#"Hello,$0 "World!""#; +} +"#####, + r#####" +fn main() { + "Hello, \"World!\""; +} +"#####, + ) +} + +#[test] +fn doctest_merge_imports() { + check_doc_test( + "merge_imports", + r#####" +use std::$0fmt::Formatter; +use std::io; +"#####, + r#####" +use std::{fmt::Formatter, io}; +"#####, + ) +} + +#[test] +fn doctest_merge_match_arms() { + check_doc_test( + "merge_match_arms", + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + $0Action::Move(..) => foo(), + Action::Stop => foo(), + } +} +"#####, + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + Action::Move(..) | Action::Stop => foo(), + } +} +"#####, + ) +} + +#[test] +fn doctest_move_arm_cond_to_match_guard() { + check_doc_test( + "move_arm_cond_to_match_guard", + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + Action::Move { distance } => $0if distance > 10 { foo() }, + _ => (), + } +} +"#####, + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + Action::Move { distance } if distance > 10 => foo(), + _ => (), + } +} +"#####, + ) +} + +#[test] +fn doctest_move_bounds_to_where_clause() { + check_doc_test( + "move_bounds_to_where_clause", + r#####" +fn apply U>(f: F, x: T) -> U { + f(x) +} +"#####, + r#####" +fn apply(f: F, x: T) -> U where F: FnOnce(T) -> U { + f(x) +} +"#####, + ) +} + +#[test] +fn doctest_move_from_mod_rs() { + check_doc_test( + "move_from_mod_rs", + r#####" +//- /main.rs +mod a; +//- /a/mod.rs +$0fn t() {}$0 +"#####, + r#####" +fn t() {} +"#####, + ) +} + +#[test] +fn doctest_move_guard_to_arm_body() { + check_doc_test( + "move_guard_to_arm_body", + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + Action::Move { distance } $0if distance > 10 => foo(), + _ => (), + } +} +"#####, + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + Action::Move { distance } => if distance > 10 { + foo() + }, + _ => (), + } +} +"#####, + ) +} + +#[test] +fn doctest_move_module_to_file() { + check_doc_test( + "move_module_to_file", + r#####" +mod $0foo { + fn t() {} +} +"#####, + r#####" +mod foo; +"#####, + ) +} + +#[test] +fn doctest_move_to_mod_rs() { + check_doc_test( + "move_to_mod_rs", + r#####" +//- /main.rs +mod a; +//- /a.rs +$0fn t() {}$0 +"#####, + r#####" +fn t() {} +"#####, + ) +} + +#[test] +fn doctest_promote_local_to_const() { + check_doc_test( + "promote_local_to_const", + r#####" +fn main() { + let foo$0 = true; + + if foo { + println!("It's true"); + } else { + println!("It's false"); + } +} +"#####, + r#####" +fn main() { + const $0FOO: bool = true; + + if FOO { + println!("It's true"); + } else { + println!("It's false"); + } +} +"#####, + ) +} + +#[test] +fn doctest_pull_assignment_up() { + check_doc_test( + "pull_assignment_up", + r#####" +fn main() { + let mut foo = 6; + + if true { + $0foo = 5; + } else { + foo = 4; + } +} +"#####, + r#####" +fn main() { + let mut foo = 6; + + foo = if true { + 5 + } else { + 4 + }; +} +"#####, + ) +} + +#[test] +fn doctest_qualify_method_call() { + check_doc_test( + "qualify_method_call", + r#####" +struct Foo; +impl Foo { + fn foo(&self) {} +} +fn main() { + let foo = Foo; + foo.fo$0o(); +} +"#####, + r#####" +struct Foo; +impl Foo { + fn foo(&self) {} +} +fn main() { + let foo = Foo; + Foo::foo(&foo); +} +"#####, + ) +} + +#[test] +fn doctest_qualify_path() { + check_doc_test( + "qualify_path", + r#####" +fn main() { + let map = HashMap$0::new(); +} +pub mod std { pub mod collections { pub struct HashMap { } } } +"#####, + r#####" +fn main() { + let map = std::collections::HashMap::new(); +} +pub mod std { pub mod collections { pub struct HashMap { } } } +"#####, + ) +} + +#[test] +fn doctest_reformat_number_literal() { + check_doc_test( + "reformat_number_literal", + r#####" +const _: i32 = 1012345$0; +"#####, + r#####" +const _: i32 = 1_012_345; +"#####, + ) +} + +#[test] +fn doctest_remove_dbg() { + check_doc_test( + "remove_dbg", + r#####" +fn main() { + $0dbg!(92); +} +"#####, + r#####" +fn main() { + 92; +} +"#####, + ) +} + +#[test] +fn doctest_remove_hash() { + check_doc_test( + "remove_hash", + r#####" +fn main() { + r#"Hello,$0 World!"#; +} +"#####, + r#####" +fn main() { + r"Hello, World!"; +} +"#####, + ) +} + +#[test] +fn doctest_remove_mut() { + check_doc_test( + "remove_mut", + r#####" +impl Walrus { + fn feed(&mut$0 self, amount: u32) {} +} +"#####, + r#####" +impl Walrus { + fn feed(&self, amount: u32) {} +} +"#####, + ) +} + +#[test] +fn doctest_remove_unused_param() { + check_doc_test( + "remove_unused_param", + r#####" +fn frobnicate(x: i32$0) {} + +fn main() { + frobnicate(92); +} +"#####, + r#####" +fn frobnicate() {} + +fn main() { + frobnicate(); +} +"#####, + ) +} + +#[test] +fn doctest_reorder_fields() { + check_doc_test( + "reorder_fields", + r#####" +struct Foo {foo: i32, bar: i32}; +const test: Foo = $0Foo {bar: 0, foo: 1} +"#####, + r#####" +struct Foo {foo: i32, bar: i32}; +const test: Foo = Foo {foo: 1, bar: 0} +"#####, + ) +} + +#[test] +fn doctest_reorder_impl_items() { + check_doc_test( + "reorder_impl_items", + r#####" +trait Foo { + type A; + const B: u8; + fn c(); +} + +struct Bar; +$0impl Foo for Bar { + const B: u8 = 17; + fn c() {} + type A = String; +} +"#####, + r#####" +trait Foo { + type A; + const B: u8; + fn c(); +} + +struct Bar; +impl Foo for Bar { + type A = String; + const B: u8 = 17; + fn c() {} +} +"#####, + ) +} + +#[test] +fn doctest_replace_char_with_string() { + check_doc_test( + "replace_char_with_string", + r#####" +fn main() { + find('{$0'); +} +"#####, + r#####" +fn main() { + find("{"); +} +"#####, + ) +} + +#[test] +fn doctest_replace_derive_with_manual_impl() { + check_doc_test( + "replace_derive_with_manual_impl", + r#####" +//- minicore: derive +trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; } +#[derive(Deb$0ug, Display)] +struct S; +"#####, + r#####" +trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; } +#[derive(Display)] +struct S; + +impl Debug for S { + $0fn fmt(&self, f: &mut Formatter) -> Result<()> { + f.debug_struct("S").finish() + } +} +"#####, + ) +} + +#[test] +fn doctest_replace_if_let_with_match() { + check_doc_test( + "replace_if_let_with_match", + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + $0if let Action::Move { distance } = action { + foo(distance) + } else { + bar() + } +} +"#####, + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + match action { + Action::Move { distance } => foo(distance), + _ => bar(), + } +} +"#####, + ) +} + +#[test] +fn doctest_replace_let_with_if_let() { + check_doc_test( + "replace_let_with_if_let", + r#####" +enum Option { Some(T), None } + +fn main(action: Action) { + $0let x = compute(); +} + +fn compute() -> Option { None } +"#####, + r#####" +enum Option { Some(T), None } + +fn main(action: Action) { + if let Some(x) = compute() { + } +} + +fn compute() -> Option { None } +"#####, + ) +} + +#[test] +fn doctest_replace_match_with_if_let() { + check_doc_test( + "replace_match_with_if_let", + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + $0match action { + Action::Move { distance } => foo(distance), + _ => bar(), + } +} +"#####, + r#####" +enum Action { Move { distance: u32 }, Stop } + +fn handle(action: Action) { + if let Action::Move { distance } = action { + foo(distance) + } else { + bar() + } +} +"#####, + ) +} + +#[test] +fn doctest_replace_qualified_name_with_use() { + check_doc_test( + "replace_qualified_name_with_use", + r#####" +mod std { pub mod collections { pub struct HashMap(T, U); } } +fn process(map: std::collections::$0HashMap) {} +"#####, + r#####" +use std::collections::HashMap; + +mod std { pub mod collections { pub struct HashMap(T, U); } } +fn process(map: HashMap) {} +"#####, + ) +} + +#[test] +fn doctest_replace_string_with_char() { + check_doc_test( + "replace_string_with_char", + r#####" +fn main() { + find("{$0"); +} +"#####, + r#####" +fn main() { + find('{'); +} +"#####, + ) +} + +#[test] +fn doctest_replace_try_expr_with_match() { + check_doc_test( + "replace_try_expr_with_match", + r#####" +//- minicore:option +fn handle() { + let pat = Some(true)$0?; +} +"#####, + r#####" +fn handle() { + let pat = match Some(true) { + Some(it) => it, + None => return None, + }; +} +"#####, + ) +} + +#[test] +fn doctest_replace_turbofish_with_explicit_type() { + check_doc_test( + "replace_turbofish_with_explicit_type", + r#####" +fn make() -> T { ) } +fn main() { + let a = make$0::(); +} +"#####, + r#####" +fn make() -> T { ) } +fn main() { + let a: i32 = make(); +} +"#####, + ) +} + +#[test] +fn doctest_sort_items() { + check_doc_test( + "sort_items", + r#####" +struct $0Foo$0 { second: u32, first: String } +"#####, + r#####" +struct Foo { first: String, second: u32 } +"#####, + ) +} + +#[test] +fn doctest_sort_items_1() { + check_doc_test( + "sort_items", + r#####" +trait $0Bar$0 { + fn second(&self) -> u32; + fn first(&self) -> String; +} +"#####, + r#####" +trait Bar { + fn first(&self) -> String; + fn second(&self) -> u32; +} +"#####, + ) +} + +#[test] +fn doctest_sort_items_2() { + check_doc_test( + "sort_items", + r#####" +struct Baz; +impl $0Baz$0 { + fn second(&self) -> u32; + fn first(&self) -> String; +} +"#####, + r#####" +struct Baz; +impl Baz { + fn first(&self) -> String; + fn second(&self) -> u32; +} +"#####, + ) +} + +#[test] +fn doctest_sort_items_3() { + check_doc_test( + "sort_items", + r#####" +enum $0Animal$0 { + Dog(String, f64), + Cat { weight: f64, name: String }, +} +"#####, + r#####" +enum Animal { + Cat { weight: f64, name: String }, + Dog(String, f64), +} +"#####, + ) +} + +#[test] +fn doctest_sort_items_4() { + check_doc_test( + "sort_items", + r#####" +enum Animal { + Dog(String, f64), + Cat $0{ weight: f64, name: String }$0, +} +"#####, + r#####" +enum Animal { + Dog(String, f64), + Cat { name: String, weight: f64 }, +} +"#####, + ) +} + +#[test] +fn doctest_split_import() { + check_doc_test( + "split_import", + r#####" +use std::$0collections::HashMap; +"#####, + r#####" +use std::{collections::HashMap}; +"#####, + ) +} + +#[test] +fn doctest_toggle_ignore() { + check_doc_test( + "toggle_ignore", + r#####" +$0#[test] +fn arithmetics { + assert_eq!(2 + 2, 5); +} +"#####, + r#####" +#[test] +#[ignore] +fn arithmetics { + assert_eq!(2 + 2, 5); +} +"#####, + ) +} + +#[test] +fn doctest_unmerge_use() { + check_doc_test( + "unmerge_use", + r#####" +use std::fmt::{Debug, Display$0}; +"#####, + r#####" +use std::fmt::{Debug}; +use std::fmt::Display; +"#####, + ) +} + +#[test] +fn doctest_unnecessary_async() { + check_doc_test( + "unnecessary_async", + r#####" +pub async f$0n foo() {} +pub async fn bar() { foo().await } +"#####, + r#####" +pub fn foo() {} +pub async fn bar() { foo() } +"#####, + ) +} + +#[test] +fn doctest_unwrap_block() { + check_doc_test( + "unwrap_block", + r#####" +fn foo() { + if true {$0 + println!("foo"); + } +} +"#####, + r#####" +fn foo() { + println!("foo"); +} +"#####, + ) +} + +#[test] +fn doctest_unwrap_result_return_type() { + check_doc_test( + "unwrap_result_return_type", + r#####" +//- minicore: result +fn foo() -> Result$0 { Ok(42i32) } +"#####, + r#####" +fn foo() -> i32 { 42i32 } +"#####, + ) +} + +#[test] +fn doctest_wrap_return_type_in_result() { + check_doc_test( + "wrap_return_type_in_result", + r#####" +//- minicore: result +fn foo() -> i32$0 { 42i32 } +"#####, + r#####" +fn foo() -> Result { Ok(42i32) } +"#####, + ) +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/tests/sourcegen.rs b/src/tools/rust-analyzer/crates/ide-assists/src/tests/sourcegen.rs new file mode 100644 index 000000000..070b83d3c --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/tests/sourcegen.rs @@ -0,0 +1,195 @@ +//! Generates `assists.md` documentation. + +use std::{fmt, fs, path::Path}; + +use test_utils::project_root; + +#[test] +fn sourcegen_assists_docs() { + let assists = Assist::collect(); + + { + // Generate doctests. + + let mut buf = " +use super::check_doc_test; +" + .to_string(); + for assist in assists.iter() { + for (idx, section) in assist.sections.iter().enumerate() { + let test_id = + if idx == 0 { assist.id.clone() } else { format!("{}_{}", &assist.id, idx) }; + let test = format!( + r######" +#[test] +fn doctest_{}() {{ + check_doc_test( + "{}", +r#####" +{}"#####, r#####" +{}"#####) +}} +"######, + &test_id, + &assist.id, + reveal_hash_comments(§ion.before), + reveal_hash_comments(§ion.after) + ); + + buf.push_str(&test) + } + } + let buf = sourcegen::add_preamble("sourcegen_assists_docs", sourcegen::reformat(buf)); + sourcegen::ensure_file_contents( + &project_root().join("crates/ide-assists/src/tests/generated.rs"), + &buf, + ); + } + + { + // Generate assists manual. Note that we do _not_ commit manual to the + // git repo. Instead, `cargo xtask release` runs this test before making + // a release. + + let contents = sourcegen::add_preamble( + "sourcegen_assists_docs", + assists.into_iter().map(|it| it.to_string()).collect::>().join("\n\n"), + ); + let dst = project_root().join("docs/user/generated_assists.adoc"); + fs::write(dst, contents).unwrap(); + } +} + +#[derive(Debug)] +struct Section { + doc: String, + before: String, + after: String, +} + +#[derive(Debug)] +struct Assist { + id: String, + location: sourcegen::Location, + sections: Vec
, +} + +impl Assist { + fn collect() -> Vec { + let handlers_dir = project_root().join("crates/ide-assists/src/handlers"); + + let mut res = Vec::new(); + for path in sourcegen::list_rust_files(&handlers_dir) { + collect_file(&mut res, path.as_path()); + } + res.sort_by(|lhs, rhs| lhs.id.cmp(&rhs.id)); + return res; + + fn collect_file(acc: &mut Vec, path: &Path) { + let text = fs::read_to_string(path).unwrap(); + let comment_blocks = sourcegen::CommentBlock::extract("Assist", &text); + + for block in comment_blocks { + // FIXME: doesn't support blank lines yet, need to tweak + // `extract_comment_blocks` for that. + let id = block.id; + assert!( + id.chars().all(|it| it.is_ascii_lowercase() || it == '_'), + "invalid assist id: {:?}", + id + ); + let mut lines = block.contents.iter().peekable(); + let location = sourcegen::Location { file: path.to_path_buf(), line: block.line }; + let mut assist = Assist { id, location, sections: Vec::new() }; + + while lines.peek().is_some() { + let doc = take_until(lines.by_ref(), "```").trim().to_string(); + assert!( + (doc.chars().next().unwrap().is_ascii_uppercase() && doc.ends_with('.')) + || assist.sections.len() > 0, + "\n\n{}: assist docs should be proper sentences, with capitalization and a full stop at the end.\n\n{}\n\n", + &assist.id, + doc, + ); + + let before = take_until(lines.by_ref(), "```"); + + assert_eq!(lines.next().unwrap().as_str(), "->"); + assert_eq!(lines.next().unwrap().as_str(), "```"); + let after = take_until(lines.by_ref(), "```"); + + assist.sections.push(Section { doc, before, after }); + } + + acc.push(assist) + } + } + + fn take_until<'a>(lines: impl Iterator, marker: &str) -> String { + let mut buf = Vec::new(); + for line in lines { + if line == marker { + break; + } + buf.push(line.clone()); + } + buf.join("\n") + } + } +} + +impl fmt::Display for Assist { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let _ = writeln!( + f, + "[discrete]\n=== `{}` +**Source:** {}", + self.id, self.location, + ); + + for section in &self.sections { + let before = section.before.replace("$0", "┃"); // Unicode pseudo-graphics bar + let after = section.after.replace("$0", "┃"); + let _ = writeln!( + f, + " +{} + +.Before +```rust +{}``` + +.After +```rust +{}```", + section.doc, + hide_hash_comments(&before), + hide_hash_comments(&after) + ); + } + + Ok(()) + } +} + +fn hide_hash_comments(text: &str) -> String { + text.split('\n') // want final newline + .filter(|&it| !(it.starts_with("# ") || it == "#")) + .map(|it| format!("{}\n", it)) + .collect() +} + +fn reveal_hash_comments(text: &str) -> String { + text.split('\n') // want final newline + .map(|it| { + if let Some(stripped) = it.strip_prefix("# ") { + stripped + } else if it == "#" { + "" + } else { + it + } + }) + .map(|it| format!("{}\n", it)) + .collect() +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/utils.rs b/src/tools/rust-analyzer/crates/ide-assists/src/utils.rs new file mode 100644 index 000000000..3e61d0741 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/utils.rs @@ -0,0 +1,703 @@ +//! Assorted functions shared by several assists. + +use std::ops; + +use itertools::Itertools; + +pub(crate) use gen_trait_fn_body::gen_trait_fn_body; +use hir::{db::HirDatabase, HirDisplay, Semantics}; +use ide_db::{famous_defs::FamousDefs, path_transform::PathTransform, RootDatabase, SnippetCap}; +use stdx::format_to; +use syntax::{ + ast::{ + self, + edit::{self, AstNodeEdit}, + edit_in_place::AttrsOwnerEdit, + make, HasArgList, HasAttrs, HasGenericParams, HasName, HasTypeBounds, Whitespace, + }, + ted, AstNode, AstToken, Direction, SmolStr, SourceFile, + SyntaxKind::*, + SyntaxNode, TextRange, TextSize, T, +}; + +use crate::assist_context::{AssistBuilder, AssistContext}; + +pub(crate) mod suggest_name; +mod gen_trait_fn_body; + +pub(crate) fn unwrap_trivial_block(block_expr: ast::BlockExpr) -> ast::Expr { + extract_trivial_expression(&block_expr) + .filter(|expr| !expr.syntax().text().contains_char('\n')) + .unwrap_or_else(|| block_expr.into()) +} + +pub fn extract_trivial_expression(block_expr: &ast::BlockExpr) -> Option { + if block_expr.modifier().is_some() { + return None; + } + let stmt_list = block_expr.stmt_list()?; + let has_anything_else = |thing: &SyntaxNode| -> bool { + let mut non_trivial_children = + stmt_list.syntax().children_with_tokens().filter(|it| match it.kind() { + WHITESPACE | T!['{'] | T!['}'] => false, + _ => it.as_node() != Some(thing), + }); + non_trivial_children.next().is_some() + }; + + if let Some(expr) = stmt_list.tail_expr() { + if has_anything_else(expr.syntax()) { + return None; + } + return Some(expr); + } + // Unwrap `{ continue; }` + let stmt = stmt_list.statements().next()?; + if let ast::Stmt::ExprStmt(expr_stmt) = stmt { + if has_anything_else(expr_stmt.syntax()) { + return None; + } + let expr = expr_stmt.expr()?; + if matches!(expr.syntax().kind(), CONTINUE_EXPR | BREAK_EXPR | RETURN_EXPR) { + return Some(expr); + } + } + None +} + +/// This is a method with a heuristics to support test methods annotated with custom test annotations, such as +/// `#[test_case(...)]`, `#[tokio::test]` and similar. +/// Also a regular `#[test]` annotation is supported. +/// +/// It may produce false positives, for example, `#[wasm_bindgen_test]` requires a different command to run the test, +/// but it's better than not to have the runnables for the tests at all. +pub fn test_related_attribute(fn_def: &ast::Fn) -> Option { + fn_def.attrs().find_map(|attr| { + let path = attr.path()?; + let text = path.syntax().text().to_string(); + if text.starts_with("test") || text.ends_with("test") { + Some(attr) + } else { + None + } + }) +} + +#[derive(Copy, Clone, PartialEq)] +pub enum DefaultMethods { + Only, + No, +} + +pub fn filter_assoc_items( + sema: &Semantics<'_, RootDatabase>, + items: &[hir::AssocItem], + default_methods: DefaultMethods, +) -> Vec { + fn has_def_name(item: &ast::AssocItem) -> bool { + match item { + ast::AssocItem::Fn(def) => def.name(), + ast::AssocItem::TypeAlias(def) => def.name(), + ast::AssocItem::Const(def) => def.name(), + ast::AssocItem::MacroCall(_) => None, + } + .is_some() + } + + items + .iter() + // Note: This throws away items with no source. + .filter_map(|&i| { + let item = match i { + hir::AssocItem::Function(i) => ast::AssocItem::Fn(sema.source(i)?.value), + hir::AssocItem::TypeAlias(i) => ast::AssocItem::TypeAlias(sema.source(i)?.value), + hir::AssocItem::Const(i) => ast::AssocItem::Const(sema.source(i)?.value), + }; + Some(item) + }) + .filter(has_def_name) + .filter(|it| match it { + ast::AssocItem::Fn(def) => matches!( + (default_methods, def.body()), + (DefaultMethods::Only, Some(_)) | (DefaultMethods::No, None) + ), + _ => default_methods == DefaultMethods::No, + }) + .collect::>() +} + +pub fn add_trait_assoc_items_to_impl( + sema: &Semantics<'_, RootDatabase>, + items: Vec, + trait_: hir::Trait, + impl_: ast::Impl, + target_scope: hir::SemanticsScope<'_>, +) -> (ast::Impl, ast::AssocItem) { + let source_scope = sema.scope_for_def(trait_); + + let transform = PathTransform::trait_impl(&target_scope, &source_scope, trait_, impl_.clone()); + + let items = items.into_iter().map(|assoc_item| { + transform.apply(assoc_item.syntax()); + assoc_item.remove_attrs_and_docs(); + assoc_item + }); + + let res = impl_.clone_for_update(); + + let assoc_item_list = res.get_or_create_assoc_item_list(); + let mut first_item = None; + for item in items { + first_item.get_or_insert_with(|| item.clone()); + match &item { + ast::AssocItem::Fn(fn_) if fn_.body().is_none() => { + let body = make::block_expr(None, Some(make::ext::expr_todo())) + .indent(edit::IndentLevel(1)); + ted::replace(fn_.get_or_create_body().syntax(), body.clone_for_update().syntax()) + } + ast::AssocItem::TypeAlias(type_alias) => { + if let Some(type_bound_list) = type_alias.type_bound_list() { + type_bound_list.remove() + } + } + _ => {} + } + + assoc_item_list.add_item(item) + } + + (res, first_item.unwrap()) +} + +#[derive(Clone, Copy, Debug)] +pub(crate) enum Cursor<'a> { + Replace(&'a SyntaxNode), + Before(&'a SyntaxNode), +} + +impl<'a> Cursor<'a> { + fn node(self) -> &'a SyntaxNode { + match self { + Cursor::Replace(node) | Cursor::Before(node) => node, + } + } +} + +pub(crate) fn render_snippet(_cap: SnippetCap, node: &SyntaxNode, cursor: Cursor<'_>) -> String { + assert!(cursor.node().ancestors().any(|it| it == *node)); + let range = cursor.node().text_range() - node.text_range().start(); + let range: ops::Range = range.into(); + + let mut placeholder = cursor.node().to_string(); + escape(&mut placeholder); + let tab_stop = match cursor { + Cursor::Replace(placeholder) => format!("${{0:{}}}", placeholder), + Cursor::Before(placeholder) => format!("$0{}", placeholder), + }; + + let mut buf = node.to_string(); + buf.replace_range(range, &tab_stop); + return buf; + + fn escape(buf: &mut String) { + stdx::replace(buf, '{', r"\{"); + stdx::replace(buf, '}', r"\}"); + stdx::replace(buf, '$', r"\$"); + } +} + +pub(crate) fn vis_offset(node: &SyntaxNode) -> TextSize { + node.children_with_tokens() + .find(|it| !matches!(it.kind(), WHITESPACE | COMMENT | ATTR)) + .map(|it| it.text_range().start()) + .unwrap_or_else(|| node.text_range().start()) +} + +pub(crate) fn invert_boolean_expression(expr: ast::Expr) -> ast::Expr { + invert_special_case(&expr).unwrap_or_else(|| make::expr_prefix(T![!], expr)) +} + +fn invert_special_case(expr: &ast::Expr) -> Option { + match expr { + ast::Expr::BinExpr(bin) => { + let bin = bin.clone_for_update(); + let op_token = bin.op_token()?; + let rev_token = match op_token.kind() { + T![==] => T![!=], + T![!=] => T![==], + T![<] => T![>=], + T![<=] => T![>], + T![>] => T![<=], + T![>=] => T![<], + // Parenthesize other expressions before prefixing `!` + _ => return Some(make::expr_prefix(T![!], make::expr_paren(expr.clone()))), + }; + ted::replace(op_token, make::token(rev_token)); + Some(bin.into()) + } + ast::Expr::MethodCallExpr(mce) => { + let receiver = mce.receiver()?; + let method = mce.name_ref()?; + let arg_list = mce.arg_list()?; + + let method = match method.text().as_str() { + "is_some" => "is_none", + "is_none" => "is_some", + "is_ok" => "is_err", + "is_err" => "is_ok", + _ => return None, + }; + Some(make::expr_method_call(receiver, make::name_ref(method), arg_list)) + } + ast::Expr::PrefixExpr(pe) if pe.op_kind()? == ast::UnaryOp::Not => match pe.expr()? { + ast::Expr::ParenExpr(parexpr) => parexpr.expr(), + _ => pe.expr(), + }, + ast::Expr::Literal(lit) => match lit.kind() { + ast::LiteralKind::Bool(b) => match b { + true => Some(ast::Expr::Literal(make::expr_literal("false"))), + false => Some(ast::Expr::Literal(make::expr_literal("true"))), + }, + _ => None, + }, + _ => None, + } +} + +pub(crate) fn next_prev() -> impl Iterator { + [Direction::Next, Direction::Prev].into_iter() +} + +pub(crate) fn does_pat_match_variant(pat: &ast::Pat, var: &ast::Pat) -> bool { + let first_node_text = |pat: &ast::Pat| pat.syntax().first_child().map(|node| node.text()); + + let pat_head = match pat { + ast::Pat::IdentPat(bind_pat) => match bind_pat.pat() { + Some(p) => first_node_text(&p), + None => return pat.syntax().text() == var.syntax().text(), + }, + pat => first_node_text(pat), + }; + + let var_head = first_node_text(var); + + pat_head == var_head +} + +pub(crate) fn does_nested_pattern(pat: &ast::Pat) -> bool { + let depth = calc_depth(pat, 0); + + if 1 < depth { + return true; + } + false +} + +fn calc_depth(pat: &ast::Pat, depth: usize) -> usize { + match pat { + ast::Pat::IdentPat(_) + | ast::Pat::BoxPat(_) + | ast::Pat::RestPat(_) + | ast::Pat::LiteralPat(_) + | ast::Pat::MacroPat(_) + | ast::Pat::OrPat(_) + | ast::Pat::ParenPat(_) + | ast::Pat::PathPat(_) + | ast::Pat::WildcardPat(_) + | ast::Pat::RangePat(_) + | ast::Pat::RecordPat(_) + | ast::Pat::RefPat(_) + | ast::Pat::SlicePat(_) + | ast::Pat::TuplePat(_) + | ast::Pat::ConstBlockPat(_) => depth, + + // FIXME: Other patterns may also be nested. Currently it simply supports only `TupleStructPat` + ast::Pat::TupleStructPat(pat) => { + let mut max_depth = depth; + for p in pat.fields() { + let d = calc_depth(&p, depth + 1); + if d > max_depth { + max_depth = d + } + } + max_depth + } + } +} + +// Uses a syntax-driven approach to find any impl blocks for the struct that +// exist within the module/file +// +// Returns `None` if we've found an existing fn +// +// FIXME: change the new fn checking to a more semantic approach when that's more +// viable (e.g. we process proc macros, etc) +// FIXME: this partially overlaps with `find_impl_block_*` +pub(crate) fn find_struct_impl( + ctx: &AssistContext<'_>, + adt: &ast::Adt, + name: &str, +) -> Option> { + let db = ctx.db(); + let module = adt.syntax().parent()?; + + let struct_def = ctx.sema.to_def(adt)?; + + let block = module.descendants().filter_map(ast::Impl::cast).find_map(|impl_blk| { + let blk = ctx.sema.to_def(&impl_blk)?; + + // FIXME: handle e.g. `struct S; impl S {}` + // (we currently use the wrong type parameter) + // also we wouldn't want to use e.g. `impl S` + + let same_ty = match blk.self_ty(db).as_adt() { + Some(def) => def == struct_def, + None => false, + }; + let not_trait_impl = blk.trait_(db).is_none(); + + if !(same_ty && not_trait_impl) { + None + } else { + Some(impl_blk) + } + }); + + if let Some(ref impl_blk) = block { + if has_fn(impl_blk, name) { + return None; + } + } + + Some(block) +} + +fn has_fn(imp: &ast::Impl, rhs_name: &str) -> bool { + if let Some(il) = imp.assoc_item_list() { + for item in il.assoc_items() { + if let ast::AssocItem::Fn(f) = item { + if let Some(name) = f.name() { + if name.text().eq_ignore_ascii_case(rhs_name) { + return true; + } + } + } + } + } + + false +} + +/// Find the start of the `impl` block for the given `ast::Impl`. +// +// FIXME: this partially overlaps with `find_struct_impl` +pub(crate) fn find_impl_block_start(impl_def: ast::Impl, buf: &mut String) -> Option { + buf.push('\n'); + let start = impl_def.assoc_item_list().and_then(|it| it.l_curly_token())?.text_range().end(); + Some(start) +} + +/// Find the end of the `impl` block for the given `ast::Impl`. +// +// FIXME: this partially overlaps with `find_struct_impl` +pub(crate) fn find_impl_block_end(impl_def: ast::Impl, buf: &mut String) -> Option { + buf.push('\n'); + let end = impl_def + .assoc_item_list() + .and_then(|it| it.r_curly_token())? + .prev_sibling_or_token()? + .text_range() + .end(); + Some(end) +} + +// Generates the surrounding `impl Type { }` including type and lifetime +// parameters +pub(crate) fn generate_impl_text(adt: &ast::Adt, code: &str) -> String { + generate_impl_text_inner(adt, None, code) +} + +// Generates the surrounding `impl for Type { }` including type +// and lifetime parameters +pub(crate) fn generate_trait_impl_text(adt: &ast::Adt, trait_text: &str, code: &str) -> String { + generate_impl_text_inner(adt, Some(trait_text), code) +} + +fn generate_impl_text_inner(adt: &ast::Adt, trait_text: Option<&str>, code: &str) -> String { + let generic_params = adt.generic_param_list(); + let mut buf = String::with_capacity(code.len()); + buf.push_str("\n\n"); + adt.attrs() + .filter(|attr| attr.as_simple_call().map(|(name, _arg)| name == "cfg").unwrap_or(false)) + .for_each(|attr| buf.push_str(format!("{}\n", attr).as_str())); + buf.push_str("impl"); + if let Some(generic_params) = &generic_params { + let lifetimes = generic_params.lifetime_params().map(|lt| format!("{}", lt.syntax())); + let toc_params = generic_params.type_or_const_params().map(|toc_param| { + let type_param = match toc_param { + ast::TypeOrConstParam::Type(x) => x, + ast::TypeOrConstParam::Const(x) => return x.syntax().to_string(), + }; + let mut buf = String::new(); + if let Some(it) = type_param.name() { + format_to!(buf, "{}", it.syntax()); + } + if let Some(it) = type_param.colon_token() { + format_to!(buf, "{} ", it); + } + if let Some(it) = type_param.type_bound_list() { + format_to!(buf, "{}", it.syntax()); + } + buf + }); + let generics = lifetimes.chain(toc_params).format(", "); + format_to!(buf, "<{}>", generics); + } + buf.push(' '); + if let Some(trait_text) = trait_text { + buf.push_str(trait_text); + buf.push_str(" for "); + } + buf.push_str(&adt.name().unwrap().text()); + if let Some(generic_params) = generic_params { + let lifetime_params = generic_params + .lifetime_params() + .filter_map(|it| it.lifetime()) + .map(|it| SmolStr::from(it.text())); + let toc_params = generic_params + .type_or_const_params() + .filter_map(|it| it.name()) + .map(|it| SmolStr::from(it.text())); + format_to!(buf, "<{}>", lifetime_params.chain(toc_params).format(", ")) + } + + match adt.where_clause() { + Some(where_clause) => { + format_to!(buf, "\n{}\n{{\n{}\n}}", where_clause, code); + } + None => { + format_to!(buf, " {{\n{}\n}}", code); + } + } + + buf +} + +pub(crate) fn add_method_to_adt( + builder: &mut AssistBuilder, + adt: &ast::Adt, + impl_def: Option, + method: &str, +) { + let mut buf = String::with_capacity(method.len() + 2); + if impl_def.is_some() { + buf.push('\n'); + } + buf.push_str(method); + + let start_offset = impl_def + .and_then(|impl_def| find_impl_block_end(impl_def, &mut buf)) + .unwrap_or_else(|| { + buf = generate_impl_text(adt, &buf); + adt.syntax().text_range().end() + }); + + builder.insert(start_offset, buf); +} + +#[derive(Debug)] +pub(crate) struct ReferenceConversion { + conversion: ReferenceConversionType, + ty: hir::Type, +} + +#[derive(Debug)] +enum ReferenceConversionType { + // reference can be stripped if the type is Copy + Copy, + // &String -> &str + AsRefStr, + // &Vec -> &[T] + AsRefSlice, + // &Box -> &T + Dereferenced, + // &Option -> Option<&T> + Option, + // &Result -> Result<&T, &E> + Result, +} + +impl ReferenceConversion { + pub(crate) fn convert_type(&self, db: &dyn HirDatabase) -> String { + match self.conversion { + ReferenceConversionType::Copy => self.ty.display(db).to_string(), + ReferenceConversionType::AsRefStr => "&str".to_string(), + ReferenceConversionType::AsRefSlice => { + let type_argument_name = + self.ty.type_arguments().next().unwrap().display(db).to_string(); + format!("&[{}]", type_argument_name) + } + ReferenceConversionType::Dereferenced => { + let type_argument_name = + self.ty.type_arguments().next().unwrap().display(db).to_string(); + format!("&{}", type_argument_name) + } + ReferenceConversionType::Option => { + let type_argument_name = + self.ty.type_arguments().next().unwrap().display(db).to_string(); + format!("Option<&{}>", type_argument_name) + } + ReferenceConversionType::Result => { + let mut type_arguments = self.ty.type_arguments(); + let first_type_argument_name = + type_arguments.next().unwrap().display(db).to_string(); + let second_type_argument_name = + type_arguments.next().unwrap().display(db).to_string(); + format!("Result<&{}, &{}>", first_type_argument_name, second_type_argument_name) + } + } + } + + pub(crate) fn getter(&self, field_name: String) -> String { + match self.conversion { + ReferenceConversionType::Copy => format!("self.{}", field_name), + ReferenceConversionType::AsRefStr + | ReferenceConversionType::AsRefSlice + | ReferenceConversionType::Dereferenced + | ReferenceConversionType::Option + | ReferenceConversionType::Result => format!("self.{}.as_ref()", field_name), + } + } +} + +// FIXME: It should return a new hir::Type, but currently constructing new types is too cumbersome +// and all users of this function operate on string type names, so they can do the conversion +// itself themselves. +pub(crate) fn convert_reference_type( + ty: hir::Type, + db: &RootDatabase, + famous_defs: &FamousDefs<'_, '_>, +) -> Option { + handle_copy(&ty, db) + .or_else(|| handle_as_ref_str(&ty, db, famous_defs)) + .or_else(|| handle_as_ref_slice(&ty, db, famous_defs)) + .or_else(|| handle_dereferenced(&ty, db, famous_defs)) + .or_else(|| handle_option_as_ref(&ty, db, famous_defs)) + .or_else(|| handle_result_as_ref(&ty, db, famous_defs)) + .map(|conversion| ReferenceConversion { ty, conversion }) +} + +fn handle_copy(ty: &hir::Type, db: &dyn HirDatabase) -> Option { + ty.is_copy(db).then(|| ReferenceConversionType::Copy) +} + +fn handle_as_ref_str( + ty: &hir::Type, + db: &dyn HirDatabase, + famous_defs: &FamousDefs<'_, '_>, +) -> Option { + let str_type = hir::BuiltinType::str().ty(db); + + ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[str_type]) + .then(|| ReferenceConversionType::AsRefStr) +} + +fn handle_as_ref_slice( + ty: &hir::Type, + db: &dyn HirDatabase, + famous_defs: &FamousDefs<'_, '_>, +) -> Option { + let type_argument = ty.type_arguments().next()?; + let slice_type = hir::Type::new_slice(type_argument); + + ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[slice_type]) + .then(|| ReferenceConversionType::AsRefSlice) +} + +fn handle_dereferenced( + ty: &hir::Type, + db: &dyn HirDatabase, + famous_defs: &FamousDefs<'_, '_>, +) -> Option { + let type_argument = ty.type_arguments().next()?; + + ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[type_argument]) + .then(|| ReferenceConversionType::Dereferenced) +} + +fn handle_option_as_ref( + ty: &hir::Type, + db: &dyn HirDatabase, + famous_defs: &FamousDefs<'_, '_>, +) -> Option { + if ty.as_adt() == famous_defs.core_option_Option()?.ty(db).as_adt() { + Some(ReferenceConversionType::Option) + } else { + None + } +} + +fn handle_result_as_ref( + ty: &hir::Type, + db: &dyn HirDatabase, + famous_defs: &FamousDefs<'_, '_>, +) -> Option { + if ty.as_adt() == famous_defs.core_result_Result()?.ty(db).as_adt() { + Some(ReferenceConversionType::Result) + } else { + None + } +} + +pub(crate) fn get_methods(items: &ast::AssocItemList) -> Vec { + items + .assoc_items() + .flat_map(|i| match i { + ast::AssocItem::Fn(f) => Some(f), + _ => None, + }) + .filter(|f| f.name().is_some()) + .collect() +} + +/// Trim(remove leading and trailing whitespace) `initial_range` in `source_file`, return the trimmed range. +pub(crate) fn trimmed_text_range(source_file: &SourceFile, initial_range: TextRange) -> TextRange { + let mut trimmed_range = initial_range; + while source_file + .syntax() + .token_at_offset(trimmed_range.start()) + .find_map(Whitespace::cast) + .is_some() + && trimmed_range.start() < trimmed_range.end() + { + let start = trimmed_range.start() + TextSize::from(1); + trimmed_range = TextRange::new(start, trimmed_range.end()); + } + while source_file + .syntax() + .token_at_offset(trimmed_range.end()) + .find_map(Whitespace::cast) + .is_some() + && trimmed_range.start() < trimmed_range.end() + { + let end = trimmed_range.end() - TextSize::from(1); + trimmed_range = TextRange::new(trimmed_range.start(), end); + } + trimmed_range +} + +/// Convert a list of function params to a list of arguments that can be passed +/// into a function call. +pub(crate) fn convert_param_list_to_arg_list(list: ast::ParamList) -> ast::ArgList { + let mut args = vec![]; + for param in list.params() { + if let Some(ast::Pat::IdentPat(pat)) = param.pat() { + if let Some(name) = pat.name() { + let name = name.to_string(); + let expr = make::expr_path(make::ext::ident_path(&name)); + args.push(expr); + } + } + } + make::arg_list(args) +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/utils/gen_trait_fn_body.rs b/src/tools/rust-analyzer/crates/ide-assists/src/utils/gen_trait_fn_body.rs new file mode 100644 index 000000000..7a0c91295 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/utils/gen_trait_fn_body.rs @@ -0,0 +1,661 @@ +//! This module contains functions to generate default trait impl function bodies where possible. + +use syntax::{ + ast::{self, edit::AstNodeEdit, make, AstNode, BinaryOp, CmpOp, HasName, LogicOp}, + ted, +}; + +/// Generate custom trait bodies without default implementation where possible. +/// +/// Returns `Option` so that we can use `?` rather than `if let Some`. Returning +/// `None` means that generating a custom trait body failed, and the body will remain +/// as `todo!` instead. +pub(crate) fn gen_trait_fn_body( + func: &ast::Fn, + trait_path: &ast::Path, + adt: &ast::Adt, +) -> Option<()> { + match trait_path.segment()?.name_ref()?.text().as_str() { + "Clone" => gen_clone_impl(adt, func), + "Debug" => gen_debug_impl(adt, func), + "Default" => gen_default_impl(adt, func), + "Hash" => gen_hash_impl(adt, func), + "PartialEq" => gen_partial_eq(adt, func), + "PartialOrd" => gen_partial_ord(adt, func), + _ => None, + } +} + +/// Generate a `Clone` impl based on the fields and members of the target type. +fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { + stdx::always!(func.name().map_or(false, |name| name.text() == "clone")); + fn gen_clone_call(target: ast::Expr) -> ast::Expr { + let method = make::name_ref("clone"); + make::expr_method_call(target, method, make::arg_list(None)) + } + let expr = match adt { + // `Clone` cannot be derived for unions, so no default impl can be provided. + ast::Adt::Union(_) => return None, + ast::Adt::Enum(enum_) => { + let list = enum_.variant_list()?; + let mut arms = vec![]; + for variant in list.variants() { + let name = variant.name()?; + let variant_name = make::ext::path_from_idents(["Self", &format!("{}", name)])?; + + match variant.field_list() { + // => match self { Self::Name { x } => Self::Name { x: x.clone() } } + Some(ast::FieldList::RecordFieldList(list)) => { + let mut pats = vec![]; + let mut fields = vec![]; + for field in list.fields() { + let field_name = field.name()?; + let pat = make::ident_pat(false, false, field_name.clone()); + pats.push(pat.into()); + + let path = make::ext::ident_path(&field_name.to_string()); + let method_call = gen_clone_call(make::expr_path(path)); + let name_ref = make::name_ref(&field_name.to_string()); + let field = make::record_expr_field(name_ref, Some(method_call)); + fields.push(field); + } + let pat = make::record_pat(variant_name.clone(), pats.into_iter()); + let fields = make::record_expr_field_list(fields); + let record_expr = make::record_expr(variant_name, fields).into(); + arms.push(make::match_arm(Some(pat.into()), None, record_expr)); + } + + // => match self { Self::Name(arg1) => Self::Name(arg1.clone()) } + Some(ast::FieldList::TupleFieldList(list)) => { + let mut pats = vec![]; + let mut fields = vec![]; + for (i, _) in list.fields().enumerate() { + let field_name = format!("arg{}", i); + let pat = make::ident_pat(false, false, make::name(&field_name)); + pats.push(pat.into()); + + let f_path = make::expr_path(make::ext::ident_path(&field_name)); + fields.push(gen_clone_call(f_path)); + } + let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter()); + let struct_name = make::expr_path(variant_name); + let tuple_expr = make::expr_call(struct_name, make::arg_list(fields)); + arms.push(make::match_arm(Some(pat.into()), None, tuple_expr)); + } + + // => match self { Self::Name => Self::Name } + None => { + let pattern = make::path_pat(variant_name.clone()); + let variant_expr = make::expr_path(variant_name); + arms.push(make::match_arm(Some(pattern), None, variant_expr)); + } + } + } + + let match_target = make::expr_path(make::ext::ident_path("self")); + let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1)); + make::expr_match(match_target, list) + } + ast::Adt::Struct(strukt) => { + match strukt.field_list() { + // => Self { name: self.name.clone() } + Some(ast::FieldList::RecordFieldList(field_list)) => { + let mut fields = vec![]; + for field in field_list.fields() { + let base = make::expr_path(make::ext::ident_path("self")); + let target = make::expr_field(base, &field.name()?.to_string()); + let method_call = gen_clone_call(target); + let name_ref = make::name_ref(&field.name()?.to_string()); + let field = make::record_expr_field(name_ref, Some(method_call)); + fields.push(field); + } + let struct_name = make::ext::ident_path("Self"); + let fields = make::record_expr_field_list(fields); + make::record_expr(struct_name, fields).into() + } + // => Self(self.0.clone(), self.1.clone()) + Some(ast::FieldList::TupleFieldList(field_list)) => { + let mut fields = vec![]; + for (i, _) in field_list.fields().enumerate() { + let f_path = make::expr_path(make::ext::ident_path("self")); + let target = make::expr_field(f_path, &format!("{}", i)); + fields.push(gen_clone_call(target)); + } + let struct_name = make::expr_path(make::ext::ident_path("Self")); + make::expr_call(struct_name, make::arg_list(fields)) + } + // => Self { } + None => { + let struct_name = make::ext::ident_path("Self"); + let fields = make::record_expr_field_list(None); + make::record_expr(struct_name, fields).into() + } + } + } + }; + let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)); + ted::replace(func.body()?.syntax(), body.clone_for_update().syntax()); + Some(()) +} + +/// Generate a `Debug` impl based on the fields and members of the target type. +fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { + let annotated_name = adt.name()?; + match adt { + // `Debug` cannot be derived for unions, so no default impl can be provided. + ast::Adt::Union(_) => None, + + // => match self { Self::Variant => write!(f, "Variant") } + ast::Adt::Enum(enum_) => { + let list = enum_.variant_list()?; + let mut arms = vec![]; + for variant in list.variants() { + let name = variant.name()?; + let variant_name = make::ext::path_from_idents(["Self", &format!("{}", name)])?; + let target = make::expr_path(make::ext::ident_path("f")); + + match variant.field_list() { + Some(ast::FieldList::RecordFieldList(list)) => { + // => f.debug_struct(name) + let target = make::expr_path(make::ext::ident_path("f")); + let method = make::name_ref("debug_struct"); + let struct_name = format!("\"{}\"", name); + let args = make::arg_list(Some(make::expr_literal(&struct_name).into())); + let mut expr = make::expr_method_call(target, method, args); + + let mut pats = vec![]; + for field in list.fields() { + let field_name = field.name()?; + + // create a field pattern for use in `MyStruct { fields.. }` + let pat = make::ident_pat(false, false, field_name.clone()); + pats.push(pat.into()); + + // => .field("field_name", field) + let method_name = make::name_ref("field"); + let name = make::expr_literal(&(format!("\"{}\"", field_name))).into(); + let path = &format!("{}", field_name); + let path = make::expr_path(make::ext::ident_path(path)); + let args = make::arg_list(vec![name, path]); + expr = make::expr_method_call(expr, method_name, args); + } + + // => .finish() + let method = make::name_ref("finish"); + let expr = make::expr_method_call(expr, method, make::arg_list(None)); + + // => MyStruct { fields.. } => f.debug_struct("MyStruct")...finish(), + let pat = make::record_pat(variant_name.clone(), pats.into_iter()); + arms.push(make::match_arm(Some(pat.into()), None, expr)); + } + Some(ast::FieldList::TupleFieldList(list)) => { + // => f.debug_tuple(name) + let target = make::expr_path(make::ext::ident_path("f")); + let method = make::name_ref("debug_tuple"); + let struct_name = format!("\"{}\"", name); + let args = make::arg_list(Some(make::expr_literal(&struct_name).into())); + let mut expr = make::expr_method_call(target, method, args); + + let mut pats = vec![]; + for (i, _) in list.fields().enumerate() { + let name = format!("arg{}", i); + + // create a field pattern for use in `MyStruct(fields..)` + let field_name = make::name(&name); + let pat = make::ident_pat(false, false, field_name.clone()); + pats.push(pat.into()); + + // => .field(field) + let method_name = make::name_ref("field"); + let field_path = &name.to_string(); + let field_path = make::expr_path(make::ext::ident_path(field_path)); + let args = make::arg_list(vec![field_path]); + expr = make::expr_method_call(expr, method_name, args); + } + + // => .finish() + let method = make::name_ref("finish"); + let expr = make::expr_method_call(expr, method, make::arg_list(None)); + + // => MyStruct (fields..) => f.debug_tuple("MyStruct")...finish(), + let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter()); + arms.push(make::match_arm(Some(pat.into()), None, expr)); + } + None => { + let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into(); + let args = make::arg_list([target, fmt_string]); + let macro_name = make::expr_path(make::ext::ident_path("write")); + let macro_call = make::expr_macro_call(macro_name, args); + + let variant_name = make::path_pat(variant_name); + arms.push(make::match_arm(Some(variant_name), None, macro_call)); + } + } + } + + let match_target = make::expr_path(make::ext::ident_path("self")); + let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1)); + let match_expr = make::expr_match(match_target, list); + + let body = make::block_expr(None, Some(match_expr)); + let body = body.indent(ast::edit::IndentLevel(1)); + ted::replace(func.body()?.syntax(), body.clone_for_update().syntax()); + Some(()) + } + + ast::Adt::Struct(strukt) => { + let name = format!("\"{}\"", annotated_name); + let args = make::arg_list(Some(make::expr_literal(&name).into())); + let target = make::expr_path(make::ext::ident_path("f")); + + let expr = match strukt.field_list() { + // => f.debug_struct("Name").finish() + None => make::expr_method_call(target, make::name_ref("debug_struct"), args), + + // => f.debug_struct("Name").field("foo", &self.foo).finish() + Some(ast::FieldList::RecordFieldList(field_list)) => { + let method = make::name_ref("debug_struct"); + let mut expr = make::expr_method_call(target, method, args); + for field in field_list.fields() { + let name = field.name()?; + let f_name = make::expr_literal(&(format!("\"{}\"", name))).into(); + let f_path = make::expr_path(make::ext::ident_path("self")); + let f_path = make::expr_ref(f_path, false); + let f_path = make::expr_field(f_path, &format!("{}", name)); + let args = make::arg_list([f_name, f_path]); + expr = make::expr_method_call(expr, make::name_ref("field"), args); + } + expr + } + + // => f.debug_tuple("Name").field(self.0).finish() + Some(ast::FieldList::TupleFieldList(field_list)) => { + let method = make::name_ref("debug_tuple"); + let mut expr = make::expr_method_call(target, method, args); + for (i, _) in field_list.fields().enumerate() { + let f_path = make::expr_path(make::ext::ident_path("self")); + let f_path = make::expr_ref(f_path, false); + let f_path = make::expr_field(f_path, &format!("{}", i)); + let method = make::name_ref("field"); + expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path))); + } + expr + } + }; + + let method = make::name_ref("finish"); + let expr = make::expr_method_call(expr, method, make::arg_list(None)); + let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)); + ted::replace(func.body()?.syntax(), body.clone_for_update().syntax()); + Some(()) + } + } +} + +/// Generate a `Debug` impl based on the fields and members of the target type. +fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { + fn gen_default_call() -> Option { + let fn_name = make::ext::path_from_idents(["Default", "default"])?; + Some(make::expr_call(make::expr_path(fn_name), make::arg_list(None))) + } + match adt { + // `Debug` cannot be derived for unions, so no default impl can be provided. + ast::Adt::Union(_) => None, + // Deriving `Debug` for enums is not stable yet. + ast::Adt::Enum(_) => None, + ast::Adt::Struct(strukt) => { + let expr = match strukt.field_list() { + Some(ast::FieldList::RecordFieldList(field_list)) => { + let mut fields = vec![]; + for field in field_list.fields() { + let method_call = gen_default_call()?; + let name_ref = make::name_ref(&field.name()?.to_string()); + let field = make::record_expr_field(name_ref, Some(method_call)); + fields.push(field); + } + let struct_name = make::ext::ident_path("Self"); + let fields = make::record_expr_field_list(fields); + make::record_expr(struct_name, fields).into() + } + Some(ast::FieldList::TupleFieldList(field_list)) => { + let struct_name = make::expr_path(make::ext::ident_path("Self")); + let fields = field_list + .fields() + .map(|_| gen_default_call()) + .collect::>>()?; + make::expr_call(struct_name, make::arg_list(fields)) + } + None => { + let struct_name = make::ext::ident_path("Self"); + let fields = make::record_expr_field_list(None); + make::record_expr(struct_name, fields).into() + } + }; + let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)); + ted::replace(func.body()?.syntax(), body.clone_for_update().syntax()); + Some(()) + } + } +} + +/// Generate a `Hash` impl based on the fields and members of the target type. +fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { + stdx::always!(func.name().map_or(false, |name| name.text() == "hash")); + fn gen_hash_call(target: ast::Expr) -> ast::Stmt { + let method = make::name_ref("hash"); + let arg = make::expr_path(make::ext::ident_path("state")); + let expr = make::expr_method_call(target, method, make::arg_list(Some(arg))); + make::expr_stmt(expr).into() + } + + let body = match adt { + // `Hash` cannot be derived for unions, so no default impl can be provided. + ast::Adt::Union(_) => return None, + + // => std::mem::discriminant(self).hash(state); + ast::Adt::Enum(_) => { + let fn_name = make_discriminant()?; + + let arg = make::expr_path(make::ext::ident_path("self")); + let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg))); + let stmt = gen_hash_call(fn_call); + + make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1)) + } + ast::Adt::Struct(strukt) => match strukt.field_list() { + // => self..hash(state); + Some(ast::FieldList::RecordFieldList(field_list)) => { + let mut stmts = vec![]; + for field in field_list.fields() { + let base = make::expr_path(make::ext::ident_path("self")); + let target = make::expr_field(base, &field.name()?.to_string()); + stmts.push(gen_hash_call(target)); + } + make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1)) + } + + // => self..hash(state); + Some(ast::FieldList::TupleFieldList(field_list)) => { + let mut stmts = vec![]; + for (i, _) in field_list.fields().enumerate() { + let base = make::expr_path(make::ext::ident_path("self")); + let target = make::expr_field(base, &format!("{}", i)); + stmts.push(gen_hash_call(target)); + } + make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1)) + } + + // No fields in the body means there's nothing to hash. + None => return None, + }, + }; + + ted::replace(func.body()?.syntax(), body.clone_for_update().syntax()); + Some(()) +} + +/// Generate a `PartialEq` impl based on the fields and members of the target type. +fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { + stdx::always!(func.name().map_or(false, |name| name.text() == "eq")); + fn gen_eq_chain(expr: Option, cmp: ast::Expr) -> Option { + match expr { + Some(expr) => Some(make::expr_bin_op(expr, BinaryOp::LogicOp(LogicOp::And), cmp)), + None => Some(cmp), + } + } + + fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField { + let pat = make::ext::simple_ident_pat(make::name(pat_name)); + let name_ref = make::name_ref(field_name); + make::record_pat_field(name_ref, pat.into()) + } + + fn gen_record_pat(record_name: ast::Path, fields: Vec) -> ast::RecordPat { + let list = make::record_pat_field_list(fields); + make::record_pat_with_fields(record_name, list) + } + + fn gen_variant_path(variant: &ast::Variant) -> Option { + make::ext::path_from_idents(["Self", &variant.name()?.to_string()]) + } + + fn gen_tuple_field(field_name: &String) -> ast::Pat { + ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name))) + } + + // FIXME: return `None` if the trait carries a generic type; we can only + // generate this code `Self` for the time being. + + let body = match adt { + // `PartialEq` cannot be derived for unions, so no default impl can be provided. + ast::Adt::Union(_) => return None, + + ast::Adt::Enum(enum_) => { + // => std::mem::discriminant(self) == std::mem::discriminant(other) + let lhs_name = make::expr_path(make::ext::ident_path("self")); + let lhs = make::expr_call(make_discriminant()?, make::arg_list(Some(lhs_name.clone()))); + let rhs_name = make::expr_path(make::ext::ident_path("other")); + let rhs = make::expr_call(make_discriminant()?, make::arg_list(Some(rhs_name.clone()))); + let eq_check = + make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs); + + let mut n_cases = 0; + let mut arms = vec![]; + for variant in enum_.variant_list()?.variants() { + n_cases += 1; + match variant.field_list() { + // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin, + Some(ast::FieldList::RecordFieldList(list)) => { + let mut expr = None; + let mut l_fields = vec![]; + let mut r_fields = vec![]; + + for field in list.fields() { + let field_name = field.name()?.to_string(); + + let l_name = &format!("l_{}", field_name); + l_fields.push(gen_record_pat_field(&field_name, l_name)); + + let r_name = &format!("r_{}", field_name); + r_fields.push(gen_record_pat_field(&field_name, r_name)); + + let lhs = make::expr_path(make::ext::ident_path(l_name)); + let rhs = make::expr_path(make::ext::ident_path(r_name)); + let cmp = make::expr_bin_op( + lhs, + BinaryOp::CmpOp(CmpOp::Eq { negated: false }), + rhs, + ); + expr = gen_eq_chain(expr, cmp); + } + + let left = gen_record_pat(gen_variant_path(&variant)?, l_fields); + let right = gen_record_pat(gen_variant_path(&variant)?, r_fields); + let tuple = make::tuple_pat(vec![left.into(), right.into()]); + + if let Some(expr) = expr { + arms.push(make::match_arm(Some(tuple.into()), None, expr)); + } + } + + Some(ast::FieldList::TupleFieldList(list)) => { + let mut expr = None; + let mut l_fields = vec![]; + let mut r_fields = vec![]; + + for (i, _) in list.fields().enumerate() { + let field_name = format!("{}", i); + + let l_name = format!("l{}", field_name); + l_fields.push(gen_tuple_field(&l_name)); + + let r_name = format!("r{}", field_name); + r_fields.push(gen_tuple_field(&r_name)); + + let lhs = make::expr_path(make::ext::ident_path(&l_name)); + let rhs = make::expr_path(make::ext::ident_path(&r_name)); + let cmp = make::expr_bin_op( + lhs, + BinaryOp::CmpOp(CmpOp::Eq { negated: false }), + rhs, + ); + expr = gen_eq_chain(expr, cmp); + } + + let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields); + let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields); + let tuple = make::tuple_pat(vec![left.into(), right.into()]); + + if let Some(expr) = expr { + arms.push(make::match_arm(Some(tuple.into()), None, expr)); + } + } + None => continue, + } + } + + let expr = match arms.len() { + 0 => eq_check, + _ => { + if n_cases > arms.len() { + let lhs = make::wildcard_pat().into(); + arms.push(make::match_arm(Some(lhs), None, eq_check)); + } + + let match_target = make::expr_tuple(vec![lhs_name, rhs_name]); + let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1)); + make::expr_match(match_target, list) + } + }; + + make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)) + } + ast::Adt::Struct(strukt) => match strukt.field_list() { + Some(ast::FieldList::RecordFieldList(field_list)) => { + let mut expr = None; + for field in field_list.fields() { + let lhs = make::expr_path(make::ext::ident_path("self")); + let lhs = make::expr_field(lhs, &field.name()?.to_string()); + let rhs = make::expr_path(make::ext::ident_path("other")); + let rhs = make::expr_field(rhs, &field.name()?.to_string()); + let cmp = + make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs); + expr = gen_eq_chain(expr, cmp); + } + make::block_expr(None, expr).indent(ast::edit::IndentLevel(1)) + } + + Some(ast::FieldList::TupleFieldList(field_list)) => { + let mut expr = None; + for (i, _) in field_list.fields().enumerate() { + let idx = format!("{}", i); + let lhs = make::expr_path(make::ext::ident_path("self")); + let lhs = make::expr_field(lhs, &idx); + let rhs = make::expr_path(make::ext::ident_path("other")); + let rhs = make::expr_field(rhs, &idx); + let cmp = + make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs); + expr = gen_eq_chain(expr, cmp); + } + make::block_expr(None, expr).indent(ast::edit::IndentLevel(1)) + } + + // No fields in the body means there's nothing to hash. + None => { + let expr = make::expr_literal("true").into(); + make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)) + } + }, + }; + + ted::replace(func.body()?.syntax(), body.clone_for_update().syntax()); + Some(()) +} + +fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { + stdx::always!(func.name().map_or(false, |name| name.text() == "partial_cmp")); + fn gen_partial_eq_match(match_target: ast::Expr) -> Option { + let mut arms = vec![]; + + let variant_name = + make::path_pat(make::ext::path_from_idents(["core", "cmp", "Ordering", "Equal"])?); + let lhs = make::tuple_struct_pat(make::ext::path_from_idents(["Some"])?, [variant_name]); + arms.push(make::match_arm(Some(lhs.into()), None, make::expr_empty_block())); + + arms.push(make::match_arm( + [make::ident_pat(false, false, make::name("ord")).into()], + None, + make::expr_return(Some(make::expr_path(make::ext::ident_path("ord")))), + )); + let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1)); + Some(make::expr_stmt(make::expr_match(match_target, list)).into()) + } + + fn gen_partial_cmp_call(lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr { + let rhs = make::expr_ref(rhs, false); + let method = make::name_ref("partial_cmp"); + make::expr_method_call(lhs, method, make::arg_list(Some(rhs))) + } + + // FIXME: return `None` if the trait carries a generic type; we can only + // generate this code `Self` for the time being. + + let body = match adt { + // `PartialOrd` cannot be derived for unions, so no default impl can be provided. + ast::Adt::Union(_) => return None, + // `core::mem::Discriminant` does not implement `PartialOrd` in stable Rust today. + ast::Adt::Enum(_) => return None, + ast::Adt::Struct(strukt) => match strukt.field_list() { + Some(ast::FieldList::RecordFieldList(field_list)) => { + let mut exprs = vec![]; + for field in field_list.fields() { + let lhs = make::expr_path(make::ext::ident_path("self")); + let lhs = make::expr_field(lhs, &field.name()?.to_string()); + let rhs = make::expr_path(make::ext::ident_path("other")); + let rhs = make::expr_field(rhs, &field.name()?.to_string()); + let ord = gen_partial_cmp_call(lhs, rhs); + exprs.push(ord); + } + + let tail = exprs.pop(); + let stmts = exprs + .into_iter() + .map(gen_partial_eq_match) + .collect::>>()?; + make::block_expr(stmts.into_iter(), tail).indent(ast::edit::IndentLevel(1)) + } + + Some(ast::FieldList::TupleFieldList(field_list)) => { + let mut exprs = vec![]; + for (i, _) in field_list.fields().enumerate() { + let idx = format!("{}", i); + let lhs = make::expr_path(make::ext::ident_path("self")); + let lhs = make::expr_field(lhs, &idx); + let rhs = make::expr_path(make::ext::ident_path("other")); + let rhs = make::expr_field(rhs, &idx); + let ord = gen_partial_cmp_call(lhs, rhs); + exprs.push(ord); + } + let tail = exprs.pop(); + let stmts = exprs + .into_iter() + .map(gen_partial_eq_match) + .collect::>>()?; + make::block_expr(stmts.into_iter(), tail).indent(ast::edit::IndentLevel(1)) + } + + // No fields in the body means there's nothing to compare. + None => { + let expr = make::expr_literal("true").into(); + make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)) + } + }, + }; + + ted::replace(func.body()?.syntax(), body.clone_for_update().syntax()); + Some(()) +} + +fn make_discriminant() -> Option { + Some(make::expr_path(make::ext::path_from_idents(["core", "mem", "discriminant"])?)) +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs b/src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs new file mode 100644 index 000000000..779cdbc93 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/utils/suggest_name.rs @@ -0,0 +1,775 @@ +//! This module contains functions to suggest names for expressions, functions and other items + +use hir::Semantics; +use ide_db::RootDatabase; +use itertools::Itertools; +use stdx::to_lower_snake_case; +use syntax::{ + ast::{self, HasName}, + match_ast, AstNode, SmolStr, +}; + +/// Trait names, that will be ignored when in `impl Trait` and `dyn Trait` +const USELESS_TRAITS: &[&str] = &["Send", "Sync", "Copy", "Clone", "Eq", "PartialEq"]; + +/// Identifier names that won't be suggested, ever +/// +/// **NOTE**: they all must be snake lower case +const USELESS_NAMES: &[&str] = + &["new", "default", "option", "some", "none", "ok", "err", "str", "string"]; + +/// Generic types replaced by their first argument +/// +/// # Examples +/// `Option` -> `Name` +/// `Result` -> `User` +const WRAPPER_TYPES: &[&str] = &["Box", "Option", "Result"]; + +/// Prefixes to strip from methods names +/// +/// # Examples +/// `vec.as_slice()` -> `slice` +/// `args.into_config()` -> `config` +/// `bytes.to_vec()` -> `vec` +const USELESS_METHOD_PREFIXES: &[&str] = &["into_", "as_", "to_"]; + +/// Useless methods that are stripped from expression +/// +/// # Examples +/// `var.name().to_string()` -> `var.name()` +const USELESS_METHODS: &[&str] = &[ + "to_string", + "as_str", + "to_owned", + "as_ref", + "clone", + "cloned", + "expect", + "expect_none", + "unwrap", + "unwrap_none", + "unwrap_or", + "unwrap_or_default", + "unwrap_or_else", + "unwrap_unchecked", + "iter", + "into_iter", + "iter_mut", +]; + +pub(crate) fn for_generic_parameter(ty: &ast::ImplTraitType) -> SmolStr { + let c = ty + .type_bound_list() + .and_then(|bounds| bounds.syntax().text().char_at(0.into())) + .unwrap_or('T'); + c.encode_utf8(&mut [0; 4]).into() +} + +/// Suggest name of variable for given expression +/// +/// **NOTE**: it is caller's responsibility to guarantee uniqueness of the name. +/// I.e. it doesn't look for names in scope. +/// +/// # Current implementation +/// +/// In current implementation, the function tries to get the name from +/// the following sources: +/// +/// * if expr is an argument to function/method, use paramter name +/// * if expr is a function/method call, use function name +/// * expression type name if it exists (E.g. `()`, `fn() -> ()` or `!` do not have names) +/// * fallback: `var_name` +/// +/// It also applies heuristics to filter out less informative names +/// +/// Currently it sticks to the first name found. +// FIXME: Microoptimize and return a `SmolStr` here. +pub(crate) fn for_variable(expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>) -> String { + // `from_param` does not benifit from stripping + // it need the largest context possible + // so we check firstmost + if let Some(name) = from_param(expr, sema) { + return name; + } + + let mut next_expr = Some(expr.clone()); + while let Some(expr) = next_expr { + let name = + from_call(&expr).or_else(|| from_type(&expr, sema)).or_else(|| from_field_name(&expr)); + if let Some(name) = name { + return name; + } + + match expr { + ast::Expr::RefExpr(inner) => next_expr = inner.expr(), + ast::Expr::BoxExpr(inner) => next_expr = inner.expr(), + ast::Expr::AwaitExpr(inner) => next_expr = inner.expr(), + // ast::Expr::BlockExpr(block) => expr = block.tail_expr(), + ast::Expr::CastExpr(inner) => next_expr = inner.expr(), + ast::Expr::MethodCallExpr(method) if is_useless_method(&method) => { + next_expr = method.receiver(); + } + ast::Expr::ParenExpr(inner) => next_expr = inner.expr(), + ast::Expr::TryExpr(inner) => next_expr = inner.expr(), + ast::Expr::PrefixExpr(prefix) if prefix.op_kind() == Some(ast::UnaryOp::Deref) => { + next_expr = prefix.expr() + } + _ => break, + } + } + + "var_name".to_string() +} + +fn normalize(name: &str) -> Option { + let name = to_lower_snake_case(name); + + if USELESS_NAMES.contains(&name.as_str()) { + return None; + } + + if !is_valid_name(&name) { + return None; + } + + Some(name) +} + +fn is_valid_name(name: &str) -> bool { + match ide_db::syntax_helpers::LexedStr::single_token(name) { + Some((syntax::SyntaxKind::IDENT, _error)) => true, + _ => false, + } +} + +fn is_useless_method(method: &ast::MethodCallExpr) -> bool { + let ident = method.name_ref().and_then(|it| it.ident_token()); + + match ident { + Some(ident) => USELESS_METHODS.contains(&ident.text()), + None => false, + } +} + +fn from_call(expr: &ast::Expr) -> Option { + from_func_call(expr).or_else(|| from_method_call(expr)) +} + +fn from_func_call(expr: &ast::Expr) -> Option { + let call = match expr { + ast::Expr::CallExpr(call) => call, + _ => return None, + }; + let func = match call.expr()? { + ast::Expr::PathExpr(path) => path, + _ => return None, + }; + let ident = func.path()?.segment()?.name_ref()?.ident_token()?; + normalize(ident.text()) +} + +fn from_method_call(expr: &ast::Expr) -> Option { + let method = match expr { + ast::Expr::MethodCallExpr(call) => call, + _ => return None, + }; + let ident = method.name_ref()?.ident_token()?; + let mut name = ident.text(); + + if USELESS_METHODS.contains(&name) { + return None; + } + + for prefix in USELESS_METHOD_PREFIXES { + if let Some(suffix) = name.strip_prefix(prefix) { + name = suffix; + break; + } + } + + normalize(name) +} + +fn from_param(expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>) -> Option { + let arg_list = expr.syntax().parent().and_then(ast::ArgList::cast)?; + let args_parent = arg_list.syntax().parent()?; + let func = match_ast! { + match args_parent { + ast::CallExpr(call) => { + let func = call.expr()?; + let func_ty = sema.type_of_expr(&func)?.adjusted(); + func_ty.as_callable(sema.db)? + }, + ast::MethodCallExpr(method) => sema.resolve_method_call_as_callable(&method)?, + _ => return None, + } + }; + + let (idx, _) = arg_list.args().find_position(|it| it == expr).unwrap(); + let (pat, _) = func.params(sema.db).into_iter().nth(idx)?; + let pat = match pat? { + either::Either::Right(pat) => pat, + _ => return None, + }; + let name = var_name_from_pat(&pat)?; + normalize(&name.to_string()) +} + +fn var_name_from_pat(pat: &ast::Pat) -> Option { + match pat { + ast::Pat::IdentPat(var) => var.name(), + ast::Pat::RefPat(ref_pat) => var_name_from_pat(&ref_pat.pat()?), + ast::Pat::BoxPat(box_pat) => var_name_from_pat(&box_pat.pat()?), + _ => None, + } +} + +fn from_type(expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>) -> Option { + let ty = sema.type_of_expr(expr)?.adjusted(); + let ty = ty.remove_ref().unwrap_or(ty); + + name_of_type(&ty, sema.db) +} + +fn name_of_type(ty: &hir::Type, db: &RootDatabase) -> Option { + let name = if let Some(adt) = ty.as_adt() { + let name = adt.name(db).to_string(); + + if WRAPPER_TYPES.contains(&name.as_str()) { + let inner_ty = ty.type_arguments().next()?; + return name_of_type(&inner_ty, db); + } + + name + } else if let Some(trait_) = ty.as_dyn_trait() { + trait_name(&trait_, db)? + } else if let Some(traits) = ty.as_impl_traits(db) { + let mut iter = traits.filter_map(|t| trait_name(&t, db)); + let name = iter.next()?; + if iter.next().is_some() { + return None; + } + name + } else { + return None; + }; + normalize(&name) +} + +fn trait_name(trait_: &hir::Trait, db: &RootDatabase) -> Option { + let name = trait_.name(db).to_string(); + if USELESS_TRAITS.contains(&name.as_str()) { + return None; + } + Some(name) +} + +fn from_field_name(expr: &ast::Expr) -> Option { + let field = match expr { + ast::Expr::FieldExpr(field) => field, + _ => return None, + }; + let ident = field.name_ref()?.ident_token()?; + normalize(ident.text()) +} + +#[cfg(test)] +mod tests { + use ide_db::base_db::{fixture::WithFixture, FileRange}; + + use super::*; + + #[track_caller] + fn check(ra_fixture: &str, expected: &str) { + let (db, file_id, range_or_offset) = RootDatabase::with_range_or_offset(ra_fixture); + let frange = FileRange { file_id, range: range_or_offset.into() }; + + let sema = Semantics::new(&db); + let source_file = sema.parse(frange.file_id); + let element = source_file.syntax().covering_element(frange.range); + let expr = + element.ancestors().find_map(ast::Expr::cast).expect("selection is not an expression"); + assert_eq!( + expr.syntax().text_range(), + frange.range, + "selection is not an expression(yet contained in one)" + ); + let name = for_variable(&expr, &sema); + assert_eq!(&name, expected); + } + + #[test] + fn no_args() { + check(r#"fn foo() { $0bar()$0 }"#, "bar"); + check(r#"fn foo() { $0bar.frobnicate()$0 }"#, "frobnicate"); + } + + #[test] + fn single_arg() { + check(r#"fn foo() { $0bar(1)$0 }"#, "bar"); + } + + #[test] + fn many_args() { + check(r#"fn foo() { $0bar(1, 2, 3)$0 }"#, "bar"); + } + + #[test] + fn path() { + check(r#"fn foo() { $0i32::bar(1, 2, 3)$0 }"#, "bar"); + } + + #[test] + fn generic_params() { + check(r#"fn foo() { $0bar::(1, 2, 3)$0 }"#, "bar"); + check(r#"fn foo() { $0bar.frobnicate::()$0 }"#, "frobnicate"); + } + + #[test] + fn to_name() { + check( + r#" +struct Args; +struct Config; +impl Args { + fn to_config(&self) -> Config {} +} +fn foo() { + $0Args.to_config()$0; +} +"#, + "config", + ); + } + + #[test] + fn plain_func() { + check( + r#" +fn bar(n: i32, m: u32); +fn foo() { bar($01$0, 2) } +"#, + "n", + ); + } + + #[test] + fn mut_param() { + check( + r#" +fn bar(mut n: i32, m: u32); +fn foo() { bar($01$0, 2) } +"#, + "n", + ); + } + + #[test] + fn func_does_not_exist() { + check(r#"fn foo() { bar($01$0, 2) }"#, "var_name"); + } + + #[test] + fn unnamed_param() { + check( + r#" +fn bar(_: i32, m: u32); +fn foo() { bar($01$0, 2) } +"#, + "var_name", + ); + } + + #[test] + fn tuple_pat() { + check( + r#" +fn bar((n, k): (i32, i32), m: u32); +fn foo() { + bar($0(1, 2)$0, 3) +} +"#, + "var_name", + ); + } + + #[test] + fn ref_pat() { + check( + r#" +fn bar(&n: &i32, m: u32); +fn foo() { bar($0&1$0, 3) } +"#, + "n", + ); + } + + #[test] + fn box_pat() { + check( + r#" +fn bar(box n: &i32, m: u32); +fn foo() { bar($01$0, 3) } +"#, + "n", + ); + } + + #[test] + fn param_out_of_index() { + check( + r#" +fn bar(n: i32, m: u32); +fn foo() { bar(1, 2, $03$0) } +"#, + "var_name", + ); + } + + #[test] + fn generic_param_resolved() { + check( + r#" +fn bar(n: T, m: u32); +fn foo() { bar($01$0, 2) } +"#, + "n", + ); + } + + #[test] + fn generic_param_unresolved() { + check( + r#" +fn bar(n: T, m: u32); +fn foo(x: T) { bar($0x$0, 2) } +"#, + "n", + ); + } + + #[test] + fn method() { + check( + r#" +struct S; +impl S { fn bar(&self, n: i32, m: u32); } +fn foo() { S.bar($01$0, 2) } +"#, + "n", + ); + } + + #[test] + fn method_on_impl_trait() { + check( + r#" +struct S; +trait T { + fn bar(&self, n: i32, m: u32); +} +impl T for S { fn bar(&self, n: i32, m: u32); } +fn foo() { S.bar($01$0, 2) } +"#, + "n", + ); + } + + #[test] + fn method_ufcs() { + check( + r#" +struct S; +impl S { fn bar(&self, n: i32, m: u32); } +fn foo() { S::bar(&S, $01$0, 2) } +"#, + "n", + ); + } + + #[test] + fn method_self() { + check( + r#" +struct S; +impl S { fn bar(&self, n: i32, m: u32); } +fn foo() { S::bar($0&S$0, 1, 2) } +"#, + "s", + ); + } + + #[test] + fn method_self_named() { + check( + r#" +struct S; +impl S { fn bar(strukt: &Self, n: i32, m: u32); } +fn foo() { S::bar($0&S$0, 1, 2) } +"#, + "strukt", + ); + } + + #[test] + fn i32() { + check(r#"fn foo() { let _: i32 = $01$0; }"#, "var_name"); + } + + #[test] + fn u64() { + check(r#"fn foo() { let _: u64 = $01$0; }"#, "var_name"); + } + + #[test] + fn bool() { + check(r#"fn foo() { let _: bool = $0true$0; }"#, "var_name"); + } + + #[test] + fn struct_unit() { + check( + r#" +struct Seed; +fn foo() { let _ = $0Seed$0; } +"#, + "seed", + ); + } + + #[test] + fn struct_unit_to_snake() { + check( + r#" +struct SeedState; +fn foo() { let _ = $0SeedState$0; } +"#, + "seed_state", + ); + } + + #[test] + fn struct_single_arg() { + check( + r#" +struct Seed(u32); +fn foo() { let _ = $0Seed(0)$0; } +"#, + "seed", + ); + } + + #[test] + fn struct_with_fields() { + check( + r#" +struct Seed { value: u32 } +fn foo() { let _ = $0Seed { value: 0 }$0; } +"#, + "seed", + ); + } + + #[test] + fn enum_() { + check( + r#" +enum Kind { A, B } +fn foo() { let _ = $0Kind::A$0; } +"#, + "kind", + ); + } + + #[test] + fn enum_generic_resolved() { + check( + r#" +enum Kind { A { x: T }, B } +fn foo() { let _ = $0Kind::A { x:1 }$0; } +"#, + "kind", + ); + } + + #[test] + fn enum_generic_unresolved() { + check( + r#" +enum Kind { A { x: T }, B } +fn foo(x: T) { let _ = $0Kind::A { x }$0; } +"#, + "kind", + ); + } + + #[test] + fn dyn_trait() { + check( + r#" +trait DynHandler {} +fn bar() -> dyn DynHandler {} +fn foo() { $0(bar())$0; } +"#, + "dyn_handler", + ); + } + + #[test] + fn impl_trait() { + check( + r#" +trait StaticHandler {} +fn bar() -> impl StaticHandler {} +fn foo() { $0(bar())$0; } +"#, + "static_handler", + ); + } + + #[test] + fn impl_trait_plus_clone() { + check( + r#" +trait StaticHandler {} +trait Clone {} +fn bar() -> impl StaticHandler + Clone {} +fn foo() { $0(bar())$0; } +"#, + "static_handler", + ); + } + + #[test] + fn impl_trait_plus_lifetime() { + check( + r#" +trait StaticHandler {} +trait Clone {} +fn bar<'a>(&'a i32) -> impl StaticHandler + 'a {} +fn foo() { $0(bar(&1))$0; } +"#, + "static_handler", + ); + } + + #[test] + fn impl_trait_plus_trait() { + check( + r#" +trait Handler {} +trait StaticHandler {} +fn bar() -> impl StaticHandler + Handler {} +fn foo() { $0(bar())$0; } +"#, + "bar", + ); + } + + #[test] + fn ref_value() { + check( + r#" +struct Seed; +fn bar() -> &Seed {} +fn foo() { $0(bar())$0; } +"#, + "seed", + ); + } + + #[test] + fn box_value() { + check( + r#" +struct Box(*const T); +struct Seed; +fn bar() -> Box {} +fn foo() { $0(bar())$0; } +"#, + "seed", + ); + } + + #[test] + fn box_generic() { + check( + r#" +struct Box(*const T); +fn bar() -> Box {} +fn foo() { $0(bar::())$0; } +"#, + "bar", + ); + } + + #[test] + fn option_value() { + check( + r#" +enum Option { Some(T) } +struct Seed; +fn bar() -> Option {} +fn foo() { $0(bar())$0; } +"#, + "seed", + ); + } + + #[test] + fn result_value() { + check( + r#" +enum Result { Ok(T), Err(E) } +struct Seed; +struct Error; +fn bar() -> Result {} +fn foo() { $0(bar())$0; } +"#, + "seed", + ); + } + + #[test] + fn ref_call() { + check( + r#" +fn foo() { $0&bar(1, 3)$0 } +"#, + "bar", + ); + } + + #[test] + fn name_to_string() { + check( + r#" +fn foo() { $0function.name().to_string()$0 } +"#, + "name", + ); + } + + #[test] + fn nested_useless_method() { + check( + r#" +fn foo() { $0function.name().as_ref().unwrap().to_string()$0 } +"#, + "name", + ); + } + + #[test] + fn struct_field_name() { + check( + r#" +struct S { + some_field: T; +} +fn foo(some_struct: S) { $0some_struct.some_field$0 } +"#, + "some_field", + ); + } +} -- cgit v1.2.3