diff options
Diffstat (limited to 'src/tools/rustfmt/src/format-diff/main.rs')
-rw-r--r-- | src/tools/rustfmt/src/format-diff/main.rs | 281 |
1 files changed, 281 insertions, 0 deletions
diff --git a/src/tools/rustfmt/src/format-diff/main.rs b/src/tools/rustfmt/src/format-diff/main.rs new file mode 100644 index 000000000..f6b739e1c --- /dev/null +++ b/src/tools/rustfmt/src/format-diff/main.rs @@ -0,0 +1,281 @@ +// Inspired by Clang's clang-format-diff: +// +// https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py + +#![deny(warnings)] + +#[macro_use] +extern crate log; + +use serde::{Deserialize, Serialize}; +use serde_json as json; +use thiserror::Error; + +use std::collections::HashSet; +use std::env; +use std::ffi::OsStr; +use std::io::{self, BufRead}; +use std::process; + +use regex::Regex; + +use clap::{CommandFactory, Parser}; + +/// The default pattern of files to format. +/// +/// We only want to format rust files by default. +const DEFAULT_PATTERN: &str = r".*\.rs"; + +#[derive(Error, Debug)] +enum FormatDiffError { + #[error("{0}")] + IncorrectOptions(#[from] getopts::Fail), + #[error("{0}")] + IncorrectFilter(#[from] regex::Error), + #[error("{0}")] + IoError(#[from] io::Error), +} + +#[derive(Parser, Debug)] +#[clap( + name = "rustfmt-format-diff", + disable_version_flag = true, + next_line_help = true +)] +pub struct Opts { + /// Skip the smallest prefix containing NUMBER slashes + #[clap( + short = 'p', + long = "skip-prefix", + value_name = "NUMBER", + default_value = "0" + )] + skip_prefix: u32, + + /// Custom pattern selecting file paths to reformat + #[clap( + short = 'f', + long = "filter", + value_name = "PATTERN", + default_value = DEFAULT_PATTERN + )] + filter: String, +} + +fn main() { + env_logger::Builder::from_env("RUSTFMT_LOG").init(); + let opts = Opts::parse(); + if let Err(e) = run(opts) { + println!("{}", e); + Opts::command() + .print_help() + .expect("cannot write to stdout"); + process::exit(1); + } +} + +#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] +struct Range { + file: String, + range: [u32; 2], +} + +fn run(opts: Opts) -> Result<(), FormatDiffError> { + let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?; + run_rustfmt(&files, &ranges) +} + +fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> { + if files.is_empty() || ranges.is_empty() { + debug!("No files to format found"); + return Ok(()); + } + + let ranges_as_json = json::to_string(ranges).unwrap(); + + debug!("Files: {:?}", files); + debug!("Ranges: {:?}", ranges); + + let rustfmt_var = env::var_os("RUSTFMT"); + let rustfmt = match &rustfmt_var { + Some(rustfmt) => rustfmt, + None => OsStr::new("rustfmt"), + }; + let exit_status = process::Command::new(rustfmt) + .args(files) + .arg("--file-lines") + .arg(ranges_as_json) + .status()?; + + if !exit_status.success() { + return Err(FormatDiffError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("rustfmt failed with {}", exit_status), + ))); + } + Ok(()) +} + +/// Scans a diff from `from`, and returns the set of files found, and the ranges +/// in those files. +fn scan_diff<R>( + from: R, + skip_prefix: u32, + file_filter: &str, +) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError> +where + R: io::Read, +{ + let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{}}}(\S*)", skip_prefix); + let diff_pattern = Regex::new(&diff_pattern).unwrap(); + + let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap(); + + let file_filter = Regex::new(&format!("^{}$", file_filter))?; + + let mut current_file = None; + + let mut files = HashSet::new(); + let mut ranges = vec![]; + for line in io::BufReader::new(from).lines() { + let line = line.unwrap(); + + if let Some(captures) = diff_pattern.captures(&line) { + current_file = Some(captures.get(1).unwrap().as_str().to_owned()); + } + + let file = match current_file { + Some(ref f) => &**f, + None => continue, + }; + + // FIXME(emilio): We could avoid this most of the time if needed, but + // it's not clear it's worth it. + if !file_filter.is_match(file) { + continue; + } + + let lines_captures = match lines_pattern.captures(&line) { + Some(captures) => captures, + None => continue, + }; + + let start_line = lines_captures + .get(1) + .unwrap() + .as_str() + .parse::<u32>() + .unwrap(); + let line_count = match lines_captures.get(3) { + Some(line_count) => line_count.as_str().parse::<u32>().unwrap(), + None => 1, + }; + + if line_count == 0 { + continue; + } + + let end_line = start_line + line_count - 1; + files.insert(file.to_owned()); + ranges.push(Range { + file: file.to_owned(), + range: [start_line, end_line], + }); + } + + Ok((files, ranges)) +} + +#[test] +fn scan_simple_git_diff() { + const DIFF: &str = include_str!("test/bindgen.diff"); + let (files, ranges) = scan_diff(DIFF.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?"); + + assert!( + files.contains("src/ir/traversal.rs"), + "Should've matched the filter" + ); + + assert!( + !files.contains("tests/headers/anon_enum.hpp"), + "Shouldn't have matched the filter" + ); + + assert_eq!( + &ranges, + &[ + Range { + file: "src/ir/item.rs".to_owned(), + range: [148, 158], + }, + Range { + file: "src/ir/item.rs".to_owned(), + range: [160, 170], + }, + Range { + file: "src/ir/traversal.rs".to_owned(), + range: [9, 16], + }, + Range { + file: "src/ir/traversal.rs".to_owned(), + range: [35, 43], + }, + ] + ); +} + +#[cfg(test)] +mod cmd_line_tests { + use super::*; + + #[test] + fn default_options() { + let empty: Vec<String> = vec![]; + let o = Opts::parse_from(&empty); + assert_eq!(DEFAULT_PATTERN, o.filter); + assert_eq!(0, o.skip_prefix); + } + + #[test] + fn good_options() { + let o = Opts::parse_from(&["test", "-p", "10", "-f", r".*\.hs"]); + assert_eq!(r".*\.hs", o.filter); + assert_eq!(10, o.skip_prefix); + } + + #[test] + fn unexpected_option() { + assert!( + Opts::command() + .try_get_matches_from(&["test", "unexpected"]) + .is_err() + ); + } + + #[test] + fn unexpected_flag() { + assert!( + Opts::command() + .try_get_matches_from(&["test", "--flag"]) + .is_err() + ); + } + + #[test] + fn overridden_option() { + assert!( + Opts::command() + .try_get_matches_from(&["test", "-p", "10", "-p", "20"]) + .is_err() + ); + } + + #[test] + fn negative_filter() { + assert!( + Opts::command() + .try_get_matches_from(&["test", "-p", "-1"]) + .is_err() + ); + } +} |