From 26a029d407be480d791972afb5975cf62c9360a6 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 19 Apr 2024 02:47:55 +0200 Subject: Adding upstream version 124.0.1. Signed-off-by: Daniel Baumann --- third_party/aom/tools/aggregate_entropy_stats.py | 39 + third_party/aom/tools/aom_entropy_optimizer.c | 761 +++ .../aom/tools/auto_refactor/auto_refactor.py | 919 +++ .../aom/tools/auto_refactor/av1_preprocess.py | 113 + .../tools/auto_refactor/c_files/decl_status_code.c | 31 + .../aom/tools/auto_refactor/c_files/func_in_out.c | 208 + .../tools/auto_refactor/c_files/global_variable.c | 27 + .../aom/tools/auto_refactor/c_files/parse_lvalue.c | 46 + .../aom/tools/auto_refactor/c_files/simple_code.c | 64 + .../aom/tools/auto_refactor/c_files/struct_code.c | 49 + .../aom/tools/auto_refactor/test_auto_refactor.py | 675 +++ third_party/aom/tools/cpplint.py | 6244 ++++++++++++++++++++ third_party/aom/tools/diff.py | 132 + third_party/aom/tools/dump_obu.cc | 168 + .../aom/tools/frame_size_variation_analyzer.py | 74 + third_party/aom/tools/gen_authors.sh | 10 + third_party/aom/tools/gen_constrained_tokenset.py | 120 + third_party/aom/tools/gop_bitrate/analyze_data.py | 18 + .../aom/tools/gop_bitrate/encode_all_script.sh | 13 + .../tools/gop_bitrate/python/bitrate_accuracy.py | 185 + third_party/aom/tools/inspect-cli.js | 39 + third_party/aom/tools/inspect-post.js | 1 + third_party/aom/tools/intersect-diffs.py | 78 + third_party/aom/tools/lint-hunks.py | 150 + third_party/aom/tools/obu_parser.cc | 190 + third_party/aom/tools/obu_parser.h | 27 + .../ratectrl_log_analyzer/analyze_ratectrl_log.py | 154 + .../aom/tools/txfm_analyzer/txfm_gen_code.cc | 580 ++ third_party/aom/tools/txfm_analyzer/txfm_graph.cc | 943 +++ third_party/aom/tools/txfm_analyzer/txfm_graph.h | 160 + third_party/aom/tools/wrap-commit-msg.py | 72 + 31 files changed, 12290 insertions(+) create mode 100644 third_party/aom/tools/aggregate_entropy_stats.py create mode 100644 
third_party/aom/tools/aom_entropy_optimizer.c create mode 100644 third_party/aom/tools/auto_refactor/auto_refactor.py create mode 100644 third_party/aom/tools/auto_refactor/av1_preprocess.py create mode 100644 third_party/aom/tools/auto_refactor/c_files/decl_status_code.c create mode 100644 third_party/aom/tools/auto_refactor/c_files/func_in_out.c create mode 100644 third_party/aom/tools/auto_refactor/c_files/global_variable.c create mode 100644 third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c create mode 100644 third_party/aom/tools/auto_refactor/c_files/simple_code.c create mode 100644 third_party/aom/tools/auto_refactor/c_files/struct_code.c create mode 100644 third_party/aom/tools/auto_refactor/test_auto_refactor.py create mode 100755 third_party/aom/tools/cpplint.py create mode 100644 third_party/aom/tools/diff.py create mode 100644 third_party/aom/tools/dump_obu.cc create mode 100644 third_party/aom/tools/frame_size_variation_analyzer.py create mode 100755 third_party/aom/tools/gen_authors.sh create mode 100755 third_party/aom/tools/gen_constrained_tokenset.py create mode 100644 third_party/aom/tools/gop_bitrate/analyze_data.py create mode 100755 third_party/aom/tools/gop_bitrate/encode_all_script.sh create mode 100644 third_party/aom/tools/gop_bitrate/python/bitrate_accuracy.py create mode 100644 third_party/aom/tools/inspect-cli.js create mode 100644 third_party/aom/tools/inspect-post.js create mode 100755 third_party/aom/tools/intersect-diffs.py create mode 100755 third_party/aom/tools/lint-hunks.py create mode 100644 third_party/aom/tools/obu_parser.cc create mode 100644 third_party/aom/tools/obu_parser.h create mode 100644 third_party/aom/tools/ratectrl_log_analyzer/analyze_ratectrl_log.py create mode 100644 third_party/aom/tools/txfm_analyzer/txfm_gen_code.cc create mode 100644 third_party/aom/tools/txfm_analyzer/txfm_graph.cc create mode 100644 third_party/aom/tools/txfm_analyzer/txfm_graph.h create mode 100755 
third_party/aom/tools/wrap-commit-msg.py (limited to 'third_party/aom/tools') diff --git a/third_party/aom/tools/aggregate_entropy_stats.py b/third_party/aom/tools/aggregate_entropy_stats.py new file mode 100644 index 0000000000..0311681f2d --- /dev/null +++ b/third_party/aom/tools/aggregate_entropy_stats.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +## Copyright (c) 2017, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 2 Clause License and +## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +## was not distributed with this source code in the LICENSE file, you can +## obtain it at www.aomedia.org/license/software. If the Alliance for Open +## Media Patent License 1.0 was not distributed with this source code in the +## PATENTS file, you can obtain it at www.aomedia.org/license/patent. +## +"""Aggregate multiple entropy stats output which is written in 32-bit int. + +python ./aggregate_entropy_stats.py [dir of stats files] [keyword of filenames] + [filename of final stats] +""" + +__author__ = "yuec@google.com" + +import os +import sys +import numpy as np + +def main(): + dir = sys.argv[1] + sum = [] + for fn in os.listdir(dir): + if sys.argv[2] in fn: + stats = np.fromfile(dir + fn, dtype=np.int32) + if len(sum) == 0: + sum = stats + else: + sum = np.add(sum, stats) + if len(sum) == 0: + print("No stats file is found. Double-check directory and keyword?") + else: + sum.tofile(dir+sys.argv[3]) + +if __name__ == '__main__': + main() diff --git a/third_party/aom/tools/aom_entropy_optimizer.c b/third_party/aom/tools/aom_entropy_optimizer.c new file mode 100644 index 0000000000..fa7bf7ea9e --- /dev/null +++ b/third_party/aom/tools/aom_entropy_optimizer.c @@ -0,0 +1,761 @@ +/* + * Copyright (c) 2017, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. 
If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +// This tool is a gadget for offline probability training. +// A binary executable aom_entropy_optimizer will be generated in tools/. It +// parses a binary file consisting of counts written in the format of +// FRAME_COUNTS in entropymode.h, and computes optimized probability tables +// and CDF tables, which will be written to a new c file optimized_probs.c +// according to format in the codebase. +// +// Command line: ./aom_entropy_optimizer [directory of the count file] +// +// The input file can either be generated by encoding a single clip by +// turning on entropy_stats experiment, or be collected at a larger scale at +// which a python script which will be provided soon can be used to aggregate +// multiple stats output. + +#include +#include + +#include "config/aom_config.h" + +#include "av1/encoder/encoder.h" + +#define SPACES_PER_TAB 2 +#define CDF_MAX_SIZE 16 + +typedef unsigned int aom_count_type; +// A log file recording parsed counts +static FILE *logfile; // TODO(yuec): make it a command line option + +static void counts_to_cdf(const aom_count_type *counts, aom_cdf_prob *cdf, + int modes) { + int64_t csum[CDF_MAX_SIZE]; + assert(modes <= CDF_MAX_SIZE); + + csum[0] = counts[0] + 1; + for (int i = 1; i < modes; ++i) csum[i] = counts[i] + 1 + csum[i - 1]; + + for (int i = 0; i < modes; ++i) fprintf(logfile, "%d ", counts[i]); + fprintf(logfile, "\n"); + + int64_t sum = csum[modes - 1]; + const int64_t round_shift = sum >> 1; + for (int i = 0; i < modes; ++i) { + cdf[i] = (csum[i] * CDF_PROB_TOP + round_shift) / sum; + cdf[i] = AOMMIN(cdf[i], CDF_PROB_TOP - (modes - 1 + i) * 4); + cdf[i] = (i == 0) ? 
AOMMAX(cdf[i], 4) : AOMMAX(cdf[i], cdf[i - 1] + 4); + } +} + +static int parse_counts_for_cdf_opt(aom_count_type **ct_ptr, + FILE *const probsfile, int tabs, + int dim_of_cts, int *cts_each_dim) { + if (dim_of_cts < 1) { + fprintf(stderr, "The dimension of a counts vector should be at least 1!\n"); + return 1; + } + const int total_modes = cts_each_dim[0]; + if (dim_of_cts == 1) { + assert(total_modes <= CDF_MAX_SIZE); + aom_cdf_prob cdfs[CDF_MAX_SIZE]; + aom_count_type *counts1d = *ct_ptr; + + counts_to_cdf(counts1d, cdfs, total_modes); + (*ct_ptr) += total_modes; + + if (tabs > 0) fprintf(probsfile, "%*c", tabs * SPACES_PER_TAB, ' '); + fprintf(probsfile, "AOM_CDF%d(", total_modes); + for (int k = 0; k < total_modes - 1; ++k) { + fprintf(probsfile, "%d", cdfs[k]); + if (k < total_modes - 2) fprintf(probsfile, ", "); + } + fprintf(probsfile, ")"); + } else { + for (int k = 0; k < total_modes; ++k) { + int tabs_next_level; + + if (dim_of_cts == 2) + fprintf(probsfile, "%*c{ ", tabs * SPACES_PER_TAB, ' '); + else + fprintf(probsfile, "%*c{\n", tabs * SPACES_PER_TAB, ' '); + tabs_next_level = dim_of_cts == 2 ? 
0 : tabs + 1; + + if (parse_counts_for_cdf_opt(ct_ptr, probsfile, tabs_next_level, + dim_of_cts - 1, cts_each_dim + 1)) { + return 1; + } + + if (dim_of_cts == 2) { + if (k == total_modes - 1) + fprintf(probsfile, " }\n"); + else + fprintf(probsfile, " },\n"); + } else { + if (k == total_modes - 1) + fprintf(probsfile, "%*c}\n", tabs * SPACES_PER_TAB, ' '); + else + fprintf(probsfile, "%*c},\n", tabs * SPACES_PER_TAB, ' '); + } + } + } + return 0; +} + +static void optimize_cdf_table(aom_count_type *counts, FILE *const probsfile, + int dim_of_cts, int *cts_each_dim, + char *prefix) { + aom_count_type *ct_ptr = counts; + + fprintf(probsfile, "%s = {\n", prefix); + fprintf(logfile, "%s\n", prefix); + if (parse_counts_for_cdf_opt(&ct_ptr, probsfile, 1, dim_of_cts, + cts_each_dim)) { + fprintf(probsfile, "Optimizer failed!\n"); + } + fprintf(probsfile, "};\n\n"); + fprintf(logfile, "============================\n"); +} + +static void optimize_uv_mode(aom_count_type *counts, FILE *const probsfile, + int dim_of_cts, int *cts_each_dim, char *prefix) { + aom_count_type *ct_ptr = counts; + + fprintf(probsfile, "%s = {\n", prefix); + fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' '); + fprintf(logfile, "%s\n", prefix); + cts_each_dim[2] = UV_INTRA_MODES - 1; + for (int k = 0; k < cts_each_dim[1]; ++k) { + fprintf(probsfile, "%*c{ ", 2 * SPACES_PER_TAB, ' '); + parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, dim_of_cts - 2, + cts_each_dim + 2); + if (k + 1 == cts_each_dim[1]) { + fprintf(probsfile, " }\n"); + } else { + fprintf(probsfile, " },\n"); + } + ++ct_ptr; + } + fprintf(probsfile, "%*c},\n", SPACES_PER_TAB, ' '); + fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' '); + cts_each_dim[2] = UV_INTRA_MODES; + parse_counts_for_cdf_opt(&ct_ptr, probsfile, 2, dim_of_cts - 1, + cts_each_dim + 1); + fprintf(probsfile, "%*c}\n", SPACES_PER_TAB, ' '); + fprintf(probsfile, "};\n\n"); + fprintf(logfile, "============================\n"); +} + +static void 
optimize_cdf_table_var_modes_2d(aom_count_type *counts, + FILE *const probsfile, + int dim_of_cts, int *cts_each_dim, + int *modes_each_ctx, char *prefix) { + aom_count_type *ct_ptr = counts; + + assert(dim_of_cts == 2); + (void)dim_of_cts; + + fprintf(probsfile, "%s = {\n", prefix); + fprintf(logfile, "%s\n", prefix); + + for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) { + int num_of_modes = modes_each_ctx[d0_idx]; + + if (num_of_modes > 0) { + fprintf(probsfile, "%*c{ ", SPACES_PER_TAB, ' '); + parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes); + ct_ptr += cts_each_dim[1] - num_of_modes; + fprintf(probsfile, " },\n"); + } else { + fprintf(probsfile, "%*c{ 0 },\n", SPACES_PER_TAB, ' '); + fprintf(logfile, "dummy cdf, no need to optimize\n"); + ct_ptr += cts_each_dim[1]; + } + } + fprintf(probsfile, "};\n\n"); + fprintf(logfile, "============================\n"); +} + +static void optimize_cdf_table_var_modes_3d(aom_count_type *counts, + FILE *const probsfile, + int dim_of_cts, int *cts_each_dim, + int *modes_each_ctx, char *prefix) { + aom_count_type *ct_ptr = counts; + + assert(dim_of_cts == 3); + (void)dim_of_cts; + + fprintf(probsfile, "%s = {\n", prefix); + fprintf(logfile, "%s\n", prefix); + + for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) { + fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' '); + for (int d1_idx = 0; d1_idx < cts_each_dim[1]; ++d1_idx) { + int num_of_modes = modes_each_ctx[d0_idx]; + + if (num_of_modes > 0) { + fprintf(probsfile, "%*c{ ", 2 * SPACES_PER_TAB, ' '); + parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes); + ct_ptr += cts_each_dim[2] - num_of_modes; + fprintf(probsfile, " },\n"); + } else { + fprintf(probsfile, "%*c{ 0 },\n", 2 * SPACES_PER_TAB, ' '); + fprintf(logfile, "dummy cdf, no need to optimize\n"); + ct_ptr += cts_each_dim[2]; + } + } + fprintf(probsfile, "%*c},\n", SPACES_PER_TAB, ' '); + } + fprintf(probsfile, "};\n\n"); + fprintf(logfile, "============================\n"); 
+} + +static void optimize_cdf_table_var_modes_4d(aom_count_type *counts, + FILE *const probsfile, + int dim_of_cts, int *cts_each_dim, + int *modes_each_ctx, char *prefix) { + aom_count_type *ct_ptr = counts; + + assert(dim_of_cts == 4); + (void)dim_of_cts; + + fprintf(probsfile, "%s = {\n", prefix); + fprintf(logfile, "%s\n", prefix); + + for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) { + fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' '); + for (int d1_idx = 0; d1_idx < cts_each_dim[1]; ++d1_idx) { + fprintf(probsfile, "%*c{\n", 2 * SPACES_PER_TAB, ' '); + for (int d2_idx = 0; d2_idx < cts_each_dim[2]; ++d2_idx) { + int num_of_modes = modes_each_ctx[d0_idx]; + + if (num_of_modes > 0) { + fprintf(probsfile, "%*c{ ", 3 * SPACES_PER_TAB, ' '); + parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes); + ct_ptr += cts_each_dim[3] - num_of_modes; + fprintf(probsfile, " },\n"); + } else { + fprintf(probsfile, "%*c{ 0 },\n", 3 * SPACES_PER_TAB, ' '); + fprintf(logfile, "dummy cdf, no need to optimize\n"); + ct_ptr += cts_each_dim[3]; + } + } + fprintf(probsfile, "%*c},\n", 2 * SPACES_PER_TAB, ' '); + } + fprintf(probsfile, "%*c},\n", SPACES_PER_TAB, ' '); + } + fprintf(probsfile, "};\n\n"); + fprintf(logfile, "============================\n"); +} + +int main(int argc, const char **argv) { + if (argc < 2) { + fprintf(stderr, "Please specify the input stats file!\n"); + exit(EXIT_FAILURE); + } + + FILE *const statsfile = fopen(argv[1], "rb"); + if (statsfile == NULL) { + fprintf(stderr, "Failed to open input file!\n"); + exit(EXIT_FAILURE); + } + + FRAME_COUNTS fc; + const size_t bytes = fread(&fc, sizeof(FRAME_COUNTS), 1, statsfile); + if (!bytes) { + fclose(statsfile); + return 1; + } + + FILE *const probsfile = fopen("optimized_probs.c", "w"); + if (probsfile == NULL) { + fprintf(stderr, + "Failed to create output file for optimized entropy tables!\n"); + exit(EXIT_FAILURE); + } + + logfile = fopen("aom_entropy_optimizer_parsed_counts.log", "w"); + 
if (logfile == NULL) { + fprintf(stderr, "Failed to create log file for parsed counts!\n"); + exit(EXIT_FAILURE); + } + + int cts_each_dim[10]; + + /* Intra mode (keyframe luma) */ + cts_each_dim[0] = KF_MODE_CONTEXTS; + cts_each_dim[1] = KF_MODE_CONTEXTS; + cts_each_dim[2] = INTRA_MODES; + optimize_cdf_table(&fc.kf_y_mode[0][0][0], probsfile, 3, cts_each_dim, + "const aom_cdf_prob\n" + "default_kf_y_mode_cdf[KF_MODE_CONTEXTS][KF_MODE_CONTEXTS]" + "[CDF_SIZE(INTRA_MODES)]"); + + cts_each_dim[0] = DIRECTIONAL_MODES; + cts_each_dim[1] = 2 * MAX_ANGLE_DELTA + 1; + optimize_cdf_table(&fc.angle_delta[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob default_angle_delta_cdf" + "[DIRECTIONAL_MODES][CDF_SIZE(2 * MAX_ANGLE_DELTA + 1)]"); + + /* Intra mode (non-keyframe luma) */ + cts_each_dim[0] = BLOCK_SIZE_GROUPS; + cts_each_dim[1] = INTRA_MODES; + optimize_cdf_table( + &fc.y_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_if_y_mode_cdf[BLOCK_SIZE_GROUPS][CDF_SIZE(INTRA_MODES)]"); + + /* Intra mode (chroma) */ + cts_each_dim[0] = CFL_ALLOWED_TYPES; + cts_each_dim[1] = INTRA_MODES; + cts_each_dim[2] = UV_INTRA_MODES; + optimize_uv_mode(&fc.uv_mode[0][0][0], probsfile, 3, cts_each_dim, + "static const aom_cdf_prob\n" + "default_uv_mode_cdf[CFL_ALLOWED_TYPES][INTRA_MODES]" + "[CDF_SIZE(UV_INTRA_MODES)]"); + + /* block partition */ + cts_each_dim[0] = PARTITION_CONTEXTS; + cts_each_dim[1] = EXT_PARTITION_TYPES; + int part_types_each_ctx[PARTITION_CONTEXTS] = { 4, 4, 4, 4, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, + 10, 10, 8, 8, 8, 8 }; + optimize_cdf_table_var_modes_2d( + &fc.partition[0][0], probsfile, 2, cts_each_dim, part_types_each_ctx, + "static const aom_cdf_prob default_partition_cdf[PARTITION_CONTEXTS]" + "[CDF_SIZE(EXT_PARTITION_TYPES)]"); + + /* tx type */ + cts_each_dim[0] = EXT_TX_SETS_INTRA; + cts_each_dim[1] = EXT_TX_SIZES; + cts_each_dim[2] = INTRA_MODES; + cts_each_dim[3] = TX_TYPES; + int 
intra_ext_tx_types_each_ctx[EXT_TX_SETS_INTRA] = { 0, 7, 5 }; + optimize_cdf_table_var_modes_4d( + &fc.intra_ext_tx[0][0][0][0], probsfile, 4, cts_each_dim, + intra_ext_tx_types_each_ctx, + "static const aom_cdf_prob default_intra_ext_tx_cdf[EXT_TX_SETS_INTRA]" + "[EXT_TX_SIZES][INTRA_MODES][CDF_SIZE(TX_TYPES)]"); + + cts_each_dim[0] = EXT_TX_SETS_INTER; + cts_each_dim[1] = EXT_TX_SIZES; + cts_each_dim[2] = TX_TYPES; + int inter_ext_tx_types_each_ctx[EXT_TX_SETS_INTER] = { 0, 16, 12, 2 }; + optimize_cdf_table_var_modes_3d( + &fc.inter_ext_tx[0][0][0], probsfile, 3, cts_each_dim, + inter_ext_tx_types_each_ctx, + "static const aom_cdf_prob default_inter_ext_tx_cdf[EXT_TX_SETS_INTER]" + "[EXT_TX_SIZES][CDF_SIZE(TX_TYPES)]"); + + /* Chroma from Luma */ + cts_each_dim[0] = CFL_JOINT_SIGNS; + optimize_cdf_table(&fc.cfl_sign[0], probsfile, 1, cts_each_dim, + "static const aom_cdf_prob\n" + "default_cfl_sign_cdf[CDF_SIZE(CFL_JOINT_SIGNS)]"); + cts_each_dim[0] = CFL_ALPHA_CONTEXTS; + cts_each_dim[1] = CFL_ALPHABET_SIZE; + optimize_cdf_table(&fc.cfl_alpha[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_cfl_alpha_cdf[CFL_ALPHA_CONTEXTS]" + "[CDF_SIZE(CFL_ALPHABET_SIZE)]"); + + /* Interpolation filter */ + cts_each_dim[0] = SWITCHABLE_FILTER_CONTEXTS; + cts_each_dim[1] = SWITCHABLE_FILTERS; + optimize_cdf_table(&fc.switchable_interp[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_switchable_interp_cdf[SWITCHABLE_FILTER_CONTEXTS]" + "[CDF_SIZE(SWITCHABLE_FILTERS)]"); + + /* Motion vector referencing */ + cts_each_dim[0] = NEWMV_MODE_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.newmv_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_newmv_cdf[NEWMV_MODE_CONTEXTS][CDF_SIZE(2)]"); + + cts_each_dim[0] = GLOBALMV_MODE_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.zeromv_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + 
"default_zeromv_cdf[GLOBALMV_MODE_CONTEXTS][CDF_SIZE(2)]"); + + cts_each_dim[0] = REFMV_MODE_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.refmv_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_refmv_cdf[REFMV_MODE_CONTEXTS][CDF_SIZE(2)]"); + + cts_each_dim[0] = DRL_MODE_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.drl_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_drl_cdf[DRL_MODE_CONTEXTS][CDF_SIZE(2)]"); + + /* ext_inter experiment */ + /* New compound mode */ + cts_each_dim[0] = INTER_MODE_CONTEXTS; + cts_each_dim[1] = INTER_COMPOUND_MODES; + optimize_cdf_table(&fc.inter_compound_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_inter_compound_mode_cdf[INTER_MODE_CONTEXTS][CDF_" + "SIZE(INTER_COMPOUND_MODES)]"); + + /* Interintra */ + cts_each_dim[0] = BLOCK_SIZE_GROUPS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.interintra[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_interintra_cdf[BLOCK_SIZE_GROUPS][CDF_SIZE(2)]"); + + cts_each_dim[0] = BLOCK_SIZE_GROUPS; + cts_each_dim[1] = INTERINTRA_MODES; + optimize_cdf_table(&fc.interintra_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_interintra_mode_cdf[BLOCK_SIZE_GROUPS][CDF_SIZE(" + "INTERINTRA_MODES)]"); + + cts_each_dim[0] = BLOCK_SIZES_ALL; + cts_each_dim[1] = 2; + optimize_cdf_table( + &fc.wedge_interintra[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_wedge_interintra_cdf[BLOCK_SIZES_ALL][CDF_SIZE(2)]"); + + /* Compound type */ + cts_each_dim[0] = BLOCK_SIZES_ALL; + cts_each_dim[1] = COMPOUND_TYPES - 1; + optimize_cdf_table(&fc.compound_type[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob default_compound_type_cdf" + "[BLOCK_SIZES_ALL][CDF_SIZE(COMPOUND_TYPES - 1)]"); + + cts_each_dim[0] = BLOCK_SIZES_ALL; + cts_each_dim[1] = 16; + optimize_cdf_table(&fc.wedge_idx[0][0], 
probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_wedge_idx_cdf[BLOCK_SIZES_ALL][CDF_SIZE(16)]"); + + /* motion_var and warped_motion experiments */ + cts_each_dim[0] = BLOCK_SIZES_ALL; + cts_each_dim[1] = MOTION_MODES; + optimize_cdf_table( + &fc.motion_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_motion_mode_cdf[BLOCK_SIZES_ALL][CDF_SIZE(MOTION_MODES)]"); + cts_each_dim[0] = BLOCK_SIZES_ALL; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.obmc[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_obmc_cdf[BLOCK_SIZES_ALL][CDF_SIZE(2)]"); + + /* Intra/inter flag */ + cts_each_dim[0] = INTRA_INTER_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table( + &fc.intra_inter[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_intra_inter_cdf[INTRA_INTER_CONTEXTS][CDF_SIZE(2)]"); + + /* Single/comp ref flag */ + cts_each_dim[0] = COMP_INTER_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table( + &fc.comp_inter[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_comp_inter_cdf[COMP_INTER_CONTEXTS][CDF_SIZE(2)]"); + + /* ext_comp_refs experiment */ + cts_each_dim[0] = COMP_REF_TYPE_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table( + &fc.comp_ref_type[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_comp_ref_type_cdf[COMP_REF_TYPE_CONTEXTS][CDF_SIZE(2)]"); + + cts_each_dim[0] = UNI_COMP_REF_CONTEXTS; + cts_each_dim[1] = UNIDIR_COMP_REFS - 1; + cts_each_dim[2] = 2; + optimize_cdf_table(&fc.uni_comp_ref[0][0][0], probsfile, 3, cts_each_dim, + "static const aom_cdf_prob\n" + "default_uni_comp_ref_cdf[UNI_COMP_REF_CONTEXTS][UNIDIR_" + "COMP_REFS - 1][CDF_SIZE(2)]"); + + /* Reference frame (single ref) */ + cts_each_dim[0] = REF_CONTEXTS; + cts_each_dim[1] = SINGLE_REFS - 1; + cts_each_dim[2] = 2; + optimize_cdf_table( + &fc.single_ref[0][0][0], probsfile, 3, cts_each_dim, + "static const aom_cdf_prob\n" + 
"default_single_ref_cdf[REF_CONTEXTS][SINGLE_REFS - 1][CDF_SIZE(2)]"); + + /* ext_refs experiment */ + cts_each_dim[0] = REF_CONTEXTS; + cts_each_dim[1] = FWD_REFS - 1; + cts_each_dim[2] = 2; + optimize_cdf_table( + &fc.comp_ref[0][0][0], probsfile, 3, cts_each_dim, + "static const aom_cdf_prob\n" + "default_comp_ref_cdf[REF_CONTEXTS][FWD_REFS - 1][CDF_SIZE(2)]"); + + cts_each_dim[0] = REF_CONTEXTS; + cts_each_dim[1] = BWD_REFS - 1; + cts_each_dim[2] = 2; + optimize_cdf_table( + &fc.comp_bwdref[0][0][0], probsfile, 3, cts_each_dim, + "static const aom_cdf_prob\n" + "default_comp_bwdref_cdf[REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)]"); + + /* palette */ + cts_each_dim[0] = PALATTE_BSIZE_CTXS; + cts_each_dim[1] = PALETTE_SIZES; + optimize_cdf_table(&fc.palette_y_size[0][0], probsfile, 2, cts_each_dim, + "const aom_cdf_prob default_palette_y_size_cdf" + "[PALATTE_BSIZE_CTXS][CDF_SIZE(PALETTE_SIZES)]"); + + cts_each_dim[0] = PALATTE_BSIZE_CTXS; + cts_each_dim[1] = PALETTE_SIZES; + optimize_cdf_table(&fc.palette_uv_size[0][0], probsfile, 2, cts_each_dim, + "const aom_cdf_prob default_palette_uv_size_cdf" + "[PALATTE_BSIZE_CTXS][CDF_SIZE(PALETTE_SIZES)]"); + + cts_each_dim[0] = PALATTE_BSIZE_CTXS; + cts_each_dim[1] = PALETTE_Y_MODE_CONTEXTS; + cts_each_dim[2] = 2; + optimize_cdf_table(&fc.palette_y_mode[0][0][0], probsfile, 3, cts_each_dim, + "const aom_cdf_prob default_palette_y_mode_cdf" + "[PALATTE_BSIZE_CTXS][PALETTE_Y_MODE_CONTEXTS]" + "[CDF_SIZE(2)]"); + + cts_each_dim[0] = PALETTE_UV_MODE_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.palette_uv_mode[0][0], probsfile, 2, cts_each_dim, + "const aom_cdf_prob default_palette_uv_mode_cdf" + "[PALETTE_UV_MODE_CONTEXTS][CDF_SIZE(2)]"); + + cts_each_dim[0] = PALETTE_SIZES; + cts_each_dim[1] = PALETTE_COLOR_INDEX_CONTEXTS; + cts_each_dim[2] = PALETTE_COLORS; + int palette_color_indexes_each_ctx[PALETTE_SIZES] = { 2, 3, 4, 5, 6, 7, 8 }; + optimize_cdf_table_var_modes_3d( + &fc.palette_y_color_index[0][0][0], 
probsfile, 3, cts_each_dim, + palette_color_indexes_each_ctx, + "const aom_cdf_prob default_palette_y_color_index_cdf[PALETTE_SIZES]" + "[PALETTE_COLOR_INDEX_CONTEXTS][CDF_SIZE(PALETTE_COLORS)]"); + + cts_each_dim[0] = PALETTE_SIZES; + cts_each_dim[1] = PALETTE_COLOR_INDEX_CONTEXTS; + cts_each_dim[2] = PALETTE_COLORS; + optimize_cdf_table_var_modes_3d( + &fc.palette_uv_color_index[0][0][0], probsfile, 3, cts_each_dim, + palette_color_indexes_each_ctx, + "const aom_cdf_prob default_palette_uv_color_index_cdf[PALETTE_SIZES]" + "[PALETTE_COLOR_INDEX_CONTEXTS][CDF_SIZE(PALETTE_COLORS)]"); + + /* Transform size */ + cts_each_dim[0] = TXFM_PARTITION_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table( + &fc.txfm_partition[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob\n" + "default_txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)]"); + + /* Skip flag */ + cts_each_dim[0] = SKIP_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.skip_txfm[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_skip_txfm_cdfs[SKIP_CONTEXTS][CDF_SIZE(2)]"); + + /* Skip mode flag */ + cts_each_dim[0] = SKIP_MODE_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.skip_mode[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_skip_mode_cdfs[SKIP_MODE_CONTEXTS][CDF_SIZE(2)]"); + + /* joint compound flag */ + cts_each_dim[0] = COMP_INDEX_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.compound_index[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob default_compound_idx_cdfs" + "[COMP_INDEX_CONTEXTS][CDF_SIZE(2)]"); + + cts_each_dim[0] = COMP_GROUP_IDX_CONTEXTS; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.comp_group_idx[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob default_comp_group_idx_cdfs" + "[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)]"); + + /* intrabc */ + cts_each_dim[0] = 2; + optimize_cdf_table( + &fc.intrabc[0], probsfile, 1, cts_each_dim, + "static const 
aom_cdf_prob default_intrabc_cdf[CDF_SIZE(2)]"); + + /* filter_intra experiment */ + cts_each_dim[0] = FILTER_INTRA_MODES; + optimize_cdf_table( + &fc.filter_intra_mode[0], probsfile, 1, cts_each_dim, + "static const aom_cdf_prob " + "default_filter_intra_mode_cdf[CDF_SIZE(FILTER_INTRA_MODES)]"); + + cts_each_dim[0] = BLOCK_SIZES_ALL; + cts_each_dim[1] = 2; + optimize_cdf_table(&fc.filter_intra[0][0], probsfile, 2, cts_each_dim, + "static const aom_cdf_prob " + "default_filter_intra_cdfs[BLOCK_SIZES_ALL][CDF_SIZE(2)]"); + + /* restoration type */ + cts_each_dim[0] = RESTORE_SWITCHABLE_TYPES; + optimize_cdf_table(&fc.switchable_restore[0], probsfile, 1, cts_each_dim, + "static const aom_cdf_prob default_switchable_restore_cdf" + "[CDF_SIZE(RESTORE_SWITCHABLE_TYPES)]"); + + cts_each_dim[0] = 2; + optimize_cdf_table(&fc.wiener_restore[0], probsfile, 1, cts_each_dim, + "static const aom_cdf_prob default_wiener_restore_cdf" + "[CDF_SIZE(2)]"); + + cts_each_dim[0] = 2; + optimize_cdf_table(&fc.sgrproj_restore[0], probsfile, 1, cts_each_dim, + "static const aom_cdf_prob default_sgrproj_restore_cdf" + "[CDF_SIZE(2)]"); + + /* intra tx size */ + cts_each_dim[0] = MAX_TX_CATS; + cts_each_dim[1] = TX_SIZE_CONTEXTS; + cts_each_dim[2] = MAX_TX_DEPTH + 1; + int intra_tx_sizes_each_ctx[MAX_TX_CATS] = { 2, 3, 3, 3 }; + optimize_cdf_table_var_modes_3d( + &fc.intra_tx_size[0][0][0], probsfile, 3, cts_each_dim, + intra_tx_sizes_each_ctx, + "static const aom_cdf_prob default_tx_size_cdf" + "[MAX_TX_CATS][TX_SIZE_CONTEXTS][CDF_SIZE(MAX_TX_DEPTH + 1)]"); + + /* transform coding */ + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = TX_SIZES; + cts_each_dim[2] = TXB_SKIP_CONTEXTS; + cts_each_dim[3] = 2; + optimize_cdf_table(&fc.txb_skip[0][0][0][0], probsfile, 4, cts_each_dim, + "static const aom_cdf_prob " + "av1_default_txb_skip_cdfs[TOKEN_CDF_Q_CTXS][TX_SIZES]" + "[TXB_SKIP_CONTEXTS][CDF_SIZE(2)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = TX_SIZES; + 
cts_each_dim[2] = PLANE_TYPES; + cts_each_dim[3] = EOB_COEF_CONTEXTS; + cts_each_dim[4] = 2; + optimize_cdf_table( + &fc.eob_extra[0][0][0][0][0], probsfile, 5, cts_each_dim, + "static const aom_cdf_prob av1_default_eob_extra_cdfs " + "[TOKEN_CDF_Q_CTXS][TX_SIZES][PLANE_TYPES][EOB_COEF_CONTEXTS]" + "[CDF_SIZE(2)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = PLANE_TYPES; + cts_each_dim[2] = 2; + cts_each_dim[3] = 5; + optimize_cdf_table(&fc.eob_multi16[0][0][0][0], probsfile, 4, cts_each_dim, + "static const aom_cdf_prob av1_default_eob_multi16_cdfs" + "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(5)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = PLANE_TYPES; + cts_each_dim[2] = 2; + cts_each_dim[3] = 6; + optimize_cdf_table(&fc.eob_multi32[0][0][0][0], probsfile, 4, cts_each_dim, + "static const aom_cdf_prob av1_default_eob_multi32_cdfs" + "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(6)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = PLANE_TYPES; + cts_each_dim[2] = 2; + cts_each_dim[3] = 7; + optimize_cdf_table(&fc.eob_multi64[0][0][0][0], probsfile, 4, cts_each_dim, + "static const aom_cdf_prob av1_default_eob_multi64_cdfs" + "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(7)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = PLANE_TYPES; + cts_each_dim[2] = 2; + cts_each_dim[3] = 8; + optimize_cdf_table(&fc.eob_multi128[0][0][0][0], probsfile, 4, cts_each_dim, + "static const aom_cdf_prob av1_default_eob_multi128_cdfs" + "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(8)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = PLANE_TYPES; + cts_each_dim[2] = 2; + cts_each_dim[3] = 9; + optimize_cdf_table(&fc.eob_multi256[0][0][0][0], probsfile, 4, cts_each_dim, + "static const aom_cdf_prob av1_default_eob_multi256_cdfs" + "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(9)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = PLANE_TYPES; + cts_each_dim[2] = 2; + cts_each_dim[3] = 10; + 
optimize_cdf_table(&fc.eob_multi512[0][0][0][0], probsfile, 4, cts_each_dim, + "static const aom_cdf_prob av1_default_eob_multi512_cdfs" + "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(10)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = PLANE_TYPES; + cts_each_dim[2] = 2; + cts_each_dim[3] = 11; + optimize_cdf_table(&fc.eob_multi1024[0][0][0][0], probsfile, 4, cts_each_dim, + "static const aom_cdf_prob av1_default_eob_multi1024_cdfs" + "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(11)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = TX_SIZES; + cts_each_dim[2] = PLANE_TYPES; + cts_each_dim[3] = LEVEL_CONTEXTS; + cts_each_dim[4] = BR_CDF_SIZE; + optimize_cdf_table(&fc.coeff_lps_multi[0][0][0][0][0], probsfile, 5, + cts_each_dim, + "static const aom_cdf_prob " + "av1_default_coeff_lps_multi_cdfs[TOKEN_CDF_Q_CTXS]" + "[TX_SIZES][PLANE_TYPES][LEVEL_CONTEXTS]" + "[CDF_SIZE(BR_CDF_SIZE)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = TX_SIZES; + cts_each_dim[2] = PLANE_TYPES; + cts_each_dim[3] = SIG_COEF_CONTEXTS; + cts_each_dim[4] = NUM_BASE_LEVELS + 2; + optimize_cdf_table( + &fc.coeff_base_multi[0][0][0][0][0], probsfile, 5, cts_each_dim, + "static const aom_cdf_prob av1_default_coeff_base_multi_cdfs" + "[TOKEN_CDF_Q_CTXS][TX_SIZES][PLANE_TYPES][SIG_COEF_CONTEXTS]" + "[CDF_SIZE(NUM_BASE_LEVELS + 2)]"); + + cts_each_dim[0] = TOKEN_CDF_Q_CTXS; + cts_each_dim[1] = TX_SIZES; + cts_each_dim[2] = PLANE_TYPES; + cts_each_dim[3] = SIG_COEF_CONTEXTS_EOB; + cts_each_dim[4] = NUM_BASE_LEVELS + 1; + optimize_cdf_table( + &fc.coeff_base_eob_multi[0][0][0][0][0], probsfile, 5, cts_each_dim, + "static const aom_cdf_prob av1_default_coeff_base_eob_multi_cdfs" + "[TOKEN_CDF_Q_CTXS][TX_SIZES][PLANE_TYPES][SIG_COEF_CONTEXTS_EOB]" + "[CDF_SIZE(NUM_BASE_LEVELS + 1)]"); + + fclose(statsfile); + fclose(logfile); + fclose(probsfile); + + return 0; +} diff --git a/third_party/aom/tools/auto_refactor/auto_refactor.py 
b/third_party/aom/tools/auto_refactor/auto_refactor.py new file mode 100644 index 0000000000..dd0d4415f9 --- /dev/null +++ b/third_party/aom/tools/auto_refactor/auto_refactor.py @@ -0,0 +1,919 @@ +# Copyright (c) 2021, Alliance for Open Media. All rights reserved +# +# This source code is subject to the terms of the BSD 2 Clause License and +# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +# was not distributed with this source code in the LICENSE file, you can +# obtain it at www.aomedia.org/license/software. If the Alliance for Open +# Media Patent License 1.0 was not distributed with this source code in the +# PATENTS file, you can obtain it at www.aomedia.org/license/patent. +# + +from __future__ import print_function +import sys +import os +import operator +from pycparser import c_parser, c_ast, parse_file +from math import * + +from inspect import currentframe, getframeinfo +from collections import deque + + +def debug_print(frameinfo): + print('******** ERROR:', frameinfo.filename, frameinfo.lineno, '********') + + +class StructItem(): + + def __init__(self, + typedef_name=None, + struct_name=None, + struct_node=None, + is_union=False): + self.typedef_name = typedef_name + self.struct_name = struct_name + self.struct_node = struct_node + self.is_union = is_union + self.child_decl_map = None + + def __str__(self): + return str(self.typedef_name) + ' ' + str(self.struct_name) + ' ' + str( + self.is_union) + + def compute_child_decl_map(self, struct_info): + self.child_decl_map = {} + if self.struct_node != None and self.struct_node.decls != None: + for decl_node in self.struct_node.decls: + if decl_node.name == None: + for sub_decl_node in decl_node.type.decls: + sub_decl_status = parse_decl_node(struct_info, sub_decl_node) + self.child_decl_map[sub_decl_node.name] = sub_decl_status + else: + decl_status = parse_decl_node(struct_info, decl_node) + self.child_decl_map[decl_status.name] = decl_status + + def 
get_child_decl_status(self, decl_name): + if self.child_decl_map == None: + debug_print(getframeinfo(currentframe())) + print('child_decl_map is None') + return None + if decl_name not in self.child_decl_map: + debug_print(getframeinfo(currentframe())) + print(decl_name, 'does not exist ') + return None + return self.child_decl_map[decl_name] + + +class StructInfo(): + + def __init__(self): + self.struct_name_dic = {} + self.typedef_name_dic = {} + self.enum_value_dic = {} # enum value -> enum_node + self.enum_name_dic = {} # enum name -> enum_node + self.struct_item_list = [] + + def get_struct_by_typedef_name(self, typedef_name): + if typedef_name in self.typedef_name_dic: + return self.typedef_name_dic[typedef_name] + else: + return None + + def get_struct_by_struct_name(self, struct_name): + if struct_name in self.struct_name_dic: + return self.struct_name_dic[struct_name] + else: + debug_print(getframeinfo(currentframe())) + print('Cant find', struct_name) + return None + + def update_struct_item_list(self): + # Collect all struct_items from struct_name_dic and typedef_name_dic + # Compute child_decl_map for each struct item. 
+ for struct_name in self.struct_name_dic.keys(): + struct_item = self.struct_name_dic[struct_name] + struct_item.compute_child_decl_map(self) + self.struct_item_list.append(struct_item) + + for typedef_name in self.typedef_name_dic.keys(): + struct_item = self.typedef_name_dic[typedef_name] + if struct_item.struct_name not in self.struct_name_dic: + struct_item.compute_child_decl_map(self) + self.struct_item_list.append(struct_item) + + def update_enum(self, enum_node): + if enum_node.name != None: + self.enum_name_dic[enum_node.name] = enum_node + + if enum_node.values != None: + enumerator_list = enum_node.values.enumerators + for enumerator in enumerator_list: + self.enum_value_dic[enumerator.name] = enum_node + + def update(self, + typedef_name=None, + struct_name=None, + struct_node=None, + is_union=False): + """T: typedef_name S: struct_name N: struct_node + + T S N + case 1: o o o + typedef struct P { + int u; + } K; + T S N + case 2: o o x + typedef struct P K; + + T S N + case 3: x o o + struct P { + int u; + }; + + T S N + case 4: o x o + typedef struct { + int u; + } K; + """ + struct_item = None + + # Check whether struct_name or typedef_name is already in the dictionary + if struct_name in self.struct_name_dic: + struct_item = self.struct_name_dic[struct_name] + + if typedef_name in self.typedef_name_dic: + struct_item = self.typedef_name_dic[typedef_name] + + if struct_item == None: + struct_item = StructItem(typedef_name, struct_name, struct_node, is_union) + + if struct_node.decls != None: + struct_item.struct_node = struct_node + + if struct_name != None: + self.struct_name_dic[struct_name] = struct_item + + if typedef_name != None: + self.typedef_name_dic[typedef_name] = struct_item + + +class StructDefVisitor(c_ast.NodeVisitor): + + def __init__(self): + self.struct_info = StructInfo() + + def visit_Struct(self, node): + if node.decls != None: + self.struct_info.update(None, node.name, node) + self.generic_visit(node) + + def visit_Union(self, 
node): + if node.decls != None: + self.struct_info.update(None, node.name, node, True) + self.generic_visit(node) + + def visit_Enum(self, node): + self.struct_info.update_enum(node) + self.generic_visit(node) + + def visit_Typedef(self, node): + if node.type.__class__.__name__ == 'TypeDecl': + typedecl = node.type + if typedecl.type.__class__.__name__ == 'Struct': + struct_node = typedecl.type + typedef_name = node.name + struct_name = struct_node.name + self.struct_info.update(typedef_name, struct_name, struct_node) + elif typedecl.type.__class__.__name__ == 'Union': + union_node = typedecl.type + typedef_name = node.name + union_name = union_node.name + self.struct_info.update(typedef_name, union_name, union_node, True) + # TODO(angiebird): Do we need to deal with enum here? + self.generic_visit(node) + + +def build_struct_info(ast): + v = StructDefVisitor() + v.visit(ast) + struct_info = v.struct_info + struct_info.update_struct_item_list() + return v.struct_info + + +class DeclStatus(): + + def __init__(self, name, struct_item=None, is_ptr_decl=False): + self.name = name + self.struct_item = struct_item + self.is_ptr_decl = is_ptr_decl + + def get_child_decl_status(self, decl_name): + if self.struct_item != None: + return self.struct_item.get_child_decl_status(decl_name) + else: + #TODO(angiebird): 2. Investigage the situation when a struct's definition can't be found. 
+ return None + + def __str__(self): + return str(self.struct_item) + ' ' + str(self.name) + ' ' + str( + self.is_ptr_decl) + + +def peel_ptr_decl(decl_type_node): + """ Remove PtrDecl and ArrayDecl layer """ + is_ptr_decl = False + peeled_decl_type_node = decl_type_node + while peeled_decl_type_node.__class__.__name__ == 'PtrDecl' or peeled_decl_type_node.__class__.__name__ == 'ArrayDecl': + is_ptr_decl = True + peeled_decl_type_node = peeled_decl_type_node.type + return is_ptr_decl, peeled_decl_type_node + + +def parse_peeled_decl_type_node(struct_info, node): + struct_item = None + if node.__class__.__name__ == 'TypeDecl': + if node.type.__class__.__name__ == 'IdentifierType': + identifier_type_node = node.type + typedef_name = identifier_type_node.names[0] + struct_item = struct_info.get_struct_by_typedef_name(typedef_name) + elif node.type.__class__.__name__ == 'Struct': + struct_node = node.type + if struct_node.name != None: + struct_item = struct_info.get_struct_by_struct_name(struct_node.name) + else: + struct_item = StructItem(None, None, struct_node, False) + struct_item.compute_child_decl_map(struct_info) + elif node.type.__class__.__name__ == 'Union': + # TODO(angiebird): Special treatment for Union? + struct_node = node.type + if struct_node.name != None: + struct_item = struct_info.get_struct_by_struct_name(struct_node.name) + else: + struct_item = StructItem(None, None, struct_node, True) + struct_item.compute_child_decl_map(struct_info) + elif node.type.__class__.__name__ == 'Enum': + # TODO(angiebird): Special treatment for Union? + struct_node = node.type + struct_item = None + else: + print('Unrecognized peeled_decl_type_node.type', + node.type.__class__.__name__) + else: + # debug_print(getframeinfo(currentframe())) + # print(node.__class__.__name__) + #TODO(angiebird): Do we need to take care of this part? 
+ pass + + return struct_item + + +def parse_decl_node(struct_info, decl_node): + # struct_item is None if this decl_node is not a struct_item + decl_node_name = decl_node.name + decl_type_node = decl_node.type + is_ptr_decl, peeled_decl_type_node = peel_ptr_decl(decl_type_node) + struct_item = parse_peeled_decl_type_node(struct_info, peeled_decl_type_node) + return DeclStatus(decl_node_name, struct_item, is_ptr_decl) + + +def get_lvalue_lead(lvalue_node): + """return '&' or '*' of lvalue if available""" + if lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&': + return '&' + elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*': + return '*' + return None + + +def parse_lvalue(lvalue_node): + """get id_chain from lvalue""" + id_chain = parse_lvalue_recursive(lvalue_node, []) + return id_chain + + +def parse_lvalue_recursive(lvalue_node, id_chain): + """cpi->rd->u -> (cpi->rd)->u""" + if lvalue_node.__class__.__name__ == 'ID': + id_chain.append(lvalue_node.name) + id_chain.reverse() + return id_chain + elif lvalue_node.__class__.__name__ == 'StructRef': + id_chain.append(lvalue_node.field.name) + return parse_lvalue_recursive(lvalue_node.name, id_chain) + elif lvalue_node.__class__.__name__ == 'ArrayRef': + return parse_lvalue_recursive(lvalue_node.name, id_chain) + elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&': + return parse_lvalue_recursive(lvalue_node.expr, id_chain) + elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*': + return parse_lvalue_recursive(lvalue_node.expr, id_chain) + else: + return None + + +class FuncDefVisitor(c_ast.NodeVisitor): + func_dictionary = {} + + def visit_FuncDef(self, node): + func_name = node.decl.name + self.func_dictionary[func_name] = node + + +def build_func_dictionary(ast): + v = FuncDefVisitor() + v.visit(ast) + return v.func_dictionary + + +def get_func_start_coord(func_node): + return func_node.coord + + +def find_end_node(node): + 
node_list = [] + for c in node: + node_list.append(c) + if len(node_list) == 0: + return node + else: + return find_end_node(node_list[-1]) + + +def get_func_end_coord(func_node): + return find_end_node(func_node).coord + + +def get_func_size(func_node): + start_coord = get_func_start_coord(func_node) + end_coord = get_func_end_coord(func_node) + if start_coord.file == end_coord.file: + return end_coord.line - start_coord.line + 1 + else: + return None + + +def save_object(obj, filename): + with open(filename, 'wb') as obj_fp: + pickle.dump(obj, obj_fp, protocol=-1) + + +def load_object(filename): + obj = None + with open(filename, 'rb') as obj_fp: + obj = pickle.load(obj_fp) + return obj + + +def get_av1_ast(gen_ast=False): + # TODO(angiebird): Generalize this path + c_filename = './av1_pp.c' + print('generate ast') + ast = parse_file(c_filename) + #save_object(ast, ast_file) + print('finished generate ast') + return ast + + +def get_func_param_id_map(func_def_node): + param_id_map = {} + func_decl = func_def_node.decl.type + param_list = func_decl.args.params + for decl in param_list: + param_id_map[decl.name] = decl + return param_id_map + + +class IDTreeStack(): + + def __init__(self, global_id_tree): + self.stack = deque() + self.global_id_tree = global_id_tree + + def add_link_node(self, node, link_id_chain): + link_node = self.add_id_node(link_id_chain) + node.link_node = link_node + node.link_id_chain = link_id_chain + + def push_id_tree(self, id_tree=None): + if id_tree == None: + id_tree = IDStatusNode() + self.stack.append(id_tree) + return id_tree + + def pop_id_tree(self): + return self.stack.pop() + + def add_id_seed_node(self, id_seed, decl_status): + return self.stack[-1].add_child(id_seed, decl_status) + + def get_id_seed_node(self, id_seed): + idx = len(self.stack) - 1 + while idx >= 0: + id_node = self.stack[idx].get_child(id_seed) + if id_node != None: + return id_node + idx -= 1 + + id_node = self.global_id_tree.get_child(id_seed) + if id_node 
!= None: + return id_node + return None + + def add_id_node(self, id_chain): + id_seed = id_chain[0] + id_seed_node = self.get_id_seed_node(id_seed) + if id_seed_node == None: + return None + if len(id_chain) == 1: + return id_seed_node + return id_seed_node.add_descendant(id_chain[1:]) + + def get_id_node(self, id_chain): + id_seed = id_chain[0] + id_seed_node = self.get_id_seed_node(id_seed) + if id_seed_node == None: + return None + if len(id_chain) == 1: + return id_seed_node + return id_seed_node.get_descendant(id_chain[1:]) + + def top(self): + return self.stack[-1] + + +class IDStatusNode(): + + def __init__(self, name=None, root=None): + if root is None: + self.root = self + else: + self.root = root + + self.name = name + + self.parent = None + self.children = {} + + self.assign = False + self.last_assign_coord = None + self.refer = False + self.last_refer_coord = None + + self.decl_status = None + + self.link_id_chain = None + self.link_node = None + + self.visit = False + + def set_link_id_chain(self, link_id_chain): + self.set_assign(False) + self.link_id_chain = link_id_chain + self.link_node = self.root.get_descendant(link_id_chain) + + def set_link_node(self, link_node): + self.set_assign(False) + self.link_id_chain = ['*'] + self.link_node = link_node + + def get_link_id_chain(self): + return self.link_id_chain + + def get_concrete_node(self): + if self.visit == True: + # return None when there is a loop + return None + self.visit = True + if self.link_node == None: + self.visit = False + return self + else: + concrete_node = self.link_node.get_concrete_node() + self.visit = False + if concrete_node == None: + return self + return concrete_node + + def set_assign(self, assign, coord=None): + concrete_node = self.get_concrete_node() + concrete_node.assign = assign + concrete_node.last_assign_coord = coord + + def get_assign(self): + concrete_node = self.get_concrete_node() + return concrete_node.assign + + def set_refer(self, refer, coord=None): + 
concrete_node = self.get_concrete_node() + concrete_node.refer = refer + concrete_node.last_refer_coord = coord + + def get_refer(self): + concrete_node = self.get_concrete_node() + return concrete_node.refer + + def set_parent(self, parent): + concrete_node = self.get_concrete_node() + concrete_node.parent = parent + + def add_child(self, name, decl_status=None): + concrete_node = self.get_concrete_node() + if name not in concrete_node.children: + child_id_node = IDStatusNode(name, concrete_node.root) + concrete_node.children[name] = child_id_node + if decl_status == None: + # Check if the child decl_status can be inferred from its parent's + # decl_status + if self.decl_status != None: + decl_status = self.decl_status.get_child_decl_status(name) + child_id_node.set_decl_status(decl_status) + return concrete_node.children[name] + + def get_child(self, name): + concrete_node = self.get_concrete_node() + if name in concrete_node.children: + return concrete_node.children[name] + else: + return None + + def add_descendant(self, id_chain): + current_node = self.get_concrete_node() + for name in id_chain: + current_node.add_child(name) + parent_node = current_node + current_node = current_node.get_child(name) + current_node.set_parent(parent_node) + return current_node + + def get_descendant(self, id_chain): + current_node = self.get_concrete_node() + for name in id_chain: + current_node = current_node.get_child(name) + if current_node == None: + return None + return current_node + + def get_children(self): + current_node = self.get_concrete_node() + return current_node.children + + def set_decl_status(self, decl_status): + current_node = self.get_concrete_node() + current_node.decl_status = decl_status + + def get_decl_status(self): + current_node = self.get_concrete_node() + return current_node.decl_status + + def __str__(self): + if self.link_id_chain is None: + return str(self.name) + ' a: ' + str(int(self.assign)) + ' r: ' + str( + int(self.refer)) + else: + return 
str(self.name) + ' -> ' + ' '.join(self.link_id_chain) + + def collect_assign_refer_status(self, + id_chain=None, + assign_ls=None, + refer_ls=None): + if id_chain == None: + id_chain = [] + if assign_ls == None: + assign_ls = [] + if refer_ls == None: + refer_ls = [] + id_chain.append(self.name) + if self.assign: + info_str = ' '.join([ + ' '.join(id_chain[1:]), 'a:', + str(int(self.assign)), 'r:', + str(int(self.refer)), + str(self.last_assign_coord) + ]) + assign_ls.append(info_str) + if self.refer: + info_str = ' '.join([ + ' '.join(id_chain[1:]), 'a:', + str(int(self.assign)), 'r:', + str(int(self.refer)), + str(self.last_refer_coord) + ]) + refer_ls.append(info_str) + for c in self.children: + self.children[c].collect_assign_refer_status(id_chain, assign_ls, + refer_ls) + id_chain.pop() + return assign_ls, refer_ls + + def show(self): + assign_ls, refer_ls = self.collect_assign_refer_status() + print('---- assign ----') + for item in assign_ls: + print(item) + print('---- refer ----') + for item in refer_ls: + print(item) + + +class FuncInOutVisitor(c_ast.NodeVisitor): + + def __init__(self, + func_def_node, + struct_info, + func_dictionary, + keep_body_id_tree=True, + call_param_map=None, + global_id_tree=None, + func_history=None, + unknown=None): + self.func_dictionary = func_dictionary + self.struct_info = struct_info + self.param_id_map = get_func_param_id_map(func_def_node) + self.parent_node = None + self.global_id_tree = global_id_tree + self.body_id_tree = None + self.keep_body_id_tree = keep_body_id_tree + if func_history == None: + self.func_history = {} + else: + self.func_history = func_history + + if unknown == None: + self.unknown = [] + else: + self.unknown = unknown + + self.id_tree_stack = IDTreeStack(global_id_tree) + self.id_tree_stack.push_id_tree() + + #TODO move this part into a function + for param in self.param_id_map: + decl_node = self.param_id_map[param] + decl_status = parse_decl_node(self.struct_info, decl_node) + descendant = 
self.id_tree_stack.add_id_seed_node(decl_status.name, + decl_status) + if call_param_map is not None and param in call_param_map: + # This is a function call. + # Map the input parameter to the caller's nodes + # TODO(angiebird): Can we use add_link_node here? + descendant.set_link_node(call_param_map[param]) + + def get_id_tree_stack(self): + return self.id_tree_stack + + def generic_visit(self, node): + prev_parent = self.parent_node + self.parent_node = node + for c in node: + self.visit(c) + self.parent_node = prev_parent + + # TODO rename + def add_new_id_tree(self, node): + self.id_tree_stack.push_id_tree() + self.generic_visit(node) + id_tree = self.id_tree_stack.pop_id_tree() + if self.parent_node == None and self.keep_body_id_tree == True: + # this is function body + self.body_id_tree = id_tree + + def visit_For(self, node): + self.add_new_id_tree(node) + + def visit_Compound(self, node): + self.add_new_id_tree(node) + + def visit_Decl(self, node): + if node.type.__class__.__name__ != 'FuncDecl': + decl_status = parse_decl_node(self.struct_info, node) + descendant = self.id_tree_stack.add_id_seed_node(decl_status.name, + decl_status) + if node.init is not None: + init_id_chain = self.process_lvalue(node.init) + if init_id_chain != None: + if decl_status.struct_item is None: + init_descendant = self.id_tree_stack.add_id_node(init_id_chain) + if init_descendant != None: + init_descendant.set_refer(True, node.coord) + else: + self.unknown.append(node) + descendant.set_assign(True, node.coord) + else: + self.id_tree_stack.add_link_node(descendant, init_id_chain) + else: + self.unknown.append(node) + else: + descendant.set_assign(True, node.coord) + self.generic_visit(node) + + def is_lvalue(self, node): + if self.parent_node is None: + # TODO(angiebird): Do every lvalue has parent_node != None? 
+ return False + if self.parent_node.__class__.__name__ == 'StructRef': + return False + if self.parent_node.__class__.__name__ == 'ArrayRef' and node == self.parent_node.name: + # if node == self.parent_node.subscript, the node could be lvalue + return False + if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '&': + return False + if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '*': + return False + return True + + def process_lvalue(self, node): + id_chain = parse_lvalue(node) + if id_chain == None: + return id_chain + elif id_chain[0] in self.struct_info.enum_value_dic: + return None + else: + return id_chain + + def process_possible_lvalue(self, node): + if self.is_lvalue(node): + id_chain = self.process_lvalue(node) + lead_char = get_lvalue_lead(node) + # make sure the id is not an enum value + if id_chain == None: + self.unknown.append(node) + return + descendant = self.id_tree_stack.add_id_node(id_chain) + if descendant == None: + self.unknown.append(node) + return + decl_status = descendant.get_decl_status() + if decl_status == None: + descendant.set_assign(True, node.coord) + descendant.set_refer(True, node.coord) + self.unknown.append(node) + return + if self.parent_node.__class__.__name__ == 'Assignment': + if node is self.parent_node.lvalue: + if decl_status.struct_item != None: + if len(id_chain) > 1: + descendant.set_assign(True, node.coord) + elif len(id_chain) == 1: + if lead_char == '*': + descendant.set_assign(True, node.coord) + else: + right_id_chain = self.process_lvalue(self.parent_node.rvalue) + if right_id_chain != None: + self.id_tree_stack.add_link_node(descendant, right_id_chain) + else: + #TODO(angiebird): 1.Find a better way to deal with this case. 
+ descendant.set_assign(True, node.coord) + else: + debug_print(getframeinfo(currentframe())) + else: + descendant.set_assign(True, node.coord) + elif node is self.parent_node.rvalue: + if decl_status.struct_item is None: + descendant.set_refer(True, node.coord) + if lead_char == '&': + descendant.set_assign(True, node.coord) + else: + left_id_chain = self.process_lvalue(self.parent_node.lvalue) + left_lead_char = get_lvalue_lead(self.parent_node.lvalue) + if left_id_chain != None: + if len(left_id_chain) > 1: + descendant.set_refer(True, node.coord) + elif len(left_id_chain) == 1: + if left_lead_char == '*': + descendant.set_refer(True, node.coord) + else: + #TODO(angiebird): Check whether the other node is linked to this node. + pass + else: + self.unknown.append(self.parent_node.lvalue) + debug_print(getframeinfo(currentframe())) + else: + self.unknown.append(self.parent_node.lvalue) + debug_print(getframeinfo(currentframe())) + else: + debug_print(getframeinfo(currentframe())) + elif self.parent_node.__class__.__name__ == 'UnaryOp': + # TODO(angiebird): Consider +=, *=, -=, /= etc + if self.parent_node.op == '--' or self.parent_node.op == '++' or\ + self.parent_node.op == 'p--' or self.parent_node.op == 'p++': + descendant.set_assign(True, node.coord) + descendant.set_refer(True, node.coord) + else: + descendant.set_refer(True, node.coord) + elif self.parent_node.__class__.__name__ == 'Decl': + #The logic is at visit_Decl + pass + elif self.parent_node.__class__.__name__ == 'ExprList': + #The logic is at visit_FuncCall + pass + else: + descendant.set_refer(True, node.coord) + + def visit_ID(self, node): + # If the parent is a FuncCall, this ID is a function name. 
+ if self.parent_node.__class__.__name__ != 'FuncCall': + self.process_possible_lvalue(node) + self.generic_visit(node) + + def visit_StructRef(self, node): + self.process_possible_lvalue(node) + self.generic_visit(node) + + def visit_ArrayRef(self, node): + self.process_possible_lvalue(node) + self.generic_visit(node) + + def visit_UnaryOp(self, node): + if node.op == '&' or node.op == '*': + self.process_possible_lvalue(node) + self.generic_visit(node) + + def visit_FuncCall(self, node): + if node.name.__class__.__name__ == 'ID': + if node.name.name in self.func_dictionary: + if node.name.name not in self.func_history: + self.func_history[node.name.name] = True + func_def_node = self.func_dictionary[node.name.name] + call_param_map = self.process_func_call(node, func_def_node) + + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary, False, + call_param_map, self.global_id_tree, + self.func_history, self.unknown) + visitor.visit(func_def_node.body) + else: + self.unknown.append(node) + self.generic_visit(node) + + def process_func_call(self, func_call_node, func_def_node): + # set up a refer/assign for func parameters + # return call_param_map + call_param_ls = func_call_node.args.exprs + call_param_map = {} + + func_decl = func_def_node.decl.type + decl_param_ls = func_decl.args.params + for param_node, decl_node in zip(call_param_ls, decl_param_ls): + id_chain = self.process_lvalue(param_node) + if id_chain != None: + descendant = self.id_tree_stack.add_id_node(id_chain) + if descendant == None: + self.unknown.append(param_node) + else: + decl_status = descendant.get_decl_status() + if decl_status != None: + if decl_status.struct_item == None: + if decl_status.is_ptr_decl == True: + descendant.set_assign(True, param_node.coord) + descendant.set_refer(True, param_node.coord) + else: + descendant.set_refer(True, param_node.coord) + else: + call_param_map[decl_node.name] = descendant + else: + self.unknown.append(param_node) + else: + 
self.unknown.append(param_node) + return call_param_map + + +def build_global_id_tree(ast, struct_info): + global_id_tree = IDStatusNode() + for node in ast.ext: + if node.__class__.__name__ == 'Decl': + # id tree is for tracking assign/refer status + # we don't care about function id because they can't be changed + if node.type.__class__.__name__ != 'FuncDecl': + decl_status = parse_decl_node(struct_info, node) + descendant = global_id_tree.add_child(decl_status.name, decl_status) + return global_id_tree + + +class FuncAnalyzer(): + + def __init__(self): + self.ast = get_av1_ast() + self.struct_info = build_struct_info(self.ast) + self.func_dictionary = build_func_dictionary(self.ast) + self.global_id_tree = build_global_id_tree(self.ast, self.struct_info) + + def analyze(self, func_name): + if func_name in self.func_dictionary: + func_def_node = self.func_dictionary[func_name] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary, True, None, + self.global_id_tree) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + root.top().show() + else: + print(func_name, "doesn't exist") + + +if __name__ == '__main__': + fa = FuncAnalyzer() + fa.analyze('tpl_get_satd_cost') + pass diff --git a/third_party/aom/tools/auto_refactor/av1_preprocess.py b/third_party/aom/tools/auto_refactor/av1_preprocess.py new file mode 100644 index 0000000000..ea76912cf1 --- /dev/null +++ b/third_party/aom/tools/auto_refactor/av1_preprocess.py @@ -0,0 +1,113 @@ +# Copyright (c) 2021, Alliance for Open Media. All rights reserved +# +# This source code is subject to the terms of the BSD 2 Clause License and +# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +# was not distributed with this source code in the LICENSE file, you can +# obtain it at www.aomedia.org/license/software. 
# Copyright (c) 2021, Alliance for Open Media. All rights reserved
#
# This source code is subject to the terms of the BSD 2 Clause License and
# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
# was not distributed with this source code in the LICENSE file, you can
# obtain it at www.aomedia.org/license/software. If the Alliance for Open
# Media Patent License 1.0 was not distributed with this source code in the
# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
#
"""Run the C preprocessor over the AV1 sources and write ./av1_pp.c.

Usage: av1_preprocess.py <aom_dir> <build_dir> <fake_header_dir>
"""

import os
import shlex
import subprocess
import sys

# File name fragments that mark a source file as platform/SIMD specific.
_SIMD_KEYWORDS = ("avx2", "sse2", "sse3", "ssse3", "sse4", "dspr2", "neon",
                  "msa", "simd", "x86")


def is_code_file(filename):
  """Return True when `filename` ends with ".c" or ".h"."""
  return filename.endswith((".c", ".h"))


def is_simd_file(filename):
  """Return True when `filename` contains one of the SIMD keywords."""
  return any(keyword in filename for keyword in _SIMD_KEYWORDS)


def get_code_file_list(path, exclude_file_set):
  """Return the paths of the non-SIMD .c/.h files found below `path`.

  Files whose base name is in `exclude_file_set` are left out.
  """
  code_file_list = []
  for cur_dir, _, file_list in os.walk(path):
    for filename in file_list:
      if (is_code_file(filename) and not is_simd_file(filename) and
          filename not in exclude_file_set):
        code_file_list.append(os.path.join(cur_dir, filename))
  return code_file_list


def av1_exclude_file_set():
  """Return the base names of the files that are never preprocessed."""
  return {
      "cfl_ppc.c",
      "ppc_cpudetect.c",
  }


def get_av1_pp_command(fake_header_dir, code_file_list):
  """Return the gcc -E shell command line for `code_file_list`.

  The include directory and the file names are shell-quoted, so paths that
  contain spaces or other shell metacharacters survive; names made only of
  safe characters are emitted unchanged.
  """
  defines = [
      "-D'ATTRIBUTE_PACKED='",
      "-D'__attribute__(x)='",
      "-D'__inline__='",
      "-D'float_t=float'",
      "-D'DECLARE_ALIGNED(n, typ, val)=typ val'",
      "-D'volatile='",
      "-D'AV1_K_MEANS_DIM=2'",
      "-D'INLINE='",
      "-D'AOM_INLINE='",
      "-D'AOM_FORCE_INLINE='",
      "-D'inline='",
  ]
  pre_command = ("gcc -w -nostdinc -E -I./ -I../ -I" +
                 shlex.quote(fake_header_dir) + " " + " ".join(defines))
  quoted_files = [shlex.quote(code_file) for code_file in code_file_list]
  return pre_command + " " + " ".join(quoted_files)


def modify_av1_rtcd(build_dir):
  """Replace "#ifdef RTCD_C" with "#if 0" in <build_dir>/config/av1_rtcd.h.

  The header is rewritten in place.
  """
  av1_rtcd = os.path.join(build_dir, "config/av1_rtcd.h")
  with open(av1_rtcd) as fp:
    string = fp.read()
  new_string = string.replace("#ifdef RTCD_C", "#if 0")
  with open(av1_rtcd, "w") as fp:
    fp.write(new_string)


def preprocess_av1(aom_dir, build_dir, fake_header_dir):
  """Preprocess av1/encoder and av1/common into ./av1_pp.c.

  gcc is run with `build_dir` as its working directory, so a relative
  `fake_header_dir` is resolved against `build_dir`.  The output file is
  created in the directory this function is called from.

  Returns:
    The exit status of the preprocessor command.
  """
  output = os.path.join(os.getcwd(), "av1_pp.c")
  path_list = [
      os.path.join(aom_dir, "av1/encoder"),
      os.path.join(aom_dir, "av1/common")
  ]
  exclude_file_set = av1_exclude_file_set()
  code_file_list = []
  for path in path_list:
    code_file_list.extend(
        get_code_file_list(os.path.realpath(path), exclude_file_set))
  modify_av1_rtcd(build_dir)
  cmd = get_av1_pp_command(fake_header_dir, code_file_list)
  # Pass the working directory and the output file to the child instead of
  # os.chdir() plus a shell redirect: the caller's cwd is never changed (so
  # nothing is left to restore on error) and the output path needs no
  # quoting.  The shell is still required because `cmd` relies on shell
  # quoting for its -D'...' options.
  with open(output, "wb") as output_fp:
    return subprocess.call(cmd, shell=True, cwd=build_dir, stdout=output_fp)


if __name__ == "__main__":
  aom_dir = sys.argv[1]
  build_dir = sys.argv[2]
  fake_header_dir = sys.argv[3]
  preprocess_av1(aom_dir, build_dir, fake_header_dir)
/*
 * Copyright (c) 2021, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

/*
 * Each function below contains declarations only: no statements and no
 * return. NOTE(review): the function names suggest this file is parser input
 * for the auto_refactor declaration tests rather than code meant to be run;
 * confirm against test_auto_refactor.py before changing anything here.
 */

/* Struct S1 with a single int member, also reachable through typedef T1. */
typedef struct S1 {
  int x;
} T1;

/* Declares an array of 3 ints. */
int parse_decl_node_2(void) { int arr[3]; }

/* Declares a pointer to int. */
int parse_decl_node_3(void) { int *a; }

/* Declares an array of 3 T1 structs. */
int parse_decl_node_4(void) { T1 t1[3]; }

/* Declares an array of 3 pointers to T1. */
int parse_decl_node_5(void) { T1 *t2[3]; }

/* Declares a 3x3 array of T1 structs. */
int parse_decl_node_6(void) { T1 t3[3][3]; }

/*
 * Declares a plain int, a struct through its typedef name, the same struct
 * through its tag name, and a pointer to the typedef'd struct.
 */
int main(void) {
  int a;
  T1 t1;
  struct S1 s1;
  T1 *t2;
}
+ */ + +typedef struct XD { + int u; + int v; +} XD; + +typedef struct RD { + XD *xd; + int u; + int v; +} RD; + +typedef struct VP9_COMP { + int y; + RD *rd; + RD rd2; + int arr[3]; + union { + int z; + }; + struct { + int w; + }; +} VP9_COMP; + +int sub_func(VP9_COMP *cpi, int b) { + int d; + cpi->y += 1; + cpi->y -= b; + d = cpi->y * 2; + return d; +} + +int func_id_forrest_show(VP9_COMP *cpi, int b) { + int c = 2; + int x = cpi->y + c * 2 + 1; + int y; + RD *rd = cpi->rd; + y = cpi->rd->u; + return x + y; +} + +int func_link_id_chain_1(VP9_COMP *cpi) { + RD *rd = cpi->rd; + rd->u = 0; +} + +int func_link_id_chain_2(VP9_COMP *cpi) { + RD *rd = cpi->rd; + XD *xd = rd->xd; + xd->u = 0; +} + +int func_assign_refer_status_1(VP9_COMP *cpi) { RD *rd = cpi->rd; } + +int func_assign_refer_status_2(VP9_COMP *cpi) { + RD *rd2; + rd2 = cpi->rd; +} + +int func_assign_refer_status_3(VP9_COMP *cpi) { + int a; + a = cpi->y; +} + +int func_assign_refer_status_4(VP9_COMP *cpi) { + int *b; + b = &cpi->y; +} + +int func_assign_refer_status_5(VP9_COMP *cpi) { + RD *rd5; + rd5 = &cpi->rd2; +} + +int func_assign_refer_status_6(VP9_COMP *cpi, VP9_COMP *cpi2) { + cpi->rd = cpi2->rd; +} + +int func_assign_refer_status_7(VP9_COMP *cpi, VP9_COMP *cpi2) { + cpi->arr[3] = 0; +} + +int func_assign_refer_status_8(VP9_COMP *cpi, VP9_COMP *cpi2) { + int x = cpi->arr[3]; +} + +int func_assign_refer_status_9(VP9_COMP *cpi) { + { + RD *rd = cpi->rd; + { rd->u = 0; } + } +} + +int func_assign_refer_status_10(VP9_COMP *cpi) { cpi->arr[cpi->rd->u] = 0; } + +int func_assign_refer_status_11(VP9_COMP *cpi) { + RD *rd11 = &cpi->rd2; + rd11->v = 1; +} + +int func_assign_refer_status_12(VP9_COMP *cpi, VP9_COMP *cpi2) { + *cpi->rd = *cpi2->rd; +} + +int func_assign_refer_status_13(VP9_COMP *cpi) { + cpi->z = 0; + cpi->w = 0; +} + +int func(VP9_COMP *cpi, int x) { + int a; + cpi->y = 4; + a = 3 + cpi->y; + a = a * x; + cpi->y *= 4; + RD *ref_rd = cpi->rd; + ref_rd->u = 0; + cpi->rd2.v = 1; + cpi->rd->v = 1; 
+ RD *ref_rd2 = &cpi->rd2; + RD **ref_rd3 = &(&cpi->rd2); + int b = sub_func(cpi, a); + cpi->rd->v++; + return b; +} + +int func_sub_call_1(VP9_COMP *cpi2, int x) { cpi2->y = 4; } + +int func_call_1(VP9_COMP *cpi, int y) { func_sub_call_1(cpi, y); } + +int func_sub_call_2(VP9_COMP *cpi2, RD *rd, int x) { rd->u = 0; } + +int func_call_2(VP9_COMP *cpi, int y) { func_sub_call_2(cpi, &cpi->rd, y); } + +int func_sub_call_3(VP9_COMP *cpi2, int x) {} + +int func_call_3(VP9_COMP *cpi, int y) { func_sub_call_3(cpi, ++cpi->y); } + +int func_sub_sub_call_4(VP9_COMP *cpi3, XD *xd) { + cpi3->rd.u = 0; + xd->u = 0; +} + +int func_sub_call_4(VP9_COMP *cpi2, RD *rd) { + func_sub_sub_call_4(cpi2, rd->xd); +} + +int func_call_4(VP9_COMP *cpi, int y) { func_sub_call_4(cpi, &cpi->rd); } + +int func_sub_call_5(VP9_COMP *cpi) { + cpi->y = 2; + func_call_5(cpi); +} + +int func_call_5(VP9_COMP *cpi) { func_sub_call_5(cpi); } + +int func_compound_1(VP9_COMP *cpi) { + for (int i = 0; i < 10; ++i) { + cpi->y++; + } +} + +int func_compound_2(VP9_COMP *cpi) { + for (int i = 0; i < cpi->y; ++i) { + cpi->rd->u = i; + } +} + +int func_compound_3(VP9_COMP *cpi) { + int i = 3; + while (i > 0) { + cpi->rd->u = i; + i--; + } +} + +int func_compound_4(VP9_COMP *cpi) { + while (cpi->y-- >= 0) { + } +} + +int func_compound_5(VP9_COMP *cpi) { + do { + } while (cpi->y-- >= 0); +} + +int func_compound_6(VP9_COMP *cpi) { + for (int i = 0; i < 10; ++i) cpi->y--; +} + +int main(void) { + int x; + VP9_COMP cpi; + RD rd; + cpi->rd = rd; + func(&cpi, x); +} diff --git a/third_party/aom/tools/auto_refactor/c_files/global_variable.c b/third_party/aom/tools/auto_refactor/c_files/global_variable.c new file mode 100644 index 0000000000..26d5385e97 --- /dev/null +++ b/third_party/aom/tools/auto_refactor/c_files/global_variable.c @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021, Alliance for Open Media. 
All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +extern const int global_a[13]; + +const int global_b = 0; + +typedef struct S1 { + int x; +} T1; + +struct S3 { + int x; +} s3; + +int func_global_1(int *a) { + *a = global_a[3]; + return 0; +} diff --git a/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c b/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c new file mode 100644 index 0000000000..fa44d72381 --- /dev/null +++ b/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +typedef struct RD { + int u; + int v; + int arr[3]; +} RD; + +typedef struct VP9_COMP { + int y; + RD *rd; + RD rd2; + RD rd3[2]; +} VP9_COMP; + +int parse_lvalue_2(VP9_COMP *cpi) { RD *rd2 = &cpi->rd2; } + +int func(VP9_COMP *cpi, int x) { + cpi->rd->u = 0; + + int y; + y = 0; + + cpi->rd2.v = 0; + + cpi->rd->arr[2] = 0; + + cpi->rd3[1]->arr[2] = 0; + + return 0; +} + +int main(void) { + int x = 0; + VP9_COMP cpi; + func(&cpi, x); +} diff --git a/third_party/aom/tools/auto_refactor/c_files/simple_code.c b/third_party/aom/tools/auto_refactor/c_files/simple_code.c new file mode 100644 index 0000000000..902cd1d826 --- /dev/null +++ b/third_party/aom/tools/auto_refactor/c_files/simple_code.c @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +typedef struct S { + int x; + int y; + int z; +} S; + +typedef struct T { + S s; +} T; + +int d(S *s) { + ++s->x; + s->x--; + s->y = s->y + 1; + int *c = &s->x; + S ss; + ss.x = 1; + ss.x += 2; + ss.z *= 2; + return 0; +} +int b(S *s) { + d(s); + return 0; +} +int c(int x) { + if (x) { + c(x - 1); + } else { + S s; + d(&s); + } + return 0; +} +int a(S *s) { + b(s); + c(1); + return 0; +} +int e(void) { + c(0); + return 0; +} +int main(void) { + int p = 3; + S s; + s.x = p + 1; + s.y = 2; + s.z = 3; + a(&s); + T t; + t.s.x = 3; +} diff --git a/third_party/aom/tools/auto_refactor/c_files/struct_code.c b/third_party/aom/tools/auto_refactor/c_files/struct_code.c new file mode 100644 index 0000000000..7f24d41075 --- /dev/null +++ b/third_party/aom/tools/auto_refactor/c_files/struct_code.c @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +typedef struct S1 { + int x; +} T1; + +struct S3 { + int x; +}; + +typedef struct { + int x; + struct S3 s3; +} T4; + +typedef union U5 { + int x; + double y; +} T5; + +typedef struct S6 { + struct { + int x; + }; + union { + int y; + int z; + }; +} T6; + +typedef struct S7 { + struct { + int x; + } y; + union { + int w; + } z; +} T7; + +int main(void) {} diff --git a/third_party/aom/tools/auto_refactor/test_auto_refactor.py b/third_party/aom/tools/auto_refactor/test_auto_refactor.py new file mode 100644 index 0000000000..6b1e269efa --- /dev/null +++ b/third_party/aom/tools/auto_refactor/test_auto_refactor.py @@ -0,0 +1,675 @@ +#!/usr/bin/env python +# Copyright (c) 2021, Alliance for Open Media. All rights reserved +# +# This source code is subject to the terms of the BSD 2 Clause License and +# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +# was not distributed with this source code in the LICENSE file, you can +# obtain it at www.aomedia.org/license/software. If the Alliance for Open +# Media Patent License 1.0 was not distributed with this source code in the +# PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+# + +import pprint +import re +import os, sys +import io +import unittest as googletest + +sys.path[0:0] = ['.', '..'] + +from pycparser import c_parser, parse_file +from pycparser.c_ast import * +from pycparser.c_parser import CParser, Coord, ParseError + +from auto_refactor import * + + +def get_c_file_path(filename): + return os.path.join('c_files', filename) + + +class TestStructInfo(googletest.TestCase): + + def setUp(self): + filename = get_c_file_path('struct_code.c') + self.ast = parse_file(filename) + + def test_build_struct_info(self): + struct_info = build_struct_info(self.ast) + typedef_name_dic = struct_info.typedef_name_dic + self.assertEqual('T1' in typedef_name_dic, True) + self.assertEqual('T4' in typedef_name_dic, True) + self.assertEqual('T5' in typedef_name_dic, True) + + struct_name_dic = struct_info.struct_name_dic + struct_name = 'S1' + self.assertEqual(struct_name in struct_name_dic, True) + struct_item = struct_name_dic[struct_name] + self.assertEqual(struct_item.is_union, False) + + struct_name = 'S3' + self.assertEqual(struct_name in struct_name_dic, True) + struct_item = struct_name_dic[struct_name] + self.assertEqual(struct_item.is_union, False) + + struct_name = 'U5' + self.assertEqual(struct_name in struct_name_dic, True) + struct_item = struct_name_dic[struct_name] + self.assertEqual(struct_item.is_union, True) + + self.assertEqual(len(struct_info.struct_item_list), 6) + + def test_get_child_decl_status(self): + struct_info = build_struct_info(self.ast) + struct_item = struct_info.typedef_name_dic['T4'] + + decl_status = struct_item.child_decl_map['x'] + self.assertEqual(decl_status.struct_item, None) + self.assertEqual(decl_status.is_ptr_decl, False) + + decl_status = struct_item.child_decl_map['s3'] + self.assertEqual(decl_status.struct_item.struct_name, 'S3') + self.assertEqual(decl_status.is_ptr_decl, False) + + struct_item = struct_info.typedef_name_dic['T6'] + decl_status = struct_item.child_decl_map['x'] + 
self.assertEqual(decl_status.struct_item, None) + self.assertEqual(decl_status.is_ptr_decl, False) + + decl_status = struct_item.child_decl_map['y'] + self.assertEqual(decl_status.struct_item, None) + self.assertEqual(decl_status.is_ptr_decl, False) + + decl_status = struct_item.child_decl_map['z'] + self.assertEqual(decl_status.struct_item, None) + self.assertEqual(decl_status.is_ptr_decl, False) + + struct_item = struct_info.typedef_name_dic['T7'] + decl_status = struct_item.child_decl_map['y'] + self.assertEqual('x' in decl_status.struct_item.child_decl_map, True) + + struct_item = struct_info.typedef_name_dic['T7'] + decl_status = struct_item.child_decl_map['z'] + self.assertEqual('w' in decl_status.struct_item.child_decl_map, True) + + +class TestParseLvalue(googletest.TestCase): + + def setUp(self): + filename = get_c_file_path('parse_lvalue.c') + self.ast = parse_file(filename) + self.func_dictionary = build_func_dictionary(self.ast) + + def test_parse_lvalue(self): + func_node = self.func_dictionary['func'] + func_body_items = func_node.body.block_items + id_list = parse_lvalue(func_body_items[0].lvalue) + ref_id_list = ['cpi', 'rd', 'u'] + self.assertEqual(id_list, ref_id_list) + + id_list = parse_lvalue(func_body_items[2].lvalue) + ref_id_list = ['y'] + self.assertEqual(id_list, ref_id_list) + + id_list = parse_lvalue(func_body_items[3].lvalue) + ref_id_list = ['cpi', 'rd2', 'v'] + self.assertEqual(id_list, ref_id_list) + + id_list = parse_lvalue(func_body_items[4].lvalue) + ref_id_list = ['cpi', 'rd', 'arr'] + self.assertEqual(id_list, ref_id_list) + + id_list = parse_lvalue(func_body_items[5].lvalue) + ref_id_list = ['cpi', 'rd3', 'arr'] + self.assertEqual(id_list, ref_id_list) + + def test_parse_lvalue_2(self): + func_node = self.func_dictionary['parse_lvalue_2'] + func_body_items = func_node.body.block_items + id_list = parse_lvalue(func_body_items[0].init) + ref_id_list = ['cpi', 'rd2'] + self.assertEqual(id_list, ref_id_list) + + +class 
TestIDStatusNode(googletest.TestCase): + + def test_add_descendant(self): + root = IDStatusNode('root') + id_chain1 = ['cpi', 'rd', 'u'] + id_chain2 = ['cpi', 'rd', 'v'] + root.add_descendant(id_chain1) + root.add_descendant(id_chain2) + + ref_children_list1 = ['cpi'] + children_list1 = list(root.children.keys()) + self.assertEqual(children_list1, ref_children_list1) + + ref_children_list2 = ['rd'] + children_list2 = list(root.children['cpi'].children.keys()) + self.assertEqual(children_list2, ref_children_list2) + + ref_children_list3 = ['u', 'v'] + children_list3 = list(root.children['cpi'].children['rd'].children.keys()) + self.assertEqual(children_list3, ref_children_list3) + + def test_get_descendant(self): + root = IDStatusNode('root') + id_chain1 = ['cpi', 'rd', 'u'] + id_chain2 = ['cpi', 'rd', 'v'] + ref_descendant_1 = root.add_descendant(id_chain1) + ref_descendant_2 = root.add_descendant(id_chain2) + + descendant_1 = root.get_descendant(id_chain1) + self.assertEqual(descendant_1 is ref_descendant_1, True) + + descendant_2 = root.get_descendant(id_chain2) + self.assertEqual(descendant_2 is ref_descendant_2, True) + + id_chain3 = ['cpi', 'rd', 'h'] + descendant_3 = root.get_descendant(id_chain3) + self.assertEqual(descendant_3, None) + + +class TestFuncInOut(googletest.TestCase): + + def setUp(self): + c_filename = get_c_file_path('func_in_out.c') + self.ast = parse_file(c_filename) + self.func_dictionary = build_func_dictionary(self.ast) + self.struct_info = build_struct_info(self.ast) + + def test_get_func_param_id_map(self): + func_def_node = self.func_dictionary['func'] + param_id_map = get_func_param_id_map(func_def_node) + ref_param_id_map_keys = ['cpi', 'x'] + self.assertEqual(list(param_id_map.keys()), ref_param_id_map_keys) + + def test_assign_refer_status_1(self): + func_def_node = self.func_dictionary['func_assign_refer_status_1'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + 
visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + body_id_tree = visitor.body_id_tree + + id_chain = ['rd'] + descendant = body_id_tree.get_descendant(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), False) + ref_link_id_chain = ['cpi', 'rd'] + self.assertEqual(ref_link_id_chain, descendant.get_link_id_chain()) + + id_chain = ['cpi', 'rd'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), False) + self.assertEqual(None, descendant.get_link_id_chain()) + + def test_assign_refer_status_2(self): + func_def_node = self.func_dictionary['func_assign_refer_status_2'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + body_id_tree = visitor.body_id_tree + + id_chain = ['rd2'] + descendant = body_id_tree.get_descendant(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), False) + + ref_link_id_chain = ['cpi', 'rd'] + self.assertEqual(ref_link_id_chain, descendant.get_link_id_chain()) + + id_chain = ['cpi', 'rd'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), False) + self.assertEqual(None, descendant.get_link_id_chain()) + + def test_assign_refer_status_3(self): + func_def_node = self.func_dictionary['func_assign_refer_status_3'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + body_id_tree = visitor.body_id_tree + + id_chain = ['a'] + descendant = body_id_tree.get_descendant(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + self.assertEqual(None, descendant.get_link_id_chain()) + + id_chain = 
['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), True) + self.assertEqual(None, descendant.get_link_id_chain()) + + def test_assign_refer_status_4(self): + func_def_node = self.func_dictionary['func_assign_refer_status_4'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + body_id_tree = visitor.body_id_tree + + id_chain = ['b'] + descendant = body_id_tree.get_descendant(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + self.assertEqual(None, descendant.get_link_id_chain()) + + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), True) + self.assertEqual(None, descendant.get_link_id_chain()) + + def test_assign_refer_status_5(self): + func_def_node = self.func_dictionary['func_assign_refer_status_5'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + body_id_tree = visitor.body_id_tree + + id_chain = ['rd5'] + descendant = body_id_tree.get_descendant(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), False) + + id_chain = ['cpi', 'rd2'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), False) + self.assertEqual(None, descendant.get_link_id_chain()) + + def test_assign_refer_status_6(self): + func_def_node = self.func_dictionary['func_assign_refer_status_6'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + + id_chain = ['cpi', 'rd'] + 
descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + self.assertEqual(None, descendant.get_link_id_chain()) + + id_chain = ['cpi2', 'rd'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), True) + self.assertEqual(None, descendant.get_link_id_chain()) + + def test_assign_refer_status_7(self): + func_def_node = self.func_dictionary['func_assign_refer_status_7'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'arr'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + def test_assign_refer_status_8(self): + func_def_node = self.func_dictionary['func_assign_refer_status_8'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'arr'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), True) + + def test_assign_refer_status_9(self): + func_def_node = self.func_dictionary['func_assign_refer_status_9'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'rd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + def test_assign_refer_status_10(self): + func_def_node = self.func_dictionary['func_assign_refer_status_10'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = 
visitor.get_id_tree_stack() + id_chain = ['cpi', 'rd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), True) + + id_chain = ['cpi', 'arr'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + def test_assign_refer_status_11(self): + func_def_node = self.func_dictionary['func_assign_refer_status_11'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'rd2', 'v'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + def test_assign_refer_status_12(self): + func_def_node = self.func_dictionary['func_assign_refer_status_12'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'rd'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + id_chain = ['cpi2', 'rd'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), True) + + def test_assign_refer_status_13(self): + func_def_node = self.func_dictionary['func_assign_refer_status_13'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'z'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + id_chain = ['cpi', 'w'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + 
self.assertEqual(descendant.get_refer(), False) + + def test_id_status_forrest_1(self): + func_def_node = self.func_dictionary['func'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack().top() + children_names = set(root.get_children().keys()) + ref_children_names = set(['cpi', 'x']) + self.assertEqual(children_names, ref_children_names) + + root = visitor.body_id_tree + children_names = set(root.get_children().keys()) + ref_children_names = set(['a', 'ref_rd', 'ref_rd2', 'ref_rd3', 'b']) + self.assertEqual(children_names, ref_children_names) + + def test_id_status_forrest_show(self): + func_def_node = self.func_dictionary['func_id_forrest_show'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + visitor.get_id_tree_stack().top().show() + + def test_id_status_forrest_2(self): + func_def_node = self.func_dictionary['func_id_forrest_show'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack().top() + self.assertEqual(root, root.root) + + id_chain = ['cpi', 'rd'] + descendant = root.get_descendant(id_chain) + self.assertEqual(root, descendant.root) + + id_chain = ['b'] + descendant = root.get_descendant(id_chain) + self.assertEqual(root, descendant.root) + + def test_link_id_chain_1(self): + func_def_node = self.func_dictionary['func_link_id_chain_1'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'rd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + + def test_link_id_chain_2(self): + func_def_node = self.func_dictionary['func_link_id_chain_2'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + 
self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'rd', 'xd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + + def test_func_call_1(self): + func_def_node = self.func_dictionary['func_call_1'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + id_chain = ['y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), True) + + def test_func_call_2(self): + func_def_node = self.func_dictionary['func_call_2'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'rd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + id_chain = ['cpi', 'rd'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), False) + + def test_func_call_3(self): + func_def_node = self.func_dictionary['func_call_3'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), True) + + def test_func_call_4(self): + func_def_node = self.func_dictionary['func_call_4'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + 
visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + + id_chain = ['cpi', 'rd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + id_chain = ['cpi', 'rd', 'xd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + def test_func_call_5(self): + func_def_node = self.func_dictionary['func_call_5'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + def test_func_compound_1(self): + func_def_node = self.func_dictionary['func_compound_1'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), True) + + def test_func_compound_2(self): + func_def_node = self.func_dictionary['func_compound_2'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), False) + self.assertEqual(descendant.get_refer(), True) + + id_chain = ['cpi', 'rd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + def test_func_compound_3(self): + func_def_node = self.func_dictionary['func_compound_3'] + visitor = FuncInOutVisitor(func_def_node, 
self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + + id_chain = ['cpi', 'rd', 'u'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), False) + + def test_func_compound_4(self): + func_def_node = self.func_dictionary['func_compound_4'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), True) + + def test_func_compound_5(self): + func_def_node = self.func_dictionary['func_compound_5'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), True) + + def test_func_compound_6(self): + func_def_node = self.func_dictionary['func_compound_6'] + visitor = FuncInOutVisitor(func_def_node, self.struct_info, + self.func_dictionary) + visitor.visit(func_def_node.body) + root = visitor.get_id_tree_stack() + id_chain = ['cpi', 'y'] + descendant = root.get_id_node(id_chain) + self.assertEqual(descendant.get_assign(), True) + self.assertEqual(descendant.get_refer(), True) + + +class TestDeclStatus(googletest.TestCase): + + def setUp(self): + filename = get_c_file_path('decl_status_code.c') + self.ast = parse_file(filename) + self.func_dictionary = build_func_dictionary(self.ast) + self.struct_info = build_struct_info(self.ast) + + def test_parse_decl_node(self): + func_def_node = self.func_dictionary['main'] + decl_list = func_def_node.body.block_items + decl_status = parse_decl_node(self.struct_info, 
decl_list[0]) + self.assertEqual(decl_status.name, 'a') + self.assertEqual(decl_status.is_ptr_decl, False) + + decl_status = parse_decl_node(self.struct_info, decl_list[1]) + self.assertEqual(decl_status.name, 't1') + self.assertEqual(decl_status.is_ptr_decl, False) + + decl_status = parse_decl_node(self.struct_info, decl_list[2]) + self.assertEqual(decl_status.name, 's1') + self.assertEqual(decl_status.is_ptr_decl, False) + + decl_status = parse_decl_node(self.struct_info, decl_list[3]) + self.assertEqual(decl_status.name, 't2') + self.assertEqual(decl_status.is_ptr_decl, True) + + def test_parse_decl_node_2(self): + func_def_node = self.func_dictionary['parse_decl_node_2'] + decl_list = func_def_node.body.block_items + decl_status = parse_decl_node(self.struct_info, decl_list[0]) + self.assertEqual(decl_status.name, 'arr') + self.assertEqual(decl_status.is_ptr_decl, True) + self.assertEqual(decl_status.struct_item, None) + + def test_parse_decl_node_3(self): + func_def_node = self.func_dictionary['parse_decl_node_3'] + decl_list = func_def_node.body.block_items + decl_status = parse_decl_node(self.struct_info, decl_list[0]) + self.assertEqual(decl_status.name, 'a') + self.assertEqual(decl_status.is_ptr_decl, True) + self.assertEqual(decl_status.struct_item, None) + + def test_parse_decl_node_4(self): + func_def_node = self.func_dictionary['parse_decl_node_4'] + decl_list = func_def_node.body.block_items + decl_status = parse_decl_node(self.struct_info, decl_list[0]) + self.assertEqual(decl_status.name, 't1') + self.assertEqual(decl_status.is_ptr_decl, True) + self.assertEqual(decl_status.struct_item.typedef_name, 'T1') + self.assertEqual(decl_status.struct_item.struct_name, 'S1') + + def test_parse_decl_node_5(self): + func_def_node = self.func_dictionary['parse_decl_node_5'] + decl_list = func_def_node.body.block_items + decl_status = parse_decl_node(self.struct_info, decl_list[0]) + self.assertEqual(decl_status.name, 't2') + 
self.assertEqual(decl_status.is_ptr_decl, True) + self.assertEqual(decl_status.struct_item.typedef_name, 'T1') + self.assertEqual(decl_status.struct_item.struct_name, 'S1') + + def test_parse_decl_node_6(self): + func_def_node = self.func_dictionary['parse_decl_node_6'] + decl_list = func_def_node.body.block_items + decl_status = parse_decl_node(self.struct_info, decl_list[0]) + self.assertEqual(decl_status.name, 't3') + self.assertEqual(decl_status.is_ptr_decl, True) + self.assertEqual(decl_status.struct_item.typedef_name, 'T1') + self.assertEqual(decl_status.struct_item.struct_name, 'S1') + + +if __name__ == '__main__': + googletest.main() diff --git a/third_party/aom/tools/cpplint.py b/third_party/aom/tools/cpplint.py new file mode 100755 index 0000000000..e3ebde2f5a --- /dev/null +++ b/third_party/aom/tools/cpplint.py @@ -0,0 +1,6244 @@ +#!/usr/bin/env python +# +# Copyright (c) 2009 Google Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Does google-lint on c++ files. + +The goal of this script is to identify places in the code that *may* +be in non-compliance with google style. It does not attempt to fix +up these problems -- the point is to educate. It does also not +attempt to find all problems, or to ensure that everything it does +find is legitimately a problem. + +In particular, we can get very confused by /* and // inside strings! +We do a small hack, which is to ignore //'s with "'s after them on the +same line, but it is far from perfect (in either direction). +""" + +import codecs +import copy +import getopt +import math # for log +import os +import re +import sre_compile +import string +import sys +import unicodedata +import sysconfig + +try: + xrange # Python 2 +except NameError: + xrange = range # Python 3 + + +_USAGE = """ +Syntax: cpplint.py [--verbose=#] [--output=vs7] [--filter=-x,+y,...] + [--counting=total|toplevel|detailed] [--root=subdir] + [--linelength=digits] [--headers=x,y,...] + [--quiet] + [file] ... + + The style guidelines this tries to follow are those in + https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml + + Every problem is given a confidence score from 1-5, with 5 meaning we are + certain of the problem, and 1 meaning it could be a legitimate construct. + This will miss some errors, and is not a substitute for a code review. 
+ + To suppress false-positive errors of a certain category, add a + 'NOLINT(category)' comment to the line. NOLINT or NOLINT(*) + suppresses errors of all categories on that line. + + The files passed in will be linted; at least one file must be provided. + Default linted extensions are .cc, .cpp, .cu, .cuh and .h. Change the + extensions with the --extensions flag. + + Flags: + + output=vs7 + By default, the output is formatted to ease emacs parsing. Visual Studio + compatible output (vs7) may also be used. Other formats are unsupported. + + verbose=# + Specify a number 0-5 to restrict errors to certain verbosity levels. + + quiet + Don't print anything if no errors are found. + + filter=-x,+y,... + Specify a comma-separated list of category-filters to apply: only + error messages whose category names pass the filters will be printed. + (Category names are printed with the message and look like + "[whitespace/indent]".) Filters are evaluated left to right. + "-FOO" and "FOO" means "do not print categories that start with FOO". + "+FOO" means "do print categories that start with FOO". + + Examples: --filter=-whitespace,+whitespace/braces + --filter=whitespace,runtime/printf,+runtime/printf_format + --filter=-,+build/include_what_you_use + + To see a list of all the categories used in cpplint, pass no arg: + --filter= + + counting=total|toplevel|detailed + The total number of errors found is always printed. If + 'toplevel' is provided, then the count of errors in each of + the top-level categories like 'build' and 'whitespace' will + also be printed. If 'detailed' is provided, then a count + is provided for each category like 'build/class'. + + root=subdir + The root directory used for deriving header guard CPP variable. + By default, the header guard CPP variable is calculated as the relative + path to the directory that contains .git, .hg, or .svn. When this flag + is specified, the relative path is calculated from the specified + directory. 
If the specified directory does not exist, this flag is + ignored. + + Examples: + Assuming that top/src/.git exists (and cwd=top/src), the header guard + CPP variables for top/src/chrome/browser/ui/browser.h are: + + No flag => CHROME_BROWSER_UI_BROWSER_H_ + --root=chrome => BROWSER_UI_BROWSER_H_ + --root=chrome/browser => UI_BROWSER_H_ + --root=.. => SRC_CHROME_BROWSER_UI_BROWSER_H_ + + linelength=digits + This is the allowed line length for the project. The default value is + 80 characters. + + Examples: + --linelength=120 + + extensions=extension,extension,... + The allowed file extensions that cpplint will check + + Examples: + --extensions=hpp,cpp + + headers=x,y,... + The header extensions that cpplint will treat as .h in checks. Values are + automatically added to --extensions list. + + Examples: + --headers=hpp,hxx + --headers=hpp + + cpplint.py supports per-directory configurations specified in CPPLINT.cfg + files. CPPLINT.cfg file can contain a number of key=value pairs. + Currently the following options are supported: + + set noparent + filter=+filter1,-filter2,... + exclude_files=regex + linelength=80 + root=subdir + headers=x,y,... + + "set noparent" option prevents cpplint from traversing directory tree + upwards looking for more .cfg files in parent directories. This option + is usually placed in the top-level project directory. + + The "filter" option is similar in function to --filter flag. It specifies + message filters in addition to the |_DEFAULT_FILTERS| and those specified + through --filter command-line flag. + + "exclude_files" allows to specify a regular expression to be matched against + a file name. If the expression matches, the file is skipped and not run + through liner. + + "linelength" allows to specify the allowed line length for the project. + + The "root" option is similar in function to the --root flag (see example + above). Paths are relative to the directory of the CPPLINT.cfg. 
+ + The "headers" option is similar in function to the --headers flag + (see example above). + + CPPLINT.cfg has an effect on files in the same directory and all + sub-directories, unless overridden by a nested configuration file. + + Example file: + filter=-build/include_order,+build/include_alpha + exclude_files=.*\.cc + + The above example disables build/include_order warning and enables + build/include_alpha as well as excludes all .cc from being + processed by linter, in the current directory (where the .cfg + file is located) and all sub-directories. +""" + +# We categorize each error message we print. Here are the categories. +# We want an explicit list so we can list them all in cpplint --filter=. +# If you add a new error message with a new category, add it to the list +# here! cpplint_unittest.py should tell you if you forget to do this. +_ERROR_CATEGORIES = [ + 'build/class', + 'build/c++11', + 'build/c++14', + 'build/c++tr1', + 'build/deprecated', + 'build/endif_comment', + 'build/explicit_make_pair', + 'build/forward_decl', + 'build/header_guard', + 'build/include', + 'build/include_alpha', + 'build/include_order', + 'build/include_what_you_use', + 'build/namespaces', + 'build/printf_format', + 'build/storage_class', + 'legal/copyright', + 'readability/alt_tokens', + 'readability/braces', + 'readability/casting', + 'readability/check', + 'readability/constructors', + 'readability/fn_size', + 'readability/inheritance', + 'readability/multiline_comment', + 'readability/multiline_string', + 'readability/namespace', + 'readability/nolint', + 'readability/nul', + 'readability/strings', + 'readability/todo', + 'readability/utf8', + 'runtime/arrays', + 'runtime/casting', + 'runtime/explicit', + 'runtime/int', + 'runtime/init', + 'runtime/invalid_increment', + 'runtime/member_string_references', + 'runtime/memset', + 'runtime/indentation_namespace', + 'runtime/operator', + 'runtime/printf', + 'runtime/printf_format', + 'runtime/references', + 'runtime/string', 
+ 'runtime/threadsafe_fn', + 'runtime/vlog', + 'whitespace/blank_line', + 'whitespace/braces', + 'whitespace/comma', + 'whitespace/comments', + 'whitespace/empty_conditional_body', + 'whitespace/empty_if_body', + 'whitespace/empty_loop_body', + 'whitespace/end_of_line', + 'whitespace/ending_newline', + 'whitespace/forcolon', + 'whitespace/indent', + 'whitespace/line_length', + 'whitespace/newline', + 'whitespace/operators', + 'whitespace/parens', + 'whitespace/semicolon', + 'whitespace/tab', + 'whitespace/todo', + ] + +# These error categories are no longer enforced by cpplint, but for backwards- +# compatibility they may still appear in NOLINT comments. +_LEGACY_ERROR_CATEGORIES = [ + 'readability/streams', + 'readability/function', + ] + +# The default state of the category filter. This is overridden by the --filter= +# flag. By default all errors are on, so only add here categories that should be +# off by default (i.e., categories that must be enabled by the --filter= flags). +# All entries here should start with a '-' or '+', as in the --filter= flag. +_DEFAULT_FILTERS = ['-build/include_alpha'] + +# The default list of categories suppressed for C (not C++) files. +_DEFAULT_C_SUPPRESSED_CATEGORIES = [ + 'readability/casting', + ] + +# The default list of categories suppressed for Linux Kernel files. +_DEFAULT_KERNEL_SUPPRESSED_CATEGORIES = [ + 'whitespace/tab', + ] + +# We used to check for high-bit characters, but after much discussion we +# decided those were OK, as long as they were in UTF-8 and didn't represent +# hard-coded international strings, which belong in a separate i18n file. 
+ +# C++ headers +_CPP_HEADERS = frozenset([ + # Legacy + 'algobase.h', + 'algo.h', + 'alloc.h', + 'builtinbuf.h', + 'bvector.h', + 'complex.h', + 'defalloc.h', + 'deque.h', + 'editbuf.h', + 'fstream.h', + 'function.h', + 'hash_map', + 'hash_map.h', + 'hash_set', + 'hash_set.h', + 'hashtable.h', + 'heap.h', + 'indstream.h', + 'iomanip.h', + 'iostream.h', + 'istream.h', + 'iterator.h', + 'list.h', + 'map.h', + 'multimap.h', + 'multiset.h', + 'ostream.h', + 'pair.h', + 'parsestream.h', + 'pfstream.h', + 'procbuf.h', + 'pthread_alloc', + 'pthread_alloc.h', + 'rope', + 'rope.h', + 'ropeimpl.h', + 'set.h', + 'slist', + 'slist.h', + 'stack.h', + 'stdiostream.h', + 'stl_alloc.h', + 'stl_relops.h', + 'streambuf.h', + 'stream.h', + 'strfile.h', + 'strstream.h', + 'tempbuf.h', + 'tree.h', + 'type_traits.h', + 'vector.h', + # 17.6.1.2 C++ library headers + 'algorithm', + 'array', + 'atomic', + 'bitset', + 'chrono', + 'codecvt', + 'complex', + 'condition_variable', + 'deque', + 'exception', + 'forward_list', + 'fstream', + 'functional', + 'future', + 'initializer_list', + 'iomanip', + 'ios', + 'iosfwd', + 'iostream', + 'istream', + 'iterator', + 'limits', + 'list', + 'locale', + 'map', + 'memory', + 'mutex', + 'new', + 'numeric', + 'ostream', + 'queue', + 'random', + 'ratio', + 'regex', + 'scoped_allocator', + 'set', + 'sstream', + 'stack', + 'stdexcept', + 'streambuf', + 'string', + 'strstream', + 'system_error', + 'thread', + 'tuple', + 'typeindex', + 'typeinfo', + 'type_traits', + 'unordered_map', + 'unordered_set', + 'utility', + 'valarray', + 'vector', + # 17.6.1.2 C++ headers for C library facilities + 'cassert', + 'ccomplex', + 'cctype', + 'cerrno', + 'cfenv', + 'cfloat', + 'cinttypes', + 'ciso646', + 'climits', + 'clocale', + 'cmath', + 'csetjmp', + 'csignal', + 'cstdalign', + 'cstdarg', + 'cstdbool', + 'cstddef', + 'cstdint', + 'cstdio', + 'cstdlib', + 'cstring', + 'ctgmath', + 'ctime', + 'cuchar', + 'cwchar', + 'cwctype', + ]) + +# Type names +_TYPES = re.compile( + 
r'^(?:' + # [dcl.type.simple] + r'(char(16_t|32_t)?)|wchar_t|' + r'bool|short|int|long|signed|unsigned|float|double|' + # [support.types] + r'(ptrdiff_t|size_t|max_align_t|nullptr_t)|' + # [cstdint.syn] + r'(u?int(_fast|_least)?(8|16|32|64)_t)|' + r'(u?int(max|ptr)_t)|' + r')$') + + +# These headers are excluded from [build/include] and [build/include_order] +# checks: +# - Anything not following google file name conventions (containing an +# uppercase character, such as Python.h or nsStringAPI.h, for example). +# - Lua headers. +_THIRD_PARTY_HEADERS_PATTERN = re.compile( + r'^(?:[^/]*[A-Z][^/]*\.h|lua\.h|lauxlib\.h|lualib\.h)$') + +# Pattern for matching FileInfo.BaseName() against test file name +_TEST_FILE_SUFFIX = r'(_test|_unittest|_regtest)$' + +# Pattern that matches only complete whitespace, possibly across multiple lines. +_EMPTY_CONDITIONAL_BODY_PATTERN = re.compile(r'^\s*$', re.DOTALL) + +# Assertion macros. These are defined in base/logging.h and +# testing/base/public/gunit.h. +_CHECK_MACROS = [ + 'DCHECK', 'CHECK', + 'EXPECT_TRUE', 'ASSERT_TRUE', + 'EXPECT_FALSE', 'ASSERT_FALSE', + ] + +# Replacement macros for CHECK/DCHECK/EXPECT_TRUE/EXPECT_FALSE +_CHECK_REPLACEMENT = dict([(m, {}) for m in _CHECK_MACROS]) + +for op, replacement in [('==', 'EQ'), ('!=', 'NE'), + ('>=', 'GE'), ('>', 'GT'), + ('<=', 'LE'), ('<', 'LT')]: + _CHECK_REPLACEMENT['DCHECK'][op] = 'DCHECK_%s' % replacement + _CHECK_REPLACEMENT['CHECK'][op] = 'CHECK_%s' % replacement + _CHECK_REPLACEMENT['EXPECT_TRUE'][op] = 'EXPECT_%s' % replacement + _CHECK_REPLACEMENT['ASSERT_TRUE'][op] = 'ASSERT_%s' % replacement + +for op, inv_replacement in [('==', 'NE'), ('!=', 'EQ'), + ('>=', 'LT'), ('>', 'LE'), + ('<=', 'GT'), ('<', 'GE')]: + _CHECK_REPLACEMENT['EXPECT_FALSE'][op] = 'EXPECT_%s' % inv_replacement + _CHECK_REPLACEMENT['ASSERT_FALSE'][op] = 'ASSERT_%s' % inv_replacement + +# Alternative tokens and their replacements. 
For full list, see section 2.5 +# Alternative tokens [lex.digraph] in the C++ standard. +# +# Digraphs (such as '%:') are not included here since it's a mess to +# match those on a word boundary. +_ALT_TOKEN_REPLACEMENT = { + 'and': '&&', + 'bitor': '|', + 'or': '||', + 'xor': '^', + 'compl': '~', + 'bitand': '&', + 'and_eq': '&=', + 'or_eq': '|=', + 'xor_eq': '^=', + 'not': '!', + 'not_eq': '!=' + } + +# Compile regular expression that matches all the above keywords. The "[ =()]" +# bit is meant to avoid matching these keywords outside of boolean expressions. +# +# False positives include C-style multi-line comments and multi-line strings +# but those have always been troublesome for cpplint. +_ALT_TOKEN_REPLACEMENT_PATTERN = re.compile( + r'[ =()](' + ('|'.join(_ALT_TOKEN_REPLACEMENT.keys())) + r')(?=[ (]|$)') + + +# These constants define types of headers for use with +# _IncludeState.CheckNextIncludeOrder(). +_C_SYS_HEADER = 1 +_CPP_SYS_HEADER = 2 +_LIKELY_MY_HEADER = 3 +_POSSIBLE_MY_HEADER = 4 +_OTHER_HEADER = 5 + +# These constants define the current inline assembly state +_NO_ASM = 0 # Outside of inline assembly block +_INSIDE_ASM = 1 # Inside inline assembly block +_END_ASM = 2 # Last line of inline assembly block +_BLOCK_ASM = 3 # The whole block is an inline assembly block + +# Match start of assembly blocks +_MATCH_ASM = re.compile(r'^\s*(?:asm|_asm|__asm|__asm__)' + r'(?:\s+(volatile|__volatile__))?' + r'\s*[{(]') + +# Match strings that indicate we're working on a C (not C++) file. +_SEARCH_C_FILE = re.compile(r'\b(?:LINT_C_FILE|' + r'vim?:\s*.*(\s*|:)filetype=c(\s*|:|$))') + +# Match string that indicates we're working on a Linux Kernel file. +_SEARCH_KERNEL_FILE = re.compile(r'\b(?:LINT_KERNEL_FILE)') + +_regexp_compile_cache = {} + +# {str, set(int)}: a map from error categories to sets of linenumbers +# on which those errors are expected and should be suppressed. 
+_error_suppressions = {} + +# The root directory used for deriving header guard CPP variable. +# This is set by --root flag. +_root = None +_root_debug = False + +# The allowed line length of files. +# This is set by --linelength flag. +_line_length = 80 + +# The allowed extensions for file names +# This is set by --extensions flag. +_valid_extensions = set(['cc', 'h', 'cpp', 'cu', 'cuh']) + +# Treat all headers starting with 'h' equally: .h, .hpp, .hxx etc. +# This is set by --headers flag. +_hpp_headers = set(['h']) + +# {str, bool}: a map from error categories to booleans which indicate if the +# category should be suppressed for every line. +_global_error_suppressions = {} + +def ProcessHppHeadersOption(val): + global _hpp_headers + try: + _hpp_headers = set(val.split(',')) + # Automatically append to extensions list so it does not have to be set 2 times + _valid_extensions.update(_hpp_headers) + except ValueError: + PrintUsage('Header extensions must be comma separated list.') + +def IsHeaderExtension(file_extension): + return file_extension in _hpp_headers + +def ParseNolintSuppressions(filename, raw_line, linenum, error): + """Updates the global list of line error-suppressions. + + Parses any NOLINT comments on the current line, updating the global + error_suppressions store. Reports an error if the NOLINT comment + was malformed. + + Args: + filename: str, the name of the input file. + raw_line: str, the line of input text, with comments. + linenum: int, the number of the current line. + error: function, an error handler. 
+ """ + matched = Search(r'\bNOLINT(NEXTLINE)?\b(\([^)]+\))?', raw_line) + if matched: + if matched.group(1): + suppressed_line = linenum + 1 + else: + suppressed_line = linenum + category = matched.group(2) + if category in (None, '(*)'): # => "suppress all" + _error_suppressions.setdefault(None, set()).add(suppressed_line) + else: + if category.startswith('(') and category.endswith(')'): + category = category[1:-1] + if category in _ERROR_CATEGORIES: + _error_suppressions.setdefault(category, set()).add(suppressed_line) + elif category not in _LEGACY_ERROR_CATEGORIES: + error(filename, linenum, 'readability/nolint', 5, + 'Unknown NOLINT error category: %s' % category) + + +def ProcessGlobalSuppresions(lines): + """Updates the list of global error suppressions. + + Parses any lint directives in the file that have global effect. + + Args: + lines: An array of strings, each representing a line of the file, with the + last element being empty if the file is terminated with a newline. + """ + for line in lines: + if _SEARCH_C_FILE.search(line): + for category in _DEFAULT_C_SUPPRESSED_CATEGORIES: + _global_error_suppressions[category] = True + if _SEARCH_KERNEL_FILE.search(line): + for category in _DEFAULT_KERNEL_SUPPRESSED_CATEGORIES: + _global_error_suppressions[category] = True + + +def ResetNolintSuppressions(): + """Resets the set of NOLINT suppressions to empty.""" + _error_suppressions.clear() + _global_error_suppressions.clear() + + +def IsErrorSuppressedByNolint(category, linenum): + """Returns true if the specified error category is suppressed on this line. + + Consults the global error_suppressions map populated by + ParseNolintSuppressions/ProcessGlobalSuppresions/ResetNolintSuppressions. + + Args: + category: str, the category of the error. + linenum: int, the current line number. + Returns: + bool, True iff the error should be suppressed due to a NOLINT comment or + global suppression. 
+ """ + return (_global_error_suppressions.get(category, False) or + linenum in _error_suppressions.get(category, set()) or + linenum in _error_suppressions.get(None, set())) + + +def Match(pattern, s): + """Matches the string with the pattern, caching the compiled regexp.""" + # The regexp compilation caching is inlined in both Match and Search for + # performance reasons; factoring it out into a separate function turns out + # to be noticeably expensive. + if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].match(s) + + +def ReplaceAll(pattern, rep, s): + """Replaces instances of pattern in a string with a replacement. + + The compiled regex is kept in a cache shared by Match and Search. + + Args: + pattern: regex pattern + rep: replacement text + s: search string + + Returns: + string with replacements made (or original string if no replacements) + """ + if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].sub(rep, s) + + +def Search(pattern, s): + """Searches the string for the pattern, caching the compiled regexp.""" + if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].search(s) + + +def _IsSourceExtension(s): + """File extension (excluding dot) matches a source file extension.""" + return s in ('c', 'cc', 'cpp', 'cxx') + + +class _IncludeState(object): + """Tracks line numbers for includes, and the order in which includes appear. + + include_list contains list of lists of (header, line number) pairs. + It's a lists of lists rather than just one flat list to make it + easier to update across preprocessor boundaries. + + Call CheckNextIncludeOrder() once for each header in the file, passing + in the type constants defined above. 
Calls in an illegal order will + raise an _IncludeError with an appropriate error message. + + """ + # self._section will move monotonically through this set. If it ever + # needs to move backwards, CheckNextIncludeOrder will raise an error. + _INITIAL_SECTION = 0 + _MY_H_SECTION = 1 + _C_SECTION = 2 + _CPP_SECTION = 3 + _OTHER_H_SECTION = 4 + + _TYPE_NAMES = { + _C_SYS_HEADER: 'C system header', + _CPP_SYS_HEADER: 'C++ system header', + _LIKELY_MY_HEADER: 'header this file implements', + _POSSIBLE_MY_HEADER: 'header this file may implement', + _OTHER_HEADER: 'other header', + } + _SECTION_NAMES = { + _INITIAL_SECTION: "... nothing. (This can't be an error.)", + _MY_H_SECTION: 'a header this file implements', + _C_SECTION: 'C system header', + _CPP_SECTION: 'C++ system header', + _OTHER_H_SECTION: 'other header', + } + + def __init__(self): + self.include_list = [[]] + self.ResetSection('') + + def FindHeader(self, header): + """Check if a header has already been included. + + Args: + header: header to check. + Returns: + Line number of previous occurrence, or -1 if the header has not + been seen before. + """ + for section_list in self.include_list: + for f in section_list: + if f[0] == header: + return f[1] + return -1 + + def ResetSection(self, directive): + """Reset section checking for preprocessor directive. + + Args: + directive: preprocessor directive (e.g. "if", "else"). + """ + # The name of the current section. + self._section = self._INITIAL_SECTION + # The path of last found header. + self._last_header = '' + + # Update list of includes. Note that we never pop from the + # include list. + if directive in ('if', 'ifdef', 'ifndef'): + self.include_list.append([]) + elif directive in ('else', 'elif'): + self.include_list[-1] = [] + + def SetLastHeader(self, header_path): + self._last_header = header_path + + def CanonicalizeAlphabeticalOrder(self, header_path): + """Returns a path canonicalized for alphabetical comparison. 
+ + - replaces "-" with "_" so they both cmp the same. + - removes '-inl' since we don't require them to be after the main header. + - lowercase everything, just in case. + + Args: + header_path: Path to be canonicalized. + + Returns: + Canonicalized path. + """ + return header_path.replace('-inl.h', '.h').replace('-', '_').lower() + + def IsInAlphabeticalOrder(self, clean_lines, linenum, header_path): + """Check if a header is in alphabetical order with the previous header. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + header_path: Canonicalized header to be checked. + + Returns: + Returns true if the header is in alphabetical order. + """ + # If previous section is different from current section, _last_header will + # be reset to empty string, so it's always less than current header. + # + # If previous line was a blank line, assume that the headers are + # intentionally sorted the way they are. + if (self._last_header > header_path and + Match(r'^\s*#\s*include\b', clean_lines.elided[linenum - 1])): + return False + return True + + def CheckNextIncludeOrder(self, header_type): + """Returns a non-empty error message if the next header is out of order. + + This function also updates the internal state to be ready to check + the next include. + + Args: + header_type: One of the _XXX_HEADER constants defined above. + + Returns: + The empty string if the header is in the right order, or an + error message describing what's wrong. 
+ + """ + error_message = ('Found %s after %s' % + (self._TYPE_NAMES[header_type], + self._SECTION_NAMES[self._section])) + + last_section = self._section + + if header_type == _C_SYS_HEADER: + if self._section <= self._C_SECTION: + self._section = self._C_SECTION + else: + self._last_header = '' + return error_message + elif header_type == _CPP_SYS_HEADER: + if self._section <= self._CPP_SECTION: + self._section = self._CPP_SECTION + else: + self._last_header = '' + return error_message + elif header_type == _LIKELY_MY_HEADER: + if self._section <= self._MY_H_SECTION: + self._section = self._MY_H_SECTION + else: + self._section = self._OTHER_H_SECTION + elif header_type == _POSSIBLE_MY_HEADER: + if self._section <= self._MY_H_SECTION: + self._section = self._MY_H_SECTION + else: + # This will always be the fallback because we're not sure + # enough that the header is associated with this file. + self._section = self._OTHER_H_SECTION + else: + assert header_type == _OTHER_HEADER + self._section = self._OTHER_H_SECTION + + if last_section != self._section: + self._last_header = '' + + return '' + + +class _CppLintState(object): + """Maintains module-wide state..""" + + def __init__(self): + self.verbose_level = 1 # global setting. + self.error_count = 0 # global count of reported errors + # filters to apply when emitting error messages + self.filters = _DEFAULT_FILTERS[:] + # backup of filter list. Used to restore the state after each file. + self._filters_backup = self.filters[:] + self.counting = 'total' # In what way are we counting errors? + self.errors_by_category = {} # string to int dict storing error counts + self.quiet = False # Suppress non-error messagess? 
+ + # output format: + # "emacs" - format that emacs can parse (default) + # "vs7" - format that Microsoft Visual Studio 7 can parse + self.output_format = 'emacs' + + def SetOutputFormat(self, output_format): + """Sets the output format for errors.""" + self.output_format = output_format + + def SetQuiet(self, quiet): + """Sets the module's quiet settings, and returns the previous setting.""" + last_quiet = self.quiet + self.quiet = quiet + return last_quiet + + def SetVerboseLevel(self, level): + """Sets the module's verbosity, and returns the previous setting.""" + last_verbose_level = self.verbose_level + self.verbose_level = level + return last_verbose_level + + def SetCountingStyle(self, counting_style): + """Sets the module's counting options.""" + self.counting = counting_style + + def SetFilters(self, filters): + """Sets the error-message filters. + + These filters are applied when deciding whether to emit a given + error message. + + Args: + filters: A string of comma-separated filters (eg "+whitespace/indent"). + Each filter should start with + or -; else we die. + + Raises: + ValueError: The comma-separated filters did not all start with '+' or '-'. + E.g. "-,+whitespace,-whitespace/indent,whitespace/badfilter" + """ + # Default filters always have less priority than the flag ones. + self.filters = _DEFAULT_FILTERS[:] + self.AddFilters(filters) + + def AddFilters(self, filters): + """ Adds more filters to the existing list of error-message filters. 
""" + for filt in filters.split(','): + clean_filt = filt.strip() + if clean_filt: + self.filters.append(clean_filt) + for filt in self.filters: + if not (filt.startswith('+') or filt.startswith('-')): + raise ValueError('Every filter in --filters must start with + or -' + ' (%s does not)' % filt) + + def BackupFilters(self): + """ Saves the current filter list to backup storage.""" + self._filters_backup = self.filters[:] + + def RestoreFilters(self): + """ Restores filters previously backed up.""" + self.filters = self._filters_backup[:] + + def ResetErrorCounts(self): + """Sets the module's error statistic back to zero.""" + self.error_count = 0 + self.errors_by_category = {} + + def IncrementErrorCount(self, category): + """Bumps the module's error statistic.""" + self.error_count += 1 + if self.counting in ('toplevel', 'detailed'): + if self.counting != 'detailed': + category = category.split('/')[0] + if category not in self.errors_by_category: + self.errors_by_category[category] = 0 + self.errors_by_category[category] += 1 + + def PrintErrorCounts(self): + """Print a summary of errors by category, and the total.""" + for category, count in self.errors_by_category.iteritems(): + sys.stderr.write('Category \'%s\' errors found: %d\n' % + (category, count)) + sys.stdout.write('Total errors found: %d\n' % self.error_count) + +_cpplint_state = _CppLintState() + + +def _OutputFormat(): + """Gets the module's output format.""" + return _cpplint_state.output_format + + +def _SetOutputFormat(output_format): + """Sets the module's output format.""" + _cpplint_state.SetOutputFormat(output_format) + +def _Quiet(): + """Return's the module's quiet setting.""" + return _cpplint_state.quiet + +def _SetQuiet(quiet): + """Set the module's quiet status, and return previous setting.""" + return _cpplint_state.SetQuiet(quiet) + + +def _VerboseLevel(): + """Returns the module's verbosity setting.""" + return _cpplint_state.verbose_level + + +def _SetVerboseLevel(level): + """Sets 
the module's verbosity, and returns the previous setting.""" + return _cpplint_state.SetVerboseLevel(level) + + +def _SetCountingStyle(level): + """Sets the module's counting options.""" + _cpplint_state.SetCountingStyle(level) + + +def _Filters(): + """Returns the module's list of output filters, as a list.""" + return _cpplint_state.filters + + +def _SetFilters(filters): + """Sets the module's error-message filters. + + These filters are applied when deciding whether to emit a given + error message. + + Args: + filters: A string of comma-separated filters (eg "whitespace/indent"). + Each filter should start with + or -; else we die. + """ + _cpplint_state.SetFilters(filters) + +def _AddFilters(filters): + """Adds more filter overrides. + + Unlike _SetFilters, this function does not reset the current list of filters + available. + + Args: + filters: A string of comma-separated filters (eg "whitespace/indent"). + Each filter should start with + or -; else we die. + """ + _cpplint_state.AddFilters(filters) + +def _BackupFilters(): + """ Saves the current filter list to backup storage.""" + _cpplint_state.BackupFilters() + +def _RestoreFilters(): + """ Restores filters previously backed up.""" + _cpplint_state.RestoreFilters() + +class _FunctionState(object): + """Tracks current function name and the number of lines in its body.""" + + _NORMAL_TRIGGER = 250 # for --v=0, 500 for --v=1, etc. + _TEST_TRIGGER = 400 # about 50% more than _NORMAL_TRIGGER. + + def __init__(self): + self.in_a_function = False + self.lines_in_function = 0 + self.current_function = '' + + def Begin(self, function_name): + """Start analyzing function body. + + Args: + function_name: The name of the function being tracked. 
+ """ + self.in_a_function = True + self.lines_in_function = 0 + self.current_function = function_name + + def Count(self): + """Count line in current function body.""" + if self.in_a_function: + self.lines_in_function += 1 + + def Check(self, error, filename, linenum): + """Report if too many lines in function body. + + Args: + error: The function to call with any errors found. + filename: The name of the current file. + linenum: The number of the line to check. + """ + if not self.in_a_function: + return + + if Match(r'T(EST|est)', self.current_function): + base_trigger = self._TEST_TRIGGER + else: + base_trigger = self._NORMAL_TRIGGER + trigger = base_trigger * 2**_VerboseLevel() + + if self.lines_in_function > trigger: + error_level = int(math.log(self.lines_in_function / base_trigger, 2)) + # 50 => 0, 100 => 1, 200 => 2, 400 => 3, 800 => 4, 1600 => 5, ... + if error_level > 5: + error_level = 5 + error(filename, linenum, 'readability/fn_size', error_level, + 'Small and focused functions are preferred:' + ' %s has %d non-comment lines' + ' (error triggered by exceeding %d lines).' % ( + self.current_function, self.lines_in_function, trigger)) + + def End(self): + """Stop analyzing function body.""" + self.in_a_function = False + + +class _IncludeError(Exception): + """Indicates a problem with the include order in a file.""" + pass + + +class FileInfo(object): + """Provides utility functions for filenames. + + FileInfo provides easy access to the components of a file's path + relative to the project root. + """ + + def __init__(self, filename): + self._filename = filename + + def FullName(self): + """Make Windows paths like Unix.""" + return os.path.abspath(self._filename).replace('\\', '/') + + def RepositoryName(self): + """FullName after removing the local path to the repository. 
+ + If we have a real absolute path name here we can try to do something smart: + detecting the root of the checkout and truncating /path/to/checkout from + the name so that we get header guards that don't include things like + "C:\Documents and Settings\..." or "/home/username/..." in them and thus + people on different computers who have checked the source out to different + locations won't see bogus errors. + """ + fullname = self.FullName() + + if os.path.exists(fullname): + project_dir = os.path.dirname(fullname) + + if os.path.exists(os.path.join(project_dir, ".svn")): + # If there's a .svn file in the current directory, we recursively look + # up the directory tree for the top of the SVN checkout + root_dir = project_dir + one_up_dir = os.path.dirname(root_dir) + while os.path.exists(os.path.join(one_up_dir, ".svn")): + root_dir = os.path.dirname(root_dir) + one_up_dir = os.path.dirname(one_up_dir) + + prefix = os.path.commonprefix([root_dir, project_dir]) + return fullname[len(prefix) + 1:] + + # Not SVN <= 1.6? Try to find a git, hg, or svn top level directory by + # searching up from the current path. + root_dir = current_dir = os.path.dirname(fullname) + while current_dir != os.path.dirname(current_dir): + if (os.path.exists(os.path.join(current_dir, ".git")) or + os.path.exists(os.path.join(current_dir, ".hg")) or + os.path.exists(os.path.join(current_dir, ".svn"))): + root_dir = current_dir + current_dir = os.path.dirname(current_dir) + + if (os.path.exists(os.path.join(root_dir, ".git")) or + os.path.exists(os.path.join(root_dir, ".hg")) or + os.path.exists(os.path.join(root_dir, ".svn"))): + prefix = os.path.commonprefix([root_dir, project_dir]) + return fullname[len(prefix) + 1:] + + # Don't know what to do; header guard warnings may be wrong... + return fullname + + def Split(self): + """Splits the file into the directory, basename, and extension. 
+ + For 'chrome/browser/browser.cc', Split() would + return ('chrome/browser', 'browser', '.cc') + + Returns: + A tuple of (directory, basename, extension). + """ + + googlename = self.RepositoryName() + project, rest = os.path.split(googlename) + return (project,) + os.path.splitext(rest) + + def BaseName(self): + """File base name - text after the final slash, before the final period.""" + return self.Split()[1] + + def Extension(self): + """File extension - text following the final period.""" + return self.Split()[2] + + def NoExtension(self): + """File has no source file extension.""" + return '/'.join(self.Split()[0:2]) + + def IsSource(self): + """File has a source file extension.""" + return _IsSourceExtension(self.Extension()[1:]) + + +def _ShouldPrintError(category, confidence, linenum): + """If confidence >= verbose, category passes filter and is not suppressed.""" + + # There are three ways we might decide not to print an error message: + # a "NOLINT(category)" comment appears in the source, + # the verbosity level isn't high enough, or the filters filter it out. + if IsErrorSuppressedByNolint(category, linenum): + return False + + if confidence < _cpplint_state.verbose_level: + return False + + is_filtered = False + for one_filter in _Filters(): + if one_filter.startswith('-'): + if category.startswith(one_filter[1:]): + is_filtered = True + elif one_filter.startswith('+'): + if category.startswith(one_filter[1:]): + is_filtered = False + else: + assert False # should have been checked for in SetFilter. + if is_filtered: + return False + + return True + + +def Error(filename, linenum, category, confidence, message): + """Logs the fact we've found a lint error. + + We log where the error was found, and also our confidence in the error, + that is, how certain we are this is a legitimate style regression, and + not a misidentification or a use that's sometimes justified. 
+ + False positives can be suppressed by the use of + "cpplint(category)" comments on the offending line. These are + parsed into _error_suppressions. + + Args: + filename: The name of the file containing the error. + linenum: The number of the line containing the error. + category: A string used to describe the "category" this bug + falls under: "whitespace", say, or "runtime". Categories + may have a hierarchy separated by slashes: "whitespace/indent". + confidence: A number from 1-5 representing a confidence score for + the error, with 5 meaning that we are certain of the problem, + and 1 meaning that it could be a legitimate construct. + message: The error message. + """ + if _ShouldPrintError(category, confidence, linenum): + _cpplint_state.IncrementErrorCount(category) + if _cpplint_state.output_format == 'vs7': + sys.stderr.write('%s(%s): error cpplint: [%s] %s [%d]\n' % ( + filename, linenum, category, message, confidence)) + elif _cpplint_state.output_format == 'eclipse': + sys.stderr.write('%s:%s: warning: %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + else: + sys.stderr.write('%s:%s: %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + + +# Matches standard C++ escape sequences per 2.13.2.3 of the C++ standard. +_RE_PATTERN_CLEANSE_LINE_ESCAPES = re.compile( + r'\\([abfnrtv?"\\\']|\d+|x[0-9a-fA-F]+)') +# Match a single C style comment on the same line. +_RE_PATTERN_C_COMMENTS = r'/\*(?:[^*]|\*(?!/))*\*/' +# Matches multi-line C style comments. +# This RE is a little bit more complicated than one might expect, because we +# have to take care of space removals tools so we can handle comments inside +# statements better. +# The current rule is: We only clear spaces from both sides when we're at the +# end of the line. Otherwise, we try to remove spaces from the right side, +# if this doesn't work we try on left side but only if there's a non-character +# on the right. 
_RE_PATTERN_CLEANSE_LINE_C_COMMENTS = re.compile(
    r'(\s*' + _RE_PATTERN_C_COMMENTS + r'\s*$|' +
    _RE_PATTERN_C_COMMENTS + r'\s+|' +
    r'\s+' + _RE_PATTERN_C_COMMENTS + r'(?=\W)|' +
    _RE_PATTERN_C_COMMENTS + r')')


def IsCppString(line):
  """Does line terminate so, that the next symbol is in string constant.

  This function does not consider single-line nor multi-line comments.

  Args:
    line: is a partial line of code starting from the 0..n.

  Returns:
    True, if next character appended to 'line' is inside a
    string constant.
  """

  line = line.replace(r'\\', 'XX')  # after this, \\" does not match to \"
  # An odd number of unescaped double quotes (not counting the char
  # literal '"') means we are inside a string.
  return ((line.count('"') - line.count(r'\"') - line.count("'\"'")) & 1) == 1


def CleanseRawStrings(raw_lines):
  """Removes C++11 raw strings from lines.

    Before:
      static const char kData[] = R"(
          multi-line string
          )";

    After:
      static const char kData[] = ""
          (replaced by blank line)
          "";

  Args:
    raw_lines: list of raw lines.

  Returns:
    list of lines with C++11 raw strings replaced by empty strings.
  """

  delimiter = None
  lines_without_raw_strings = []
  for line in raw_lines:
    if delimiter:
      # Inside a raw string, look for the end
      end = line.find(delimiter)
      if end >= 0:
        # Found the end of the string, match leading space for this
        # line and resume copying the original lines, and also insert
        # a "" on the last line.
        leading_space = Match(r'^(\s*)\S', line)
        line = leading_space.group(1) + '""' + line[end + len(delimiter):]
        delimiter = None
      else:
        # Haven't found the end yet, append a blank line.
        line = '""'

    # Look for beginning of a raw string, and replace them with
    # empty strings.  This is done in a loop to handle multiple raw
    # strings on the same line.
    while delimiter is None:
      # Look for beginning of a raw string.
      # See 2.14.15 [lex.string] for syntax.
      #
      # Once we have matched a raw string, we check the prefix of the
      # line to make sure that the line is not part of a single line
      # comment.  It's done this way because we remove raw strings
      # before removing comments as opposed to removing comments
      # before removing raw strings.  This is because there are some
      # cpplint checks that requires the comments to be preserved, but
      # we don't want to check comments that are inside raw strings.
      matched = Match(r'^(.*?)\b(?:R|u8R|uR|UR|LR)"([^\s\\()]*)\((.*)$', line)
      if (matched and
          not Match(r'^([^\'"]|\'(\\.|[^\'])*\'|"(\\.|[^"])*")*//',
                    matched.group(1))):
        delimiter = ')' + matched.group(2) + '"'

        end = matched.group(3).find(delimiter)
        if end >= 0:
          # Raw string ended on same line
          line = (matched.group(1) + '""' +
                  matched.group(3)[end + len(delimiter):])
          delimiter = None
        else:
          # Start of a multi-line raw string
          line = matched.group(1) + '""'
      else:
        break

    lines_without_raw_strings.append(line)

  # TODO(unknown): if delimiter is not None here, we might want to
  # emit a warning for unterminated string.
  return lines_without_raw_strings


def FindNextMultiLineCommentStart(lines, lineix):
  """Find the beginning marker for a multiline comment.

  Returns the index of the first line at or after lineix that starts with
  '/*' and does not close the comment on that same line, or len(lines) if
  there is none.
  """
  while lineix < len(lines):
    stripped = lines[lineix].strip()
    if stripped.startswith('/*'):
      # Only return this marker if the comment goes beyond this line
      if stripped.find('*/', 2) < 0:
        return lineix
    lineix += 1
  return len(lines)


def FindNextMultiLineCommentEnd(lines, lineix):
  """We are inside a comment, find the end marker.

  Returns the index of the first line at or after lineix that ends with
  '*/', or len(lines) if there is none.
  """
  while lineix < len(lines):
    if lines[lineix].strip().endswith('*/'):
      return lineix
    lineix += 1
  return len(lines)


def RemoveMultiLineCommentsFromRange(lines, begin, end):
  """Clears a range of lines for multi-line comments."""
  # Leaving a '/**/' placeholder keeps the lines non-empty, so we will not
  # get unnecessary blank line warnings later in the code.
  for i in range(begin, end):
    lines[i] = '/**/'


def RemoveMultiLineComments(filename, lines, error):
  """Removes multiline (c-style) comments from lines.

  Args:
    filename: The name of the current file.
    lines: The list of lines; modified in place.
    error: The function to call with any errors found.
  """
  lineix = 0
  while lineix < len(lines):
    lineix_begin = FindNextMultiLineCommentStart(lines, lineix)
    if lineix_begin >= len(lines):
      return
    lineix_end = FindNextMultiLineCommentEnd(lines, lineix_begin)
    if lineix_end >= len(lines):
      error(filename, lineix_begin + 1, 'readability/multiline_comment', 5,
            'Could not find end of multi-line comment')
      return
    RemoveMultiLineCommentsFromRange(lines, lineix_begin, lineix_end + 1)
    lineix = lineix_end + 1


def CleanseComments(line):
  """Removes //-comments and single-line C-style /* */ comments.

  Args:
    line: A line of C++ source.

  Returns:
    The line with single-line comments removed.
  """
  commentpos = line.find('//')
  if commentpos != -1 and not IsCppString(line[:commentpos]):
    line = line[:commentpos].rstrip()
  # get rid of /* ... */
  return _RE_PATTERN_CLEANSE_LINE_C_COMMENTS.sub('', line)


class CleansedLines(object):
  """Holds 4 copies of all lines with different preprocessing applied to them.

  1) elided member contains lines without strings and comments.
  2) lines member contains lines without comments.
  3) raw_lines member contains all the lines without processing.
  4) lines_without_raw_strings member is same as raw_lines, but with C++11 raw
     strings removed.
  All these members are lists of the same length.
  """

  def __init__(self, lines):
    self.elided = []
    self.lines = []
    self.raw_lines = lines
    self.num_lines = len(lines)
    self.lines_without_raw_strings = CleanseRawStrings(lines)
    for line in self.lines_without_raw_strings:
      self.lines.append(CleanseComments(line))
      # Strings are collapsed before comments are stripped so that a '//'
      # inside a string literal is not mistaken for a comment.
      self.elided.append(CleanseComments(self._CollapseStrings(line)))

  def NumLines(self):
    """Returns the number of lines represented."""
    return self.num_lines

  @staticmethod
  def _CollapseStrings(elided):
    """Collapses strings and chars on a line to simple "" or '' blocks.

    We nix strings first so we're not fooled by text like '"http://"'

    Args:
      elided: The line being processed.

    Returns:
      The line with collapsed strings.
    """
    if _RE_PATTERN_INCLUDE.match(elided):
      return elided

    # Remove escaped characters first to make quote/single quote collapsing
    # basic.  Things that look like escaped characters shouldn't occur
    # outside of strings and chars.
    elided = _RE_PATTERN_CLEANSE_LINE_ESCAPES.sub('', elided)

    # Replace quoted strings and digit separators.  Both single quotes
    # and double quotes are processed in the same loop, otherwise
    # nested quotes wouldn't work.
    collapsed = ''
    while True:
      # Find the first quote character
      match = Match(r'^([^\'"]*)([\'"])(.*)$', elided)
      if not match:
        collapsed += elided
        break
      head, quote, tail = match.groups()

      if quote == '"':
        # Collapse double quoted strings
        second_quote = tail.find('"')
        if second_quote >= 0:
          collapsed += head + '""'
          elided = tail[second_quote + 1:]
        else:
          # Unmatched double quote, don't bother processing the rest
          # of the line since this is probably a multiline string.
          collapsed += elided
          break
      else:
        # Found single quote, check nearby text to eliminate digit separators.
        #
        # There is no special handling for floating point here, because
        # the integer/fractional/exponent parts would all be parsed
        # correctly as long as there are digits on both sides of the
        # separator.  So we are fine as long as we don't see something
        # like "0.'3" (gcc 4.9.0 will not allow this literal).
        if Search(r'\b(?:0[bBxX]?|[1-9])[0-9a-fA-F]*$', head):
          match_literal = Match(r'^((?:\'?[0-9a-zA-Z_])*)(.*)$', "'" + tail)
          collapsed += head + match_literal.group(1).replace("'", '')
          elided = match_literal.group(2)
        else:
          second_quote = tail.find('\'')
          if second_quote >= 0:
            collapsed += head + "''"
            elided = tail[second_quote + 1:]
          else:
            # Unmatched single quote
            collapsed += elided
            break

    return collapsed


def FindEndOfExpressionInLine(line, startpos, stack):
  """Find the position just after the end of current parenthesized expression.

  Args:
    line: a CleansedLines line.
    startpos: start searching at this position.
    stack: nesting stack at startpos.

  Returns:
    On finding matching end: (index just after matching end, None)
    On finding an unclosed expression: (-1, None)
    Otherwise: (-1, new stack at end of this line)
  """
  # range() rather than xrange(): the latter does not exist on Python 3.
  for i in range(startpos, len(line)):
    char = line[i]
    if char in '([{':
      # Found start of parenthesized expression, push to expression stack
      stack.append(char)
    elif char == '<':
      # Found potential start of template argument list
      if i > 0 and line[i - 1] == '<':
        # Left shift operator
        if stack and stack[-1] == '<':
          stack.pop()
          if not stack:
            return (-1, None)
      elif i > 0 and Search(r'\boperator\s*$', line[0:i]):
        # operator<, don't add to stack
        continue
      else:
        # Tentative start of template argument list
        stack.append('<')
    elif char in ')]}':
      # Found end of parenthesized expression.
      #
      # If we are currently expecting a matching '>', the pending '<'
      # must have been an operator.  Remove them from expression stack.
      while stack and stack[-1] == '<':
        stack.pop()
      if not stack:
        return (-1, None)
      if ((stack[-1] == '(' and char == ')') or
          (stack[-1] == '[' and char == ']') or
          (stack[-1] == '{' and char == '}')):
        stack.pop()
        if not stack:
          return (i + 1, None)
      else:
        # Mismatched parentheses
        return (-1, None)
    elif char == '>':
      # Found potential end of template argument list.

      # Ignore "->" and operator functions
      if (i > 0 and
          (line[i - 1] == '-' or Search(r'\boperator\s*$', line[0:i - 1]))):
        continue

      # Pop the stack if there is a matching '<'.  Otherwise, ignore
      # this '>' since it must be an operator.
      if stack:
        if stack[-1] == '<':
          stack.pop()
          if not stack:
            return (i + 1, None)
    elif char == ';':
      # Found something that look like end of statements.  If we are currently
      # expecting a '>', the matching '<' must have been an operator, since
      # template argument list should not contain statements.
      while stack and stack[-1] == '<':
        stack.pop()
      if not stack:
        return (-1, None)

  # Did not find end of expression or unbalanced parentheses on this line
  return (-1, stack)


def CloseExpression(clean_lines, linenum, pos):
  """If input points to ( or { or [ or <, finds the position that closes it.

  If lines[linenum][pos] points to a '(' or '{' or '[' or '<', finds the
  linenum/pos that correspond to the closing of the expression.

  TODO(unknown): cpplint spends a fair bit of time matching parentheses.
  Ideally we would want to index all opening and closing parentheses once
  and have CloseExpression be just a simple lookup, but due to preprocessor
  tricks, this is not so easy.

  Args:
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    pos: A position on the line.

  Returns:
    A tuple (line, linenum, pos) pointer *past* the closing brace, or
    (line, len(lines), -1) if we never find a close.  Note we ignore
    strings and comments when matching; and the line we return is the
    'cleansed' line at linenum.
  """

  line = clean_lines.elided[linenum]
  if (line[pos] not in '({[<') or Match(r'<[<=]', line[pos:]):
    return (line, clean_lines.NumLines(), -1)

  # Check first line
  (end_pos, stack) = FindEndOfExpressionInLine(line, pos, [])
  if end_pos > -1:
    return (line, linenum, end_pos)

  # Continue scanning forward
  while stack and linenum < clean_lines.NumLines() - 1:
    linenum += 1
    line = clean_lines.elided[linenum]
    (end_pos, stack) = FindEndOfExpressionInLine(line, 0, stack)
    if end_pos > -1:
      return (line, linenum, end_pos)

  # Did not find end of expression before end of file, give up
  return (line, clean_lines.NumLines(), -1)


def FindStartOfExpressionInLine(line, endpos, stack):
  """Find position at the matching start of current expression.

  This is almost the reverse of FindEndOfExpressionInLine, but note
  that the input position and returned position differs by 1.

  Args:
    line: a CleansedLines line.
    endpos: start searching at this position.
    stack: nesting stack at endpos.

  Returns:
    On finding matching start: (index at matching start, None)
    On finding an unclosed expression: (-1, None)
    Otherwise: (-1, new stack at beginning of this line)
  """
  i = endpos
  while i >= 0:
    char = line[i]
    if char in ')]}':
      # Found end of expression, push to expression stack
      stack.append(char)
    elif char == '>':
      # Found potential end of template argument list.
      #
      # Ignore it if it's a "->" or ">=" or "operator>"
      if (i > 0 and
          (line[i - 1] == '-' or
           Match(r'\s>=\s', line[i - 1:]) or
           Search(r'\boperator\s*$', line[0:i]))):
        i -= 1
      else:
        stack.append('>')
    elif char == '<':
      # Found potential start of template argument list
      if i > 0 and line[i - 1] == '<':
        # Left shift operator
        i -= 1
      else:
        # If there is a matching '>', we can pop the expression stack.
        # Otherwise, ignore this '<' since it must be an operator.
        if stack and stack[-1] == '>':
          stack.pop()
          if not stack:
            return (i, None)
    elif char in '([{':
      # Found start of expression.
      #
      # If there are any unmatched '>' on the stack, they must be
      # operators.  Remove those.
      while stack and stack[-1] == '>':
        stack.pop()
      if not stack:
        return (-1, None)
      if ((char == '(' and stack[-1] == ')') or
          (char == '[' and stack[-1] == ']') or
          (char == '{' and stack[-1] == '}')):
        stack.pop()
        if not stack:
          return (i, None)
      else:
        # Mismatched parentheses
        return (-1, None)
    elif char == ';':
      # Found something that look like end of statements.  If we are currently
      # expecting a '<', the matching '>' must have been an operator, since
      # template argument list should not contain statements.
      while stack and stack[-1] == '>':
        stack.pop()
      if not stack:
        return (-1, None)

    i -= 1

  return (-1, stack)


def ReverseCloseExpression(clean_lines, linenum, pos):
  """If input points to ) or } or ] or >, finds the position that opens it.

  If lines[linenum][pos] points to a ')' or '}' or ']' or '>', finds the
  linenum/pos that correspond to the opening of the expression.

  Args:
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    pos: A position on the line.

  Returns:
    A tuple (line, linenum, pos) pointer *at* the opening brace, or
    (line, 0, -1) if we never find the matching opening brace.  Note
    we ignore strings and comments when matching; and the line we
    return is the 'cleansed' line at linenum.
  """
  line = clean_lines.elided[linenum]
  if line[pos] not in ')}]>':
    return (line, 0, -1)

  # Check last line
  (start_pos, stack) = FindStartOfExpressionInLine(line, pos, [])
  if start_pos > -1:
    return (line, linenum, start_pos)

  # Continue scanning backward
  while stack and linenum > 0:
    linenum -= 1
    line = clean_lines.elided[linenum]
    (start_pos, stack) = FindStartOfExpressionInLine(line, len(line) - 1, stack)
    if start_pos > -1:
      return (line, linenum, start_pos)

  # Did not find start of expression before beginning of file, give up
  return (line, 0, -1)


def CheckForCopyright(filename, lines, error):
  """Logs an error if no Copyright message appears at the top of the file."""

  # We'll say it should occur by line 10. Don't forget there's a
  # placeholder line at the front.
  # range() rather than xrange(): the latter does not exist on Python 3.
  for line in range(1, min(len(lines), 11)):
    if re.search(r'Copyright', lines[line], re.I): break
  else:                       # means no copyright line was found
    error(filename, 0, 'legal/copyright', 5,
          'No copyright message found. '
          'You should have a line: "Copyright [year] <Copyright Owner>"')


def GetIndentLevel(line):
  """Return the number of leading spaces in line.

  Args:
    line: A string to check.

  Returns:
    An integer count of leading spaces, possibly zero.
  """
  indent = Match(r'^( *)\S', line)
  if indent:
    return len(indent.group(1))
  else:
    return 0

def PathSplitToList(path):
  """Returns the path split into a list by the separator.

  Args:
    path: An absolute or relative path (e.g. '/a/b/c/' or '../a')

  Returns:
    A list of path components (e.g. ['a', 'b', 'c']).
  """
  lst = []
  while True:
    (head, tail) = os.path.split(path)
    if head == path:  # absolute paths end
      lst.append(head)
      break
    if tail == path:  # relative paths end
      lst.append(tail)
      break

    path = head
    lst.append(tail)

  lst.reverse()
  return lst

def GetHeaderGuardCPPVariable(filename):
  """Returns the CPP variable that should be used as a header guard.

  Args:
    filename: The name of a C++ header file.

  Returns:
    The CPP variable that should be used as a header guard in the
    named file.

  """

  # Restores original filename in case that cpplint is invoked from Emacs's
  # flymake.
  filename = re.sub(r'_flymake\.h$', '.h', filename)
  filename = re.sub(r'/\.flymake/([^/]*)$', r'/\1', filename)
  # Replace 'c++' with 'cpp'.
  filename = filename.replace('C++', 'cpp').replace('c++', 'cpp')

  fileinfo = FileInfo(filename)
  file_path_from_root = fileinfo.RepositoryName()

  def FixupPathFromRoot():
    if _root_debug:
      sys.stderr.write("\n_root fixup, _root = '%s', repository name = '%s'\n"
          %(_root, fileinfo.RepositoryName()))

    # Process the file path with the --root flag if it was set.
    if not _root:
      if _root_debug:
        sys.stderr.write("_root unspecified\n")
      return file_path_from_root

    def StripListPrefix(lst, prefix):
      # f(['x', 'y'], ['w', 'z']) -> None  (not a valid prefix)
      if lst[:len(prefix)] != prefix:
        return None
      # f(['a', 'b', 'c', 'd'], ['a', 'b']) -> ['c', 'd']
      return lst[(len(prefix)):]

    # root behavior:
    #   --root=subdir , lstrips subdir from the header guard
    maybe_path = StripListPrefix(PathSplitToList(file_path_from_root),
                                 PathSplitToList(_root))

    if _root_debug:
      sys.stderr.write(("_root lstrip (maybe_path=%s, file_path_from_root=%s," +
          " _root=%s)\n") %(maybe_path, file_path_from_root, _root))

    if maybe_path:
      return os.path.join(*maybe_path)

    #   --root=.. , will prepend the outer directory to the header guard
    full_path = fileinfo.FullName()
    root_abspath = os.path.abspath(_root)

    maybe_path = StripListPrefix(PathSplitToList(full_path),
                                 PathSplitToList(root_abspath))

    if _root_debug:
      sys.stderr.write(("_root prepend (maybe_path=%s, full_path=%s, " +
          "root_abspath=%s)\n") %(maybe_path, full_path, root_abspath))

    if maybe_path:
      return os.path.join(*maybe_path)

    if _root_debug:
      sys.stderr.write("_root ignore, returning %s\n" %(file_path_from_root))

    #   --root=FAKE_DIR is ignored
    return file_path_from_root

  file_path_from_root = FixupPathFromRoot()
  return re.sub(r'[^a-zA-Z0-9]', '_', file_path_from_root).upper() + '_'


def CheckForHeaderGuard(filename, clean_lines, error):
  """Checks that the file contains a header guard.

  Logs an error if no #ifndef header guard is present.  For other
  headers, checks that the full pathname is used.

  Args:
    filename: The name of the C++ header file.
    clean_lines: A CleansedLines instance containing the file.
    error: The function to call with any errors found.
  """

  # Don't check for header guards if there are error suppression
  # comments somewhere in this file.
  #
  # Because this is silencing a warning for a nonexistent line, we
  # only support the very specific NOLINT(build/header_guard) syntax,
  # and not the general NOLINT or NOLINT(*) syntax.
  raw_lines = clean_lines.lines_without_raw_strings
  for i in raw_lines:
    if Search(r'//\s*NOLINT\(build/header_guard\)', i):
      return

  cppvar = GetHeaderGuardCPPVariable(filename)

  ifndef = ''
  ifndef_linenum = 0
  define = ''
  endif = ''
  endif_linenum = 0
  for linenum, line in enumerate(raw_lines):
    linesplit = line.split()
    if len(linesplit) >= 2:
      # find the first occurrence of #ifndef and #define, save arg
      if not ifndef and linesplit[0] == '#ifndef':
        # set ifndef to the header guard presented on the #ifndef line.
        ifndef = linesplit[1]
        ifndef_linenum = linenum
      if not define and linesplit[0] == '#define':
        define = linesplit[1]
    # find the last occurrence of #endif, save entire line
    if line.startswith('#endif'):
      endif = line
      endif_linenum = linenum

  if not ifndef or not define or ifndef != define:
    error(filename, 0, 'build/header_guard', 5,
          'No #ifndef header guard found, suggested CPP variable is: %s' %
          cppvar)
    return

  # The guard should be PATH_FILE_H_, but we also allow PATH_FILE_H__
  # for backward compatibility.
  if ifndef != cppvar:
    error_level = 0
    if ifndef != cppvar + '_':
      error_level = 5

    ParseNolintSuppressions(filename, raw_lines[ifndef_linenum], ifndef_linenum,
                            error)
    error(filename, ifndef_linenum, 'build/header_guard', error_level,
          '#ifndef header guard has wrong style, please use: %s' % cppvar)

  # Check for "//" comments on endif line.
  ParseNolintSuppressions(filename, raw_lines[endif_linenum], endif_linenum,
                          error)
  match = Match(r'#endif\s*//\s*' + cppvar + r'(_)?\b', endif)
  if match:
    if match.group(1) == '_':
      # Issue low severity warning for deprecated double trailing underscore
      error(filename, endif_linenum, 'build/header_guard', 0,
            '#endif line should be "#endif // %s"' % cppvar)
    return

  # Didn't find the corresponding "//" comment.  If this file does not
  # contain any "//" comments at all, it could be that the compiler
  # only wants "/**/" comments, look for those instead.
  no_single_line_comments = True
  # range() rather than xrange(): the latter does not exist on Python 3.
  for i in range(1, len(raw_lines) - 1):
    line = raw_lines[i]
    if Match(r'^(?:(?:\'(?:\.|[^\'])*\')|(?:"(?:\.|[^"])*")|[^\'"])*//', line):
      no_single_line_comments = False
      break

  if no_single_line_comments:
    match = Match(r'#endif\s*/\*\s*' + cppvar + r'(_)?\s*\*/', endif)
    if match:
      if match.group(1) == '_':
        # Low severity warning for double trailing underscore
        error(filename, endif_linenum, 'build/header_guard', 0,
              '#endif line should be "#endif /* %s */"' % cppvar)
      return

  # Didn't find anything
  error(filename, endif_linenum, 'build/header_guard', 5,
        '#endif line should be "#endif // %s"' % cppvar)


def CheckHeaderFileIncluded(filename, include_state, error):
  """Logs an error if a .cc file does not include its header."""

  # Do not check test files
  fileinfo = FileInfo(filename)
  if Search(_TEST_FILE_SUFFIX, fileinfo.BaseName()):
    return

  headerfile = filename[0:len(filename) - len(fileinfo.Extension())] + '.h'
  if not os.path.exists(headerfile):
    return
  headername = FileInfo(headerfile).RepositoryName()
  first_include = 0
  for section_list in include_state.include_list:
    for f in section_list:
      if headername in f[0] or f[0] in headername:
        return
      if not first_include:
        first_include = f[1]

  error(filename, first_include, 'build/include', 5,
        '%s should include its header file %s' % (fileinfo.RepositoryName(),
                                                  headername))


def CheckForBadCharacters(filename, lines, error):
  """Logs an error for each line containing bad characters.

  Two kinds of bad characters:

  1. Unicode replacement characters: These indicate that either the file
  contained invalid UTF-8 (likely) or Unicode replacement characters (which
  it shouldn't).  Note that it's possible for this to throw off line
  numbering if the invalid UTF-8 occurred adjacent to a newline.

  2. NUL bytes.  These are problematic for some tools.

  Args:
    filename: The name of the current file.
    lines: An array of strings, each representing a line of the file.
    error: The function to call with any errors found.
  """
  for linenum, line in enumerate(lines):
    if u'\ufffd' in line:
      error(filename, linenum, 'readability/utf8', 5,
            'Line contains invalid UTF-8 (or Unicode replacement character).')
    if '\0' in line:
      error(filename, linenum, 'readability/nul', 5, 'Line contains NUL byte.')


def CheckForNewlineAtEOF(filename, lines, error):
  """Logs an error if there is no newline char at the end of the file.

  Args:
    filename: The name of the current file.
    lines: An array of strings, each representing a line of the file.
    error: The function to call with any errors found.
  """

  # The array lines() was created by adding two newlines to the
  # original file (go figure), then splitting on \n.
  # To verify that the file ends in \n, we just have to make sure the
  # last-but-two element of lines() exists and is empty.
  if len(lines) < 3 or lines[-2]:
    error(filename, len(lines) - 2, 'whitespace/ending_newline', 5,
          'Could not find a newline character at the end of the file.')


def CheckForMultilineCommentsAndStrings(filename, clean_lines, linenum, error):
  """Logs an error if we see /* ... */ or "..." that extend past one line.

  /* ... */ comments are legit inside macros, for one line.
  Otherwise, we prefer // comments, so it's ok to warn about the
  other.  Likewise, it's ok for strings to extend across multiple
  lines, as long as a line continuation character (backslash)
  terminates each line. Although not currently prohibited by the C++
  style guide, it's ugly and unnecessary. We don't do well with either
  in this lint program, so we warn about both.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    error: The function to call with any errors found.
  """
  line = clean_lines.elided[linenum]

  # Remove all \\ (escaped backslashes) from the line. They are OK, and the
  # second (escaped) slash may trigger later \" detection erroneously.
  line = line.replace('\\\\', '')

  if line.count('/*') > line.count('*/'):
    error(filename, linenum, 'readability/multiline_comment', 5,
          'Complex multi-line /*...*/-style comment found. '
          'Lint may give bogus warnings. '
          'Consider replacing these with //-style comments, '
          'with #if 0...#endif, '
          'or with more clearly structured multi-line comments.')

  if (line.count('"') - line.count('\\"')) % 2:
    error(filename, linenum, 'readability/multiline_string', 5,
          'Multi-line string ("...") found. This lint script doesn\'t '
          'do well with such strings, and may give bogus warnings. '
          'Use C++11 raw strings or concatenation instead.')


# (non-threadsafe name, thread-safe alternative, validation pattern)
#
# The validation pattern is used to eliminate false positives such as:
#  _rand();               // false positive due to substring match.
#  ->rand();              // some member function rand().
#  ACMRandom rand(seed);  // some variable named rand.
#  ISAACRandom rand();    // another variable named rand.
#
# Basically we require the return value of these functions to be used
# in some expression context on the same line by matching on some
# operator before the function name.  This eliminates constructors and
# member function calls.
# Table rows: (non-threadsafe name, thread-safe alternative, validation
# pattern).
#
# The validation pattern is used to eliminate false positives such as:
#  _rand();               // false positive due to substring match.
#  ->rand();              // some member function rand().
#  ACMRandom rand(seed);  // some variable named rand.
#  ISAACRandom rand();    // another variable named rand.
#
# Basically we require the return value of these functions to be used
# in some expression context on the same line by matching on some
# operator before the function name.  This eliminates constructors and
# member function calls.
_UNSAFE_FUNC_PREFIX = r'(?:[-+*/=%^&|(<]\s*|>\s+)'
_THREADING_LIST = tuple(
    (name + '(', name + '_r(', _UNSAFE_FUNC_PREFIX + name + argument_pattern)
    for name, argument_pattern in (
        # r'\([^)]+\)' requires at least one argument character;
        # r'\(\)' requires an empty argument list.
        ('asctime', r'\([^)]+\)'),
        ('ctime', r'\([^)]+\)'),
        ('getgrgid', r'\([^)]+\)'),
        ('getgrnam', r'\([^)]+\)'),
        ('getlogin', r'\(\)'),
        ('getpwnam', r'\([^)]+\)'),
        ('getpwuid', r'\([^)]+\)'),
        ('gmtime', r'\([^)]+\)'),
        ('localtime', r'\([^)]+\)'),
        ('rand', r'\(\)'),
        ('strtok', r'\([^)]+\)'),
        ('ttyname', r'\([^)]+\)'),
    ))


def CheckPosixThreading(filename, clean_lines, linenum, error):
  """Checks for calls to thread-unsafe functions.

  Much code was originally written without multi-threading in mind, and
  many engineers learned posix before its threading extensions existed.
  This check nudges them towards the thread-safe variants (when posix is
  used directly).

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    error: The function to call with any errors found.
  """
  text = clean_lines.elided[linenum]
  for unsafe_name, safe_name, call_pattern in _THREADING_LIST:
    # The pattern confirms that this really is a call to the function we
    # are looking for, and not a look-alike.
    if not Search(call_pattern, text):
      continue
    error(filename, linenum, 'runtime/threadsafe_fn', 2,
          'Consider using %s...) instead of %s...) for improved thread safety.'
          % (safe_name, unsafe_name))
def CheckVlogArguments(filename, clean_lines, linenum, error):
  """Checks that VLOG() is only used with a numeric verbosity level.

  For example, VLOG(2) is correct.  VLOG(INFO), VLOG(WARNING), VLOG(ERROR),
  and VLOG(FATAL) are not.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    error: The function to call with any errors found.
  """
  symbolic_level = Search(r'\bVLOG\((INFO|ERROR|WARNING|DFATAL|FATAL)\)',
                          clean_lines.elided[linenum])
  if not symbolic_level:
    return
  error(filename, linenum, 'runtime/vlog', 5,
        'VLOG() should be used with numeric verbosity level. '
        'Use LOG() if you want symbolic severity levels.')


# Matches a statement consisting solely of "*name++;" or "*name--;", which
# moves the pointer instead of changing the value it points at.
_RE_PATTERN_INVALID_INCREMENT = re.compile(
    r'^\s*\*\w+(\+\+|--);')


def CheckInvalidIncrement(filename, clean_lines, linenum, error):
  """Checks for the invalid increment *count++.

  For example, the following function:
    void increment_counter(int* count) {
      *count++;
    }
  is invalid, because it effectively does count++, moving the pointer, and
  should be replaced with ++*count, (*count)++ or *count += 1.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    error: The function to call with any errors found.
  """
  statement = clean_lines.elided[linenum]
  if _RE_PATTERN_INVALID_INCREMENT.match(statement) is None:
    return
  error(filename, linenum, 'runtime/invalid_increment', 5,
        'Changing pointer instead of value (or unused value of operator*).')
def IsMacroDefinition(clean_lines, linenum):
  """Tells whether the given line belongs to a macro definition.

  A line qualifies when it starts with '#define', or when the line before
  it ends with a backslash.

  Args:
    clean_lines: Indexable sequence of line strings.
    linenum: The number of the line to check.

  Returns:
    True if the line is part of a macro definition, False otherwise.
  """
  return bool(
      Search(r'^#define', clean_lines[linenum]) or
      (linenum > 0 and Search(r'\\$', clean_lines[linenum - 1])))


def IsForwardClassDeclaration(clean_lines, linenum):
  """Matches a line that forward-declares a class (e.g. "class Foo;").

  Args:
    clean_lines: Indexable sequence of line strings.
    linenum: The number of the line to check.

  Returns:
    The result of Match() for the forward-declaration pattern on that line.
  """
  declaration = clean_lines[linenum]
  return Match(r'^\s*(\btemplate\b)*.*class\s+\w+;\s*$', declaration)


class _BlockInfo(object):
  """Stores information about a generic block of code."""

  def __init__(self, linenum, seen_open_brace):
    # Line on which the block starts.
    self.starting_linenum = linenum
    # Whether the block's opening brace has been encountered yet.
    self.seen_open_brace = seen_open_brace
    # Running count of unmatched '(' seen inside this block.
    self.open_parentheses = 0
    # Inline-assembly state; starts out as "not in assembly".
    self.inline_asm = _NO_ASM
    # Generic blocks do not take part in namespace indentation checks.
    self.check_namespace_indentation = False

  def CheckBegin(self, filename, clean_lines, linenum, error):
    """Run checks that apply to text up to the opening brace.

    This hook mostly exists for the text between a class identifier and
    its "{", usually where the base class is specified.  Other kinds of
    block have nothing to check there, so the default does nothing.

    Args:
      filename: The name of the current file.
      clean_lines: A CleansedLines instance containing the file.
      linenum: The number of the line to check.
      error: The function to call with any errors found.
    """

  def CheckEnd(self, filename, clean_lines, linenum, error):
    """Run checks that apply to text after the closing brace.

    This hook is mostly used for checking end-of-namespace comments; the
    default does nothing.

    Args:
      filename: The name of the current file.
      clean_lines: A CleansedLines instance containing the file.
      linenum: The number of the line to check.
      error: The function to call with any errors found.
    """

  def IsBlockInfo(self):
    """Returns true if this block is exactly a _BlockInfo.

    Handy for verifying that an object is a plain _BlockInfo and not an
    instance of one of the derived classes.

    Returns:
      True for this class, False for derived classes.
    """
    return self.__class__ is _BlockInfo
class _ExternCInfo(_BlockInfo):
  """Stores information about an 'extern "C"' block."""

  def __init__(self, linenum):
    # An extern "C" block is only recorded once its brace has been seen.
    _BlockInfo.__init__(self, linenum, True)


class _ClassInfo(_BlockInfo):
  """Stores information about a class or struct."""

  def __init__(self, name, class_or_struct, clean_lines, linenum):
    """Records the declaration found on line linenum.

    Args:
      name: Name of the class as written in the declaration.
      class_or_struct: The keyword used, 'class' or 'struct'.
      clean_lines: A CleansedLines instance containing the file.
      linenum: The number of the line holding the declaration.
    """
    _BlockInfo.__init__(self, linenum, False)
    self.name = name
    self.is_derived = False
    self.check_namespace_indentation = True
    # Members default to public in a struct and to private in a class.
    if class_or_struct == 'struct':
      self.access = 'public'
      self.is_struct = True
    else:
      self.access = 'private'
      self.is_struct = False

    # Remember initial indentation level for this class.  Using raw_lines here
    # instead of elided to account for leading comments.
    self.class_indent = GetIndentLevel(clean_lines.raw_lines[linenum])

    # Try to find the end of the class.  This will be confused by things like:
    #   class A {
    #   } *x = { ...
    #
    # But it's still good enough for CheckSectionSpacing.
    self.last_line = 0
    depth = 0
    for i in range(linenum, clean_lines.NumLines()):
      line = clean_lines.elided[i]
      depth += line.count('{') - line.count('}')
      if not depth:
        self.last_line = i
        break

  def CheckBegin(self, filename, clean_lines, linenum, error):
    """Marks the class as derived when a bare ':' precedes its brace.

    Args:
      filename: The name of the current file.
      clean_lines: A CleansedLines instance containing the file.
      linenum: The number of the line to check.
      error: The function to call with any errors found.
    """
    # Look for a bare ':' (one that is not part of '::').
    if Search('(^|[^:]):($|[^:])', clean_lines.elided[linenum]):
      self.is_derived = True

  def CheckEnd(self, filename, clean_lines, linenum, error):
    """Checks DISALLOW macro placement and closing-brace alignment.

    Args:
      filename: The name of the current file.
      clean_lines: A CleansedLines instance containing the file.
      linenum: The number of the line holding the closing brace.
      error: The function to call with any errors found.
    """
    # If there is a DISALLOW macro, it should appear near the end of
    # the class.  The pattern does not change inside the loop, so build it
    # once up front.
    disallow_pattern = (
        r'\b(DISALLOW_COPY_AND_ASSIGN|DISALLOW_IMPLICIT_CONSTRUCTORS)\(' +
        self.name + r'\)')
    seen_last_thing_in_class = False
    # Walk backwards from the line above the closing brace.  range() is used
    # (as in __init__) because xrange is not a builtin on Python 3.
    for i in range(linenum - 1, self.starting_linenum, -1):
      match = Search(disallow_pattern, clean_lines.elided[i])
      if match:
        if seen_last_thing_in_class:
          error(filename, i, 'readability/constructors', 3,
                match.group(1) + ' should be the last thing in the class')
        break

      if not Match(r'^\s*$', clean_lines.elided[i]):
        seen_last_thing_in_class = True

    # Check that closing brace is aligned with beginning of the class.
    # Only do this if the closing brace is indented by only whitespaces.
    # This means we will not check single-line class definitions.
    indent = Match(r'^( *)\}', clean_lines.elided[linenum])
    if indent and len(indent.group(1)) != self.class_indent:
      if self.is_struct:
        parent = 'struct ' + self.name
      else:
        parent = 'class ' + self.name
      error(filename, linenum, 'whitespace/indent', 3,
            'Closing brace should be aligned with beginning of %s' % parent)
class _NamespaceInfo(_BlockInfo):
  """Stores information about a namespace."""

  def __init__(self, name, linenum):
    _BlockInfo.__init__(self, linenum, False)
    # An anonymous namespace is stored with an empty name.
    self.name = name or ''
    self.check_namespace_indentation = True

  def CheckEnd(self, filename, clean_lines, linenum, error):
    """Checks the comment that terminates the namespace.

    Args:
      filename: The name of the current file.
      clean_lines: A CleansedLines instance containing the file.
      linenum: The number of the line holding the closing brace.
      error: The function to call with any errors found.
    """
    closing_line = clean_lines.raw_lines[linenum]

    # Short namespaces are let off when they carry no terminating comment
    # at all.  If a namespace comment is already present it is still
    # validated below, however short the namespace is.
    #
    # TODO(unknown): We always want to check end of namespace comments
    # if a namespace is large, but sometimes we also want to apply the
    # check if a short namespace contained nontrivial things (something
    # other than forward declarations).  There is currently no logic on
    # deciding what these nontrivial things are, so this check is
    # triggered by namespace size only, which works most of the time.
    is_short = linenum - self.starting_linenum < 10
    if is_short and not Match(r'^\s*};*\s*(//|/\*).*\bnamespace\b',
                              closing_line):
      return

    # Look for a matching comment at the end of the namespace.
    #
    # C style "/* */" comments are accepted too, so that code which
    # terminates namespaces inside preprocessor macros can be cpplint
    # clean, and so is a trailing period as in "// end of namespace <name>.".
    #
    # Nothing else is accepted, otherwise an existing comment that is a
    # mere substring of the expected namespace could slip through.
    if self.name:
      # Named namespace
      named_terminator = (r'^\s*};*\s*(//|/\*).*\bnamespace\s+' +
                          re.escape(self.name) + r'[\*/\.\\\s]*$')
      if Match(named_terminator, closing_line):
        return
      error(filename, linenum, 'readability/namespace', 5,
            'Namespace should be terminated with "// namespace %s"' %
            self.name)
      return

    # Anonymous namespace
    if Match(r'^\s*};*\s*(//|/\*).*\bnamespace[\*/\.\\\s]*$', closing_line):
      return

    # If "// namespace anonymous" or "// anonymous namespace (more text)",
    # mention "// anonymous namespace" as an acceptable form
    if Match(r'^\s*}.*\b(namespace anonymous|anonymous namespace)\b',
             closing_line):
      message = ('Anonymous namespace should be terminated with "// namespace"'
                 ' or "// anonymous namespace"')
    else:
      message = 'Anonymous namespace should be terminated with "// namespace"'
    error(filename, linenum, 'readability/namespace', 5, message)


class _PreprocessorInfo(object):
  """Checkpoints of the nesting stack taken when #if/#else is seen."""

  def __init__(self, stack_before_if):
    # Snapshot of the whole nesting stack as it was at the #if.
    self.stack_before_if = stack_before_if

    # Snapshot of the whole nesting stack as it was at the #else; empty
    # until one is recorded.
    self.stack_before_else = []

    # Set once an #else or #elif has been encountered for this #if.
    self.seen_else = False
class NestingState(object):
  """Holds states related to parsing braces."""

  def __init__(self):
    # Stack for tracking all braces.  An object is pushed whenever we
    # see a "{", and popped when we see a "}".  Only 3 types of
    # objects are possible:
    # - _ClassInfo: a class or struct.
    # - _NamespaceInfo: a namespace.
    # - _BlockInfo: some other type of block.
    self.stack = []

    # Top of the previous stack before each Update().
    #
    # Because the nesting_stack is updated at the end of each line, we
    # had to do some convoluted checks to find out what is the current
    # scope at the beginning of the line.  This check is simplified by
    # saving the previous top of nesting stack.
    #
    # We could save the full stack, but we only need the top.  Copying
    # the full nesting stack would slow down cpplint by ~10%.
    #
    # None means "the stack was empty", which is also the value Update()
    # stores in that situation.
    self.previous_stack_top = None

    # Stack of _PreprocessorInfo objects.
    self.pp_stack = []

  def SeenOpenBrace(self):
    """Check if we have seen the opening brace for the innermost block.

    Returns:
      True if we have seen the opening brace, False if the innermost
      block is still expecting an opening brace.
    """
    return (not self.stack) or self.stack[-1].seen_open_brace

  def InNamespaceBody(self):
    """Check if we are currently one level inside a namespace body.

    Returns:
      True if top of the stack is a namespace block, False otherwise.
    """
    # bool() keeps an empty stack from leaking out as the return value.
    return bool(self.stack and isinstance(self.stack[-1], _NamespaceInfo))

  def InExternC(self):
    """Check if we are currently one level inside an 'extern "C"' block.

    Returns:
      True if top of the stack is an extern block, False otherwise.
    """
    return bool(self.stack and isinstance(self.stack[-1], _ExternCInfo))

  def InClassDeclaration(self):
    """Check if we are currently one level inside a class or struct declaration.

    Returns:
      True if top of the stack is a class/struct, False otherwise.
    """
    return bool(self.stack and isinstance(self.stack[-1], _ClassInfo))

  def InAsmBlock(self):
    """Check if we are currently one level inside an inline ASM block.

    Returns:
      True if the top of the stack is a block containing inline ASM.
    """
    return bool(self.stack and self.stack[-1].inline_asm != _NO_ASM)

  def InTemplateArgumentList(self, clean_lines, linenum, pos):
    """Check if current position is inside template argument list.

    Args:
      clean_lines: A CleansedLines instance containing the file.
      linenum: The number of the line to check.
      pos: position just after the suspected template argument.
    Returns:
      True if (linenum, pos) is inside template arguments.
    """
    while linenum < clean_lines.NumLines():
      # Find the earliest character that might indicate a template argument
      line = clean_lines.elided[linenum]
      match = Match(r'^[^{};=\[\]\.<>]*(.)', line[pos:])
      if not match:
        linenum += 1
        pos = 0
        continue
      token = match.group(1)
      pos += len(match.group(0))

      # These things do not look like template argument list:
      #   class Suspect {
      #   class Suspect x; }
      if token in ('{', '}', ';'):
        return False

      # These things look like template argument list:
      #   template <class Suspect>
      #   template <class Suspect = default_value>
      #   template <class Suspect[]>
      #   template <class Suspect...>
      if token in ('>', '=', '[', ']', '.'):
        return True

      # Check if token is an unmatched '<'.
      # If not, move on to the next character.
      if token != '<':
        pos += 1
        if pos >= len(line):
          linenum += 1
          pos = 0
        continue

      # We can't be sure if we just find a single '<', and need to
      # find the matching '>'.
      (_, end_line, end_pos) = CloseExpression(clean_lines, linenum, pos - 1)
      if end_pos < 0:
        # Not sure if template argument list or syntax error in file
        return False
      linenum = end_line
      pos = end_pos
    return False

  def UpdatePreprocessor(self, line):
    """Update preprocessor stack.

    We need to handle preprocessors due to classes like this:
      #ifdef SWIG
      struct ResultDetailsPageElementExtensionPoint {
      #else
      struct ResultDetailsPageElementExtensionPoint : public Extension {
      #endif

    We make the following assumptions (good enough for most files):
    - Preprocessor condition evaluates to true from #if up to first
      #else/#elif/#endif.

    - Preprocessor condition evaluates to false from #else/#elif up
      to #endif.  We still perform lint checks on these lines, but
      these do not affect nesting stack.

    Args:
      line: current line to check.
    """
    if Match(r'^\s*#\s*(if|ifdef|ifndef)\b', line):
      # Beginning of #if block, save the nesting stack here.  The saved
      # stack will allow us to restore the parsing state in the #else case.
      self.pp_stack.append(_PreprocessorInfo(copy.deepcopy(self.stack)))
    elif Match(r'^\s*#\s*(else|elif)\b', line):
      # Beginning of #else block
      if self.pp_stack:
        if not self.pp_stack[-1].seen_else:
          # This is the first #else or #elif block.  Remember the
          # whole nesting stack up to this point.  This is what we
          # keep after the #endif.
          self.pp_stack[-1].seen_else = True
          self.pp_stack[-1].stack_before_else = copy.deepcopy(self.stack)

        # Restore the stack to how it was before the #if
        self.stack = copy.deepcopy(self.pp_stack[-1].stack_before_if)
      else:
        # TODO(unknown): unexpected #else, issue warning?
        pass
    elif Match(r'^\s*#\s*endif\b', line):
      # End of #if or #else blocks.
      if self.pp_stack:
        # If we saw an #else, we will need to restore the nesting
        # stack to its former state before the #else, otherwise we
        # will just continue from where we left off.
        if self.pp_stack[-1].seen_else:
          # Here we can just use a shallow copy since we are the last
          # reference to it.
          self.stack = self.pp_stack[-1].stack_before_else
        # Drop the corresponding #if
        self.pp_stack.pop()
      else:
        # TODO(unknown): unexpected #endif, issue warning?
        pass

  # TODO(unknown): Update() is too long, but we will refactor later.
  def Update(self, filename, clean_lines, linenum, error):
    """Update nesting state with current line.

    Args:
      filename: The name of the current file.
      clean_lines: A CleansedLines instance containing the file.
      linenum: The number of the line to check.
      error: The function to call with any errors found.
    """
    line = clean_lines.elided[linenum]

    # Remember top of the previous nesting stack.
    #
    # The stack is always pushed/popped and not modified in place, so
    # we can just do a shallow copy instead of copy.deepcopy.  Using
    # deepcopy would slow down cpplint by ~28%.
    if self.stack:
      self.previous_stack_top = self.stack[-1]
    else:
      self.previous_stack_top = None

    # Update pp_stack
    self.UpdatePreprocessor(line)

    # Count parentheses.  This is to avoid adding struct arguments to
    # the nesting stack.
    if self.stack:
      inner_block = self.stack[-1]
      depth_change = line.count('(') - line.count(')')
      inner_block.open_parentheses += depth_change

      # Also check if we are starting or ending an inline assembly block.
      if inner_block.inline_asm in (_NO_ASM, _END_ASM):
        if (depth_change != 0 and
            inner_block.open_parentheses == 1 and
            _MATCH_ASM.match(line)):
          # Enter assembly block
          inner_block.inline_asm = _INSIDE_ASM
        else:
          # Not entering assembly block.  If previous line was _END_ASM,
          # we will now shift to _NO_ASM state.
          inner_block.inline_asm = _NO_ASM
      elif (inner_block.inline_asm == _INSIDE_ASM and
            inner_block.open_parentheses == 0):
        # Exit assembly block
        inner_block.inline_asm = _END_ASM

    # Consume namespace declaration at the beginning of the line.  Do
    # this in a loop so that we catch same line declarations like this:
    #   namespace proto2 { namespace bridge { class MessageSet; } }
    while True:
      # Match start of namespace.  The "\b\s*" below catches namespace
      # declarations even if it weren't followed by a whitespace, this
      # is so that we don't confuse our namespace checker.  The
      # missing spaces will be flagged by CheckSpacing.
      namespace_decl_match = Match(r'^\s*namespace\b\s*([:\w]+)?(.*)$', line)
      if not namespace_decl_match:
        break

      new_namespace = _NamespaceInfo(namespace_decl_match.group(1), linenum)
      self.stack.append(new_namespace)

      line = namespace_decl_match.group(2)
      if line.find('{') != -1:
        new_namespace.seen_open_brace = True
        line = line[line.find('{') + 1:]

    # Look for a class declaration in whatever is left of the line
    # after parsing namespaces.  The regexp accounts for decorated classes
    # such as in:
    #   class LOCKABLE API Object {
    #   };
    class_decl_match = Match(
        r'^(\s*(?:template\s*<[\w\s<>,:]*>\s*)?'
        r'(class|struct)\s+(?:[A-Z_]+\s+)*(\w+(?:::\w+)*))'
        r'(.*)$', line)
    if (class_decl_match and
        (not self.stack or self.stack[-1].open_parentheses == 0)):
      # We do not want to accept classes that are actually template arguments:
      #   template <class Ignore1,
      #             class Ignore2 = Default<Args>,
      #             template <Args> class Ignore3>
      #   void Function() {};
      #
      # To avoid template argument cases, we scan forward and look for
      # an unmatched '>'.  If we see one, assume we are inside a
      # template argument list.
      end_declaration = len(class_decl_match.group(1))
      if not self.InTemplateArgumentList(clean_lines, linenum, end_declaration):
        self.stack.append(_ClassInfo(
            class_decl_match.group(3), class_decl_match.group(2),
            clean_lines, linenum))
        line = class_decl_match.group(4)

    # If we have not yet seen the opening brace for the innermost block,
    # run checks here.
    if not self.SeenOpenBrace():
      self.stack[-1].CheckBegin(filename, clean_lines, linenum, error)

    # Update access control if we are inside a class/struct
    if self.stack and isinstance(self.stack[-1], _ClassInfo):
      classinfo = self.stack[-1]
      access_match = Match(
          r'^(.*)\b(public|private|protected|signals)(\s+(?:slots\s*)?)?'
          r':(?:[^:]|$)',
          line)
      if access_match:
        classinfo.access = access_match.group(2)

        # Check that access keywords are indented +1 space.  Skip this
        # check if the keywords are not preceded by whitespaces.
        indent = access_match.group(1)
        if (len(indent) != classinfo.class_indent + 1 and
            Match(r'^\s*$', indent)):
          if classinfo.is_struct:
            parent = 'struct ' + classinfo.name
          else:
            parent = 'class ' + classinfo.name
          slots = ''
          if access_match.group(3):
            slots = access_match.group(3)
          error(filename, linenum, 'whitespace/indent', 3,
                '%s%s: should be indented +1 space inside %s' % (
                    access_match.group(2), slots, parent))

    # Consume braces or semicolons from what's left of the line
    while True:
      # Match first brace, semicolon, or closed parenthesis.
      matched = Match(r'^[^{;)}]*([{;)}])(.*)$', line)
      if not matched:
        break

      token = matched.group(1)
      if token == '{':
        # If namespace or class hasn't seen a opening brace yet, mark
        # namespace/class head as complete.  Push a new block onto the
        # stack otherwise.
        if not self.SeenOpenBrace():
          self.stack[-1].seen_open_brace = True
        elif Match(r'^extern\s*"[^"]*"\s*\{', line):
          self.stack.append(_ExternCInfo(linenum))
        else:
          self.stack.append(_BlockInfo(linenum, True))
          if _MATCH_ASM.match(line):
            self.stack[-1].inline_asm = _BLOCK_ASM

      elif token == ';' or token == ')':
        # If we haven't seen an opening brace yet, but we already saw
        # a semicolon, this is probably a forward declaration.  Pop
        # the stack for these.
        #
        # Similarly, if we haven't seen an opening brace yet, but we
        # already saw a closing parenthesis, then these are probably
        # function arguments with extra "class" or "struct" keywords.
        # Also pop these stack for these.
        if not self.SeenOpenBrace():
          self.stack.pop()
      else:  # token == '}'
        # Perform end of block checks and pop the stack.
        if self.stack:
          self.stack[-1].CheckEnd(filename, clean_lines, linenum, error)
          self.stack.pop()
      line = matched.group(2)

  def InnermostClass(self):
    """Get class info on the top of the stack.

    Returns:
      A _ClassInfo object if we are inside a class, or None otherwise.
    """
    # Search from the innermost block outwards.
    for block in reversed(self.stack):
      if isinstance(block, _ClassInfo):
        return block
    return None

  def CheckCompletedBlocks(self, filename, error):
    """Checks that all classes and namespaces have been completely parsed.

    Call this when all lines in a file have been processed.
    Args:
      filename: The name of the current file.
      error: The function to call with any errors found.
    """
    # Note: This test can result in false positives if #ifdef constructs
    # get in the way of brace matching. See the testBuildClass test in
    # cpplint_unittest.py for an example of this.
    for obj in self.stack:
      if isinstance(obj, _ClassInfo):
        error(filename, obj.starting_linenum, 'build/class', 5,
              'Failed to find complete declaration of class %s' %
              obj.name)
      elif isinstance(obj, _NamespaceInfo):
        error(filename, obj.starting_linenum, 'build/namespaces', 5,
              'Failed to find complete declaration of namespace %s' %
              obj.name)
+ - put storage class first (e.g. "static const" instead of "const static"). + - "%lld" instead of %qd" in printf-type functions. + - "%1$d" is non-standard in printf-type functions. + - "\%" is an undefined character escape sequence. + - text after #endif is not allowed. + - invalid inner-style forward declaration. + - >? and ?= and )\?=?\s*(\w+|[+-]?\d+)(\.\d*)?', + line): + error(filename, linenum, 'build/deprecated', 3, + '>? and ))?' + # r'\s*const\s*' + type_name + '\s*&\s*\w+\s*;' + error(filename, linenum, 'runtime/member_string_references', 2, + 'const string& members are dangerous. It is much better to use ' + 'alternatives, such as pointers or simple constants.') + + # Everything else in this function operates on class declarations. + # Return early if the top of the nesting stack is not a class, or if + # the class head is not completed yet. + classinfo = nesting_state.InnermostClass() + if not classinfo or not classinfo.seen_open_brace: + return + + # The class may have been declared with namespace or classname qualifiers. + # The constructor and destructor will not have those qualifiers. + base_classname = classinfo.name.split('::')[-1] + + # Look for single-argument constructors that aren't marked explicit. + # Technically a valid construct, but against style. + explicit_constructor_match = Match( + r'\s+(?:(?:inline|constexpr)\s+)*(explicit\s+)?' 
+ r'(?:(?:inline|constexpr)\s+)*%s\s*' + r'\(((?:[^()]|\([^()]*\))*)\)' + % re.escape(base_classname), + line) + + if explicit_constructor_match: + is_marked_explicit = explicit_constructor_match.group(1) + + if not explicit_constructor_match.group(2): + constructor_args = [] + else: + constructor_args = explicit_constructor_match.group(2).split(',') + + # collapse arguments so that commas in template parameter lists and function + # argument parameter lists don't split arguments in two + i = 0 + while i < len(constructor_args): + constructor_arg = constructor_args[i] + while (constructor_arg.count('<') > constructor_arg.count('>') or + constructor_arg.count('(') > constructor_arg.count(')')): + constructor_arg += ',' + constructor_args[i + 1] + del constructor_args[i + 1] + constructor_args[i] = constructor_arg + i += 1 + + defaulted_args = [arg for arg in constructor_args if '=' in arg] + noarg_constructor = (not constructor_args or # empty arg list + # 'void' arg specifier + (len(constructor_args) == 1 and + constructor_args[0].strip() == 'void')) + onearg_constructor = ((len(constructor_args) == 1 and # exactly one arg + not noarg_constructor) or + # all but at most one arg defaulted + (len(constructor_args) >= 1 and + not noarg_constructor and + len(defaulted_args) >= len(constructor_args) - 1)) + initializer_list_constructor = bool( + onearg_constructor and + Search(r'\bstd\s*::\s*initializer_list\b', constructor_args[0])) + copy_constructor = bool( + onearg_constructor and + Match(r'(const\s+)?%s(\s*<[^>]*>)?(\s+const)?\s*(?:<\w+>\s*)?&' + % re.escape(base_classname), constructor_args[0].strip())) + + if (not is_marked_explicit and + onearg_constructor and + not initializer_list_constructor and + not copy_constructor): + if defaulted_args: + error(filename, linenum, 'runtime/explicit', 5, + 'Constructors callable with one argument ' + 'should be marked explicit.') + else: + error(filename, linenum, 'runtime/explicit', 5, + 'Single-parameter constructors 
def CheckSpacingForFunctionCall(filename, clean_lines, linenum, error):
  """Checks for the correctness of various spacing around function calls.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    error: The function to call with any errors found.
  """
  line = clean_lines.elided[linenum]

  # Function calls often sit inside if/for/while/switch expressions, which
  # have their own, more liberal conventions.  When the line holds such a
  # construct, only the text inside its parentheses is examined, and the
  # stricter function-call rules are applied to that.
  control_flow_patterns = (
      r'\bif\s*\((.*)\)\s*{',
      r'\bfor\s*\((.*)\)\s*{',
      r'\bwhile\s*\((.*)\)\s*[{;]',
      r'\bswitch\s*\((.*)\)\s*{',
  )
  fncall = line  # without a control flow construct, look at the whole line
  for pattern in control_flow_patterns:
    condition = Search(pattern, line)
    if condition:
      fncall = condition.group(1)
      break

  # Except in if/for/while/switch, there should never be space
  # immediately inside parens (eg "f( 3, 4 )").  We make an exception
  # for nested parens ( (a+b) + c ).  Likewise, there should never be
  # a space before a ( when it's a function argument.  I assume it's a
  # function argument when the char before the whitespace is legal in
  # a function name (alnum + _) and we're not starting a macro. Also ignore
  # pointers and references to arrays and functions coz they're too tricky:
  # we use a very simple way to recognize these:
  # " (something)(maybe-something)" or
  # " (something)(maybe-something," or
  # " (something)[something]"
  # Note that we assume the contents of [] to be short enough that
  # they'll never need to wrap.
  ignored_constructs = (
      # Control structures.
      r'\b(if|for|while|switch|return|new|delete|catch|sizeof)\b',
      # Pointers/references to functions.
      r' \([^)]+\)\([^)]*(\)|,$)',
      # Pointers/references to arrays.
      r' \([^)]+\)\[[^\]]+\]',
  )
  if any(Search(pattern, fncall) for pattern in ignored_constructs):
    return

  if Search(r'\w\s*\(\s(?!\s*\\$)', fncall):  # a ( used for a fn call
    error(filename, linenum, 'whitespace/parens', 4,
          'Extra space after ( in function call')
  elif Search(r'\(\s+(?!(\s*\\)|\()', fncall):
    error(filename, linenum, 'whitespace/parens', 2,
          'Extra space after (')

  # Constructs in which a space before "(" is tolerated.
  space_before_paren_exemptions = (
      r'_{0,2}asm_{0,2}\s+_{0,2}volatile_{0,2}\s+\(',
      r'#\s*define|typedef|using\s+\w+\s*=',
      r'\w\s+\((\w+::)*\*\w+\)\(',
      r'\bcase\s+\(',
  )
  if (Search(r'\w\s+\(', fncall) and
      not any(Search(pattern, fncall)
              for pattern in space_before_paren_exemptions)):
    # TODO(unknown): Space after an operator function seem to be a common
    # error, silence those for now by restricting them to highest verbosity.
    confidence = 0 if Search(r'\boperator_*\b', line) else 4
    error(filename, linenum, 'whitespace/parens', confidence,
          'Extra space before ( in function call')

  # If the ) is followed only by a newline or a { + newline, assume it's
  # part of a control statement (if/while/etc), and don't complain
  if Search(r'[^)]\s+\)\s*[^{\s]', fncall):
    # If the closing parenthesis is preceded by only whitespaces,
    # try to give a more descriptive error message.
    if Search(r'^\s+\)', fncall):
      message = 'Closing ) should be moved to the previous line'
    else:
      message = 'Extra space before )'
    error(filename, linenum, 'whitespace/parens', 2, message)


def IsBlankLine(line):
  """Returns true if the given line is blank.

  A line counts as blank when it is empty or consists of white space only.

  Args:
    line: A line of a string.

  Returns:
    True, if the given line is blank.
  """
  if not line:
    return True
  return line.isspace()
+ + Returns: + True, if the given line is blank. + """ + return not line or line.isspace() + + +def CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line, + error): + is_namespace_indent_item = ( + len(nesting_state.stack) > 1 and + nesting_state.stack[-1].check_namespace_indentation and + isinstance(nesting_state.previous_stack_top, _NamespaceInfo) and + nesting_state.previous_stack_top == nesting_state.stack[-2]) + + if ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item, + clean_lines.elided, line): + CheckItemIndentationInNamespace(filename, clean_lines.elided, + line, error) + + +def CheckForFunctionLengths(filename, clean_lines, linenum, + function_state, error): + """Reports for long function bodies. + + For an overview why this is done, see: + https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions + + Uses a simplistic algorithm assuming other style guidelines + (especially spacing) are followed. + Only checks unindented functions, so class members are unchecked. + Trivial bodies are unchecked, so constructors with huge initializer lists + may be missed. + Blank/comment lines are not counted so as to avoid encouraging the removal + of vertical space and comments just to get through a lint check. + NOLINT *on the last line of a function* disables this check. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + function_state: Current function name and lines in body so far. + error: The function to call with any errors found. + """ + lines = clean_lines.lines + line = lines[linenum] + joined_line = '' + + starting_func = False + regexp = r'(\w(\w|::|\*|\&|\s)*)\(' # decls * & space::name( ... + match_result = Match(regexp, line) + if match_result: + # If the name is all caps and underscores, figure it's a macro and + # ignore it, unless it's TEST or TEST_F. 
+ function_name = match_result.group(1).split()[-1] + if function_name == 'TEST' or function_name == 'TEST_F' or ( + not Match(r'[A-Z_]+$', function_name)): + starting_func = True + + if starting_func: + body_found = False + for start_linenum in xrange(linenum, clean_lines.NumLines()): + start_line = lines[start_linenum] + joined_line += ' ' + start_line.lstrip() + if Search(r'(;|})', start_line): # Declarations and trivial functions + body_found = True + break # ... ignore + elif Search(r'{', start_line): + body_found = True + function = Search(r'((\w|:)*)\(', line).group(1) + if Match(r'TEST', function): # Handle TEST... macros + parameter_regexp = Search(r'(\(.*\))', joined_line) + if parameter_regexp: # Ignore bad syntax + function += parameter_regexp.group(1) + else: + function += '()' + function_state.Begin(function) + break + if not body_found: + # No body for the function (or evidence of a non-function) was found. + error(filename, linenum, 'readability/fn_size', 5, + 'Lint failed to find start of function body.') + elif Match(r'^\}\s*$', line): # function end + function_state.Check(error, filename, linenum) + function_state.End() + elif not Match(r'^\s*$', line): + function_state.Count() # Count non-blank/non-comment lines. + + +_RE_PATTERN_TODO = re.compile(r'^//(\s*)TODO(\(.+?\))?:?(\s|$)?') + + +def CheckComment(line, filename, linenum, next_line_start, error): + """Checks for common mistakes in comments. + + Args: + line: The line in question. + filename: The name of the current file. + linenum: The number of the line to check. + next_line_start: The first non-whitespace column of the next line. + error: The function to call with any errors found. + """ + commentpos = line.find('//') + if commentpos != -1: + # Check if the // may be in quotes. 
If so, ignore it + if re.sub(r'\\.', '', line[0:commentpos]).count('"') % 2 == 0: + # Allow one space for new scopes, two spaces otherwise: + if (not (Match(r'^.*{ *//', line) and next_line_start == commentpos) and + ((commentpos >= 1 and + line[commentpos-1] not in string.whitespace) or + (commentpos >= 2 and + line[commentpos-2] not in string.whitespace))): + error(filename, linenum, 'whitespace/comments', 2, + 'At least two spaces is best between code and comments') + + # Checks for common mistakes in TODO comments. + comment = line[commentpos:] + match = _RE_PATTERN_TODO.match(comment) + if match: + # One whitespace is correct; zero whitespace is handled elsewhere. + leading_whitespace = match.group(1) + if len(leading_whitespace) > 1: + error(filename, linenum, 'whitespace/todo', 2, + 'Too many spaces before TODO') + + username = match.group(2) + if not username: + error(filename, linenum, 'readability/todo', 2, + 'Missing username in TODO; it should look like ' + '"// TODO(my_username): Stuff."') + + middle_whitespace = match.group(3) + # Comparisons made explicit for correctness -- pylint: disable=g-explicit-bool-comparison + if middle_whitespace != ' ' and middle_whitespace != '': + error(filename, linenum, 'whitespace/todo', 2, + 'TODO(my_username) should be followed by a space') + + # If the comment contains an alphanumeric character, there + # should be a space somewhere between it and the // unless + # it's a /// or //! Doxygen comment. + if (Match(r'//[^ ]*\w', comment) and + not Match(r'(///|//\!)(\s+|$)', comment)): + error(filename, linenum, 'whitespace/comments', 4, + 'Should have a space between // and comment') + + +def CheckSpacing(filename, clean_lines, linenum, nesting_state, error): + """Checks for the correctness of various spacing issues in the code. 
+ + Things we check for: spaces around operators, spaces after + if/for/while/switch, no spaces around parens in function calls, two + spaces between code and comment, don't start a block with a blank + line, don't end a function with a blank line, don't add a blank line + after public/protected/private, don't have too many blank lines in a row. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + + # Don't use "elided" lines here, otherwise we can't check commented lines. + # Don't want to use "raw" either, because we don't want to check inside C++11 + # raw strings, + raw = clean_lines.lines_without_raw_strings + line = raw[linenum] + + # Before nixing comments, check if the line is blank for no good + # reason. This includes the first line after a block is opened, and + # blank lines at the end of a function (ie, right before a line like '}' + # + # Skip all the blank line checks if we are immediately inside a + # namespace body. In other words, don't issue blank line warnings + # for this block: + # namespace { + # + # } + # + # A warning about missing end of namespace comments will be issued instead. + # + # Also skip blank line checks for 'extern "C"' blocks, which are formatted + # like namespaces. + if (IsBlankLine(line) and + not nesting_state.InNamespaceBody() and + not nesting_state.InExternC()): + elided = clean_lines.elided + prev_line = elided[linenum - 1] + prevbrace = prev_line.rfind('{') + # TODO(unknown): Don't complain if line before blank line, and line after, + # both start with alnums and are indented the same amount. + # This ignores whitespace at the start of a namespace block + # because those are not usually indented. 
+ if prevbrace != -1 and prev_line[prevbrace:].find('}') == -1: + # OK, we have a blank line at the start of a code block. Before we + # complain, we check if it is an exception to the rule: The previous + # non-empty line has the parameters of a function header that are indented + # 4 spaces (because they did not fit in a 80 column line when placed on + # the same line as the function name). We also check for the case where + # the previous line is indented 6 spaces, which may happen when the + # initializers of a constructor do not fit into a 80 column line. + exception = False + if Match(r' {6}\w', prev_line): # Initializer list? + # We are looking for the opening column of initializer list, which + # should be indented 4 spaces to cause 6 space indentation afterwards. + search_position = linenum-2 + while (search_position >= 0 + and Match(r' {6}\w', elided[search_position])): + search_position -= 1 + exception = (search_position >= 0 + and elided[search_position][:5] == ' :') + else: + # Search for the function arguments or an initializer list. We use a + # simple heuristic here: If the line is indented 4 spaces; and we have a + # closing paren, without the opening paren, followed by an opening brace + # or colon (for initializer lists) we assume that it is the last line of + # a function header. If we have a colon indented 4 spaces, it is an + # initializer list. 
+ exception = (Match(r' {4}\w[^\(]*\)\s*(const\s*)?(\{\s*$|:)', + prev_line) + or Match(r' {4}:', prev_line)) + + if not exception: + error(filename, linenum, 'whitespace/blank_line', 2, + 'Redundant blank line at the start of a code block ' + 'should be deleted.') + # Ignore blank lines at the end of a block in a long if-else + # chain, like this: + # if (condition1) { + # // Something followed by a blank line + # + # } else if (condition2) { + # // Something else + # } + if linenum + 1 < clean_lines.NumLines(): + next_line = raw[linenum + 1] + if (next_line + and Match(r'\s*}', next_line) + and next_line.find('} else ') == -1): + error(filename, linenum, 'whitespace/blank_line', 3, + 'Redundant blank line at the end of a code block ' + 'should be deleted.') + + matched = Match(r'\s*(public|protected|private):', prev_line) + if matched: + error(filename, linenum, 'whitespace/blank_line', 3, + 'Do not leave a blank line after "%s:"' % matched.group(1)) + + # Next, check comments + next_line_start = 0 + if linenum + 1 < clean_lines.NumLines(): + next_line = raw[linenum + 1] + next_line_start = len(next_line) - len(next_line.lstrip()) + CheckComment(line, filename, linenum, next_line_start, error) + + # get rid of comments and strings + line = clean_lines.elided[linenum] + + # You shouldn't have spaces before your brackets, except maybe after + # 'delete []', 'return []() {};', or 'auto [abc, ...] = ...;'. + if Search(r'\w\s+\[', line) and not Search(r'(?:auto&?|delete|return)\s+\[', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Extra space before [') + + # In range-based for, we wanted spaces before and after the colon, but + # not around "::" tokens that might appear. 
+ if (Search(r'for *\(.*[^:]:[^: ]', line) or + Search(r'for *\(.*[^: ]:[^:]', line)): + error(filename, linenum, 'whitespace/forcolon', 2, + 'Missing space around colon in range-based for loop') + + +def CheckOperatorSpacing(filename, clean_lines, linenum, error): + """Checks for horizontal spacing around operators. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Don't try to do spacing checks for operator methods. Do this by + # replacing the troublesome characters with something else, + # preserving column position for all other characters. + # + # The replacement is done repeatedly to avoid false positives from + # operators that call operators. + while True: + match = Match(r'^(.*\boperator\b)(\S+)(\s*\(.*)$', line) + if match: + line = match.group(1) + ('_' * len(match.group(2))) + match.group(3) + else: + break + + # We allow no-spaces around = within an if: "if ( (a=Foo()) == 0 )". + # Otherwise not. Note we only check for non-spaces on *both* sides; + # sometimes people put non-spaces on one side when aligning ='s among + # many lines (not that this is behavior that I approve of...) + if ((Search(r'[\w.]=', line) or + Search(r'=[\w.]', line)) + and not Search(r'\b(if|while|for) ', line) + # Operators taken from [lex.operators] in C++11 standard. + and not Search(r'(>=|<=|==|!=|&=|\^=|\|=|\+=|\*=|\/=|\%=)', line) + and not Search(r'operator=', line)): + error(filename, linenum, 'whitespace/operators', 4, + 'Missing spaces around =') + + # It's ok not to have spaces around binary operators like + - * /, but if + # there's too little whitespace, we get concerned. It's hard to tell, + # though, so we punt on this one for now. TODO. + + # You should always have whitespace around binary operators. 
+ # + # Check <= and >= first to avoid false positives with < and >, then + # check non-include lines for spacing around < and >. + # + # If the operator is followed by a comma, assume it's be used in a + # macro context and don't do any checks. This avoids false + # positives. + # + # Note that && is not included here. This is because there are too + # many false positives due to RValue references. + match = Search(r'[^<>=!\s](==|!=|<=|>=|\|\|)[^<>=!\s,;\)]', line) + if match: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around %s' % match.group(1)) + elif not Match(r'#.*include', line): + # Look for < that is not surrounded by spaces. This is only + # triggered if both sides are missing spaces, even though + # technically should should flag if at least one side is missing a + # space. This is done to avoid some false positives with shifts. + match = Match(r'^(.*[^\s<])<[^\s=<,]', line) + if match: + (_, _, end_pos) = CloseExpression( + clean_lines, linenum, len(match.group(1))) + if end_pos <= -1: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around <') + + # Look for > that is not surrounded by spaces. Similar to the + # above, we only trigger if both sides are missing spaces to avoid + # false positives with shifts. + match = Match(r'^(.*[^-\s>])>[^\s=>,]', line) + if match: + (_, _, start_pos) = ReverseCloseExpression( + clean_lines, linenum, len(match.group(1))) + if start_pos <= -1: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around >') + + # We allow no-spaces around << when used like this: 10<<20, but + # not otherwise (particularly, not when used as streams) + # + # We also allow operators following an opening parenthesis, since + # those tend to be macros that deal with operators. 
+ match = Search(r'(operator|[^\s(<])(?:L|UL|LL|ULL|l|ul|ll|ull)?<<([^\s,=<])', line) + if (match and not (match.group(1).isdigit() and match.group(2).isdigit()) and + not (match.group(1) == 'operator' and match.group(2) == ';')): + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around <<') + + # We allow no-spaces around >> for almost anything. This is because + # C++11 allows ">>" to close nested templates, which accounts for + # most cases when ">>" is not followed by a space. + # + # We still warn on ">>" followed by alpha character, because that is + # likely due to ">>" being used for right shifts, e.g.: + # value >> alpha + # + # When ">>" is used to close templates, the alphanumeric letter that + # follows would be part of an identifier, and there should still be + # a space separating the template type and the identifier. + # type> alpha + match = Search(r'>>[a-zA-Z_]', line) + if match: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around >>') + + # There shouldn't be space around unary operators + match = Search(r'(!\s|~\s|[\s]--[\s;]|[\s]\+\+[\s;])', line) + if match: + error(filename, linenum, 'whitespace/operators', 4, + 'Extra space for operator %s' % match.group(1)) + + +def CheckParenthesisSpacing(filename, clean_lines, linenum, error): + """Checks for horizontal spacing around parentheses. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. 
+ """ + line = clean_lines.elided[linenum] + + # No spaces after an if, while, switch, or for + match = Search(r' (if\(|for\(|while\(|switch\()', line) + if match: + error(filename, linenum, 'whitespace/parens', 5, + 'Missing space before ( in %s' % match.group(1)) + + # For if/for/while/switch, the left and right parens should be + # consistent about how many spaces are inside the parens, and + # there should either be zero or one spaces inside the parens. + # We don't want: "if ( foo)" or "if ( foo )". + # Exception: "for ( ; foo; bar)" and "for (foo; bar; )" are allowed. + match = Search(r'\b(if|for|while|switch)\s*' + r'\(([ ]*)(.).*[^ ]+([ ]*)\)\s*{\s*$', + line) + if match: + if len(match.group(2)) != len(match.group(4)): + if not (match.group(3) == ';' and + len(match.group(2)) == 1 + len(match.group(4)) or + not match.group(2) and Search(r'\bfor\s*\(.*; \)', line)): + error(filename, linenum, 'whitespace/parens', 5, + 'Mismatching spaces inside () in %s' % match.group(1)) + if len(match.group(2)) not in [0, 1]: + error(filename, linenum, 'whitespace/parens', 5, + 'Should have zero or one spaces inside ( and ) in %s' % + match.group(1)) + + +def CheckCommaSpacing(filename, clean_lines, linenum, error): + """Checks for horizontal spacing near commas and semicolons. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + raw = clean_lines.lines_without_raw_strings + line = clean_lines.elided[linenum] + + # You should always have a space after a comma (either as fn arg or operator) + # + # This does not apply when the non-space character following the + # comma is another comma, since the only time when that happens is + # for empty macro arguments. 
+ # + # We run this check in two passes: first pass on elided lines to + # verify that lines contain missing whitespaces, second pass on raw + # lines to confirm that those missing whitespaces are not due to + # elided comments. + if (Search(r',[^,\s]', ReplaceAll(r'\boperator\s*,\s*\(', 'F(', line)) and + Search(r',[^,\s]', raw[linenum])): + error(filename, linenum, 'whitespace/comma', 3, + 'Missing space after ,') + + # You should always have a space after a semicolon + # except for few corner cases + # TODO(unknown): clarify if 'if (1) { return 1;}' is requires one more + # space after ; + if Search(r';[^\s};\\)/]', line): + error(filename, linenum, 'whitespace/semicolon', 3, + 'Missing space after ;') + + +def _IsType(clean_lines, nesting_state, expr): + """Check if expression looks like a type name, returns true if so. + + Args: + clean_lines: A CleansedLines instance containing the file. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + expr: The expression to check. + Returns: + True, if token looks like a type. + """ + # Keep only the last token in the expression + last_word = Match(r'^.*(\b\S+)$', expr) + if last_word: + token = last_word.group(1) + else: + token = expr + + # Match native types and stdint types + if _TYPES.match(token): + return True + + # Try a bit harder to match templated types. Walk up the nesting + # stack until we find something that resembles a typename + # declaration for what we are looking for. + typename_pattern = (r'\b(?:typename|class|struct)\s+' + re.escape(token) + + r'\b') + block_index = len(nesting_state.stack) - 1 + while block_index >= 0: + if isinstance(nesting_state.stack[block_index], _NamespaceInfo): + return False + + # Found where the opening brace is. We want to scan from this + # line up to the beginning of the function, minus a few lines. + # template + # class C + # : public ... 
{ // start scanning here + last_line = nesting_state.stack[block_index].starting_linenum + + next_block_start = 0 + if block_index > 0: + next_block_start = nesting_state.stack[block_index - 1].starting_linenum + first_line = last_line + while first_line >= next_block_start: + if clean_lines.elided[first_line].find('template') >= 0: + break + first_line -= 1 + if first_line < next_block_start: + # Didn't find any "template" keyword before reaching the next block, + # there are probably no template things to check for this block + block_index -= 1 + continue + + # Look for typename in the specified range + for i in xrange(first_line, last_line + 1, 1): + if Search(typename_pattern, clean_lines.elided[i]): + return True + block_index -= 1 + + return False + + +def CheckBracesSpacing(filename, clean_lines, linenum, nesting_state, error): + """Checks for horizontal spacing near commas. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Except after an opening paren, or after another opening brace (in case of + # an initializer list, for instance), you should have spaces before your + # braces when they are delimiting blocks, classes, namespaces etc. + # And since you should never have braces at the beginning of a line, + # this is an easy test. Except that braces used for initialization don't + # follow the same rule; we often don't want spaces before those. + match = Match(r'^(.*[^ ({>]){', line) + + if match: + # Try a bit harder to check for brace initialization. This + # happens in one of the following forms: + # Constructor() : initializer_list_{} { ... 
} + # Constructor{}.MemberFunction() + # Type variable{}; + # FunctionCall(type{}, ...); + # LastArgument(..., type{}); + # LOG(INFO) << type{} << " ..."; + # map_of_type[{...}] = ...; + # ternary = expr ? new type{} : nullptr; + # OuterTemplate{}> + # + # We check for the character following the closing brace, and + # silence the warning if it's one of those listed above, i.e. + # "{.;,)<>]:". + # + # To account for nested initializer list, we allow any number of + # closing braces up to "{;,)<". We can't simply silence the + # warning on first sight of closing brace, because that would + # cause false negatives for things that are not initializer lists. + # Silence this: But not this: + # Outer{ if (...) { + # Inner{...} if (...){ // Missing space before { + # }; } + # + # There is a false negative with this approach if people inserted + # spurious semicolons, e.g. "if (cond){};", but we will catch the + # spurious semicolon with a separate check. + leading_text = match.group(1) + (endline, endlinenum, endpos) = CloseExpression( + clean_lines, linenum, len(match.group(1))) + trailing_text = '' + if endpos > -1: + trailing_text = endline[endpos:] + for offset in xrange(endlinenum + 1, + min(endlinenum + 3, clean_lines.NumLines() - 1)): + trailing_text += clean_lines.elided[offset] + # We also suppress warnings for `uint64_t{expression}` etc., as the style + # guide recommends brace initialization for integral types to avoid + # overflow/truncation. + if (not Match(r'^[\s}]*[{.;,)<>\]:]', trailing_text) + and not _IsType(clean_lines, nesting_state, leading_text)): + error(filename, linenum, 'whitespace/braces', 5, + 'Missing space before {') + + # Make sure '} else {' has spaces. + if Search(r'}else', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Missing space before else') + + # You shouldn't have a space before a semicolon at the end of the line. + # There's a special case for "for" since the style guide allows space before + # the semicolon there. 
+ if Search(r':\s*;\s*$', line): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Semicolon defining empty statement. Use {} instead.') + elif Search(r'^\s*;\s*$', line): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Line contains only semicolon. If this should be an empty statement, ' + 'use {} instead.') + elif (Search(r'\s+;\s*$', line) and + not Search(r'\bfor\b', line)): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Extra space before last semicolon. If this should be an empty ' + 'statement, use {} instead.') + + +def IsDecltype(clean_lines, linenum, column): + """Check if the token ending on (linenum, column) is decltype(). + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: the number of the line to check. + column: end column of the token to check. + Returns: + True if this token is decltype() expression, False otherwise. + """ + (text, _, start_col) = ReverseCloseExpression(clean_lines, linenum, column) + if start_col < 0: + return False + if Search(r'\bdecltype\s*$', text[0:start_col]): + return True + return False + + +def CheckSectionSpacing(filename, clean_lines, class_info, linenum, error): + """Checks for additional blank line issues related to sections. + + Currently the only thing checked here is blank line before protected/private. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + class_info: A _ClassInfo objects. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + # Skip checks if the class is small, where small means 25 lines or less. + # 25 lines seems like a good cutoff since that's the usual height of + # terminals, and any class that can't fit in one screen can't really + # be considered "small". + # + # Also skip checks if we are on the first line. This accounts for + # classes that look like + # class Foo { public: ... 
}; + # + # If we didn't find the end of the class, last_line would be zero, + # and the check will be skipped by the first condition. + if (class_info.last_line - class_info.starting_linenum <= 24 or + linenum <= class_info.starting_linenum): + return + + matched = Match(r'\s*(public|protected|private):', clean_lines.lines[linenum]) + if matched: + # Issue warning if the line before public/protected/private was + # not a blank line, but don't do this if the previous line contains + # "class" or "struct". This can happen two ways: + # - We are at the beginning of the class. + # - We are forward-declaring an inner class that is semantically + # private, but needed to be public for implementation reasons. + # Also ignores cases where the previous line ends with a backslash as can be + # common when defining classes in C macros. + prev_line = clean_lines.lines[linenum - 1] + if (not IsBlankLine(prev_line) and + not Search(r'\b(class|struct)\b', prev_line) and + not Search(r'\\$', prev_line)): + # Try a bit harder to find the beginning of the class. This is to + # account for multi-line base-specifier lists, e.g.: + # class Derived + # : public Base { + end_class_head = class_info.starting_linenum + for i in range(class_info.starting_linenum, linenum): + if Search(r'\{\s*$', clean_lines.lines[i]): + end_class_head = i + break + if end_class_head < linenum - 1: + error(filename, linenum, 'whitespace/blank_line', 3, + '"%s:" should be preceded by a blank line' % matched.group(1)) + + +def GetPreviousNonBlankLine(clean_lines, linenum): + """Return the most recent non-blank line and its line number. + + Args: + clean_lines: A CleansedLines instance containing the file contents. + linenum: The number of the line to check. + + Returns: + A tuple with two elements. The first element is the contents of the last + non-blank line before the current line, or the empty string if this is the + first non-blank line. 
The second is the line number of that line, or -1 + if this is the first non-blank line. + """ + + prevlinenum = linenum - 1 + while prevlinenum >= 0: + prevline = clean_lines.elided[prevlinenum] + if not IsBlankLine(prevline): # if not a blank line... + return (prevline, prevlinenum) + prevlinenum -= 1 + return ('', -1) + + +def CheckBraces(filename, clean_lines, linenum, error): + """Looks for misplaced braces (e.g. at the end of line). + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + line = clean_lines.elided[linenum] # get rid of comments and strings + + if Match(r'\s*{\s*$', line): + # We allow an open brace to start a line in the case where someone is using + # braces in a block to explicitly create a new scope, which is commonly used + # to control the lifetime of stack-allocated variables. Braces are also + # used for brace initializers inside function calls. We don't detect this + # perfectly: we just don't complain if the last non-whitespace character on + # the previous non-blank line is ',', ';', ':', '(', '{', or '}', or if the + # previous line starts a preprocessor block. We also allow a brace on the + # following line if it is part of an array initialization and would not fit + # within the 80 character limit of the preceding line. + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if (not Search(r'[,;:}{(]\s*$', prevline) and + not Match(r'\s*#', prevline) and + not (GetLineWidth(prevline) > _line_length - 2 and '[]' in prevline)): + error(filename, linenum, 'whitespace/braces', 4, + '{ should almost always be at the end of the previous line') + + # An else clause should be on the same line as the preceding closing brace. 
+ if Match(r'\s*else\b\s*(?:if\b|\{|$)', line): + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if Match(r'\s*}\s*$', prevline): + error(filename, linenum, 'whitespace/newline', 4, + 'An else should appear on the same line as the preceding }') + + # If braces come on one side of an else, they should be on both. + # However, we have to worry about "else if" that spans multiple lines! + if Search(r'else if\s*\(', line): # could be multi-line if + brace_on_left = bool(Search(r'}\s*else if\s*\(', line)) + # find the ( after the if + pos = line.find('else if') + pos = line.find('(', pos) + if pos > 0: + (endline, _, endpos) = CloseExpression(clean_lines, linenum, pos) + brace_on_right = endline[endpos:].find('{') != -1 + if brace_on_left != brace_on_right: # must be brace after if + error(filename, linenum, 'readability/braces', 5, + 'If an else has a brace on one side, it should have it on both') + elif Search(r'}\s*else[^{]*$', line) or Match(r'[^}]*else\s*{', line): + error(filename, linenum, 'readability/braces', 5, + 'If an else has a brace on one side, it should have it on both') + + # Likewise, an else should never have the else clause on the same line + if Search(r'\belse [^\s{]', line) and not Search(r'\belse if\b', line): + error(filename, linenum, 'whitespace/newline', 4, + 'Else clause should never be on same line as else (use 2 lines)') + + # In the same way, a do/while should never be on one line + if Match(r'\s*do [^\s{]', line): + error(filename, linenum, 'whitespace/newline', 4, + 'do/while clauses should not be on a single line') + + # Check single-line if/else bodies. The style guide says 'curly braces are not + # required for single-line statements'. We additionally allow multi-line, + # single statements, but we reject anything with more than one semicolon in + # it. 
This means that the first semicolon after the if should be at the end of + # its line, and the line after that should have an indent level equal to or + # lower than the if. We also check for ambiguous if/else nesting without + # braces. + if_else_match = Search(r'\b(if\s*\(|else\b)', line) + if if_else_match and not Match(r'\s*#', line): + if_indent = GetIndentLevel(line) + endline, endlinenum, endpos = line, linenum, if_else_match.end() + if_match = Search(r'\bif\s*\(', line) + if if_match: + # This could be a multiline if condition, so find the end first. + pos = if_match.end() - 1 + (endline, endlinenum, endpos) = CloseExpression(clean_lines, linenum, pos) + # Check for an opening brace, either directly after the if or on the next + # line. If found, this isn't a single-statement conditional. + if (not Match(r'\s*{', endline[endpos:]) + and not (Match(r'\s*$', endline[endpos:]) + and endlinenum < (len(clean_lines.elided) - 1) + and Match(r'\s*{', clean_lines.elided[endlinenum + 1]))): + while (endlinenum < len(clean_lines.elided) + and ';' not in clean_lines.elided[endlinenum][endpos:]): + endlinenum += 1 + endpos = 0 + if endlinenum < len(clean_lines.elided): + endline = clean_lines.elided[endlinenum] + # We allow a mix of whitespace and closing braces (e.g. for one-liner + # methods) and a single \ after the semicolon (for macros) + endpos = endline.find(';') + if not Match(r';[\s}]*(\\?)$', endline[endpos:]): + # Semicolon isn't the last character, there's something trailing. + # Output a warning if the semicolon is not contained inside + # a lambda expression. 
+ if not Match(r'^[^{};]*\[[^\[\]]*\][^{}]*\{[^{}]*\}\s*\)*[;,]\s*$', + endline): + error(filename, linenum, 'readability/braces', 4, + 'If/else bodies with multiple statements require braces') + elif endlinenum < len(clean_lines.elided) - 1: + # Make sure the next line is dedented + next_line = clean_lines.elided[endlinenum + 1] + next_indent = GetIndentLevel(next_line) + # With ambiguous nested if statements, this will error out on the + # if that *doesn't* match the else, regardless of whether it's the + # inner one or outer one. + if (if_match and Match(r'\s*else\b', next_line) + and next_indent != if_indent): + error(filename, linenum, 'readability/braces', 4, + 'Else clause should be indented at the same level as if. ' + 'Ambiguous nested if/else chains require braces.') + elif next_indent > if_indent: + error(filename, linenum, 'readability/braces', 4, + 'If/else bodies with multiple statements require braces') + + +def CheckTrailingSemicolon(filename, clean_lines, linenum, error): + """Looks for redundant trailing semicolon. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + line = clean_lines.elided[linenum] + + # Block bodies should not be followed by a semicolon. Due to C++11 + # brace initialization, there are more places where semicolons are + # required than not, so we explicitly list the allowed rules rather + # than listing the disallowed ones. These are the places where "};" + # should be replaced by just "}": + # 1. Some flavor of block following closing parenthesis: + # for (;;) {}; + # while (...) {}; + # switch (...) {}; + # Function(...) {}; + # if (...) {}; + # if (...) else if (...) {}; + # + # 2. else block: + # if (...) else {}; + # + # 3. const member function: + # Function(...) const {}; + # + # 4. Block following some statement: + # x = 42; + # {}; + # + # 5. 
Block at the beginning of a function: + # Function(...) { + # {}; + # } + # + # Note that naively checking for the preceding "{" will also match + # braces inside multi-dimensional arrays, but this is fine since + # that expression will not contain semicolons. + # + # 6. Block following another block: + # while (true) {} + # {}; + # + # 7. End of namespaces: + # namespace {}; + # + # These semicolons seems far more common than other kinds of + # redundant semicolons, possibly due to people converting classes + # to namespaces. For now we do not warn for this case. + # + # Try matching case 1 first. + match = Match(r'^(.*\)\s*)\{', line) + if match: + # Matched closing parenthesis (case 1). Check the token before the + # matching opening parenthesis, and don't warn if it looks like a + # macro. This avoids these false positives: + # - macro that defines a base class + # - multi-line macro that defines a base class + # - macro that defines the whole class-head + # + # But we still issue warnings for macros that we know are safe to + # warn, specifically: + # - TEST, TEST_F, TEST_P, MATCHER, MATCHER_P + # - TYPED_TEST + # - INTERFACE_DEF + # - EXCLUSIVE_LOCKS_REQUIRED, SHARED_LOCKS_REQUIRED, LOCKS_EXCLUDED: + # + # We implement a list of safe macros instead of a list of + # unsafe macros, even though the latter appears less frequently in + # google code and would have been easier to implement. This is because + # the downside for getting the allowed checks wrong means some extra + # semicolons, while the downside for getting disallowed checks wrong + # would result in compile errors. 
+ # + # In addition to macros, we also don't want to warn on + # - Compound literals + # - Lambdas + # - alignas specifier with anonymous structs + # - decltype + closing_brace_pos = match.group(1).rfind(')') + opening_parenthesis = ReverseCloseExpression( + clean_lines, linenum, closing_brace_pos) + if opening_parenthesis[2] > -1: + line_prefix = opening_parenthesis[0][0:opening_parenthesis[2]] + macro = Search(r'\b([A-Z_][A-Z0-9_]*)\s*$', line_prefix) + func = Match(r'^(.*\])\s*$', line_prefix) + if ((macro and + macro.group(1) not in ( + 'TEST', 'TEST_F', 'MATCHER', 'MATCHER_P', 'TYPED_TEST', + 'EXCLUSIVE_LOCKS_REQUIRED', 'SHARED_LOCKS_REQUIRED', + 'LOCKS_EXCLUDED', 'INTERFACE_DEF')) or + (func and not Search(r'\boperator\s*\[\s*\]', func.group(1))) or + Search(r'\b(?:struct|union)\s+alignas\s*$', line_prefix) or + Search(r'\bdecltype$', line_prefix) or + Search(r'\s+=\s*$', line_prefix)): + match = None + if (match and + opening_parenthesis[1] > 1 and + Search(r'\]\s*$', clean_lines.elided[opening_parenthesis[1] - 1])): + # Multi-line lambda-expression + match = None + + else: + # Try matching cases 2-3. + match = Match(r'^(.*(?:else|\)\s*const)\s*)\{', line) + if not match: + # Try matching cases 4-6. These are always matched on separate lines. + # + # Note that we can't simply concatenate the previous line to the + # current line and do a single match, otherwise we may output + # duplicate warnings for the blank line case: + # if (cond) { + # // blank line + # } + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if prevline and Search(r'[;{}]\s*$', prevline): + match = Match(r'^(\s*)\{', line) + + # Check matching closing brace + if match: + (endline, endlinenum, endpos) = CloseExpression( + clean_lines, linenum, len(match.group(1))) + if endpos > -1 and Match(r'^\s*;', endline[endpos:]): + # Current {} pair is eligible for semicolon check, and we have found + # the redundant semicolon, output warning here. 
+ # + # Note: because we are scanning forward for opening braces, and + # outputting warnings for the matching closing brace, if there are + # nested blocks with trailing semicolons, we will get the error + # messages in reversed order. + + # We need to check the line forward for NOLINT + raw_lines = clean_lines.raw_lines + ParseNolintSuppressions(filename, raw_lines[endlinenum-1], endlinenum-1, + error) + ParseNolintSuppressions(filename, raw_lines[endlinenum], endlinenum, + error) + + error(filename, endlinenum, 'readability/braces', 4, + "You don't need a ; after a }") + + +def CheckEmptyBlockBody(filename, clean_lines, linenum, error): + """Look for empty loop/conditional body with only a single semicolon. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Search for loop keywords at the beginning of the line. Because only + # whitespaces are allowed before the keywords, this will also ignore most + # do-while-loops, since those lines should start with closing brace. + # + # We also check "if" blocks here, since an empty conditional block + # is likely an error. + line = clean_lines.elided[linenum] + matched = Match(r'\s*(for|while|if)\s*\(', line) + if matched: + # Find the end of the conditional expression. + (end_line, end_linenum, end_pos) = CloseExpression( + clean_lines, linenum, line.find('(')) + + # Output warning if what follows the condition expression is a semicolon. + # No warning for all other cases, including whitespace or newline, since we + # have a separate check for semicolons preceded by whitespace. 
+ if end_pos >= 0 and Match(r';', end_line[end_pos:]): + if matched.group(1) == 'if': + error(filename, end_linenum, 'whitespace/empty_conditional_body', 5, + 'Empty conditional bodies should use {}') + else: + error(filename, end_linenum, 'whitespace/empty_loop_body', 5, + 'Empty loop bodies should use {} or continue') + + # Check for if statements that have completely empty bodies (no comments) + # and no else clauses. + if end_pos >= 0 and matched.group(1) == 'if': + # Find the position of the opening { for the if statement. + # Return without logging an error if it has no brackets. + opening_linenum = end_linenum + opening_line_fragment = end_line[end_pos:] + # Loop until EOF or find anything that's not whitespace or opening {. + while not Search(r'^\s*\{', opening_line_fragment): + if Search(r'^(?!\s*$)', opening_line_fragment): + # Conditional has no brackets. + return + opening_linenum += 1 + if opening_linenum == len(clean_lines.elided): + # Couldn't find conditional's opening { or any code before EOF. + return + opening_line_fragment = clean_lines.elided[opening_linenum] + # Set opening_line (opening_line_fragment may not be entire opening line). + opening_line = clean_lines.elided[opening_linenum] + + # Find the position of the closing }. + opening_pos = opening_line_fragment.find('{') + if opening_linenum == end_linenum: + # We need to make opening_pos relative to the start of the entire line. + opening_pos += end_pos + (closing_line, closing_linenum, closing_pos) = CloseExpression( + clean_lines, opening_linenum, opening_pos) + if closing_pos < 0: + return + + # Now construct the body of the conditional. This consists of the portion + # of the opening line after the {, all lines until the closing line, + # and the portion of the closing line before the }. + if (clean_lines.raw_lines[opening_linenum] != + CleanseComments(clean_lines.raw_lines[opening_linenum])): + # Opening line ends with a comment, so conditional isn't empty. 
+ return + if closing_linenum > opening_linenum: + # Opening line after the {. Ignore comments here since we checked above. + body = list(opening_line[opening_pos+1:]) + # All lines until closing line, excluding closing line, with comments. + body.extend(clean_lines.raw_lines[opening_linenum+1:closing_linenum]) + # Closing line before the }. Won't (and can't) have comments. + body.append(clean_lines.elided[closing_linenum][:closing_pos-1]) + body = '\n'.join(body) + else: + # If statement has brackets and fits on a single line. + body = opening_line[opening_pos+1:closing_pos-1] + + # Check if the body is empty + if not _EMPTY_CONDITIONAL_BODY_PATTERN.search(body): + return + # The body is empty. Now make sure there's not an else clause. + current_linenum = closing_linenum + current_line_fragment = closing_line[closing_pos:] + # Loop until EOF or find anything that's not whitespace or else clause. + while Search(r'^\s*$|^(?=\s*else)', current_line_fragment): + if Search(r'^(?=\s*else)', current_line_fragment): + # Found an else clause, so don't log an error. + return + current_linenum += 1 + if current_linenum == len(clean_lines.elided): + break + current_line_fragment = clean_lines.elided[current_linenum] + + # The body is empty and there's no else clause until EOF or other code. + error(filename, end_linenum, 'whitespace/empty_if_body', 4, + ('If statement had no body and no else clause')) + + +def FindCheckMacro(line): + """Find a replaceable CHECK-like macro. + + Args: + line: line to search on. + Returns: + (macro name, start position), or (None, -1) if no replaceable + macro is found. + """ + for macro in _CHECK_MACROS: + i = line.find(macro) + if i >= 0: + # Find opening parenthesis. Do a regular expression match here + # to make sure that we are matching the expected CHECK macro, as + # opposed to some other macro that happens to contain the CHECK + # substring. 
+ matched = Match(r'^(.*\b' + macro + r'\s*)\(', line) + if not matched: + continue + return (macro, len(matched.group(1))) + return (None, -1) + + +def CheckCheck(filename, clean_lines, linenum, error): + """Checks the use of CHECK and EXPECT macros. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Decide the set of replacement macros that should be suggested + lines = clean_lines.elided + (check_macro, start_pos) = FindCheckMacro(lines[linenum]) + if not check_macro: + return + + # Find end of the boolean expression by matching parentheses + (last_line, end_line, end_pos) = CloseExpression( + clean_lines, linenum, start_pos) + if end_pos < 0: + return + + # If the check macro is followed by something other than a + # semicolon, assume users will log their own custom error messages + # and don't suggest any replacements. + if not Match(r'\s*;', last_line[end_pos:]): + return + + if linenum == end_line: + expression = lines[linenum][start_pos + 1:end_pos - 1] + else: + expression = lines[linenum][start_pos + 1:] + for i in xrange(linenum + 1, end_line): + expression += lines[i] + expression += last_line[0:end_pos - 1] + + # Parse expression so that we can take parentheses into account. + # This avoids false positives for inputs like "CHECK((a < 4) == b)", + # which is not replaceable by CHECK_LE. 
+ lhs = '' + rhs = '' + operator = None + while expression: + matched = Match(r'^\s*(<<|<<=|>>|>>=|->\*|->|&&|\|\||' + r'==|!=|>=|>|<=|<|\()(.*)$', expression) + if matched: + token = matched.group(1) + if token == '(': + # Parenthesized operand + expression = matched.group(2) + (end, _) = FindEndOfExpressionInLine(expression, 0, ['(']) + if end < 0: + return # Unmatched parenthesis + lhs += '(' + expression[0:end] + expression = expression[end:] + elif token in ('&&', '||'): + # Logical and/or operators. This means the expression + # contains more than one term, for example: + # CHECK(42 < a && a < b); + # + # These are not replaceable with CHECK_LE, so bail out early. + return + elif token in ('<<', '<<=', '>>', '>>=', '->*', '->'): + # Non-relational operator + lhs += token + expression = matched.group(2) + else: + # Relational operator + operator = token + rhs = matched.group(2) + break + else: + # Unparenthesized operand. Instead of appending to lhs one character + # at a time, we do another regular expression match to consume several + # characters at once if possible. Trivial benchmark shows that this + # is more efficient when the operands are longer than a single + # character, which is generally the case. + matched = Match(r'^([^-=!<>()&|]+)(.*)$', expression) + if not matched: + matched = Match(r'^(\s*\S)(.*)$', expression) + if not matched: + break + lhs += matched.group(1) + expression = matched.group(2) + + # Only apply checks if we got all parts of the boolean expression + if not (lhs and operator and rhs): + return + + # Check that rhs do not contain logical operators. We already know + # that lhs is fine since the loop above parses out && and ||. + if rhs.find('&&') > -1 or rhs.find('||') > -1: + return + + # At least one of the operands must be a constant literal. 
This is + # to avoid suggesting replacements for unprintable things like + # CHECK(variable != iterator) + # + # The following pattern matches decimal, hex integers, strings, and + # characters (in that order). + lhs = lhs.strip() + rhs = rhs.strip() + match_constant = r'^([-+]?(\d+|0[xX][0-9a-fA-F]+)[lLuU]{0,3}|".*"|\'.*\')$' + if Match(match_constant, lhs) or Match(match_constant, rhs): + # Note: since we know both lhs and rhs, we can provide a more + # descriptive error message like: + # Consider using CHECK_EQ(x, 42) instead of CHECK(x == 42) + # Instead of: + # Consider using CHECK_EQ instead of CHECK(a == b) + # + # We are still keeping the less descriptive message because if lhs + # or rhs gets long, the error message might become unreadable. + error(filename, linenum, 'readability/check', 2, + 'Consider using %s instead of %s(a %s b)' % ( + _CHECK_REPLACEMENT[check_macro][operator], + check_macro, operator)) + + +def CheckAltTokens(filename, clean_lines, linenum, error): + """Check alternative keywords being used in boolean expressions. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Avoid preprocessor lines + if Match(r'^\s*#', line): + return + + # Last ditch effort to avoid multi-line comments. This will not help + # if the comment started before the current line or ended after the + # current line, but it catches most of the false positives. At least, + # it provides a way to workaround this warning for people who use + # multi-line comments in preprocessor macros. + # + # TODO(unknown): remove this once cpplint has better support for + # multi-line comments. 
def GetLineWidth(line):
  """Determines the width of the line in column positions.

  Text strings are NFC-normalized and measured per code point: East Asian
  wide/fullwidth characters count as two columns, combining characters count
  as zero, everything else counts as one.  Any other object (e.g. a byte
  string) is measured with len().

  Args:
    line: A string, which may be a Unicode string.

  Returns:
    The width of the line in column positions, accounting for Unicode
    combining characters and wide characters.
  """
  # Derive the text type locally: this is `unicode` on Python 2 and `str` on
  # Python 3, so the check works without the Python-2-only `unicode` builtin.
  text_type = type(b''.decode('ascii'))
  if not isinstance(line, text_type):
    return len(line)

  width = 0
  for uc in unicodedata.normalize('NFC', line):
    if unicodedata.east_asian_width(uc) in ('W', 'F'):
      width += 2
    elif not unicodedata.combining(uc):
      # Issue 337
      # https://mail.python.org/pipermail/python-list/2012-August/628809.html
      if (sys.version_info.major, sys.version_info.minor) <= (3, 2):
        # https://github.com/python/cpython/blob/2.7/Include/unicodeobject.h#L81
        is_wide_build = sysconfig.get_config_var("Py_UNICODE_SIZE") >= 4
        # https://github.com/python/cpython/blob/2.7/Objects/unicodeobject.c#L564
        is_low_surrogate = 0xDC00 <= ord(uc) <= 0xDFFF
        if not is_wide_build and is_low_surrogate:
          # On narrow builds a non-BMP character is a surrogate pair; cancel
          # the increment below so the pair counts as a single column.
          width -= 1

      width += 1
  return width
+ error: The function to call with any errors found. + """ + + # Don't use "elided" lines here, otherwise we can't check commented lines. + # Don't want to use "raw" either, because we don't want to check inside C++11 + # raw strings, + raw_lines = clean_lines.lines_without_raw_strings + line = raw_lines[linenum] + prev = raw_lines[linenum - 1] if linenum > 0 else '' + + if line.find('\t') != -1: + error(filename, linenum, 'whitespace/tab', 1, + 'Tab found; better to use spaces') + + # One or three blank spaces at the beginning of the line is weird; it's + # hard to reconcile that with 2-space indents. + # NOTE: here are the conditions rob pike used for his tests. Mine aren't + # as sophisticated, but it may be worth becoming so: RLENGTH==initial_spaces + # if(RLENGTH > 20) complain = 0; + # if(match($0, " +(error|private|public|protected):")) complain = 0; + # if(match(prev, "&& *$")) complain = 0; + # if(match(prev, "\\|\\| *$")) complain = 0; + # if(match(prev, "[\",=><] *$")) complain = 0; + # if(match($0, " <<")) complain = 0; + # if(match(prev, " +for \\(")) complain = 0; + # if(prevodd && match(prevprev, " +for \\(")) complain = 0; + scope_or_label_pattern = r'\s*\w+\s*:\s*\\?$' + classinfo = nesting_state.InnermostClass() + initial_spaces = 0 + cleansed_line = clean_lines.elided[linenum] + while initial_spaces < len(line) and line[initial_spaces] == ' ': + initial_spaces += 1 + # There are certain situations we allow one space, notably for + # section labels, and also lines containing multi-line raw strings. + # We also don't check for lines that look like continuation lines + # (of lines ending in double quotes, commas, equals, or angle brackets) + # because the rules for how to indent those are non-trivial. 
+ if (not Search(r'[",=><] *$', prev) and + (initial_spaces == 1 or initial_spaces == 3) and + not Match(scope_or_label_pattern, cleansed_line) and + not (clean_lines.raw_lines[linenum] != line and + Match(r'^\s*""', line))): + error(filename, linenum, 'whitespace/indent', 3, + 'Weird number of spaces at line-start. ' + 'Are you using a 2-space indent?') + + if line and line[-1].isspace(): + error(filename, linenum, 'whitespace/end_of_line', 4, + 'Line ends in whitespace. Consider deleting these extra spaces.') + + # Check if the line is a header guard. + is_header_guard = False + if IsHeaderExtension(file_extension): + cppvar = GetHeaderGuardCPPVariable(filename) + if (line.startswith('#ifndef %s' % cppvar) or + line.startswith('#define %s' % cppvar) or + line.startswith('#endif // %s' % cppvar)): + is_header_guard = True + # #include lines and header guards can be long, since there's no clean way to + # split them. + # + # URLs can be long too. It's possible to split these, but it makes them + # harder to cut&paste. + # + # The "$Id:...$" comment may also get very long without it being the + # developers fault. + if (not line.startswith('#include') and not is_header_guard and + not Match(r'^\s*//.*http(s?)://\S*$', line) and + not Match(r'^\s*//\s*[^\s]*$', line) and + not Match(r'^// \$Id:.*#[0-9]+ \$$', line)): + line_width = GetLineWidth(line) + if line_width > _line_length: + error(filename, linenum, 'whitespace/line_length', 2, + 'Lines should be <= %i characters long' % _line_length) + + if (cleansed_line.count(';') > 1 and + # for loops are allowed two ;'s (and may run over two lines). 
+ cleansed_line.find('for') == -1 and + (GetPreviousNonBlankLine(clean_lines, linenum)[0].find('for') == -1 or + GetPreviousNonBlankLine(clean_lines, linenum)[0].find(';') != -1) and + # It's ok to have many commands in a switch case that fits in 1 line + not ((cleansed_line.find('case ') != -1 or + cleansed_line.find('default:') != -1) and + cleansed_line.find('break;') != -1)): + error(filename, linenum, 'whitespace/newline', 0, + 'More than one command on the same line') + + # Some more style checks + CheckBraces(filename, clean_lines, linenum, error) + CheckTrailingSemicolon(filename, clean_lines, linenum, error) + CheckEmptyBlockBody(filename, clean_lines, linenum, error) + CheckSpacing(filename, clean_lines, linenum, nesting_state, error) + CheckOperatorSpacing(filename, clean_lines, linenum, error) + CheckParenthesisSpacing(filename, clean_lines, linenum, error) + CheckCommaSpacing(filename, clean_lines, linenum, error) + CheckBracesSpacing(filename, clean_lines, linenum, nesting_state, error) + CheckSpacingForFunctionCall(filename, clean_lines, linenum, error) + CheckCheck(filename, clean_lines, linenum, error) + CheckAltTokens(filename, clean_lines, linenum, error) + classinfo = nesting_state.InnermostClass() + if classinfo: + CheckSectionSpacing(filename, clean_lines, classinfo, linenum, error) + + +_RE_PATTERN_INCLUDE = re.compile(r'^\s*#\s*include\s*([<"])([^>"]*)[>"].*$') +# Matches the first component of a filename delimited by -s and _s. That is: +# _RE_FIRST_COMPONENT.match('foo').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo.cc').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo-bar_baz.cc').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo_bar-baz.cc').group(0) == 'foo' +_RE_FIRST_COMPONENT = re.compile(r'^[^-_.]+') + + +def _DropCommonSuffixes(filename): + """Drops common suffixes like _test.cc or -inl.h from filename. 
+ + For example: + >>> _DropCommonSuffixes('foo/foo-inl.h') + 'foo/foo' + >>> _DropCommonSuffixes('foo/bar/foo.cc') + 'foo/bar/foo' + >>> _DropCommonSuffixes('foo/foo_internal.h') + 'foo/foo' + >>> _DropCommonSuffixes('foo/foo_unusualinternal.h') + 'foo/foo_unusualinternal' + + Args: + filename: The input filename. + + Returns: + The filename with the common suffix removed. + """ + for suffix in ('test.cc', 'regtest.cc', 'unittest.cc', + 'inl.h', 'impl.h', 'internal.h'): + if (filename.endswith(suffix) and len(filename) > len(suffix) and + filename[-len(suffix) - 1] in ('-', '_')): + return filename[:-len(suffix) - 1] + return os.path.splitext(filename)[0] + + +def _ClassifyInclude(fileinfo, include, is_system): + """Figures out what kind of header 'include' is. + + Args: + fileinfo: The current file cpplint is running over. A FileInfo instance. + include: The path to a #included file. + is_system: True if the #include used <> rather than "". + + Returns: + One of the _XXX_HEADER constants. + + For example: + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'stdio.h', True) + _C_SYS_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'string', True) + _CPP_SYS_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/foo.h', False) + _LIKELY_MY_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo_unknown_extension.cc'), + ... 'bar/foo_other_ext.h', False) + _POSSIBLE_MY_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/bar.h', False) + _OTHER_HEADER + """ + # This is a list of all standard c++ header files, except + # those already checked for above. + is_cpp_h = include in _CPP_HEADERS + + if is_system: + if is_cpp_h: + return _CPP_SYS_HEADER + else: + return _C_SYS_HEADER + + # If the target file and the include we're checking share a + # basename when we drop common extensions, and the include + # lives in . , then it's likely to be owned by the target file. 
+ target_dir, target_base = ( + os.path.split(_DropCommonSuffixes(fileinfo.RepositoryName()))) + include_dir, include_base = os.path.split(_DropCommonSuffixes(include)) + if target_base == include_base and ( + include_dir == target_dir or + include_dir == os.path.normpath(target_dir + '/../public')): + return _LIKELY_MY_HEADER + + # If the target and include share some initial basename + # component, it's possible the target is implementing the + # include, so it's allowed to be first, but we'll never + # complain if it's not there. + target_first_component = _RE_FIRST_COMPONENT.match(target_base) + include_first_component = _RE_FIRST_COMPONENT.match(include_base) + if (target_first_component and include_first_component and + target_first_component.group(0) == + include_first_component.group(0)): + return _POSSIBLE_MY_HEADER + + return _OTHER_HEADER + + + +def CheckIncludeLine(filename, clean_lines, linenum, include_state, error): + """Check rules that are applicable to #include lines. + + Strings on #include lines are NOT removed from elided line, to make + certain tasks easier. However, to prevent false positives, checks + applicable to #include lines in CheckLanguage must be put here. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + include_state: An _IncludeState instance in which the headers are inserted. + error: The function to call with any errors found. + """ + fileinfo = FileInfo(filename) + line = clean_lines.lines[linenum] + + # "include" should use the new style "foo/bar.h" instead of just "bar.h" + # Only do this check if the included header follows google naming + # conventions. If not, assume that it's a 3rd party API that + # requires special include conventions. + # + # We also make an exception for Lua headers, which follow google + # naming convention but not the include convention. 
+ match = Match(r'#include\s*"([^/]+\.h)"', line) + if match and not _THIRD_PARTY_HEADERS_PATTERN.match(match.group(1)): + error(filename, linenum, 'build/include', 4, + 'Include the directory when naming .h files') + + # we shouldn't include a file more than once. actually, there are a + # handful of instances where doing so is okay, but in general it's + # not. + match = _RE_PATTERN_INCLUDE.search(line) + if match: + include = match.group(2) + is_system = (match.group(1) == '<') + duplicate_line = include_state.FindHeader(include) + if duplicate_line >= 0: + error(filename, linenum, 'build/include', 4, + '"%s" already included at %s:%s' % + (include, filename, duplicate_line)) + elif (include.endswith('.cc') and + os.path.dirname(fileinfo.RepositoryName()) != os.path.dirname(include)): + error(filename, linenum, 'build/include', 4, + 'Do not include .cc files from other packages') + elif not _THIRD_PARTY_HEADERS_PATTERN.match(include): + include_state.include_list[-1].append((include, linenum)) + + # We want to ensure that headers appear in the right order: + # 1) for foo.cc, foo.h (preferred location) + # 2) c system files + # 3) cpp system files + # 4) for foo.cc, foo.h (deprecated location) + # 5) other google headers + # + # We classify each include statement as one of those 5 types + # using a number of techniques. The include_state object keeps + # track of the highest type seen, and complains if we see a + # lower type after that. + error_message = include_state.CheckNextIncludeOrder( + _ClassifyInclude(fileinfo, include, is_system)) + if error_message: + error(filename, linenum, 'build/include_order', 4, + '%s. Should be: %s.h, c system, c++ system, other.' 
% + (error_message, fileinfo.BaseName())) + canonical_include = include_state.CanonicalizeAlphabeticalOrder(include) + if not include_state.IsInAlphabeticalOrder( + clean_lines, linenum, canonical_include): + error(filename, linenum, 'build/include_alpha', 4, + 'Include "%s" not in alphabetical order' % include) + include_state.SetLastHeader(canonical_include) + + + +def _GetTextInside(text, start_pattern): + r"""Retrieves all the text between matching open and close parentheses. + + Given a string of lines and a regular expression string, retrieve all the text + following the expression and between opening punctuation symbols like + (, [, or {, and the matching close-punctuation symbol. This properly nested + occurrences of the punctuations, so for the text like + printf(a(), b(c())); + a call to _GetTextInside(text, r'printf\(') will return 'a(), b(c())'. + start_pattern must match string having an open punctuation symbol at the end. + + Args: + text: The lines to extract text. Its comments and strings must be elided. + It can be single line and can span multiple lines. + start_pattern: The regexp string indicating where to start extracting + the text. + Returns: + The extracted text. + None if either the opening string or ending punctuation could not be found. + """ + # TODO(unknown): Audit cpplint.py to see what places could be profitably + # rewritten to use _GetTextInside (and use inferior regexp matching today). + + # Give opening punctuations to get the matching close-punctuations. + matching_punctuation = {'(': ')', '{': '}', '[': ']'} + closing_punctuation = set(matching_punctuation.itervalues()) + + # Find the position to start extracting text. + match = re.search(start_pattern, text, re.M) + if not match: # start_pattern not found in text. 
+ return None + start_position = match.end(0) + + assert start_position > 0, ( + 'start_pattern must ends with an opening punctuation.') + assert text[start_position - 1] in matching_punctuation, ( + 'start_pattern must ends with an opening punctuation.') + # Stack of closing punctuations we expect to have in text after position. + punctuation_stack = [matching_punctuation[text[start_position - 1]]] + position = start_position + while punctuation_stack and position < len(text): + if text[position] == punctuation_stack[-1]: + punctuation_stack.pop() + elif text[position] in closing_punctuation: + # A closing punctuation without matching opening punctuations. + return None + elif text[position] in matching_punctuation: + punctuation_stack.append(matching_punctuation[text[position]]) + position += 1 + if punctuation_stack: + # Opening punctuations left without matching close-punctuations. + return None + # punctuations match. + return text[start_position:position - 1] + + +# Patterns for matching call-by-reference parameters. +# +# Supports nested templates up to 2 levels deep using this messy pattern: +# < (?: < (?: < [^<>]* +# > +# | [^<>] )* +# > +# | [^<>] )* +# > +_RE_PATTERN_IDENT = r'[_a-zA-Z]\w*' # =~ [[:alpha:]][[:alnum:]]* +_RE_PATTERN_TYPE = ( + r'(?:const\s+)?(?:typename\s+|class\s+|struct\s+|union\s+|enum\s+)?' + r'(?:\w|' + r'\s*<(?:<(?:<[^<>]*>|[^<>])*>|[^<>])*>|' + r'::)+') +# A call-by-reference parameter ends with '& identifier'. +_RE_PATTERN_REF_PARAM = re.compile( + r'(' + _RE_PATTERN_TYPE + r'(?:\s*(?:\bconst\b|[*]))*\s*' + r'&\s*' + _RE_PATTERN_IDENT + r')\s*(?:=[^,()]+)?[,)]') +# A call-by-const-reference parameter either ends with 'const& identifier' +# or looks like 'const type& identifier' when 'type' is atomic. +_RE_PATTERN_CONST_REF_PARAM = ( + r'(?:.*\s*\bconst\s*&\s*' + _RE_PATTERN_IDENT + + r'|const\s+' + _RE_PATTERN_TYPE + r'\s*&\s*' + _RE_PATTERN_IDENT + r')') +# Stream types. 
def CheckLanguage(filename, clean_lines, linenum, file_extension,
                  include_state, nesting_state, error):
  """Checks rules from the 'C++ language rules' section of cppguide.html.

  Some of these rules are hard to test (function overloading, using
  uint32 inappropriately), but we do the best we can.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    file_extension: The extension (without the dot) of the filename.
    include_state: An _IncludeState instance in which the headers are inserted.
    nesting_state: A NestingState instance which maintains information about
                   the current stack of nested blocks being parsed.
    error: The function to call with any errors found.
  """
  # If the line is empty or consists of entirely a comment, no need to
  # check it.
  line = clean_lines.elided[linenum]
  if not line:
    return

  match = _RE_PATTERN_INCLUDE.search(line)
  if match:
    CheckIncludeLine(filename, clean_lines, linenum, include_state, error)
    return

  # Reset include state across preprocessor directives. This is meant
  # to silence warnings for conditional includes.
  match = Match(r'^\s*#\s*(if|ifdef|ifndef|elif|else|endif)\b', line)
  if match:
    include_state.ResetSection(match.group(1))

  # Perform other checks now that we are sure that this is not an include line
  CheckCasts(filename, clean_lines, linenum, error)
  CheckGlobalStatic(filename, clean_lines, linenum, error)
  CheckPrintf(filename, clean_lines, linenum, error)

  if IsHeaderExtension(file_extension):
    # TODO(unknown): check that 1-arg constructors are explicit.
    #                How to tell it's a constructor?
    #                (handled in CheckForNonStandardConstructs for now)
    # TODO(unknown): check that classes declare or disable copy/assign
    #                (level 1 error)
    pass

  # Check if people are using the verboten C basic types. The only exception
  # we regularly allow is "unsigned short port" for port.
  if Search(r'\bshort port\b', line):
    if not Search(r'\bunsigned short port\b', line):
      error(filename, linenum, 'runtime/int', 4,
            'Use "unsigned short" for ports, not "short"')
  else:
    match = Search(r'\b(short|long(?! +double)|long long)\b', line)
    if match:
      error(filename, linenum, 'runtime/int', 4,
            'Use int16/int64/etc, rather than the C type %s' % match.group(1))

  # Check if some verboten operator overloading is going on
  # TODO(unknown): catch out-of-line unary operator&:
  #   class X {};
  #   int operator&(const X& x) { return 42; }  // unary operator&
  # The trick is it's hard to tell apart from binary operator&:
  #   class Y { int operator&(const Y& x) { return 23; } }; // binary operator&
  if Search(r'\boperator\s*&\s*\(\s*\)', line):
    error(filename, linenum, 'runtime/operator', 4,
          'Unary operator& is dangerous. Do not use it.')

  # Check for suspicious usage of "if" like
  # } if (a == b) {
  if Search(r'\}\s*if\s*\(', line):
    error(filename, linenum, 'readability/braces', 4,
          'Did you mean "else if"? If not, start a new line for "if".')

  # Check for potential format string bugs like printf(foo).
  # We constrain the pattern not to pick things like DocidForPrintf(foo).
  # Not perfect but it can catch printf(foo.c_str()) and printf(foo->c_str())
  # TODO(unknown): Catch the following case. Need to change the calling
  # convention of the whole function to process multiple line to handle it.
  #   printf(
  #       boy_this_is_a_really_long_variable_that_cannot_fit_on_the_prev_line);
  printf_args = _GetTextInside(line, r'(?i)\b(string)?printf\s*\(')
  if printf_args:
    match = Match(r'([\w.\->()]+)$', printf_args)
    if match and match.group(1) != '__VA_ARGS__':
      function_name = re.search(r'\b((?:string)?printf)\s*\(',
                                line, re.I).group(1)
      error(filename, linenum, 'runtime/printf', 4,
            'Potential format string bug. Do %s("%%s", %s) instead.'
            % (function_name, match.group(1)))

  # Check for potential memset bugs like memset(buf, sizeof(buf), 0).
  # The second argument is accepted only when the whole of it is a literal
  # ('' for an elided char literal, a decimal number, or a hex number).  The
  # alternatives are grouped and anchored at the end so that an expression
  # that merely starts with a digit (e.g. "4 * n") is still reported.
  match = Search(r'memset\s*\(([^,]*),\s*([^,]*),\s*0\s*\)', line)
  if match and not Match(r"(?:''|-?[0-9]+|0x[0-9A-Fa-f]+)\s*$",
                         match.group(2)):
    error(filename, linenum, 'runtime/memset', 4,
          'Did you mean "memset(%s, 0, %s)"?'
          % (match.group(1), match.group(2)))

  if Search(r'\busing namespace\b', line):
    error(filename, linenum, 'build/namespaces', 5,
          'Do not use namespace using-directives. '
          'Use using-declarations instead.')

  # Detect variable-length arrays.
  match = Match(r'\s*(.+::)?(\w+) [a-z]\w*\[(.+)];', line)
  if (match and match.group(2) not in ('return', 'delete') and
      ']' not in match.group(3)):
    # Split the size using space and arithmetic operators as delimiters.
    # If any of the resulting tokens are not compile time constants then
    # report the error.  Both shift operators are delimiters.
    tokens = re.split(r'\s|\+|\-|\*|\/|<<|>>', match.group(3))
    is_const = True
    skip_next = False
    for tok in tokens:
      if skip_next:
        skip_next = False
        continue

      if Search(r'sizeof\(.+\)', tok): continue
      if Search(r'arraysize\(\w+\)', tok): continue

      tok = tok.lstrip('(')
      tok = tok.rstrip(')')
      if not tok: continue
      if Match(r'\d+', tok): continue
      if Match(r'0[xX][0-9a-fA-F]+', tok): continue
      if Match(r'k[A-Z0-9]\w*', tok): continue
      if Match(r'(.+::)?k[A-Z0-9]\w*', tok): continue
      if Match(r'(.+::)?[A-Z][A-Z0-9_]*', tok): continue
      # A catch all for tricky sizeof cases, including 'sizeof expression',
      # 'sizeof(*type)', 'sizeof(const type)', 'sizeof(struct StructName)'
      # requires skipping the next token because we split on ' ' and '*'.
      if tok.startswith('sizeof'):
        skip_next = True
        continue
      is_const = False
      break
    if not is_const:
      error(filename, linenum, 'runtime/arrays', 1,
            'Do not use variable-length arrays. Use an appropriately named '
            "('k' followed by CamelCase) compile-time constant for the size.")

  # Check for use of unnamed namespaces in header files. Registration
  # macros are typically OK, so we allow use of "namespace {" on lines
  # that end with backslashes.
  if (IsHeaderExtension(file_extension)
      and Search(r'\bnamespace\s*{', line)
      and line[-1] != '\\'):
    error(filename, linenum, 'build/namespaces', 4,
          'Do not use unnamed namespaces in header files. See '
          'https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces'
          ' for more information.')
+ """ + line = clean_lines.elided[linenum] + + # Match two lines at a time to support multiline declarations + if linenum + 1 < clean_lines.NumLines() and not Search(r'[;({]', line): + line += clean_lines.elided[linenum + 1].strip() + + # Check for people declaring static/global STL strings at the top level. + # This is dangerous because the C++ language does not guarantee that + # globals with constructors are initialized before the first access, and + # also because globals can be destroyed when some threads are still running. + # TODO(unknown): Generalize this to also find static unique_ptr instances. + # TODO(unknown): File bugs for clang-tidy to find these. + match = Match( + r'((?:|static +)(?:|const +))(?::*std::)?string( +const)? +' + r'([a-zA-Z0-9_:]+)\b(.*)', + line) + + # Remove false positives: + # - String pointers (as opposed to values). + # string *pointer + # const string *pointer + # string const *pointer + # string *const pointer + # + # - Functions and template specializations. + # string Function(... + # string Class::Method(... + # + # - Operators. These are matched separately because operator names + # cross non-word boundaries, and trying to match both operators + # and functions at the same time would decrease accuracy of + # matching identifiers. + # string Class::operator*() + if (match and + not Search(r'\bstring\b(\s+const)?\s*[\*\&]\s*(const\s+)?\w', line) and + not Search(r'\boperator\W', line) and + not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)*\s*\(([^"]|$)', match.group(4))): + if Search(r'\bconst\b', line): + error(filename, linenum, 'runtime/string', 4, + 'For a static/global string constant, use a C style string ' + 'instead: "%schar%s %s[]".' 
def CheckPrintf(filename, clean_lines, linenum, error):
  """Reports risky uses of the printf family and of strcpy/strcat.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    error: The function to call with any errors found.
  """
  text = clean_lines.elided[linenum]

  # A numeric literal as the second argument of snprintf is suspicious,
  # except a literal 0, which is how snprintf is used to calculate a size.
  snprintf_call = Search(r'snprintf\s*\(([^,]*),\s*([0-9]*)\s*,', text)
  if snprintf_call:
    buffer_arg = snprintf_call.group(1)
    size_arg = snprintf_call.group(2)
    if size_arg != '0':
      error(filename, linenum, 'runtime/printf', 3,
            'If you can, use sizeof(%s) instead of %s as the 2nd arg '
            'to snprintf.' % (buffer_arg, size_arg))

  # sprintf is forbidden outright.
  if Search(r'\bsprintf\s*\(', text):
    error(filename, linenum, 'runtime/printf', 5,
          'Never use sprintf. Use snprintf instead.')

  # strcpy and strcat are discouraged in favour of snprintf.
  unsafe_call = Search(r'\b(strcpy|strcat)\s*\(', text)
  if unsafe_call:
    error(filename, linenum, 'runtime/printf', 4,
          'Almost always, snprintf is better than %s' % unsafe_call.group(1))
def IsOutOfLineMethodDefinition(clean_lines, linenum):
  """Tells whether the enclosing function start looks like Class::Method(.

  Walks backwards from linenum over at most ten lines and stops at the first
  line that looks like the start of a function (text without parentheses
  ending in a word, followed by an opening parenthesis).

  Args:
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
  Returns:
    True if that function start is qualified with "::" (an out-of-line method
    definition), False if it is not or if no function start was found.
  """
  stop_before = max(-1, linenum - 10)
  index = linenum
  while index > stop_before:
    candidate = clean_lines.elided[index]
    if Match(r'^([^()]*\w+)\(', candidate):
      # First function start found decides the answer.
      return Match(r'^[^()]*\w+::\w+\(', candidate) is not None
    index -= 1
  return False
def CheckForNonConstReference(filename, clean_lines, linenum,
                              nesting_state, error):
  """Check for non-const references.

  Separate from CheckLanguage since it scans backwards from current
  line, instead of scanning forward.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    nesting_state: A NestingState instance which maintains information about
                   the current stack of nested blocks being parsed.
    error: The function to call with any errors found.
  """
  # Do nothing if there is no '&' on current line.
  line = clean_lines.elided[linenum]
  if '&' not in line:
    return

  # If a function is inherited, current function doesn't have much of
  # a choice, so any non-const references should not be blamed on
  # derived function.
  if IsDerivedFunction(clean_lines, linenum):
    return

  # Don't warn on out-of-line method definitions, as we would warn on the
  # in-line declaration, if it isn't marked with 'override'.
  if IsOutOfLineMethodDefinition(clean_lines, linenum):
    return

  # Long type names may be broken across multiple lines, usually in one
  # of these forms:
  #   LongType
  #       ::LongTypeContinued &identifier
  #   LongType::
  #       LongTypeContinued &identifier
  #   LongType<
  #       ...>::LongTypeContinued &identifier
  #
  # If we detected a type split across two lines, join the previous
  # line to current line so that we can match const references
  # accordingly.
  #
  # Note that this only scans back one line, since scanning back
  # arbitrary number of lines would be expensive. If you have a type
  # that spans more than 2 lines, please use a typedef.
  if linenum > 1:
    previous = None
    if Match(r'\s*::(?:[\w<>]|::)+\s*&\s*\S', line):
      # previous_line\n + ::current_line
      previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+[\w<>])\s*$',
                        clean_lines.elided[linenum - 1])
    elif Match(r'\s*[a-zA-Z_]([\w<>]|::)+\s*&\s*\S', line):
      # previous_line::\n + current_line
      previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+::)\s*$',
                        clean_lines.elided[linenum - 1])
    if previous:
      line = previous.group(1) + line.lstrip()
    else:
      # Check for templated parameter that is split across multiple lines
      endpos = line.rfind('>')
      if endpos > -1:
        (_, startline, startpos) = ReverseCloseExpression(
            clean_lines, linenum, endpos)
        if startpos > -1 and startline < linenum:
          # Found the matching < on an earlier line, collect all
          # pieces up to current line.
          line = ''
          for i in xrange(startline, linenum + 1):
            line += clean_lines.elided[i].strip()

  # Check for non-const references in function parameters. A single '&' may
  # found in the following places:
  #   inside expression: binary & for bitwise AND
  #   inside expression: unary & for taking the address of something
  #   inside declarators: reference parameter
  # We will exclude the first two cases by checking that we are not inside a
  # function body, including one that was just introduced by a trailing '{'.
  # TODO(unknown): Doesn't account for 'catch(Exception& e)' [rare].
  if (nesting_state.previous_stack_top and
      not isinstance(nesting_state.previous_stack_top,
                     (_ClassInfo, _NamespaceInfo))):
    # Not at toplevel, not within a class, and not within a namespace
    return

  # Avoid initializer lists. We only need to scan back from the
  # current line for something that starts with ':'.
  #
  # We don't need to check the current line, since the '&' would
  # appear inside the second set of parentheses on the current line as
  # opposed to the first set.
  if linenum > 0:
    for i in xrange(linenum - 1, max(0, linenum - 10), -1):
      previous_line = clean_lines.elided[i]
      if not Search(r'[),]\s*$', previous_line):
        break
      if Match(r'^\s*:\s+\S', previous_line):
        return

  # Avoid preprocessors
  if Search(r'\\\s*$', line):
    return

  # Avoid constructor initializer lists
  if IsInitializerList(clean_lines, linenum):
    return

  # We allow non-const references in a few standard places, like functions
  # called "swap()" or iostream operators like "<<" or ">>". Do not check
  # those function parameters.
  #
  # We also accept & in static_assert, which looks like a function but
  # it's actually a declaration expression.
  #
  # The optional template argument list after swap is a run of word
  # characters and colons (e.g. swap<ns::Type>), hence the character class.
  allowed_functions = (r'(?:[sS]wap(?:<[\w:]+>)?|'
                       r'operator\s*[<>][<>]|'
                       r'static_assert|COMPILE_ASSERT'
                       r')\s*\(')
  if Search(allowed_functions, line):
    return
  elif not Search(r'\S+\([^)]*$', line):
    # Don't see an allowed function on this line. Actually we
    # didn't see any function name on this line, so this is likely a
    # multi-line parameter list. Try a bit harder to catch this case.
    for i in xrange(2):
      if (linenum > i and
          Search(allowed_functions, clean_lines.elided[linenum - i - 1])):
        return

  decls = ReplaceAll(r'{[^}]*}', ' ', line)  # exclude function body
  for parameter in re.findall(_RE_PATTERN_REF_PARAM, decls):
    if (not Match(_RE_PATTERN_CONST_REF_PARAM, parameter) and
        not Match(_RE_PATTERN_REF_STREAM_PARAM, parameter)):
      error(filename, linenum, 'runtime/references', 2,
            'Is this a non-const reference? '
            'If so, make const or use a pointer: ' +
            ReplaceAll(' *<', '<', parameter))
+ # + # function // bracket + no space = false positive + # value < double(42) // bracket + space = true positive + matched_new_or_template = match.group(1) + + # Avoid arrays by looking for brackets that come after the closing + # parenthesis. + if Match(r'\([^()]+\)\s*\[', match.group(3)): + return + + # Other things to ignore: + # - Function pointers + # - Casts to pointer types + # - Placement new + # - Alias declarations + matched_funcptr = match.group(3) + if (matched_new_or_template is None and + not (matched_funcptr and + (Match(r'\((?:[^() ]+::\s*\*\s*)?[^() ]+\)\s*\(', + matched_funcptr) or + matched_funcptr.startswith('(*)'))) and + not Match(r'\s*using\s+\S+\s*=\s*' + matched_type, line) and + not Search(r'new\(\S+\)\s*' + matched_type, line)): + error(filename, linenum, 'readability/casting', 4, + 'Using deprecated casting style. ' + 'Use static_cast<%s>(...) instead' % + matched_type) + + if not expecting_function: + CheckCStyleCast(filename, clean_lines, linenum, 'static_cast', + r'\((int|float|double|bool|char|u?int(16|32|64))\)', error) + + # This doesn't catch all cases. Consider (const char * const)"hello". + # + # (char *) "foo" should always be a const_cast (reinterpret_cast won't + # compile). + if CheckCStyleCast(filename, clean_lines, linenum, 'const_cast', + r'\((char\s?\*+\s?)\)\s*"', error): + pass + else: + # Check pointer casts for other than string constants + CheckCStyleCast(filename, clean_lines, linenum, 'reinterpret_cast', + r'\((\w+\s?\*+\s?)\)', error) + + # In addition, we look for people taking the address of a cast. This + # is dangerous -- casts can assign to temporaries, so the pointer doesn't + # point where you think. + # + # Some non-identifier character is required before the '&' for the + # expression to be recognized as a cast. 
def CheckCStyleCast(filename, clean_lines, linenum, cast_type, pattern, error):
  """Reports a C-style cast on the line if the pattern finds a genuine one.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    linenum: The number of the line to check.
    cast_type: The string for the C++ cast to recommend. This is either
      reinterpret_cast, static_cast, or const_cast, depending.
    pattern: The regular expression used to find C-style casts.
    error: The function to call with any errors found.

  Returns:
    True if an error was emitted.
    False otherwise.
  """
  current = clean_lines.elided[linenum]
  cast = Search(pattern, current)
  if not cast:
    return False

  # Text in front of the cast's opening parenthesis.  Keywords such as
  # sizeof, or an ALL_CAPS macro name, right before it mean this is not a cast.
  context = current[0:cast.start(1) - 1]
  if Match(r'.*\b(?:sizeof|alignof|alignas|[_A-Z][_A-Z0-9]*)\s*$', context):
    return False

  # Prepend a few preceding lines to the context, nearest line first, to see
  # whether we are one level of parentheses inside a macro invocation.
  oldest = max(0, linenum - 5)
  index = linenum - 1
  while index > oldest:
    context = clean_lines.elided[index] + context
    index -= 1
  if Match(r'.*\b[_A-Z][_A-Z0-9]*\s*\((?:\([^()]*\)|[^()])*$', context):
    return False

  # operator++(int) and operator--(int)
  if context.endswith((' operator++', ' operator--')):
    return False

  # A single unnamed argument for a function tends to look like old style cast.
  # If we see those, don't issue warnings for deprecated casts.
  trailing = current[cast.end(0):]
  looks_like_declaration = Match(
      r'^\s*(?:;|const\b|throw\b|final\b|override\b|[=>{),]|->)', trailing)
  if looks_like_declaration:
    return False

  # At this point, all that should be left is actual casts.
  error(filename, linenum, 'readability/casting', 4,
        'Using C-style cast. Use %s<%s>(...) instead' %
        (cast_type, cast.group(1)))
  return True
# Maps each standard header to the templates it declares.  The header name,
# including its angle brackets, is what the include-what-you-use check prints
# in its 'Add #include ...' message and compares (brackets stripped) against
# the includes actually present.
_HEADERS_CONTAINING_TEMPLATES = (
    ('<deque>', ('deque',)),
    ('<functional>', ('unary_function', 'binary_function',
                      'plus', 'minus', 'multiplies', 'divides', 'modulus',
                      'negate',
                      'equal_to', 'not_equal_to', 'greater', 'less',
                      'greater_equal', 'less_equal',
                      'logical_and', 'logical_or', 'logical_not',
                      'unary_negate', 'not1', 'binary_negate', 'not2',
                      'bind1st', 'bind2nd',
                      'pointer_to_unary_function',
                      'pointer_to_binary_function',
                      'ptr_fun',
                      'mem_fun_t', 'mem_fun', 'mem_fun1_t', 'mem_fun1_ref_t',
                      'mem_fun_ref_t',
                      'const_mem_fun_t', 'const_mem_fun1_t',
                      'const_mem_fun_ref_t', 'const_mem_fun1_ref_t',
                      'mem_fun_ref',
                     )),
    ('<limits>', ('numeric_limits',)),
    ('<list>', ('list',)),
    ('<map>', ('map', 'multimap',)),
    ('<memory>', ('allocator', 'make_shared', 'make_unique', 'shared_ptr',
                  'unique_ptr', 'weak_ptr')),
    ('<queue>', ('queue', 'priority_queue',)),
    ('<set>', ('set', 'multiset',)),
    ('<stack>', ('stack',)),
    ('<string>', ('char_traits', 'basic_string',)),
    ('<tuple>', ('tuple',)),
    ('<unordered_map>', ('unordered_map', 'unordered_multimap')),
    ('<unordered_set>', ('unordered_set', 'unordered_multiset')),
    ('<utility>', ('pair',)),
    ('<vector>', ('vector',)),

    # gcc extensions.
    # Note: std::hash is their hash, ::hash is our hash
    ('<hash_map>', ('hash_map', 'hash_multimap',)),
    ('<hash_set>', ('hash_set', 'hash_multiset',)),
    ('<slist>', ('slist',)),
    )

# Names that are templates only some of the time (they are also commonly used
# as plain function names), mapped to the header that declares them.
_HEADERS_MAYBE_TEMPLATES = (
    ('<algorithm>', ('copy', 'max', 'min', 'min_element', 'sort',
                     'transform',
                    )),
    ('<utility>', ('forward', 'make_pair', 'move', 'swap')),
    )

_RE_PATTERN_STRING = re.compile(r'\bstring\b')

# List of (compiled pattern, template name, header) triples.
_re_pattern_headers_maybe_templates = []
for _header, _templates in _HEADERS_MAYBE_TEMPLATES:
  for _template in _templates:
    # Match max<type>(..., ...), max(..., ...), but not foo->max, foo.max or
    # type::max().
    _re_pattern_headers_maybe_templates.append(
        (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'),
         _template,
         _header))

# Other scripts may reach in and modify this pattern.
# List of (compiled pattern, 'name<>', header) triples.
_re_pattern_templates = []
for _header, _templates in _HEADERS_CONTAINING_TEMPLATES:
  for _template in _templates:
    _re_pattern_templates.append(
        (re.compile(r'(\<|\b)' + _template + r'\s*\<'),
         _template + '<>',
         _header))
Because of this, this function gives + some false positives. This should be sufficiently rare in practice. + + Args: + filename_cc: is the path for the .cc file + filename_h: is the path for the header path + + Returns: + Tuple with a bool and a string: + bool: True if filename_cc and filename_h belong to the same module. + string: the additional prefix needed to open the header file. + """ + + fileinfo = FileInfo(filename_cc) + if not fileinfo.IsSource(): + return (False, '') + filename_cc = filename_cc[:-len(fileinfo.Extension())] + matched_test_suffix = Search(_TEST_FILE_SUFFIX, fileinfo.BaseName()) + if matched_test_suffix: + filename_cc = filename_cc[:-len(matched_test_suffix.group(1))] + filename_cc = filename_cc.replace('/public/', '/') + filename_cc = filename_cc.replace('/internal/', '/') + + if not filename_h.endswith('.h'): + return (False, '') + filename_h = filename_h[:-len('.h')] + if filename_h.endswith('-inl'): + filename_h = filename_h[:-len('-inl')] + filename_h = filename_h.replace('/public/', '/') + filename_h = filename_h.replace('/internal/', '/') + + files_belong_to_same_module = filename_cc.endswith(filename_h) + common_path = '' + if files_belong_to_same_module: + common_path = filename_cc[:-len(filename_h)] + return files_belong_to_same_module, common_path + + +def UpdateIncludeState(filename, include_dict, io=codecs): + """Fill up the include_dict with new includes found from the file. + + Args: + filename: the name of the header to read. + include_dict: a dictionary in which the headers are inserted. + io: The io factory to use to read the file. Provided for testability. + + Returns: + True if a header was successfully added. False otherwise. 
def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error,
                              io=codecs):
  """Reports for missing stl includes.

  This function will output warnings to make sure you are including the headers
  necessary for the stl containers and functions that you use. We only give one
  reason to include a header. For example, if you use both equal_to<> and
  less<> in a .h file, only one (the latter in the file) of these will be
  reported as a reason to include the <functional>.

  Args:
    filename: The name of the current file.
    clean_lines: A CleansedLines instance containing the file.
    include_state: An _IncludeState instance.
    error: The function to call with any errors found.
    io: The IO factory to use to read the header file. Provided for unittest
        injection.
  """
  required = {}  # A map of header name to linenumber and the template entity.
                 # Example of required: { '<functional>': (1219, 'less<>') }

  for linenum in range(clean_lines.NumLines()):
    line = clean_lines.elided[linenum]
    if not line or line[0] == '#':
      continue

    # String is special -- it is a non-templatized type in STL.
    matched = _RE_PATTERN_STRING.search(line)
    if matched:
      # Don't warn about strings in non-STL namespaces:
      # (We check only the first match per line; good enough.)
      prefix = line[:matched.start()]
      if prefix.endswith('std::') or not prefix.endswith('::'):
        # The key must carry the header name: it is printed in the
        # 'Add #include ...' message and, with the brackets stripped,
        # looked up among the includes that are present.
        required['<string>'] = (linenum, 'string')

    for pattern, template, header in _re_pattern_headers_maybe_templates:
      if pattern.search(line):
        required[header] = (linenum, template)

    # The following function is just a speed up, no semantics are changed.
    if '<' not in line:  # Reduces the cpu time usage by skipping lines.
      continue

    for pattern, template, header in _re_pattern_templates:
      matched = pattern.search(line)
      if matched:
        # Don't warn about IWYU in non-STL namespaces:
        # (We check only the first match per line; good enough.)
        prefix = line[:matched.start()]
        if prefix.endswith('std::') or not prefix.endswith('::'):
          required[header] = (linenum, template)

  # The policy is that if you #include something in foo.h you don't need to
  # include it again in foo.cc. Here, we will look at possible includes.
  # Let's flatten the include_state include_list and copy it into a dictionary.
  include_dict = dict(item for sublist in include_state.include_list
                      for item in sublist)

  # Did we find the header for this file (if any) and successfully load it?
  header_found = False

  # Use the absolute path so that matching works properly.
  abs_filename = FileInfo(filename).FullName()

  # For Emacs's flymake.
  # If cpplint is invoked from Emacs's flymake, a temporary file is generated
  # by flymake and that file name might end with '_flymake.cc'. In that case,
  # restore original file name here so that the corresponding header file can be
  # found.
  # e.g. If the file name is 'foo_flymake.cc', we should search for 'foo.h'
  # instead of 'foo_flymake.h'
  abs_filename = re.sub(r'_flymake\.cc$', '.cc', abs_filename)

  # include_dict is modified during iteration (UpdateIncludeState inserts into
  # it), so we iterate over a snapshot of the keys.  list() is required: on
  # Python 3 keys() is a live view and mutating the dict while iterating it
  # raises RuntimeError.
  header_keys = list(include_dict)
  for header in header_keys:
    (same_module, common_path) = FilesBelongToSameModule(abs_filename, header)
    fullpath = common_path + header
    if same_module and UpdateIncludeState(fullpath, include_dict, io):
      header_found = True

  # If we can't find the header file for a .cc, assume it's because we don't
  # know where to look. In that case we'll give up as we're not sure they
  # didn't include it in the .h file.
  # TODO(unknown): Do a better job of finding .h files so we are confident that
  # not having the .h file means there isn't one.
  if filename.endswith('.cc') and not header_found:
    return

  # All the lines have been processed, report the errors found.
  for required_header_unstripped, (required_linenum, template) in (
      required.items()):
    if required_header_unstripped.strip('<>"') not in include_dict:
      error(filename, required_linenum,
            'build/include_what_you_use', 4,
            'Add #include ' + required_header_unstripped + ' for ' + template)
+ line = clean_lines.elided[linenum] + virtual = Match(r'^(.*)(\bvirtual\b)(.*)$', line) + if not virtual: return + + # Ignore "virtual" keywords that are near access-specifiers. These + # are only used in class base-specifier and do not apply to member + # functions. + if (Search(r'\b(public|protected|private)\s+$', virtual.group(1)) or + Match(r'^\s+(public|protected|private)\b', virtual.group(3))): + return + + # Ignore the "virtual" keyword from virtual base classes. Usually + # there is a column on the same line in these cases (virtual base + # classes are rare in google3 because multiple inheritance is rare). + if Match(r'^.*[^:]:[^:].*$', line): return + + # Look for the next opening parenthesis. This is the start of the + # parameter list (possibly on the next line shortly after virtual). + # TODO(unknown): doesn't work if there are virtual functions with + # decltype() or other things that use parentheses, but csearch suggests + # that this is rare. + end_col = -1 + end_line = -1 + start_col = len(virtual.group(2)) + for start_line in xrange(linenum, min(linenum + 3, clean_lines.NumLines())): + line = clean_lines.elided[start_line][start_col:] + parameter_list = Match(r'^([^(]*)\(', line) + if parameter_list: + # Match parentheses to find the end of the parameter list + (_, end_line, end_col) = CloseExpression( + clean_lines, start_line, start_col + len(parameter_list.group(1))) + break + start_col = 0 + + if end_col < 0: + return # Couldn't find end of parameter list, give up + + # Look for "override" or "final" after the parameter list + # (possibly on the next few lines). 
+ for i in xrange(end_line, min(end_line + 3, clean_lines.NumLines())): + line = clean_lines.elided[i][end_col:] + match = Search(r'\b(override|final)\b', line) + if match: + error(filename, linenum, 'readability/inheritance', 4, + ('"virtual" is redundant since function is ' + 'already declared as "%s"' % match.group(1))) + + # Set end_col to check whole lines after we are done with the + # first line. + end_col = 0 + if Search(r'[^\w]\s*$', line): + break + + +def CheckRedundantOverrideOrFinal(filename, clean_lines, linenum, error): + """Check if line contains a redundant "override" or "final" virt-specifier. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + # Look for closing parenthesis nearby. We need one to confirm where + # the declarator ends and where the virt-specifier starts to avoid + # false positives. + line = clean_lines.elided[linenum] + declarator_end = line.rfind(')') + if declarator_end >= 0: + fragment = line[declarator_end:] + else: + if linenum > 1 and clean_lines.elided[linenum - 1].rfind(')') >= 0: + fragment = line + else: + return + + # Check that at most one of "override" or "final" is present, not both + if Search(r'\boverride\b', fragment) and Search(r'\bfinal\b', fragment): + error(filename, linenum, 'readability/inheritance', 4, + ('"override" is redundant since function is ' + 'already declared as "final"')) + + + + +# Returns true if we are at a new block, and it is directly +# inside of a namespace. +def IsBlockInNameSpace(nesting_state, is_forward_declaration): + """Checks that the new block is directly in a namespace. + + Args: + nesting_state: The _NestingState object that contains info about our state. + is_forward_declaration: If the class is a forward declared class. + Returns: + Whether or not the new block is directly in a namespace. 
+ """ + if is_forward_declaration: + if len(nesting_state.stack) >= 1 and ( + isinstance(nesting_state.stack[-1], _NamespaceInfo)): + return True + else: + return False + + return (len(nesting_state.stack) > 1 and + nesting_state.stack[-1].check_namespace_indentation and + isinstance(nesting_state.stack[-2], _NamespaceInfo)) + + +def ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item, + raw_lines_no_comments, linenum): + """This method determines if we should apply our namespace indentation check. + + Args: + nesting_state: The current nesting state. + is_namespace_indent_item: If we just put a new class on the stack, True. + If the top of the stack is not a class, or we did not recently + add the class, False. + raw_lines_no_comments: The lines without the comments. + linenum: The current line number we are processing. + + Returns: + True if we should apply our namespace indentation check. Currently, it + only works for classes and namespaces inside of a namespace. + """ + + is_forward_declaration = IsForwardClassDeclaration(raw_lines_no_comments, + linenum) + + if not (is_namespace_indent_item or is_forward_declaration): + return False + + # If we are in a macro, we do not want to check the namespace indentation. + if IsMacroDefinition(raw_lines_no_comments, linenum): + return False + + return IsBlockInNameSpace(nesting_state, is_forward_declaration) + + +# Call this method if the line is directly inside of a namespace. +# If the line above is blank (excluding comments) or the start of +# an inner namespace, it cannot be indented. 
+def CheckItemIndentationInNamespace(filename, raw_lines_no_comments, linenum, + error): + line = raw_lines_no_comments[linenum] + if Match(r'^\s+', line): + error(filename, linenum, 'runtime/indentation_namespace', 4, + 'Do not indent within a namespace') + + +def ProcessLine(filename, file_extension, clean_lines, line, + include_state, function_state, nesting_state, error, + extra_check_functions=[]): + """Processes a single line in the file. + + Args: + filename: Filename of the file that is being processed. + file_extension: The extension (dot not included) of the file. + clean_lines: An array of strings, each representing a line of the file, + with comments stripped. + line: Number of line being processed. + include_state: An _IncludeState instance in which the headers are inserted. + function_state: A _FunctionState instance which counts function lines, etc. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: A callable to which errors are reported, which takes 4 arguments: + filename, line number, error level, and message + extra_check_functions: An array of additional check functions that will be + run on each source line. 
Each function takes 4 + arguments: filename, clean_lines, line, error + """ + raw_lines = clean_lines.raw_lines + ParseNolintSuppressions(filename, raw_lines[line], line, error) + nesting_state.Update(filename, clean_lines, line, error) + CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line, + error) + if nesting_state.InAsmBlock(): return + CheckForFunctionLengths(filename, clean_lines, line, function_state, error) + CheckForMultilineCommentsAndStrings(filename, clean_lines, line, error) + CheckStyle(filename, clean_lines, line, file_extension, nesting_state, error) + CheckLanguage(filename, clean_lines, line, file_extension, include_state, + nesting_state, error) + CheckForNonConstReference(filename, clean_lines, line, nesting_state, error) + CheckForNonStandardConstructs(filename, clean_lines, line, + nesting_state, error) + CheckVlogArguments(filename, clean_lines, line, error) + CheckPosixThreading(filename, clean_lines, line, error) + CheckInvalidIncrement(filename, clean_lines, line, error) + CheckMakePairUsesDeduction(filename, clean_lines, line, error) + CheckRedundantVirtual(filename, clean_lines, line, error) + CheckRedundantOverrideOrFinal(filename, clean_lines, line, error) + for check_fn in extra_check_functions: + check_fn(filename, clean_lines, line, error) + +def FlagCxx11Features(filename, clean_lines, linenum, error): + """Flag those c++11 features that we only allow in certain places. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line) + + # Flag unapproved C++ TR1 headers. 
+ if include and include.group(1).startswith('tr1/'): + error(filename, linenum, 'build/c++tr1', 5, + ('C++ TR1 headers such as <%s> are unapproved.') % include.group(1)) + + # Flag unapproved C++11 headers. + if include and include.group(1) in ('cfenv', + 'condition_variable', + 'fenv.h', + 'future', + 'mutex', + 'thread', + 'chrono', + 'ratio', + 'regex', + 'system_error', + ): + error(filename, linenum, 'build/c++11', 5, + ('<%s> is an unapproved C++11 header.') % include.group(1)) + + # The only place where we need to worry about C++11 keywords and library + # features in preprocessor directives is in macro definitions. + if Match(r'\s*#', line) and not Match(r'\s*#\s*define\b', line): return + + # These are classes and free functions. The classes are always + # mentioned as std::*, but we only catch the free functions if + # they're not found by ADL. They're alphabetical by header. + for top_name in ( + # type_traits + 'alignment_of', + 'aligned_union', + ): + if Search(r'\bstd::%s\b' % top_name, line): + error(filename, linenum, 'build/c++11', 5, + ('std::%s is an unapproved C++11 class or function. Send c-style ' + 'an example of where it would make your code more readable, and ' + 'they may let you use it.') % top_name) + + +def FlagCxx14Features(filename, clean_lines, linenum, error): + """Flag those C++14 features that we restrict. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line) + + # Flag unapproved C++14 headers. 
+ if include and include.group(1) in ('scoped_allocator', 'shared_mutex'): + error(filename, linenum, 'build/c++14', 5, + ('<%s> is an unapproved C++14 header.') % include.group(1)) + + +def ProcessFileData(filename, file_extension, lines, error, + extra_check_functions=[]): + """Performs lint checks and reports any errors to the given error function. + + Args: + filename: Filename of the file that is being processed. + file_extension: The extension (dot not included) of the file. + lines: An array of strings, each representing a line of the file, with the + last element being empty if the file is terminated with a newline. + error: A callable to which errors are reported, which takes 4 arguments: + filename, line number, error level, and message + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + lines = (['// marker so line numbers and indices both start at 1'] + lines + + ['// marker so line numbers end in a known way']) + + include_state = _IncludeState() + function_state = _FunctionState() + nesting_state = NestingState() + + ResetNolintSuppressions() + + CheckForCopyright(filename, lines, error) + ProcessGlobalSuppresions(lines) + RemoveMultiLineComments(filename, lines, error) + clean_lines = CleansedLines(lines) + + if IsHeaderExtension(file_extension): + CheckForHeaderGuard(filename, clean_lines, error) + + for line in xrange(clean_lines.NumLines()): + ProcessLine(filename, file_extension, clean_lines, line, + include_state, function_state, nesting_state, error, + extra_check_functions) + FlagCxx11Features(filename, clean_lines, line, error) + nesting_state.CheckCompletedBlocks(filename, error) + + CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error) + + # Check that the .cc file has included its header if it exists. 
+ if _IsSourceExtension(file_extension): + CheckHeaderFileIncluded(filename, include_state, error) + + # We check here rather than inside ProcessLine so that we see raw + # lines rather than "cleaned" lines. + CheckForBadCharacters(filename, lines, error) + + CheckForNewlineAtEOF(filename, lines, error) + +def ProcessConfigOverrides(filename): + """ Loads the configuration files and processes the config overrides. + + Args: + filename: The name of the file being processed by the linter. + + Returns: + False if the current |filename| should not be processed further. + """ + + abs_filename = os.path.abspath(filename) + cfg_filters = [] + keep_looking = True + while keep_looking: + abs_path, base_name = os.path.split(abs_filename) + if not base_name: + break # Reached the root directory. + + cfg_file = os.path.join(abs_path, "CPPLINT.cfg") + abs_filename = abs_path + if not os.path.isfile(cfg_file): + continue + + try: + with open(cfg_file) as file_handle: + for line in file_handle: + line, _, _ = line.partition('#') # Remove comments. + if not line.strip(): + continue + + name, _, val = line.partition('=') + name = name.strip() + val = val.strip() + if name == 'set noparent': + keep_looking = False + elif name == 'filter': + cfg_filters.append(val) + elif name == 'exclude_files': + # When matching exclude_files pattern, use the base_name of + # the current file name or the directory name we are processing. + # For example, if we are checking for lint errors in /foo/bar/baz.cc + # and we found the .cfg file at /foo/CPPLINT.cfg, then the config + # file's "exclude_files" filter is meant to be checked against "bar" + # and not "baz" nor "bar/baz.cc". + if base_name: + pattern = re.compile(val) + if pattern.match(base_name): + if _cpplint_state.quiet: + # Suppress "Ignoring file" warning when using --quiet. + return False + sys.stderr.write('Ignoring "%s": file excluded by "%s". 
' + 'File path component "%s" matches ' + 'pattern "%s"\n' % + (filename, cfg_file, base_name, val)) + return False + elif name == 'linelength': + global _line_length + try: + _line_length = int(val) + except ValueError: + sys.stderr.write('Line length must be numeric.') + elif name == 'root': + global _root + # root directories are specified relative to CPPLINT.cfg dir. + _root = os.path.join(os.path.dirname(cfg_file), val) + elif name == 'headers': + ProcessHppHeadersOption(val) + else: + sys.stderr.write( + 'Invalid configuration option (%s) in file %s\n' % + (name, cfg_file)) + + except IOError: + sys.stderr.write( + "Skipping config file '%s': Can't open for reading\n" % cfg_file) + keep_looking = False + + # Apply all the accumulated filters in reverse order (top-level directory + # config options having the least priority). + for filter in reversed(cfg_filters): + _AddFilters(filter) + + return True + + +def ProcessFile(filename, vlevel, extra_check_functions=[]): + """Does google-lint on a single file. + + Args: + filename: The name of the file to parse. + + vlevel: The level of errors to report. Every error of confidence + >= verbose_level will be reported. 0 is a good default. + + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + + _SetVerboseLevel(vlevel) + _BackupFilters() + old_errors = _cpplint_state.error_count + + if not ProcessConfigOverrides(filename): + _RestoreFilters() + return + + lf_lines = [] + crlf_lines = [] + try: + # Support the UNIX convention of using "-" for stdin. Note that + # we are not opening the file with universal newline support + # (which codecs doesn't support anyway), so the resulting lines do + # contain trailing '\r' characters if we are reading a file that + # has CRLF endings. + # If after the split a trailing '\r' is present, it is removed + # below. 
+ if filename == '-': + lines = codecs.StreamReaderWriter(sys.stdin, + codecs.getreader('utf8'), + codecs.getwriter('utf8'), + 'replace').read().split('\n') + else: + lines = codecs.open(filename, 'r', 'utf8', 'replace').read().split('\n') + + # Remove trailing '\r'. + # The -1 accounts for the extra trailing blank line we get from split() + for linenum in range(len(lines) - 1): + if lines[linenum].endswith('\r'): + lines[linenum] = lines[linenum].rstrip('\r') + crlf_lines.append(linenum + 1) + else: + lf_lines.append(linenum + 1) + + except IOError: + sys.stderr.write( + "Skipping input '%s': Can't open for reading\n" % filename) + _RestoreFilters() + return + + # Note, if no dot is found, this will give the entire filename as the ext. + file_extension = filename[filename.rfind('.') + 1:] + + # When reading from stdin, the extension is unknown, so no cpplint tests + # should rely on the extension. + if filename != '-' and file_extension not in _valid_extensions: + sys.stderr.write('Ignoring %s; not a valid file name ' + '(%s)\n' % (filename, ', '.join(_valid_extensions))) + else: + ProcessFileData(filename, file_extension, lines, Error, + extra_check_functions) + + # If end-of-line sequences are a mix of LF and CR-LF, issue + # warnings on the lines with CR. + # + # Don't issue any warnings if all lines are uniformly LF or CR-LF, + # since critique can handle these just fine, and the style guide + # doesn't dictate a particular end of line sequence. + # + # We can't depend on os.linesep to determine what the desired + # end-of-line sequence should be, since that will return the + # server-side end-of-line sequence. + if lf_lines and crlf_lines: + # Warn on every line with CR. An alternative approach might be to + # check whether the file is mostly CRLF or just LF, and warn on the + # minority, we bias toward LF here since most tools prefer LF. 
+ for linenum in crlf_lines: + Error(filename, linenum, 'whitespace/newline', 1, + 'Unexpected \\r (^M) found; better to use only \\n') + + # Suppress printing anything if --quiet was passed unless the error + # count has increased after processing this file. + if not _cpplint_state.quiet or old_errors != _cpplint_state.error_count: + sys.stdout.write('Done processing %s\n' % filename) + _RestoreFilters() + + +def PrintUsage(message): + """Prints a brief usage string and exits, optionally with an error message. + + Args: + message: The optional error message. + """ + sys.stderr.write(_USAGE) + if message: + sys.exit('\nFATAL ERROR: ' + message) + else: + sys.exit(1) + + +def PrintCategories(): + """Prints a list of all the error-categories used by error messages. + + These are the categories used to filter messages via --filter. + """ + sys.stderr.write(''.join(' %s\n' % cat for cat in _ERROR_CATEGORIES)) + sys.exit(0) + + +def ParseArguments(args): + """Parses the command line arguments. + + This may set the output format and verbosity level as side-effects. + + Args: + args: The command line arguments: + + Returns: + The list of filenames to lint. 
+ """ + try: + (opts, filenames) = getopt.getopt(args, '', ['help', 'output=', 'verbose=', + 'counting=', + 'filter=', + 'root=', + 'linelength=', + 'extensions=', + 'headers=', + 'quiet']) + except getopt.GetoptError: + PrintUsage('Invalid arguments.') + + verbosity = _VerboseLevel() + output_format = _OutputFormat() + filters = '' + quiet = _Quiet() + counting_style = '' + + for (opt, val) in opts: + if opt == '--help': + PrintUsage(None) + elif opt == '--output': + if val not in ('emacs', 'vs7', 'eclipse'): + PrintUsage('The only allowed output formats are emacs, vs7 and eclipse.') + output_format = val + elif opt == '--quiet': + quiet = True + elif opt == '--verbose': + verbosity = int(val) + elif opt == '--filter': + filters = val + if not filters: + PrintCategories() + elif opt == '--counting': + if val not in ('total', 'toplevel', 'detailed'): + PrintUsage('Valid counting options are total, toplevel, and detailed') + counting_style = val + elif opt == '--root': + global _root + _root = val + elif opt == '--linelength': + global _line_length + try: + _line_length = int(val) + except ValueError: + PrintUsage('Line length must be digits.') + elif opt == '--extensions': + global _valid_extensions + try: + _valid_extensions = set(val.split(',')) + except ValueError: + PrintUsage('Extensions must be comma separated list.') + elif opt == '--headers': + ProcessHppHeadersOption(val) + + if not filenames: + PrintUsage('No files were specified.') + + _SetOutputFormat(output_format) + _SetQuiet(quiet) + _SetVerboseLevel(verbosity) + _SetFilters(filters) + _SetCountingStyle(counting_style) + + return filenames + + +def main(): + filenames = ParseArguments(sys.argv[1:]) + + # Change stderr to write with replacement characters so we don't die + # if we try to print something containing non-ASCII characters. 
+ sys.stderr = codecs.StreamReaderWriter(sys.stderr, + codecs.getreader('utf8'), + codecs.getwriter('utf8'), + 'replace') + + _cpplint_state.ResetErrorCounts() + for filename in filenames: + ProcessFile(filename, _cpplint_state.verbose_level) + # If --quiet is passed, suppress printing error count unless there are errors. + if not _cpplint_state.quiet or _cpplint_state.error_count > 0: + _cpplint_state.PrintErrorCounts() + + sys.exit(_cpplint_state.error_count > 0) + + +if __name__ == '__main__': + main() diff --git a/third_party/aom/tools/diff.py b/third_party/aom/tools/diff.py new file mode 100644 index 0000000000..7bb6b7fcb4 --- /dev/null +++ b/third_party/aom/tools/diff.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +## +## Copyright (c) 2016, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 2 Clause License and +## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +## was not distributed with this source code in the LICENSE file, you can +## obtain it at www.aomedia.org/license/software. If the Alliance for Open +## Media Patent License 1.0 was not distributed with this source code in the +## PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+## +"""Classes for representing diff pieces.""" + +__author__ = "jkoleszar@google.com" + +import re + + +class DiffLines(object): + """A container for one half of a diff.""" + + def __init__(self, filename, offset, length): + self.filename = filename + self.offset = offset + self.length = length + self.lines = [] + self.delta_line_nums = [] + + def Append(self, line): + l = len(self.lines) + if line[0] != " ": + self.delta_line_nums.append(self.offset + l) + self.lines.append(line[1:]) + assert l+1 <= self.length + + def Complete(self): + return len(self.lines) == self.length + + def __contains__(self, item): + return item >= self.offset and item <= self.offset + self.length - 1 + + +class DiffHunk(object): + """A container for one diff hunk, consisting of two DiffLines.""" + + def __init__(self, header, file_a, file_b, start_a, len_a, start_b, len_b): + self.header = header + self.left = DiffLines(file_a, start_a, len_a) + self.right = DiffLines(file_b, start_b, len_b) + self.lines = [] + + def Append(self, line): + """Adds a line to the DiffHunk and its DiffLines children.""" + if line[0] == "-": + self.left.Append(line) + elif line[0] == "+": + self.right.Append(line) + elif line[0] == " ": + self.left.Append(line) + self.right.Append(line) + elif line[0] == "\\": + # Ignore newline messages from git diff. + pass + else: + assert False, ("Unrecognized character at start of diff line " + "%r" % line[0]) + self.lines.append(line) + + def Complete(self): + return self.left.Complete() and self.right.Complete() + + def __repr__(self): + return "DiffHunk(%s, %s, len %d)" % ( + self.left.filename, self.right.filename, + max(self.left.length, self.right.length)) + + +def ParseDiffHunks(stream): + """Walk a file-like object, yielding DiffHunks as they're parsed.""" + + file_regex = re.compile(r"(\+\+\+|---) (\S+)") + range_regex = re.compile(r"@@ -(\d+)(,(\d+))? 
\+(\d+)(,(\d+))?") + hunk = None + while True: + line = stream.readline() + if not line: + break + + if hunk is None: + # Parse file names + diff_file = file_regex.match(line) + if diff_file: + if line.startswith("---"): + a_line = line + a = diff_file.group(2) + continue + if line.startswith("+++"): + b_line = line + b = diff_file.group(2) + continue + + # Parse offset/lengths + diffrange = range_regex.match(line) + if diffrange: + if diffrange.group(2): + start_a = int(diffrange.group(1)) + len_a = int(diffrange.group(3)) + else: + start_a = 1 + len_a = int(diffrange.group(1)) + + if diffrange.group(5): + start_b = int(diffrange.group(4)) + len_b = int(diffrange.group(6)) + else: + start_b = 1 + len_b = int(diffrange.group(4)) + + header = [a_line, b_line, line] + hunk = DiffHunk(header, a, b, start_a, len_a, start_b, len_b) + else: + # Add the current line to the hunk + hunk.Append(line) + + # See if the whole hunk has been parsed. If so, yield it and prepare + # for the next hunk. + if hunk.Complete(): + yield hunk + hunk = None + + # Partial hunks are a parse error + assert hunk is None diff --git a/third_party/aom/tools/dump_obu.cc b/third_party/aom/tools/dump_obu.cc new file mode 100644 index 0000000000..b9ff985c44 --- /dev/null +++ b/third_party/aom/tools/dump_obu.cc @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2017, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include +#include + +#include +#include + +#include "config/aom_config.h" + +#include "common/ivfdec.h" +#include "common/obudec.h" +#include "common/tools_common.h" +#include "common/webmdec.h" +#include "tools/obu_parser.h" + +namespace { + +const size_t kInitialBufferSize = 100 * 1024; + +struct InputContext { + InputContext() = default; + ~InputContext() { free(unit_buffer); } + + void Init() { + memset(avx_ctx, 0, sizeof(*avx_ctx)); + memset(obu_ctx, 0, sizeof(*obu_ctx)); + obu_ctx->avx_ctx = avx_ctx; +#if CONFIG_WEBM_IO + memset(webm_ctx, 0, sizeof(*webm_ctx)); +#endif + } + + AvxInputContext *avx_ctx = nullptr; + ObuDecInputContext *obu_ctx = nullptr; +#if CONFIG_WEBM_IO + WebmInputContext *webm_ctx = nullptr; +#endif + uint8_t *unit_buffer = nullptr; + size_t unit_buffer_size = 0; +}; + +void PrintUsage() { + printf("Libaom OBU dump.\nUsage: dump_obu \n"); +} + +VideoFileType GetFileType(InputContext *ctx) { + // TODO(https://crbug.com/aomedia/1706): webm type does not support reading + // from stdin yet, and file_is_webm is not using the detect buffer when + // determining the type. Therefore it should only be checked when using a file + // and needs to be checked prior to other types. 
+#if CONFIG_WEBM_IO + if (file_is_webm(ctx->webm_ctx, ctx->avx_ctx)) return FILE_TYPE_WEBM; +#endif + if (file_is_ivf(ctx->avx_ctx)) return FILE_TYPE_IVF; + if (file_is_obu(ctx->obu_ctx)) return FILE_TYPE_OBU; + return FILE_TYPE_RAW; +} + +bool ReadTemporalUnit(InputContext *ctx, size_t *unit_size) { + const VideoFileType file_type = ctx->avx_ctx->file_type; + switch (file_type) { + case FILE_TYPE_IVF: { + if (ivf_read_frame(ctx->avx_ctx, &ctx->unit_buffer, unit_size, + &ctx->unit_buffer_size, NULL)) { + return false; + } + break; + } + case FILE_TYPE_OBU: { + if (obudec_read_temporal_unit(ctx->obu_ctx, &ctx->unit_buffer, unit_size, + &ctx->unit_buffer_size)) { + return false; + } + break; + } +#if CONFIG_WEBM_IO + case FILE_TYPE_WEBM: { + if (webm_read_frame(ctx->webm_ctx, &ctx->unit_buffer, unit_size, + &ctx->unit_buffer_size)) { + return false; + } + break; + } +#endif + default: + // TODO(tomfinegan): Abuse FILE_TYPE_RAW for AV1/OBU elementary streams? + fprintf(stderr, "Error: Unsupported file type.\n"); + return false; + } + + return true; +} + +} // namespace + +int main(int argc, const char *argv[]) { + // TODO(tomfinegan): Could do with some params for verbosity. + if (argc < 2) { + PrintUsage(); + return EXIT_SUCCESS; + } + + const std::string filename = argv[1]; + + using FilePtr = std::unique_ptr; + FilePtr input_file(fopen(filename.c_str(), "rb"), &fclose); + if (input_file.get() == nullptr) { + input_file.release(); + fprintf(stderr, "Error: Cannot open input file.\n"); + return EXIT_FAILURE; + } + + AvxInputContext avx_ctx; + InputContext input_ctx; + input_ctx.avx_ctx = &avx_ctx; + ObuDecInputContext obu_ctx; + input_ctx.obu_ctx = &obu_ctx; +#if CONFIG_WEBM_IO + WebmInputContext webm_ctx; + input_ctx.webm_ctx = &webm_ctx; +#endif + + input_ctx.Init(); + avx_ctx.file = input_file.get(); + avx_ctx.file_type = GetFileType(&input_ctx); + + // Note: the reader utilities will realloc the buffer using realloc() etc. 
+ // Can't have nice things like unique_ptr wrappers with that type of + // behavior underneath the function calls. + input_ctx.unit_buffer = + reinterpret_cast(calloc(kInitialBufferSize, 1)); + if (!input_ctx.unit_buffer) { + fprintf(stderr, "Error: No memory, can't alloc input buffer.\n"); + return EXIT_FAILURE; + } + input_ctx.unit_buffer_size = kInitialBufferSize; + + size_t unit_size = 0; + int unit_number = 0; + int64_t obu_overhead_bytes_total = 0; + while (ReadTemporalUnit(&input_ctx, &unit_size)) { + printf("Temporal unit %d\n", unit_number); + + int obu_overhead_current_unit = 0; + if (!aom_tools::DumpObu(input_ctx.unit_buffer, static_cast(unit_size), + &obu_overhead_current_unit)) { + fprintf(stderr, "Error: Temporal Unit parse failed on unit number %d.\n", + unit_number); + return EXIT_FAILURE; + } + printf(" OBU overhead: %d\n", obu_overhead_current_unit); + ++unit_number; + obu_overhead_bytes_total += obu_overhead_current_unit; + } + + printf("File total OBU overhead: %" PRId64 "\n", obu_overhead_bytes_total); + return EXIT_SUCCESS; +} diff --git a/third_party/aom/tools/frame_size_variation_analyzer.py b/third_party/aom/tools/frame_size_variation_analyzer.py new file mode 100644 index 0000000000..5c02319df1 --- /dev/null +++ b/third_party/aom/tools/frame_size_variation_analyzer.py @@ -0,0 +1,74 @@ +# RTC frame size variation analyzer +# Usage: +# 1. Config with "-DCONFIG_OUTPUT_FRAME_SIZE=1". +# 2. Build aomenc. Encode a file, and generate output file: frame_sizes.csv +# 3. Run: python ./frame_size.py frame_sizes.csv target-bitrate fps +# Where target-bitrate: Bitrate (kbps), and fps is frame per second. 
+# Example: python ../aom/tools/frame_size_variation_analyzer.py frame_sizes.csv +# 1000 30 + +import numpy as np +import csv +import sys +import matplotlib.pyplot as plt + +# return the moving average +def moving_average(x, w): + return np.convolve(x, np.ones(w), 'valid') / w + +def frame_size_analysis(filename, target_br, fps): + tbr = target_br * 1000 / fps + + with open(filename, 'r') as infile: + raw_data = list(csv.reader(infile, delimiter=',')) + + data = np.array(raw_data).astype(float) + fsize = data[:, 0].astype(float) # frame size + qindex = data[:, 1].astype(float) # qindex + + # Frame bit rate mismatch + mismatch = np.absolute(fsize - np.full(fsize.size, tbr)) + + # Count how many frames are more than 2.5x of frame target bit rate. + tbr_thr = tbr * 2.5 + cnt = 0 + idx = np.arange(fsize.size) + for i in idx: + if fsize[i] > tbr_thr: + cnt = cnt + 1 + + # Use the 15-frame moving window + win = 15 + avg_fsize = moving_average(fsize, win) + win_mismatch = np.absolute(avg_fsize - np.full(avg_fsize.size, tbr)) + + print('[Target frame rate (bit)]:', "%.2f"%tbr) + print('[Average frame rate (bit)]:', "%.2f"%np.average(fsize)) + print('[Frame rate standard deviation]:', "%.2f"%np.std(fsize)) + print('[Max/min frame rate (bit)]:', "%.2f"%np.max(fsize), '/', "%.2f"%np.min(fsize)) + print('[Average frame rate mismatch (bit)]:', "%.2f"%np.average(mismatch)) + print('[Number of frames (frame rate > 2.5x of target frame rate)]:', cnt) + print(' Moving window size:', win) + print('[Moving average frame rate mismatch (bit)]:', "%.2f"%np.average(win_mismatch)) + print('------------------------------') + + figure, axis = plt.subplots(2) + x = np.arange(fsize.size) + axis[0].plot(x, fsize, color='blue') + axis[0].set_title("frame sizes") + axis[1].plot(x, qindex, color='blue') + axis[1].set_title("frame qindex") + plt.tight_layout() + + # Save the plot + plotname = filename + '.png' + plt.savefig(plotname) + plt.show() + +if __name__ == '__main__': + if (len(sys.argv) < 
4): + print(sys.argv[0], 'input_file, target_bitrate, fps') + sys.exit() + target_br = int(sys.argv[2]) + fps = int(sys.argv[3]) + frame_size_analysis(sys.argv[1], target_br, fps) diff --git a/third_party/aom/tools/gen_authors.sh b/third_party/aom/tools/gen_authors.sh new file mode 100755 index 0000000000..5def8bc898 --- /dev/null +++ b/third_party/aom/tools/gen_authors.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Add organization names manually. + +cat <" | sort | uniq | grep -v "corp.google\|clang-format") +EOF diff --git a/third_party/aom/tools/gen_constrained_tokenset.py b/third_party/aom/tools/gen_constrained_tokenset.py new file mode 100755 index 0000000000..f5b0816dbf --- /dev/null +++ b/third_party/aom/tools/gen_constrained_tokenset.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +## +## Copyright (c) 2016, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 2 Clause License and +## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +## was not distributed with this source code in the LICENSE file, you can +## obtain it at www.aomedia.org/license/software. If the Alliance for Open +## Media Patent License 1.0 was not distributed with this source code in the +## PATENTS file, you can obtain it at www.aomedia.org/license/patent. +## +"""Generate the probability model for the constrained token set. + +Model obtained from a 2-sided zero-centered distribution derived +from a Pareto distribution. The cdf of the distribution is: +cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha/(alpha + |x|)} ^ beta] + +For a given beta and a given probability of the 1-node, the alpha +is first solved, and then the {alpha, beta} pair is used to generate +the probabilities for the rest of the nodes. 
+""" + +import heapq +import sys +import numpy as np +import scipy.optimize +import scipy.stats + + +def cdf_spareto(x, xm, beta): + p = 1 - (xm / (np.abs(x) + xm))**beta + p = 0.5 + 0.5 * np.sign(x) * p + return p + + +def get_spareto(p, beta): + cdf = cdf_spareto + + def func(x): + return ((cdf(1.5, x, beta) - cdf(0.5, x, beta)) / + (1 - cdf(0.5, x, beta)) - p)**2 + + alpha = scipy.optimize.fminbound(func, 1e-12, 10000, xtol=1e-12) + parray = np.zeros(11) + parray[0] = 2 * (cdf(0.5, alpha, beta) - 0.5) + parray[1] = (2 * (cdf(1.5, alpha, beta) - cdf(0.5, alpha, beta))) + parray[2] = (2 * (cdf(2.5, alpha, beta) - cdf(1.5, alpha, beta))) + parray[3] = (2 * (cdf(3.5, alpha, beta) - cdf(2.5, alpha, beta))) + parray[4] = (2 * (cdf(4.5, alpha, beta) - cdf(3.5, alpha, beta))) + parray[5] = (2 * (cdf(6.5, alpha, beta) - cdf(4.5, alpha, beta))) + parray[6] = (2 * (cdf(10.5, alpha, beta) - cdf(6.5, alpha, beta))) + parray[7] = (2 * (cdf(18.5, alpha, beta) - cdf(10.5, alpha, beta))) + parray[8] = (2 * (cdf(34.5, alpha, beta) - cdf(18.5, alpha, beta))) + parray[9] = (2 * (cdf(66.5, alpha, beta) - cdf(34.5, alpha, beta))) + parray[10] = 2 * (1. - cdf(66.5, alpha, beta)) + return parray + + +def quantize_probs(p, save_first_bin, bits): + """Quantize probability precisely. + + Quantize probabilities minimizing dH (Kullback-Leibler divergence) + approximated by: sum (p_i-q_i)^2/p_i. + References: + https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence + https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit + """ + num_sym = p.size + p = np.clip(p, 1e-16, 1) + L = 2**bits + pL = p * L + ip = 1. 
/ p # inverse probability + q = np.clip(np.round(pL), 1, L + 1 - num_sym) + quant_err = (pL - q)**2 * ip + sgn = np.sign(L - q.sum()) # direction of correction + if sgn != 0: # correction is needed + v = [] # heap of adjustment results (adjustment err, index) of each symbol + for i in range(1 if save_first_bin else 0, num_sym): + q_adj = q[i] + sgn + if q_adj > 0 and q_adj < L: + adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i] + heapq.heappush(v, (adj_err, i)) + while q.sum() != L: + # apply lowest error adjustment + (adj_err, i) = heapq.heappop(v) + quant_err[i] += adj_err + q[i] += sgn + # calculate the cost of adjusting this symbol again + q_adj = q[i] + sgn + if q_adj > 0 and q_adj < L: + adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i] + heapq.heappush(v, (adj_err, i)) + return q + + +def get_quantized_spareto(p, beta, bits, first_token): + parray = get_spareto(p, beta) + parray = parray[1:] / (1 - parray[0]) + # CONFIG_NEW_TOKENSET + if first_token > 1: + parray = parray[1:] / (1 - parray[0]) + qarray = quantize_probs(parray, first_token == 1, bits) + return qarray.astype(np.int) + + +def main(bits=15, first_token=1): + beta = 8 + for q in range(1, 256): + parray = get_quantized_spareto(q / 256., beta, bits, first_token) + assert parray.sum() == 2**bits + print('{', ', '.join('%d' % i for i in parray), '},') + + +if __name__ == '__main__': + if len(sys.argv) > 2: + main(int(sys.argv[1]), int(sys.argv[2])) + elif len(sys.argv) > 1: + main(int(sys.argv[1])) + else: + main() diff --git a/third_party/aom/tools/gop_bitrate/analyze_data.py b/third_party/aom/tools/gop_bitrate/analyze_data.py new file mode 100644 index 0000000000..4e006b9220 --- /dev/null +++ b/third_party/aom/tools/gop_bitrate/analyze_data.py @@ -0,0 +1,18 @@ +with open('experiment.txt', 'r') as file: + lines = file.readlines() + curr_filename = '' + keyframe = 0 + actual_value = 0 + estimate_value = 0 + print('filename, estimated value (b), actual value (b)') + for line in lines: + if 
line.startswith('input:'): + curr_filename = line[13:].strip() + if line.startswith('estimated'): + estimate_value = float(line[19:].strip()) + if line.startswith('frame:'): + actual_value += float(line[line.find('size')+6:line.find('total')-2]) + if line.startswith('****'): + print(f'{curr_filename}, {estimate_value}, {actual_value}') + estimate_value = 0 + actual_value = 0 diff --git a/third_party/aom/tools/gop_bitrate/encode_all_script.sh b/third_party/aom/tools/gop_bitrate/encode_all_script.sh new file mode 100755 index 0000000000..0689b33138 --- /dev/null +++ b/third_party/aom/tools/gop_bitrate/encode_all_script.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#INPUT=media/cheer_sif.y4m +OUTPUT=test.webm +LIMIT=17 +CPU_USED=3 +CQ_LEVEL=36 + +for input in media/* +do + echo "****" >> experiment.txt + echo "input: $input" >> experiment.txt + ./aomenc --limit=$LIMIT --codec=av1 --cpu-used=$CPU_USED --end-usage=q --cq-level=$CQ_LEVEL --psnr --threads=0 --profile=0 --lag-in-frames=35 --min-q=0 --max-q=63 --auto-alt-ref=1 --passes=2 --kf-max-dist=160 --kf-min-dist=0 --drop-frame=0 --static-thresh=0 --minsection-pct=0 --maxsection-pct=2000 --arnr-maxframes=7 --arnr-strength=5 --sharpness=0 --undershoot-pct=100 --overshoot-pct=100 --frame-parallel=0 --tile-columns=0 -o $OUTPUT $input >> experiment.txt +done diff --git a/third_party/aom/tools/gop_bitrate/python/bitrate_accuracy.py b/third_party/aom/tools/gop_bitrate/python/bitrate_accuracy.py new file mode 100644 index 0000000000..2a5da6a794 --- /dev/null +++ b/third_party/aom/tools/gop_bitrate/python/bitrate_accuracy.py @@ -0,0 +1,185 @@ +import numpy as np + +# Model A only. +# Uses least squares regression to find the solution +# when there is one unknown variable. +def lstsq_solution(A, B): + A_inv = np.linalg.pinv(A) + x = np.matmul(A_inv, B) + return x[0][0] + +# Model B only. +# Uses the pseudoinverse matrix to find the solution +# when there are two unknown variables. 
+def pinv_solution(A, mv, B): + new_A = np.concatenate((A, mv), axis=1) + new_A_inv = np.linalg.pinv(new_A) + new_x = np.matmul(new_A_inv, B) + print("pinv solution:", new_x[0][0], new_x[1][0]) + return (new_x[0][0], new_x[1][0]) + +# Model A only. +# Finds the coefficient to multiply A by to minimize +# the percentage error between A and B. +def minimize_percentage_error_model_a(A, B): + R = np.divide(A, B) + num = 0 + den = 0 + best_x = 0 + best_error = 100 + for r_i in R: + num += r_i + den += r_i**2 + if den == 0: + return 0 + return (num/den)[0] + +# Model B only. +# Finds the coefficients to multiply to the frame bitrate +# and the motion vector bitrate to minimize the percent error. +def minimize_percentage_error_model_b(r_e, r_m, r_f): + r_ef = np.divide(r_e, r_f) + r_mf = np.divide(r_m, r_f) + sum_ef = np.sum(r_ef) + sum_ef_sq = np.sum(np.square(r_ef)) + sum_mf = np.sum(r_mf) + sum_mf_sq = np.sum(np.square(r_mf)) + sum_ef_mf = np.sum(np.multiply(r_ef, r_mf)) + # Divides x by y. If y is zero, returns 0. + divide = lambda x, y : 0 if y == 0 else x / y + # Set up and solve the matrix equation + A = np.array([[1, divide(sum_ef_mf, sum_ef_sq)],[divide(sum_ef_mf, sum_mf_sq), 1]]) + B = np.array([divide(sum_ef, sum_ef_sq), divide(sum_mf, sum_mf_sq)]) + A_inv = np.linalg.pinv(A) + x = np.matmul(A_inv, B) + return x + +# Model A only. +# Calculates the least squares error between A and B +# using coefficients in X. +def average_lstsq_error(A, B, x): + error = 0 + n = 0 + for i, a in enumerate(A): + a = a[0] + b = B[i][0] + if b == 0: + continue + n += 1 + error += (b - x*a)**2 + if n == 0: + return None + error /= n + return error + +# Model A only. +# Calculates the average percentage error between A and B. +def average_percent_error_model_a(A, B, x): + error = 0 + n = 0 + for i, a in enumerate(A): + a = a[0] + b = B[i][0] + if b == 0: + continue + n += 1 + error_i = (abs(x*a-b)/b)*100 + error += error_i + error /= n + return error + +# Model B only. 
+# Calculates the average percentage error between A and B. +def average_percent_error_model_b(A, M, B, x): + error = 0 + for i, a in enumerate(A): + a = a[0] + mv = M[i] + b = B[i][0] + if b == 0: + continue + estimate = x[0]*a + estimate += x[1]*mv + error += abs(estimate - b) / b + error *= 100 + error /= A.shape[0] + return error + +def average_squared_error_model_a(A, B, x): + error = 0 + n = 0 + for i, a in enumerate(A): + a = a[0] + b = B[i][0] + if b == 0: + continue + n += 1 + error_i = (1 - x*(a/b))**2 + error += error_i + error /= n + error = error**0.5 + return error * 100 + +def average_squared_error_model_b(A, M, B, x): + error = 0 + n = 0 + for i, a in enumerate(A): + a = a[0] + b = B[i][0] + mv = M[i] + if b == 0: + continue + n += 1 + error_i = 1 - ((x[0]*a + x[1]*mv)/b) + error_i = error_i**2 + error += error_i + error /= n + error = error**0.5 + return error * 100 + +# Traverses the data and prints out one value for +# each update type. +def print_solutions(file_path): + data = np.genfromtxt(file_path, delimiter="\t") + prev_update = 0 + split_list_indices = list() + for i, val in enumerate(data): + if prev_update != val[3]: + split_list_indices.append(i) + prev_update = val[3] + split = np.split(data, split_list_indices) + for array in split: + A, mv, B, update = np.hsplit(array, 4) + z = np.where(B == 0)[0] + r_e = np.delete(A, z, axis=0) + r_m = np.delete(mv, z, axis=0) + r_f = np.delete(B, z, axis=0) + A = r_e + mv = r_m + B = r_f + all_zeros = not A.any() + if all_zeros: + continue + print("update type:", update[0][0]) + x_ls = lstsq_solution(A, B) + x_a = minimize_percentage_error_model_a(A, B) + x_b = minimize_percentage_error_model_b(A, mv, B) + percent_error_a = average_percent_error_model_a(A, B, x_a) + percent_error_b = average_percent_error_model_b(A, mv, B, x_b)[0] + baseline_percent_error_a = average_percent_error_model_a(A, B, 1) + baseline_percent_error_b = average_percent_error_model_b(A, mv, B, [1, 1])[0] + + squared_error_a = 
average_squared_error_model_a(A, B, x_a) + squared_error_b = average_squared_error_model_b(A, mv, B, x_b)[0] + baseline_squared_error_a = average_squared_error_model_a(A, B, 1) + baseline_squared_error_b = average_squared_error_model_b(A, mv, B, [1, 1])[0] + + print("model,\tframe_coeff,\tmv_coeff,\terror,\tbaseline_error") + print("Model A %_error,\t" + str(x_a) + ",\t" + str(0) + ",\t" + str(percent_error_a) + ",\t" + str(baseline_percent_error_a)) + print("Model A sq_error,\t" + str(x_a) + ",\t" + str(0) + ",\t" + str(squared_error_a) + ",\t" + str(baseline_squared_error_a)) + print("Model B %_error,\t" + str(x_b[0]) + ",\t" + str(x_b[1]) + ",\t" + str(percent_error_b) + ",\t" + str(baseline_percent_error_b)) + print("Model B sq_error,\t" + str(x_b[0]) + ",\t" + str(x_b[1]) + ",\t" + str(squared_error_b) + ",\t" + str(baseline_squared_error_b)) + print() + +if __name__ == "__main__": + print_solutions("data2/all_lowres_target_lt600_data.txt") diff --git a/third_party/aom/tools/inspect-cli.js b/third_party/aom/tools/inspect-cli.js new file mode 100644 index 0000000000..a14c08111a --- /dev/null +++ b/third_party/aom/tools/inspect-cli.js @@ -0,0 +1,39 @@ +/** + * This tool lets you test if the compiled Javascript decoder is functioning properly. You'll + * need to download a SpiderMonkey js-shell to run this script. 
+ * https://archive.mozilla.org/pub/firefox/nightly/latest-mozilla-central/ + * + * Example: + * js-shell inspect-cli.js video.ivf + */ +load("inspect.js"); +var buffer = read(scriptArgs[0], "binary"); +var Module = { + noExitRuntime: true, + noInitialRun: true, + preInit: [], + preRun: [], + postRun: [function () { + printErr(`Loaded Javascript Decoder OK`); + }], + memoryInitializerPrefixURL: "bin/", + arguments: ['input.ivf', 'output.raw'], + on_frame_decoded_json: function (jsonString) { + let json = JSON.parse("[" + Module.UTF8ToString(jsonString) + "null]"); + json.forEach(frame => { + if (frame) { + print(frame.frame); + } + }); + } +}; +DecoderModule(Module); +Module.FS.writeFile("/tmp/input.ivf", buffer, { encoding: "binary" }); +Module._open_file(); +Module._set_layers(0xFFFFFFFF); // Set this to zero if you want to benchmark decoding. +while(true) { + printErr("Decoding Frame ..."); + if (Module._read_frame()) { + break; + } +} diff --git a/third_party/aom/tools/inspect-post.js b/third_party/aom/tools/inspect-post.js new file mode 100644 index 0000000000..31c40bb82c --- /dev/null +++ b/third_party/aom/tools/inspect-post.js @@ -0,0 +1 @@ +Module["FS"] = FS; diff --git a/third_party/aom/tools/intersect-diffs.py b/third_party/aom/tools/intersect-diffs.py new file mode 100755 index 0000000000..960183675d --- /dev/null +++ b/third_party/aom/tools/intersect-diffs.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +## +## Copyright (c) 2016, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 2 Clause License and +## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +## was not distributed with this source code in the LICENSE file, you can +## obtain it at www.aomedia.org/license/software. If the Alliance for Open +## Media Patent License 1.0 was not distributed with this source code in the +## PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+## +"""Calculates the "intersection" of two unified diffs. + +Given two diffs, A and B, it finds all hunks in B that had non-context lines +in A and prints them to stdout. This is useful to determine the hunks in B that +are relevant to A. The resulting file can be applied with patch(1) on top of A. +""" + +__author__ = "jkoleszar@google.com" + +import sys + +import diff + + +def FormatDiffHunks(hunks): + """Re-serialize a list of DiffHunks.""" + r = [] + last_header = None + for hunk in hunks: + this_header = hunk.header[0:2] + if last_header != this_header: + r.extend(hunk.header) + last_header = this_header + else: + r.extend(hunk.header[2]) + r.extend(hunk.lines) + r.append("\n") + return "".join(r) + + +def ZipHunks(rhs_hunks, lhs_hunks): + """Join two hunk lists on filename.""" + for rhs_hunk in rhs_hunks: + rhs_file = rhs_hunk.right.filename.split("/")[1:] + + for lhs_hunk in lhs_hunks: + lhs_file = lhs_hunk.left.filename.split("/")[1:] + if lhs_file != rhs_file: + continue + yield (rhs_hunk, lhs_hunk) + + +def main(): + old_hunks = [x for x in diff.ParseDiffHunks(open(sys.argv[1], "r"))] + new_hunks = [x for x in diff.ParseDiffHunks(open(sys.argv[2], "r"))] + out_hunks = [] + + # Join the right hand side of the older diff with the left hand side of the + # newer diff. 
+ for old_hunk, new_hunk in ZipHunks(old_hunks, new_hunks): + if new_hunk in out_hunks: + continue + old_lines = old_hunk.right + new_lines = new_hunk.left + + # Determine if this hunk overlaps any non-context line from the other + for i in old_lines.delta_line_nums: + if i in new_lines: + out_hunks.append(new_hunk) + break + + if out_hunks: + print(FormatDiffHunks(out_hunks)) + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/third_party/aom/tools/lint-hunks.py b/third_party/aom/tools/lint-hunks.py new file mode 100755 index 0000000000..8b3af972fc --- /dev/null +++ b/third_party/aom/tools/lint-hunks.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +## +## Copyright (c) 2016, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 2 Clause License and +## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +## was not distributed with this source code in the LICENSE file, you can +## obtain it at www.aomedia.org/license/software. If the Alliance for Open +## Media Patent License 1.0 was not distributed with this source code in the +## PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+## +"""Performs style checking on each diff hunk.""" +import getopt +import os +import io +import subprocess +import sys + +import diff + + +SHORT_OPTIONS = "h" +LONG_OPTIONS = ["help"] + +TOPLEVEL_CMD = ["git", "rev-parse", "--show-toplevel"] +DIFF_CMD = ["git", "diff"] +DIFF_INDEX_CMD = ["git", "diff-index", "-u", "HEAD", "--"] +SHOW_CMD = ["git", "show"] +CPPLINT_FILTERS = ["-readability/casting"] + + +class Usage(Exception): + pass + + +class SubprocessException(Exception): + def __init__(self, args): + msg = "Failed to execute '%s'"%(" ".join(args)) + super(SubprocessException, self).__init__(msg) + + +class Subprocess(subprocess.Popen): + """Adds the notion of an expected returncode to Popen.""" + + def __init__(self, args, expected_returncode=0, **kwargs): + self._args = args + self._expected_returncode = expected_returncode + super(Subprocess, self).__init__(args, **kwargs) + + def communicate(self, *args, **kwargs): + result = super(Subprocess, self).communicate(*args, **kwargs) + if self._expected_returncode is not None: + try: + ok = self.returncode in self._expected_returncode + except TypeError: + ok = self.returncode == self._expected_returncode + if not ok: + raise SubprocessException(self._args) + return result + + +def main(argv=None): + if argv is None: + argv = sys.argv + try: + try: + opts, args = getopt.getopt(argv[1:], SHORT_OPTIONS, LONG_OPTIONS) + except getopt.error as msg: + raise Usage(msg) + + # process options + for o, _ in opts: + if o in ("-h", "--help"): + print(__doc__) + sys.exit(0) + + if args and len(args) > 1: + print(__doc__) + sys.exit(0) + + # Find the fully qualified path to the root of the tree + tl = Subprocess(TOPLEVEL_CMD, stdout=subprocess.PIPE, text=True) + tl = tl.communicate()[0].strip() + + # See if we're working on the index or not. 
+ if args: + diff_cmd = DIFF_CMD + [args[0] + "^!"] + else: + diff_cmd = DIFF_INDEX_CMD + + # Build the command line to execute cpplint + cpplint_cmd = [os.path.join(tl, "tools", "cpplint.py"), + "--filter=" + ",".join(CPPLINT_FILTERS), + "-"] + + # Get a list of all affected lines + file_affected_line_map = {} + p = Subprocess(diff_cmd, stdout=subprocess.PIPE, text=True) + stdout = p.communicate()[0] + for hunk in diff.ParseDiffHunks(io.StringIO(stdout)): + filename = hunk.right.filename[2:] + if filename not in file_affected_line_map: + file_affected_line_map[filename] = set() + file_affected_line_map[filename].update(hunk.right.delta_line_nums) + + # Run each affected file through cpplint + lint_failed = False + for filename, affected_lines in file_affected_line_map.items(): + if filename.split(".")[-1] not in ("c", "h", "cc"): + continue + if filename.startswith("third_party"): + continue + + if args: + # File contents come from git + show_cmd = SHOW_CMD + [args[0] + ":" + filename] + show = Subprocess(show_cmd, stdout=subprocess.PIPE, text=True) + lint = Subprocess(cpplint_cmd, expected_returncode=(0, 1), + stdin=show.stdout, stderr=subprocess.PIPE, + text=True) + lint_out = lint.communicate()[1] + else: + # File contents come from the working tree + lint = Subprocess(cpplint_cmd, expected_returncode=(0, 1), + stdin=subprocess.PIPE, stderr=subprocess.PIPE, + text=True) + stdin = open(os.path.join(tl, filename)).read() + lint_out = lint.communicate(stdin)[1] + + for line in lint_out.split("\n"): + fields = line.split(":") + if fields[0] != "-": + continue + warning_line_num = int(fields[1]) + if warning_line_num in affected_lines: + print("%s:%d:%s"%(filename, warning_line_num, + ":".join(fields[2:]))) + lint_failed = True + + # Set exit code if any relevant lint errors seen + if lint_failed: + return 1 + + except Usage as err: + print(err, file=sys.stderr) + print("for help use --help", file=sys.stderr) + return 2 + +if __name__ == "__main__": + 
sys.exit(main()) diff --git a/third_party/aom/tools/obu_parser.cc b/third_party/aom/tools/obu_parser.cc new file mode 100644 index 0000000000..5716b46218 --- /dev/null +++ b/third_party/aom/tools/obu_parser.cc @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2017, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ +#include + +#include +#include + +#include "aom/aom_codec.h" +#include "aom/aom_integer.h" +#include "aom_ports/mem_ops.h" +#include "av1/common/obu_util.h" +#include "tools/obu_parser.h" + +namespace aom_tools { + +// Basic OBU syntax +// 8 bits: Header +// 7 +// forbidden bit +// 6,5,4,3 +// type bits +// 2 +// extension flag bit +// 1 +// has size field bit +// 0 +// reserved bit +const uint32_t kObuForbiddenBitMask = 0x1; +const uint32_t kObuForbiddenBitShift = 7; +const uint32_t kObuTypeBitsMask = 0xF; +const uint32_t kObuTypeBitsShift = 3; +const uint32_t kObuExtensionFlagBitMask = 0x1; +const uint32_t kObuExtensionFlagBitShift = 2; +const uint32_t kObuHasSizeFieldBitMask = 0x1; +const uint32_t kObuHasSizeFieldBitShift = 1; + +// When extension flag bit is set: +// 8 bits: extension header +// 7,6,5 +// temporal ID +// 4,3 +// spatial ID +// 2,1,0 +// reserved bits +const uint32_t kObuExtTemporalIdBitsMask = 0x7; +const uint32_t kObuExtTemporalIdBitsShift = 5; +const uint32_t kObuExtSpatialIdBitsMask = 0x3; +const uint32_t kObuExtSpatialIdBitsShift = 3; + +bool ValidObuType(int obu_type) { + switch (obu_type) { + case OBU_SEQUENCE_HEADER: + case OBU_TEMPORAL_DELIMITER: + case OBU_FRAME_HEADER: + case 
OBU_TILE_GROUP: + case OBU_METADATA: + case OBU_FRAME: + case OBU_REDUNDANT_FRAME_HEADER: + case OBU_TILE_LIST: + case OBU_PADDING: return true; + } + return false; +} + +bool ParseObuHeader(uint8_t obu_header_byte, ObuHeader *obu_header) { + const int forbidden_bit = + (obu_header_byte >> kObuForbiddenBitShift) & kObuForbiddenBitMask; + if (forbidden_bit) { + fprintf(stderr, "Invalid OBU, forbidden bit set.\n"); + return false; + } + + obu_header->type = static_cast( + (obu_header_byte >> kObuTypeBitsShift) & kObuTypeBitsMask); + if (!ValidObuType(obu_header->type)) { + fprintf(stderr, "Invalid OBU type: %d.\n", obu_header->type); + return false; + } + + obu_header->has_extension = + (obu_header_byte >> kObuExtensionFlagBitShift) & kObuExtensionFlagBitMask; + obu_header->has_size_field = + (obu_header_byte >> kObuHasSizeFieldBitShift) & kObuHasSizeFieldBitMask; + return true; +} + +bool ParseObuExtensionHeader(uint8_t ext_header_byte, ObuHeader *obu_header) { + obu_header->temporal_layer_id = + (ext_header_byte >> kObuExtTemporalIdBitsShift) & + kObuExtTemporalIdBitsMask; + obu_header->spatial_layer_id = + (ext_header_byte >> kObuExtSpatialIdBitsShift) & kObuExtSpatialIdBitsMask; + + return true; +} + +void PrintObuHeader(const ObuHeader *header) { + printf( + " OBU type: %s\n" + " extension: %s\n", + aom_obu_type_to_string(static_cast(header->type)), + header->has_extension ? "yes" : "no"); + if (header->has_extension) { + printf( + " temporal_id: %d\n" + " spatial_id: %d\n", + header->temporal_layer_id, header->spatial_layer_id); + } +} + +bool DumpObu(const uint8_t *data, int length, int *obu_overhead_bytes) { + const int kObuHeaderSizeBytes = 1; + const int kMinimumBytesRequired = 1 + kObuHeaderSizeBytes; + int consumed = 0; + int obu_overhead = 0; + ObuHeader obu_header; + while (consumed < length) { + const int remaining = length - consumed; + if (remaining < kMinimumBytesRequired) { + fprintf(stderr, + "OBU parse error. 
Did not consume all data, %d bytes remain.\n", + remaining); + return false; + } + + int obu_header_size = 0; + + memset(&obu_header, 0, sizeof(obu_header)); + const uint8_t obu_header_byte = *(data + consumed); + if (!ParseObuHeader(obu_header_byte, &obu_header)) { + fprintf(stderr, "OBU parsing failed at offset %d.\n", consumed); + return false; + } + + ++obu_overhead; + ++obu_header_size; + + if (obu_header.has_extension) { + const uint8_t obu_ext_header_byte = + *(data + consumed + kObuHeaderSizeBytes); + if (!ParseObuExtensionHeader(obu_ext_header_byte, &obu_header)) { + fprintf(stderr, "OBU extension parsing failed at offset %d.\n", + consumed + kObuHeaderSizeBytes); + return false; + } + + ++obu_overhead; + ++obu_header_size; + } + + PrintObuHeader(&obu_header); + + uint64_t obu_size = 0; + size_t length_field_size = 0; + if (aom_uleb_decode(data + consumed + obu_header_size, + remaining - obu_header_size, &obu_size, + &length_field_size) != 0) { + fprintf(stderr, "OBU size parsing failed at offset %d.\n", + consumed + obu_header_size); + return false; + } + int current_obu_length = static_cast(obu_size); + if (obu_header_size + static_cast(length_field_size) + + current_obu_length > + remaining) { + fprintf(stderr, "OBU parsing failed: not enough OBU data.\n"); + return false; + } + consumed += obu_header_size + static_cast(length_field_size) + + current_obu_length; + printf(" length: %d\n", + static_cast(obu_header_size + length_field_size + + current_obu_length)); + } + + if (obu_overhead_bytes != nullptr) *obu_overhead_bytes = obu_overhead; + printf(" TU size: %d\n", consumed); + + return true; +} + +} // namespace aom_tools diff --git a/third_party/aom/tools/obu_parser.h b/third_party/aom/tools/obu_parser.h new file mode 100644 index 0000000000..1d7d2d794b --- /dev/null +++ b/third_party/aom/tools/obu_parser.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2017, Alliance for Open Media. 
All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#ifndef AOM_TOOLS_OBU_PARSER_H_ +#define AOM_TOOLS_OBU_PARSER_H_ + +#include + +namespace aom_tools { + +// Print information obtained from OBU(s) in data until data is exhausted or an +// error occurs. Returns true when all data is consumed successfully, and +// optionally reports OBU storage overhead via obu_overhead_bytes when the +// pointer is non-null. +bool DumpObu(const uint8_t *data, int length, int *obu_overhead_bytes); + +} // namespace aom_tools + +#endif // AOM_TOOLS_OBU_PARSER_H_ diff --git a/third_party/aom/tools/ratectrl_log_analyzer/analyze_ratectrl_log.py b/third_party/aom/tools/ratectrl_log_analyzer/analyze_ratectrl_log.py new file mode 100644 index 0000000000..9afb78cbf5 --- /dev/null +++ b/third_party/aom/tools/ratectrl_log_analyzer/analyze_ratectrl_log.py @@ -0,0 +1,154 @@ +#!/usr/bin/python3 +## +## Copyright (c) 2022, Alliance for Open Media. All rights reserved +## +## This source code is subject to the terms of the BSD 2 Clause License and +## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +## was not distributed with this source code in the LICENSE file, you can +## obtain it at www.aomedia.org/license/software. If the Alliance for Open +## Media Patent License 1.0 was not distributed with this source code in the +## PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+## +""" Analyze the log generated by experimental flag CONFIG_RATECTRL_LOG.""" + +import matplotlib.pyplot as plt +import os + + +def get_file_basename(filename): + return filename.split(".")[0] + + +def parse_log(log_file): + data_list = [] + with open(log_file) as fp: + for line in fp: + dic = {} + word_ls = line.split() + i = 0 + while i < len(word_ls): + dic[word_ls[i]] = float(word_ls[i + 1]) + i += 2 + data_list.append(dic) + fp.close() + return data_list + + +def extract_data(data_list, name): + arr = [] + for data in data_list: + arr.append(data[name]) + return arr + + +def visualize_q_indices(exp_summary, exp_list, fig_path=None): + for exp in exp_list: + data = parse_log(exp["log"]) + q_indices = extract_data(data, "q") + plt.title(exp_summary) + plt.xlabel("frame_coding_idx") + plt.ylabel("q_index") + plt.plot(q_indices, marker=".", label=exp["label"]) + plt.legend() + if fig_path: + plt.savefig(fig_path) + else: + plt.show() + plt.clf() + + +def get_rc_type_from_exp_type(exp_type): + if exp_type == "Q_3P": + return "q" + return "vbr" + + +def test_video(exe_name, input, exp_type, level, log=None, limit=150): + basic_cmd = ("--test-decode=warn --threads=0 --profile=0 --min-q=0 --max-q=63" + " --auto-alt-ref=1 --kf-max-dist=160 --kf-min-dist=0 " + "--drop-frame=0 --static-thresh=0 --minsection-pct=0 " + "--maxsection-pct=2000 --arnr-maxframes=7 --arnr-strength=5 " + "--sharpness=0 --undershoot-pct=100 --overshoot-pct=100 " + "--frame-parallel=0 --tile-columns=0 --cpu-used=3 " + "--lag-in-frames=48 --psnr") + rc_type = get_rc_type_from_exp_type(exp_type) + rc_cmd = "--end-usage=" + rc_type + level_cmd = "" + if rc_type == "q": + level_cmd += "--cq-level=" + str(level) + elif rc_type == "vbr": + level_cmd += "--target-bitrate=" + str(level) + limit_cmd = "--limit=" + str(limit) + passes_cmd = "--passes=3 --second-pass-log=second_pass_log" + output_cmd = "-o test.webm" + input_cmd = "~/data/" + input + log_cmd = "" + if log != None: + log_cmd = ">" + log + 
cmd_ls = [ + exe_name, basic_cmd, rc_cmd, level_cmd, limit_cmd, passes_cmd, output_cmd, + input_cmd, log_cmd + ] + cmd = " ".join(cmd_ls) + os.system(cmd) + + +def gen_ratectrl_log(test_case): + exe = test_case["exe"] + video = test_case["video"] + exp_type = test_case["exp_type"] + level = test_case["level"] + log = test_case["log"] + test_video(exe, video, exp_type, level, log=log, limit=150) + return log + + +def gen_test_case(exp_type, dataset, videoname, level, log_dir=None): + test_case = {} + exe = "./aomenc_bl" + if exp_type == "BA_3P": + exe = "./aomenc_ba" + test_case["exe"] = exe + + video = os.path.join(dataset, videoname) + test_case["video"] = video + test_case["exp_type"] = exp_type + test_case["level"] = level + + video_basename = get_file_basename(videoname) + log = ".".join([dataset, video_basename, exp_type, str(level)]) + if log_dir != None: + log = os.path.join(log_dir, log) + test_case["log"] = log + return test_case + + +def run_ratectrl_exp(exp_config): + fp = open(exp_config) + log_dir = "./lowres_rc_log" + fig_dir = "./lowres_rc_fig" + dataset = "lowres" + for line in fp: + word_ls = line.split() + dataset = word_ls[0] + videoname = word_ls[1] + exp_type_ls = ["VBR_3P", "BA_3P"] + level_ls = [int(v) for v in word_ls[2:4]] + exp_ls = [] + for i in range(len(exp_type_ls)): + exp_type = exp_type_ls[i] + test_case = gen_test_case(exp_type, dataset, videoname, level_ls[i], + log_dir) + log = gen_ratectrl_log(test_case) + exp = {} + exp["log"] = log + exp["label"] = exp_type + exp_ls.append(exp) + video_basename = get_file_basename(videoname) + fig_path = os.path.join(fig_dir, video_basename + ".png") + visualize_q_indices(video_basename, exp_ls, fig_path) + fp.close() + + +if __name__ == "__main__": + run_ratectrl_exp("exp_rc_config") diff --git a/third_party/aom/tools/txfm_analyzer/txfm_gen_code.cc b/third_party/aom/tools/txfm_analyzer/txfm_gen_code.cc new file mode 100644 index 0000000000..7c5400b91a --- /dev/null +++ 
b/third_party/aom/tools/txfm_analyzer/txfm_gen_code.cc @@ -0,0 +1,580 @@ +/* + * Copyright (c) 2018, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include +#include +#include +#include + +#include "tools/txfm_analyzer/txfm_graph.h" + +typedef enum CODE_TYPE { + CODE_TYPE_C, + CODE_TYPE_SSE2, + CODE_TYPE_SSE4_1 +} CODE_TYPE; + +int get_cos_idx(double value, int mod) { + return round(acos(fabs(value)) / PI * mod); +} + +char *cos_text_arr(double value, int mod, char *text, int size) { + int num = get_cos_idx(value, mod); + if (value < 0) { + snprintf(text, size, "-cospi[%2d]", num); + } else { + snprintf(text, size, " cospi[%2d]", num); + } + + if (num == 0) + printf("v: %f -> %d/%d v==-1 is %d\n", value, num, mod, value == -1); + + return text; +} + +char *cos_text_sse2(double w0, double w1, int mod, char *text, int size) { + int idx0 = get_cos_idx(w0, mod); + int idx1 = get_cos_idx(w1, mod); + char p[] = "p"; + char n[] = "m"; + char *sgn0 = w0 < 0 ? n : p; + char *sgn1 = w1 < 0 ? n : p; + snprintf(text, size, "cospi_%s%02d_%s%02d", sgn0, idx0, sgn1, idx1); + return text; +} + +char *cos_text_sse4_1(double w, int mod, char *text, int size) { + int idx = get_cos_idx(w, mod); + char p[] = "p"; + char n[] = "m"; + char *sgn = w < 0 ? 
n : p; + snprintf(text, size, "cospi_%s%02d", sgn, idx); + return text; +} + +void node_to_code_c(Node *node, const char *buf0, const char *buf1) { + int cnt = 0; + for (int i = 0; i < 2; i++) { + if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++; + } + if (cnt == 2) { + int cnt2 = 0; + printf(" %s[%d] =", buf1, node->nodeIdx); + for (int i = 0; i < 2; i++) { + if (fabs(node->inWeight[i]) == 1) { + cnt2++; + } + } + if (cnt2 == 2) { + printf(" apply_value("); + } + int cnt1 = 0; + for (int i = 0; i < 2; i++) { + if (node->inWeight[i] == 1) { + if (cnt1 > 0) + printf(" + %s[%d]", buf0, node->inNodeIdx[i]); + else + printf(" %s[%d]", buf0, node->inNodeIdx[i]); + cnt1++; + } else if (node->inWeight[i] == -1) { + if (cnt1 > 0) + printf(" - %s[%d]", buf0, node->inNodeIdx[i]); + else + printf("-%s[%d]", buf0, node->inNodeIdx[i]); + cnt1++; + } + } + if (cnt2 == 2) { + printf(", stage_range[stage])"); + } + printf(";\n"); + } else { + char w0[100]; + char w1[100]; + printf( + " %s[%d] = half_btf(%s, %s[%d], %s, %s[%d], " + "cos_bit);\n", + buf1, node->nodeIdx, cos_text_arr(node->inWeight[0], COS_MOD, w0, 100), + buf0, node->inNodeIdx[0], + cos_text_arr(node->inWeight[1], COS_MOD, w1, 100), buf0, + node->inNodeIdx[1]); + } +} + +void gen_code_c(Node *node, int stage_num, int node_num, TYPE_TXFM type) { + char *fun_name = new char[100]; + get_fun_name(fun_name, 100, type, node_num); + + printf("\n"); + printf( + "void av1_%s(const int32_t *input, int32_t *output, int8_t cos_bit, " + "const int8_t* stage_range) " + "{\n", + fun_name); + printf(" assert(output != input);\n"); + printf(" const int32_t size = %d;\n", node_num); + printf(" const int32_t *cospi = cospi_arr(cos_bit);\n"); + printf("\n"); + + printf(" int32_t stage = 0;\n"); + printf(" int32_t *bf0, *bf1;\n"); + printf(" int32_t step[%d];\n", node_num); + + const char *buf0 = "bf0"; + const char *buf1 = "bf1"; + const char *input = "input"; + + int si = 0; + printf("\n"); + printf(" // stage 
%d;\n", si); + printf(" apply_range(stage, input, %s, size, stage_range[stage]);\n", input); + + si = 1; + printf("\n"); + printf(" // stage %d;\n", si); + printf(" stage++;\n"); + if (si % 2 == (stage_num - 1) % 2) { + printf(" %s = output;\n", buf1); + } else { + printf(" %s = step;\n", buf1); + } + + for (int ni = 0; ni < node_num; ni++) { + int idx = get_idx(si, ni, node_num); + node_to_code_c(node + idx, input, buf1); + } + + printf(" range_check_buf(stage, input, bf1, size, stage_range[stage]);\n"); + + for (int si = 2; si < stage_num; si++) { + printf("\n"); + printf(" // stage %d\n", si); + printf(" stage++;\n"); + if (si % 2 == (stage_num - 1) % 2) { + printf(" %s = step;\n", buf0); + printf(" %s = output;\n", buf1); + } else { + printf(" %s = output;\n", buf0); + printf(" %s = step;\n", buf1); + } + + // computation code + for (int ni = 0; ni < node_num; ni++) { + int idx = get_idx(si, ni, node_num); + node_to_code_c(node + idx, buf0, buf1); + } + + if (si != stage_num - 1) { + printf( + " range_check_buf(stage, input, bf1, size, stage_range[stage]);\n"); + } + } + printf(" apply_range(stage, input, output, size, stage_range[stage]);\n"); + printf("}\n"); +} + +void single_node_to_code_sse2(Node *node, const char *buf0, const char *buf1) { + printf(" %s[%2d] =", buf1, node->nodeIdx); + if (node->inWeight[0] == 1 && node->inWeight[1] == 1) { + printf(" _mm_adds_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0, + node->inNodeIdx[1]); + } else if (node->inWeight[0] == 1 && node->inWeight[1] == -1) { + printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0, + node->inNodeIdx[1]); + } else if (node->inWeight[0] == -1 && node->inWeight[1] == 1) { + printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[1], buf0, + node->inNodeIdx[0]); + } else if (node->inWeight[0] == 1 && node->inWeight[1] == 0) { + printf(" %s[%d]", buf0, node->inNodeIdx[0]); + } else if (node->inWeight[0] == 0 && node->inWeight[1] == 1) { + printf(" %s[%d]", 
buf0, node->inNodeIdx[1]); + } else if (node->inWeight[0] == -1 && node->inWeight[1] == 0) { + printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, node->inNodeIdx[0]); + } else if (node->inWeight[0] == 0 && node->inWeight[1] == -1) { + printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, node->inNodeIdx[1]); + } + printf(";\n"); +} + +void pair_node_to_code_sse2(Node *node, Node *partnerNode, const char *buf0, + const char *buf1) { + char temp0[100]; + char temp1[100]; + // btf_16_sse2_type0(w0, w1, in0, in1, out0, out1) + if (node->inNodeIdx[0] != partnerNode->inNodeIdx[0]) + printf(" btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n", + cos_text_sse2(node->inWeight[0], node->inWeight[1], COS_MOD, temp0, + 100), + cos_text_sse2(partnerNode->inWeight[1], partnerNode->inWeight[0], + COS_MOD, temp1, 100), + buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1, + node->nodeIdx, buf1, partnerNode->nodeIdx); + else + printf(" btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n", + cos_text_sse2(node->inWeight[0], node->inWeight[1], COS_MOD, temp0, + 100), + cos_text_sse2(partnerNode->inWeight[0], partnerNode->inWeight[1], + COS_MOD, temp1, 100), + buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1, + node->nodeIdx, buf1, partnerNode->nodeIdx); +} + +Node *get_partner_node(Node *node) { + int diff = node->inNode[1]->nodeIdx - node->nodeIdx; + return node + diff; +} + +void node_to_code_sse2(Node *node, const char *buf0, const char *buf1) { + int cnt = 0; + int cnt1 = 0; + if (node->visited == 0) { + node->visited = 1; + for (int i = 0; i < 2; i++) { + if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++; + if (fabs(node->inWeight[i]) == 1) cnt1++; + } + if (cnt == 2) { + if (cnt1 == 2) { + // has a partner + Node *partnerNode = get_partner_node(node); + partnerNode->visited = 1; + single_node_to_code_sse2(node, buf0, buf1); + single_node_to_code_sse2(partnerNode, buf0, buf1); + } else { + single_node_to_code_sse2(node, buf0, buf1); + } + } 
else { + Node *partnerNode = get_partner_node(node); + partnerNode->visited = 1; + pair_node_to_code_sse2(node, partnerNode, buf0, buf1); + } + } +} + +void gen_cospi_list_sse2(Node *node, int stage_num, int node_num) { + int visited[65][65][2][2]; + memset(visited, 0, sizeof(visited)); + char text[100]; + char text1[100]; + char text2[100]; + int size = 100; + printf("\n"); + for (int si = 1; si < stage_num; si++) { + for (int ni = 0; ni < node_num; ni++) { + int idx = get_idx(si, ni, node_num); + int cnt = 0; + Node *node0 = node + idx; + if (node0->visited == 0) { + node0->visited = 1; + for (int i = 0; i < 2; i++) { + if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0) + cnt++; + } + if (cnt != 2) { + { + double w0 = node0->inWeight[0]; + double w1 = node0->inWeight[1]; + int idx0 = get_cos_idx(w0, COS_MOD); + int idx1 = get_cos_idx(w1, COS_MOD); + int sgn0 = w0 < 0 ? 1 : 0; + int sgn1 = w1 < 0 ? 1 : 0; + + if (!visited[idx0][idx1][sgn0][sgn1]) { + visited[idx0][idx1][sgn0][sgn1] = 1; + printf(" __m128i %s = pair_set_epi16(%s, %s);\n", + cos_text_sse2(w0, w1, COS_MOD, text, size), + cos_text_arr(w0, COS_MOD, text1, size), + cos_text_arr(w1, COS_MOD, text2, size)); + } + } + Node *node1 = get_partner_node(node0); + node1->visited = 1; + if (node1->inNode[0]->nodeIdx != node0->inNode[0]->nodeIdx) { + double w0 = node1->inWeight[0]; + double w1 = node1->inWeight[1]; + int idx0 = get_cos_idx(w0, COS_MOD); + int idx1 = get_cos_idx(w1, COS_MOD); + int sgn0 = w0 < 0 ? 1 : 0; + int sgn1 = w1 < 0 ? 1 : 0; + + if (!visited[idx1][idx0][sgn1][sgn0]) { + visited[idx1][idx0][sgn1][sgn0] = 1; + printf(" __m128i %s = pair_set_epi16(%s, %s);\n", + cos_text_sse2(w1, w0, COS_MOD, text, size), + cos_text_arr(w1, COS_MOD, text1, size), + cos_text_arr(w0, COS_MOD, text2, size)); + } + } else { + double w0 = node1->inWeight[0]; + double w1 = node1->inWeight[1]; + int idx0 = get_cos_idx(w0, COS_MOD); + int idx1 = get_cos_idx(w1, COS_MOD); + int sgn0 = w0 < 0 ? 
1 : 0; + int sgn1 = w1 < 0 ? 1 : 0; + + if (!visited[idx0][idx1][sgn0][sgn1]) { + visited[idx0][idx1][sgn0][sgn1] = 1; + printf(" __m128i %s = pair_set_epi16(%s, %s);\n", + cos_text_sse2(w0, w1, COS_MOD, text, size), + cos_text_arr(w0, COS_MOD, text1, size), + cos_text_arr(w1, COS_MOD, text2, size)); + } + } + } + } + } + } +} + +void gen_code_sse2(Node *node, int stage_num, int node_num, TYPE_TXFM type) { + char *fun_name = new char[100]; + get_fun_name(fun_name, 100, type, node_num); + + printf("\n"); + printf( + "void %s_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) " + "{\n", + fun_name); + + printf(" const int32_t* cospi = cospi_arr(cos_bit);\n"); + printf(" const __m128i __zero = _mm_setzero_si128();\n"); + printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n"); + + graph_reset_visited(node, stage_num, node_num); + gen_cospi_list_sse2(node, stage_num, node_num); + graph_reset_visited(node, stage_num, node_num); + for (int si = 1; si < stage_num; si++) { + char in[100]; + char out[100]; + printf("\n"); + printf(" // stage %d\n", si); + if (si == 1) + snprintf(in, 100, "%s", "input"); + else + snprintf(in, 100, "x%d", si - 1); + if (si == stage_num - 1) { + snprintf(out, 100, "%s", "output"); + } else { + snprintf(out, 100, "x%d", si); + printf(" __m128i %s[%d];\n", out, node_num); + } + // computation code + for (int ni = 0; ni < node_num; ni++) { + int idx = get_idx(si, ni, node_num); + node_to_code_sse2(node + idx, in, out); + } + } + + printf("}\n"); +} +void gen_cospi_list_sse4_1(Node *node, int stage_num, int node_num) { + int visited[65][2]; + memset(visited, 0, sizeof(visited)); + char text[100]; + char text1[100]; + int size = 100; + printf("\n"); + for (int si = 1; si < stage_num; si++) { + for (int ni = 0; ni < node_num; ni++) { + int idx = get_idx(si, ni, node_num); + Node *node0 = node + idx; + if (node0->visited == 0) { + int cnt = 0; + node0->visited = 1; + for (int i = 0; i < 2; i++) { + if 
(fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0) + cnt++; + } + if (cnt != 2) { + for (int i = 0; i < 2; i++) { + if (fabs(node0->inWeight[i]) != 1 && + fabs(node0->inWeight[i]) != 0) { + double w = node0->inWeight[i]; + int idx = get_cos_idx(w, COS_MOD); + int sgn = w < 0 ? 1 : 0; + + if (!visited[idx][sgn]) { + visited[idx][sgn] = 1; + printf(" __m128i %s = _mm_set1_epi32(%s);\n", + cos_text_sse4_1(w, COS_MOD, text, size), + cos_text_arr(w, COS_MOD, text1, size)); + } + } + } + Node *node1 = get_partner_node(node0); + node1->visited = 1; + } + } + } + } +} + +void single_node_to_code_sse4_1(Node *node, const char *buf0, + const char *buf1) { + printf(" %s[%2d] =", buf1, node->nodeIdx); + if (node->inWeight[0] == 1 && node->inWeight[1] == 1) { + printf(" _mm_add_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0, + node->inNodeIdx[1]); + } else if (node->inWeight[0] == 1 && node->inWeight[1] == -1) { + printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0, + node->inNodeIdx[1]); + } else if (node->inWeight[0] == -1 && node->inWeight[1] == 1) { + printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[1], buf0, + node->inNodeIdx[0]); + } else if (node->inWeight[0] == 1 && node->inWeight[1] == 0) { + printf(" %s[%d]", buf0, node->inNodeIdx[0]); + } else if (node->inWeight[0] == 0 && node->inWeight[1] == 1) { + printf(" %s[%d]", buf0, node->inNodeIdx[1]); + } else if (node->inWeight[0] == -1 && node->inWeight[1] == 0) { + printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, node->inNodeIdx[0]); + } else if (node->inWeight[0] == 0 && node->inWeight[1] == -1) { + printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, node->inNodeIdx[1]); + } + printf(";\n"); +} + +void pair_node_to_code_sse4_1(Node *node, Node *partnerNode, const char *buf0, + const char *buf1) { + char temp0[100]; + char temp1[100]; + if (node->inWeight[0] * partnerNode->inWeight[0] < 0) { + /* type0 + * cos sin + * sin -cos + */ + // btf_32_sse2_type0(w0, w1, in0, in1, 
out0, out1) + // out0 = w0*in0 + w1*in1 + // out1 = -w0*in1 + w1*in0 + printf( + " btf_32_type0_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], " + "__rounding, cos_bit);\n", + cos_text_sse4_1(node->inWeight[0], COS_MOD, temp0, 100), + cos_text_sse4_1(node->inWeight[1], COS_MOD, temp1, 100), buf0, + node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1, node->nodeIdx, buf1, + partnerNode->nodeIdx); + } else { + /* type1 + * cos sin + * -sin cos + */ + // btf_32_sse2_type1(w0, w1, in0, in1, out0, out1) + // out0 = w0*in0 + w1*in1 + // out1 = w0*in1 - w1*in0 + printf( + " btf_32_type1_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], " + "__rounding, cos_bit);\n", + cos_text_sse4_1(node->inWeight[0], COS_MOD, temp0, 100), + cos_text_sse4_1(node->inWeight[1], COS_MOD, temp1, 100), buf0, + node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1, node->nodeIdx, buf1, + partnerNode->nodeIdx); + } +} + +void node_to_code_sse4_1(Node *node, const char *buf0, const char *buf1) { + int cnt = 0; + int cnt1 = 0; + if (node->visited == 0) { + node->visited = 1; + for (int i = 0; i < 2; i++) { + if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++; + if (fabs(node->inWeight[i]) == 1) cnt1++; + } + if (cnt == 2) { + if (cnt1 == 2) { + // has a partner + Node *partnerNode = get_partner_node(node); + partnerNode->visited = 1; + single_node_to_code_sse4_1(node, buf0, buf1); + single_node_to_code_sse4_1(partnerNode, buf0, buf1); + } else { + single_node_to_code_sse2(node, buf0, buf1); + } + } else { + Node *partnerNode = get_partner_node(node); + partnerNode->visited = 1; + pair_node_to_code_sse4_1(node, partnerNode, buf0, buf1); + } + } +} + +void gen_code_sse4_1(Node *node, int stage_num, int node_num, TYPE_TXFM type) { + char *fun_name = new char[100]; + get_fun_name(fun_name, 100, type, node_num); + + printf("\n"); + printf( + "void %s_sse4_1(const __m128i *input, __m128i *output, int8_t cos_bit) " + "{\n", + fun_name); + + printf(" const int32_t* cospi = 
cospi_arr(cos_bit);\n"); + printf(" const __m128i __zero = _mm_setzero_si128();\n"); + printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n"); + + graph_reset_visited(node, stage_num, node_num); + gen_cospi_list_sse4_1(node, stage_num, node_num); + graph_reset_visited(node, stage_num, node_num); + for (int si = 1; si < stage_num; si++) { + char in[100]; + char out[100]; + printf("\n"); + printf(" // stage %d\n", si); + if (si == 1) + snprintf(in, 100, "%s", "input"); + else + snprintf(in, 100, "x%d", si - 1); + if (si == stage_num - 1) { + snprintf(out, 100, "%s", "output"); + } else { + snprintf(out, 100, "x%d", si); + printf(" __m128i %s[%d];\n", out, node_num); + } + // computation code + for (int ni = 0; ni < node_num; ni++) { + int idx = get_idx(si, ni, node_num); + node_to_code_sse4_1(node + idx, in, out); + } + } + + printf("}\n"); +} + +void gen_hybrid_code(CODE_TYPE code_type, TYPE_TXFM txfm_type, int node_num) { + int stage_num = get_hybrid_stage_num(txfm_type, node_num); + + Node *node = new Node[node_num * stage_num]; + init_graph(node, stage_num, node_num); + + gen_hybrid_graph_1d(node, stage_num, node_num, 0, 0, node_num, txfm_type); + + switch (code_type) { + case CODE_TYPE_C: gen_code_c(node, stage_num, node_num, txfm_type); break; + case CODE_TYPE_SSE2: + gen_code_sse2(node, stage_num, node_num, txfm_type); + break; + case CODE_TYPE_SSE4_1: + gen_code_sse4_1(node, stage_num, node_num, txfm_type); + break; + } + + delete[] node; +} + +int main(int argc, char **argv) { + CODE_TYPE code_type = CODE_TYPE_SSE4_1; + for (int txfm_type = TYPE_DCT; txfm_type < TYPE_LAST; txfm_type++) { + for (int node_num = 4; node_num <= 64; node_num *= 2) { + gen_hybrid_code(code_type, (TYPE_TXFM)txfm_type, node_num); + } + } + return 0; +} diff --git a/third_party/aom/tools/txfm_analyzer/txfm_graph.cc b/third_party/aom/tools/txfm_analyzer/txfm_graph.cc new file mode 100644 index 0000000000..a249061008 --- /dev/null +++ 
b/third_party/aom/tools/txfm_analyzer/txfm_graph.cc @@ -0,0 +1,943 @@ +/* + * Copyright (c) 2018, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include "tools/txfm_analyzer/txfm_graph.h" + +#include +#include +#include + +typedef struct Node Node; + +void get_fun_name(char *str_fun_name, int str_buf_size, const TYPE_TXFM type, + const int txfm_size) { + if (type == TYPE_DCT) + snprintf(str_fun_name, str_buf_size, "fdct%d_new", txfm_size); + else if (type == TYPE_ADST) + snprintf(str_fun_name, str_buf_size, "fadst%d_new", txfm_size); + else if (type == TYPE_IDCT) + snprintf(str_fun_name, str_buf_size, "idct%d_new", txfm_size); + else if (type == TYPE_IADST) + snprintf(str_fun_name, str_buf_size, "iadst%d_new", txfm_size); +} + +void get_txfm_type_name(char *str_fun_name, int str_buf_size, + const TYPE_TXFM type, const int txfm_size) { + if (type == TYPE_DCT) + snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size); + else if (type == TYPE_ADST) + snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size); + else if (type == TYPE_IDCT) + snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size); + else if (type == TYPE_IADST) + snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size); +} + +void get_hybrid_2d_type_name(char *buf, int buf_size, const TYPE_TXFM type0, + const TYPE_TXFM type1, const int txfm_size0, + const int txfm_size1) { + if (type0 == TYPE_DCT && type1 == TYPE_DCT) + snprintf(buf, buf_size, "_dct_dct_%dx%d", txfm_size1, txfm_size0); + else if (type0 
== TYPE_DCT && type1 == TYPE_ADST) + snprintf(buf, buf_size, "_dct_adst_%dx%d", txfm_size1, txfm_size0); + else if (type0 == TYPE_ADST && type1 == TYPE_ADST) + snprintf(buf, buf_size, "_adst_adst_%dx%d", txfm_size1, txfm_size0); + else if (type0 == TYPE_ADST && type1 == TYPE_DCT) + snprintf(buf, buf_size, "_adst_dct_%dx%d", txfm_size1, txfm_size0); +} + +TYPE_TXFM get_inv_type(TYPE_TXFM type) { + if (type == TYPE_DCT) + return TYPE_IDCT; + else if (type == TYPE_ADST) + return TYPE_IADST; + else if (type == TYPE_IDCT) + return TYPE_DCT; + else if (type == TYPE_IADST) + return TYPE_ADST; + else + return TYPE_LAST; +} + +void reference_dct_1d(double *in, double *out, int size) { + const double kInvSqrt2 = 0.707106781186547524400844362104; + for (int k = 0; k < size; k++) { + out[k] = 0; // initialize out[k] + for (int n = 0; n < size; n++) { + out[k] += in[n] * cos(PI * (2 * n + 1) * k / (2 * size)); + } + if (k == 0) out[k] = out[k] * kInvSqrt2; + } +} + +void reference_dct_2d(double *in, double *out, int size) { + double *tempOut = new double[size * size]; + // dct each row: in -> out + for (int r = 0; r < size; r++) { + reference_dct_1d(in + r * size, out + r * size, size); + } + + for (int r = 0; r < size; r++) { + // out ->tempOut + for (int c = 0; c < size; c++) { + tempOut[r * size + c] = out[c * size + r]; + } + } + for (int r = 0; r < size; r++) { + reference_dct_1d(tempOut + r * size, out + r * size, size); + } + delete[] tempOut; +} + +void reference_adst_1d(double *in, double *out, int size) { + for (int k = 0; k < size; k++) { + out[k] = 0; // initialize out[k] + for (int n = 0; n < size; n++) { + out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size)); + } + } +} + +void reference_hybrid_2d(double *in, double *out, int size, int type0, + int type1) { + double *tempOut = new double[size * size]; + // dct each row: in -> out + for (int r = 0; r < size; r++) { + if (type0 == TYPE_DCT) + reference_dct_1d(in + r * size, out + r * size, size); + 
else + reference_adst_1d(in + r * size, out + r * size, size); + } + + for (int r = 0; r < size; r++) { + // out ->tempOut + for (int c = 0; c < size; c++) { + tempOut[r * size + c] = out[c * size + r]; + } + } + for (int r = 0; r < size; r++) { + if (type1 == TYPE_DCT) + reference_dct_1d(tempOut + r * size, out + r * size, size); + else + reference_adst_1d(tempOut + r * size, out + r * size, size); + } + delete[] tempOut; +} + +void reference_hybrid_2d_new(double *in, double *out, int size0, int size1, + int type0, int type1) { + double *tempOut = new double[size0 * size1]; + // dct each row: in -> out + for (int r = 0; r < size1; r++) { + if (type0 == TYPE_DCT) + reference_dct_1d(in + r * size0, out + r * size0, size0); + else + reference_adst_1d(in + r * size0, out + r * size0, size0); + } + + for (int r = 0; r < size1; r++) { + // out ->tempOut + for (int c = 0; c < size0; c++) { + tempOut[c * size1 + r] = out[r * size0 + c]; + } + } + for (int r = 0; r < size0; r++) { + if (type1 == TYPE_DCT) + reference_dct_1d(tempOut + r * size1, out + r * size1, size1); + else + reference_adst_1d(tempOut + r * size1, out + r * size1, size1); + } + delete[] tempOut; +} + +unsigned int get_max_bit(unsigned int x) { + int max_bit = -1; + while (x) { + x = x >> 1; + max_bit++; + } + return max_bit; +} + +unsigned int bitwise_reverse(unsigned int x, int max_bit) { + x = ((x >> 16) & 0x0000ffff) | ((x & 0x0000ffff) << 16); + x = ((x >> 8) & 0x00ff00ff) | ((x & 0x00ff00ff) << 8); + x = ((x >> 4) & 0x0f0f0f0f) | ((x & 0x0f0f0f0f) << 4); + x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2); + x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1); + x = x >> (31 - max_bit); + return x; +} + +int get_idx(int ri, int ci, int cSize) { return ri * cSize + ci; } + +void add_node(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int in, double w) { + int outIdx = get_idx(stage_idx, node_idx, node_num); + int inIdx = get_idx(stage_idx - 1, in, node_num); + int idx = 
node[outIdx].inNodeNum; + if (idx < 2) { + node[outIdx].inNode[idx] = &node[inIdx]; + node[outIdx].inNodeIdx[idx] = in; + node[outIdx].inWeight[idx] = w; + idx++; + node[outIdx].inNodeNum = idx; + } else { + printf("Error: inNode is full"); + } +} + +void connect_node(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int in0, double w0, int in1, double w1) { + int outIdx = get_idx(stage_idx, node_idx, node_num); + int inIdx0 = get_idx(stage_idx - 1, in0, node_num); + int inIdx1 = get_idx(stage_idx - 1, in1, node_num); + + int idx = 0; + // if(w0 != 0) { + node[outIdx].inNode[idx] = &node[inIdx0]; + node[outIdx].inNodeIdx[idx] = in0; + node[outIdx].inWeight[idx] = w0; + idx++; + //} + + // if(w1 != 0) { + node[outIdx].inNode[idx] = &node[inIdx1]; + node[outIdx].inNodeIdx[idx] = in1; + node[outIdx].inWeight[idx] = w1; + idx++; + //} + + node[outIdx].inNodeNum = idx; +} + +void propagate(Node *node, int stage_num, int node_num, int stage_idx) { + for (int ni = 0; ni < node_num; ni++) { + int outIdx = get_idx(stage_idx, ni, node_num); + node[outIdx].value = 0; + for (int k = 0; k < node[outIdx].inNodeNum; k++) { + node[outIdx].value += + node[outIdx].inNode[k]->value * node[outIdx].inWeight[k]; + } + } +} + +int64_t round_shift(int64_t value, int bit) { + if (bit > 0) { + if (value < 0) { + return -round_shift(-value, bit); + } else { + return (value + (1 << (bit - 1))) >> bit; + } + } else { + return value << (-bit); + } +} + +void round_shift_array(int32_t *arr, int size, int bit) { + if (bit == 0) { + return; + } else { + for (int i = 0; i < size; i++) { + arr[i] = round_shift(arr[i], bit); + } + } +} + +void graph_reset_visited(Node *node, int stage_num, int node_num) { + for (int si = 0; si < stage_num; si++) { + for (int ni = 0; ni < node_num; ni++) { + int idx = get_idx(si, ni, node_num); + node[idx].visited = 0; + } + } +} + +void estimate_value(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int estimate_bit) { + 
if (stage_idx > 0) {
    // (Closing part of estimate_value(); its signature precedes this chunk.)
    // Fixed-point evaluation of one node: each input weight is scaled by
    // 2^estimate_bit and rounded, each input value is rounded, the products
    // are accumulated in 64 bits and the sum is passed through round_shift()
    // with estimate_bit.
    int outIdx = get_idx(stage_idx, node_idx, node_num);
    int64_t out = 0;
    node[outIdx].value = 0;
    for (int k = 0; k < node[outIdx].inNodeNum; k++) {
      int64_t w = round(node[outIdx].inWeight[k] * (1 << estimate_bit));
      int64_t v = round(node[outIdx].inNode[k]->value);
      out += v * w;
    }
    node[outIdx].value = round_shift(out, estimate_bit);
  }
}

// Rounds the value of node (stage_idx, node_idx) and passes it through
// round_shift() with a negated bit count.
// NOTE(review): round_shift() is defined outside this chunk; a negative bit
// presumably means "shift left by amplify_bit" -- confirm.
// stage_num is not used here.
void amplify_value(Node *node, int stage_num, int node_num, int stage_idx,
                   int node_idx, int amplify_bit) {
  int outIdx = get_idx(stage_idx, node_idx, node_num);
  node[outIdx].value = round_shift(round(node[outIdx].value), -amplify_bit);
}

// Runs estimate_value() followed by amplify_value() on every node of stage
// stage_idx, in node order.
void propagate_estimate_amlify(Node *node, int stage_num, int node_num,
                               int stage_idx, int amplify_bit,
                               int estimate_bit) {
  for (int ni = 0; ni < node_num; ni++) {
    estimate_value(node, stage_num, node_num, stage_idx, ni, estimate_bit);
    amplify_value(node, stage_num, node_num, stage_idx, ni, amplify_bit);
  }
}

// Resets a stage_num x node_num graph: records each node's stage/node index,
// zeroes its value and input count, and for every stage after the first calls
// connect_node() with the node's own index as both inputs (weights 1 and 0).
// NOTE(review): connect_node() is defined outside this chunk; judging by its
// declared parameters (stage_idx, node_idx, in0, w0, in1, w1) this presumably
// wires a pass-through from the same node index of the previous stage.
void init_graph(Node *node, int stage_num, int node_num) {
  for (int si = 0; si < stage_num; si++) {
    for (int ni = 0; ni < node_num; ni++) {
      int outIdx = get_idx(si, ni, node_num);
      node[outIdx].stageIdx = si;
      node[outIdx].nodeIdx = ni;
      node[outIdx].value = 0;
      node[outIdx].inNodeNum = 0;
      if (si >= 1) {
        connect_node(node, stage_num, node_num, si, ni, ni, 1, ni, 0);
      }
    }
  }
}

// Butterfly layer over the N nodes starting at node_idx, written into stage
// stage_idx + 1. Output i is connected to input i and to the mirrored input
// N - 1 - i (weight 1). With star == 0 the weight on input i is +1 in the
// first half and -1 in the second half; star == 1 swaps those two signs.
void gen_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
                 int node_idx, int N, int star) {
  for (int i = 0; i < N / 2; i++) {
    int out = node_idx + i;
    int in1 = node_idx + N - 1 - i;
    if (star == 1) {
      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
                   1);
    } else {
      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
                   1);
    }
  }
  for (int i = N / 2; i < N; i++) {
    int out = node_idx + i;
    int in1 = node_idx + N - 1 - i;
    if (star == 1) {
      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
                   1);
    } else {
      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
                   1);
    }
  }
}

// Permutation layer: input i (weight 1) is routed to the output whose offset
// is bitwise_reverse(i, get_max_bit(N - 1)).
void gen_P_graph(Node *node, int stage_num, int node_num, int stage_idx,
                 int node_idx, int N) {
  int max_bit = get_max_bit(N - 1);
  for (int i = 0; i < N; i++) {
    int out = node_idx + bitwise_reverse(i, max_bit);
    int in = node_idx + i;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
  }
}

// "Type 1" rotation layer used by gen_R_graph(): output ni combines input ni
// and the mirrored input N - ni - 1 with sin/cos weights of the angle
// PI * ai / (4 * N), where ai = bitwise_reverse(N + ni, get_max_bit(N)).
// First half uses (sin, cos); second half uses (cos, -sin).
void gen_type1_graph(Node *node, int stage_num, int node_num, int stage_idx,
                     int node_idx, int N) {
  int max_bit = get_max_bit(N);
  for (int ni = 0; ni < N / 2; ni++) {
    int ai = bitwise_reverse(N + ni, max_bit);
    int out = node_idx + ni;
    int in1 = node_idx + N - ni - 1;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
                 sin(PI * ai / (2 * 2 * N)), in1, cos(PI * ai / (2 * 2 * N)));
  }
  for (int ni = N / 2; ni < N; ni++) {
    int ai = bitwise_reverse(N + ni, max_bit);
    int out = node_idx + ni;
    int in1 = node_idx + N - ni - 1;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
                 cos(PI * ai / (2 * 2 * N)), in1, -sin(PI * ai / (2 * 2 * N)));
  }
}

// "Type 2" layer used by gen_R_graph(): the outer quarters [0, N/4) and
// [3N/4, N) pass through unchanged; the two inner quarters combine input ni
// with the mirrored input N - ni - 1 using +/-cos(PI / 4) weights.
void gen_type2_graph(Node *node, int stage_num, int node_num, int stage_idx,
                     int node_idx, int N) {
  for (int ni = 0; ni < N / 4; ni++) {
    int out = node_idx + ni;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
  }

  for (int ni = N / 4; ni < N / 2; ni++) {
    int out = node_idx + ni;
    int in1 = node_idx + N - ni - 1;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
                 -cos(PI / 4), in1, cos(-PI / 4));
  }

  for (int ni = N / 2; ni < N * 3 / 4; ni++) {
    int out = node_idx + ni;
    int in1 = node_idx + N - ni - 1;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
                 cos(-PI / 4), in1, cos(PI / 4));
  }

  for (int ni = N * 3 / 4; ni < N; ni++) {
    int out = node_idx + ni;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
  }
}

// "Type 3" layer used by gen_R_graph(). The N nodes are processed in groups
// of N_over_i; each group is split into four quarters: pass-through, two
// rotation quarters (weights +/-cos and +/-sin of kj * PI / i, paired with the
// mirrored input N - (offset + ni) - 1), and pass-through again. The signs of
// the rotation weights differ between the lower half (nj < N / 2) and the
// upper half of the nodes.
void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx,
                     int node_idx, int idx, int N) {
  // TODO(angiebird): Simplify and clarify this function

  int i = 2 * N / (1 << (idx / 2));
  int max_bit =
      get_max_bit(i / 2) - 1;  // the max_bit counts on i/2 instead of N here
  int N_over_i = 2 << (idx / 2);

  for (int nj = 0; nj < N / 2; nj += N_over_i) {
    int j = nj / (N_over_i);
    int kj = bitwise_reverse(i / 4 + j, max_bit);
    // printf("kj = %d\n", kj);

    // I_N/2i   --- 0
    int offset = nj;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }

    // -C_Kj/i --- S_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = -cos(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;
      double w1 = sin(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // S_kj/i  --- -C_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = -sin(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;
      double w1 = -cos(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // I_N/2i   --- 0
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }
  }

  for (int nj = N / 2; nj < N; nj += N_over_i) {
    int j = nj / N_over_i;
    int kj = bitwise_reverse(i / 4 + j, max_bit);

    // I_N/2i --- 0
    int offset = nj;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }

    // C_kj/i --- -S_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = cos(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;
      double w1 = -sin(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // S_kj/i --- C_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = sin(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;
      double w1 = cos(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // I_N/2i --- 0
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }
  }
}

// "Type 4" layer used by gen_R_graph(): tiles the N nodes with butterflies of
// size B_size = 2^((idx + 1) / 2), alternating star = 0 / 1 from one tile to
// the next.
void gen_type4_graph(Node *node, int stage_num, int node_num, int stage_idx,
                     int node_idx, int idx, int N) {
  int B_size = 1 << ((idx + 1) / 2);
  for (int ni = 0; ni < N; ni += B_size) {
    gen_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni, B_size,
                (ni / B_size) % 2);
  }
}

// Builds the R sub-graph over N nodes as max_idx = 2 * (get_max_bit(N) + 1) - 3
// layers. Layer idx is written starting from stage
// stage_idx + max_idx - idx - 1, so idx 0 (type 1) lands on the last stage and
// idx max_idx - 1 (type 2) on the first; the layers in between alternate
// between type 4 (odd idx) and type 3 (even idx).
void gen_R_graph(Node *node, int stage_num, int node_num, int stage_idx,
                 int node_idx, int N) {
  int max_idx = 2 * (get_max_bit(N) + 1) - 3;
  for (int idx = 0; idx < max_idx; idx++) {
    int s = stage_idx + max_idx - idx - 1;
    if (idx == 0) {
      // type 1
      gen_type1_graph(node, stage_num, node_num, s, node_idx, N);
    } else if (idx == max_idx - 1) {
      // type 2
      gen_type2_graph(node, stage_num, node_num, s, node_idx, N);
    } else if ((idx + 1) % 2 == 0) {
      // type 4
      gen_type4_graph(node, stage_num, node_num, s, node_idx, idx, N);
    } else if ((idx + 1) % 2 == 1) {
      // type 3
      gen_type3_graph(node, stage_num, node_num, s, node_idx, idx, N);
    } else {
      // Not reachable for idx >= 0: the two parity tests above are exhaustive.
      printf("check gen_R_graph()\n");
    }
  }
}

// Recursive DCT construction. For N > 2: one butterfly layer, then a DCT of
// size N / 2 on the lower half and an R graph of size N / 2 on the upper half,
// both starting one stage later. For N <= 2: a single 2-point stage with
// +/-cos(PI / 4) weights.
void gen_DCT_graph(Node *node, int stage_num, int node_num, int stage_idx,
                   int node_idx, int N) {
  if (N > 2) {
    gen_B_graph(node, stage_num, node_num, stage_idx, node_idx, N, 0);
    gen_DCT_graph(node, stage_num, node_num, stage_idx + 1, node_idx, N / 2);
    gen_R_graph(node, stage_num, node_num, stage_idx + 1, node_idx + N / 2,
                N / 2);
  } else {
    // generate dct_2
    connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
                 cos(PI / 4), node_idx + 1, cos(PI / 4));
    connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
                 node_idx + 1, -cos(PI / 4), node_idx, cos(PI / 4));
  }
}

// Number of stages reserved for a 1-D DCT of the given size:
// 2 * get_max_bit(size).
int get_dct_stage_num(int size) { return 2 * get_max_bit(size); }

// Full 1-D DCT: gen_DCT_graph() followed by the gen_P_graph() permutation,
// which is written from stage stage_idx + dct_stage_num - 2 (i.e. into the
// last of the dct_stage_num stages).
void gen_DCT_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int dct_node_num) {
  gen_DCT_graph(node, stage_num, node_num, stage_idx, node_idx, dct_node_num);
  int dct_stage_num = get_dct_stage_num(dct_node_num);
  gen_P_graph(node, stage_num, node_num, stage_idx + dct_stage_num - 2,
              node_idx, dct_node_num);
}

// ADST butterfly over size = 2^(adst_idx + 1) nodes: output ni in the first
// half is in[ni] + in[ni + size / 2]; output ni in the second half is
// -in[ni] + in[ni - size / 2].
void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int adst_idx) {
  int size = 1 << (adst_idx + 1);
  for (int ni = 0; ni < size / 2; ni++) {
    int nOut = node_idx + ni;
    int nIn = nOut + size / 2;
    connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, 1, nIn,
                 1);
    // printf("nOut: %d nIn: %d\n", nOut, nIn);
  }
  for (int ni = size / 2; ni < size; ni++) {
    int nOut = node_idx + ni;
    int nIn = nOut - size / 2;
    connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, -1, nIn,
                 1);
    // printf("ndctOut: %d nIn: %d\n", nOut, nIn);
  }
}

// Tiles adst_node_num nodes with gen_adst_B_graph() blocks of
// 2^(adst_idx + 1) nodes each.
void gen_adst_U_graph(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int adst_idx, int adst_node_num) {
  int size = 1 << (adst_idx + 1);
  for (int ni = 0; ni < adst_node_num; ni += size) {
    gen_adst_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
                     adst_idx);
  }
}

// 2-point rotation on nodes node_idx and node_idx + 1 with angle freq * PI:
//   out[0] =  cos * in[0] + sin * in[1]
//   out[1] = -cos * in[1] + sin * in[0]
void gen_adst_T_graph(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, double freq) {
  connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
               cos(freq * PI), node_idx + 1, sin(freq * PI));
  connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
               node_idx + 1, -cos(freq * PI), node_idx, sin(freq * PI));
}

// Applies gen_adst_T_graph() to consecutive node pairs of a 2^adst_idx-node
// block, pair i using freq = (1 + 4 * i) / 2^(adst_idx + 1).
void gen_adst_E_graph(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int adst_idx) {
  int size = 1 << (adst_idx);
  for (int i = 0; i < size / 2; i++) {
    int ni = i * 2;
    double fi = (1 + 4 * i) * 1.0 / (1 << (adst_idx + 1));
    gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
  }
}

// Applies gen_adst_E_graph() to every odd-numbered block of 2^adst_idx nodes;
// even-numbered blocks are not touched by this function.
void gen_adst_V_graph(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int adst_idx, int adst_node_num) {
  int size = 1 << (adst_idx);
  for (int i = 0; i < adst_node_num / size; i++) {
    if (i % 2 == 1) {
      int ni = i * size;
      gen_adst_E_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
                       adst_idx);
    }
  }
}
// Applies gen_adst_T_graph() to every consecutive node pair, pair i using
// freq = (1 + 4 * i) / (4 * adst_node_num).
void gen_adst_VJ_graph(Node *node, int stage_num, int node_num, int stage_idx,
                       int node_idx, int adst_node_num) {
  for (int i = 0; i < adst_node_num / 2; i++) {
    int ni = i * 2;
    double fi = (1 + 4 * i) * 1.0 / (4 * adst_node_num);
    gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
  }
}
// Permutation layer: even outputs pass through, odd output ni reads input
// adst_node_num - ni.
void gen_adst_Q_graph(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int adst_node_num) {
  // reverse order when idx is 1, 3, 5, 7 ...
  // example of adst_node_num = 8:
  //   0 1 2 3 4 5 6 7
  // --> 0 7 2 5 4 3 6 1
  for (int ni = 0; ni < adst_node_num; ni++) {
    if (ni % 2 == 0) {
      int out = node_idx + ni;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out,
                   0);
    } else {
      int out = node_idx + ni;
      int in = node_idx + adst_node_num - ni;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }
  }
}
// Permutation layer that reverses the node order within the block.
void gen_adst_Ibar_graph(Node *node, int stage_num, int node_num, int stage_idx,
                         int node_idx, int adst_node_num) {
  // reverse order
  // 0 1 2 3 --> 3 2 1 0
  for (int ni = 0; ni < adst_node_num; ni++) {
    int out = node_idx + ni;
    int in = node_idx + adst_node_num - ni - 1;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
  }
}

// Index mapping of the Q permutation (same rule as gen_adst_Q_graph()):
// even `out` maps to itself, odd `out` maps to adst_node_num - out.
int get_Q_out2in(int adst_node_num, int out) {
  int in;
  if (out % 2 == 0) {
    in = out;
  } else {
    in = adst_node_num - out;
  }
  return in;
}

// Index mapping of the Ibar (reversal) permutation.
int get_Ibar_out2in(int adst_node_num, int out) {
  return adst_node_num - out - 1;
}

// Single permutation layer equal to Ibar followed by Q, built by composing the
// two index mappings above.
void gen_adst_IbarQ_graph(Node *node, int stage_num, int node_num,
                          int stage_idx, int node_idx, int adst_node_num) {
  // in -> Ibar -> Q -> out
  for (int ni = 0; ni < adst_node_num; ni++) {
    int out = node_idx + ni;
    int in = node_idx +
             get_Ibar_out2in(adst_node_num, get_Q_out2in(adst_node_num, ni));
    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
  }
}

// Sign layer: every output reads the input at its own index; even nodes keep
// their sign (weight 1), odd nodes are negated (weight -1).
void gen_adst_D_graph(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int adst_node_num) {
  // alternate signs; the node order is not changed
  for (int ni = 0; ni < adst_node_num; ni++) {
    int out = node_idx + ni;
    int in = out;
    if (ni % 2 == 0) {
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    } else {
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, -1, in,
                   0);
    }
  }
}

// Maps x to an input index for the Ht layers: x is bit-reversed over
// get_max_bit(adst_node_num - 1) bits, then bit 0 is copied and each higher
// bit i becomes (bit i) XOR (bit i - 1).
int get_hadamard_idx(int x, int adst_node_num) {
  int max_bit = get_max_bit(adst_node_num - 1);
  x = bitwise_reverse(x, max_bit);

  // gray code
  int c = x & 1;
  int p = x & 1;  // initial value is never read; p is reassigned in the loop
  int y = c;

  for (int i = 1; i <= max_bit; i++) {
    p = c;
    c = (x >> i) & 1;
    y += (c ^ p) << i;
  }
  return y;
}

// Permutation layer: output ni reads input get_hadamard_idx(ni) with weight 1.
void gen_adst_Ht_graph(Node *node, int stage_num, int node_num, int stage_idx,
                       int node_idx, int adst_node_num) {
  for (int ni = 0; ni < adst_node_num; ni++) {
    int out = node_idx + ni;
    int in = node_idx + get_hadamard_idx(ni, adst_node_num);
    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
  }
}

// Same permutation as gen_adst_Ht_graph(), with odd outputs negated
// (weight -1) and even outputs kept (weight 1).
void gen_adst_HtD_graph(Node *node, int stage_num, int node_num, int stage_idx,
                        int node_idx, int adst_node_num) {
  for (int ni = 0; ni < adst_node_num; ni++) {
    int out = node_idx + ni;
    int in = node_idx + get_hadamard_idx(ni, adst_node_num);
    double inW;
    if (ni % 2 == 0)
      inW = 1;
    else
      inW = -1;
    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, inW, in, 0);
  }
}

// Number of stages reserved for a 1-D ADST: 2 * get_max_bit(adst_node_num) + 2.
int get_adst_stage_num(int adst_node_num) {
  return 2 * get_max_bit(adst_node_num) + 2;
}

// Builds the inverse-ADST graph layer by layer starting at stage_idx:
// IbarQ, VJ, then a (U, V) pair for adst_idx = max_bit - 1 down to 1, and
// finally HtD. Returns the number of layers generated plus one.
int gen_iadst_graph(Node *node, int stage_num, int node_num, int stage_idx,
                    int node_idx, int adst_node_num) {
  int max_bit = get_max_bit(adst_node_num);
  int si = 0;
  gen_adst_IbarQ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
                       adst_node_num);
  si++;
  gen_adst_VJ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
                    adst_node_num);
  si++;
  for (int adst_idx = max_bit - 1; adst_idx >= 1; adst_idx--) {
    gen_adst_U_graph(node, stage_num, node_num, stage_idx + si, node_idx,
                     adst_idx, adst_node_num);
    si++;
    gen_adst_V_graph(node, stage_num, node_num, stage_idx + si, node_idx,
                     adst_idx, adst_node_num);
    si++;
  }
  gen_adst_HtD_graph(node, stage_num, node_num, stage_idx + si, node_idx,
                     adst_node_num);
  si++;
  return si + 1;
}

// Builds the forward ADST at (stage_idx, node_idx) of `node` by generating an
// inverse-ADST graph in a temporary buffer and writing its inverse into
// `node` via gen_inv_graph(). Returns gen_iadst_graph()'s return value.
int gen_adst_graph(Node *node, int stage_num, int node_num, int stage_idx,
                   int node_idx, int adst_node_num) {
  int hybrid_stage_num = get_hybrid_stage_num(TYPE_ADST, adst_node_num);
  // generate an adst tempNode
  Node *tempNode = new Node[hybrid_stage_num * adst_node_num];
  init_graph(tempNode, hybrid_stage_num, adst_node_num);
  int si = gen_iadst_graph(tempNode, hybrid_stage_num, adst_node_num, 0, 0,
                           adst_node_num);

  // tempNode's inverse graph to node[stage_idx][node_idx]
  gen_inv_graph(tempNode, hybrid_stage_num, adst_node_num, node, stage_num,
                node_num, stage_idx, node_idx);
  delete[] tempNode;
  return si;
}

// Transpose layer for a square dct_node_num x dct_node_num block: the input
// at offset first * dct_node_num + second feeds the output at offset
// second * dct_node_num + first in stage stage_idx + 1.
void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int dct_node_num) {
  for (int first = 0; first < dct_node_num; first++) {
    for (int second = 0; second < dct_node_num; second++) {
      // int sIn = stage_idx;
      int sOut = stage_idx + 1;
      int nIn = node_idx + first * dct_node_num + second;
      int nOut = node_idx + second * dct_node_num + first;

      // printf("sIn: %d nIn: %d  sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut);

      connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
    }
  }
}

// Rectangular variant of connect_layer_2d(): the input at offset
// i * dct_node_num0 + j feeds the output at offset j * dct_node_num1 + i.
void connect_layer_2d_new(Node *node, int stage_num, int node_num,
                          int stage_idx, int node_idx, int dct_node_num0,
                          int dct_node_num1) {
  for (int i = 0; i < dct_node_num1; i++) {
    for (int j = 0; j < dct_node_num0; j++) {
      // int sIn = stage_idx;
      int sOut = stage_idx + 1;
      int nIn = node_idx + i * dct_node_num0 + j;
      int nOut = node_idx + j * dct_node_num1 + i;

      // printf("sIn: %d nIn: %d  sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut);

      connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
    }
  }
}

// Square 2-D DCT: dct_node_num 1-D DCTs starting at stage_idx, a second set
// starting at stage_idx + dct_stage_num, and a transpose layer connecting the
// last stage of the first set to the first stage of the second.
void gen_DCT_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
                      int node_idx, int dct_node_num) {
  int dct_stage_num = get_dct_stage_num(dct_node_num);
  // put 2 layers of dct_node_num DCTs on the graph
  for (int ni = 0; ni < dct_node_num; ni++) {
    gen_DCT_graph_1d(node, stage_num, node_num, stage_idx,
                     node_idx + ni * dct_node_num, dct_node_num);
    gen_DCT_graph_1d(node, stage_num, node_num, stage_idx + dct_stage_num,
                     node_idx + ni * dct_node_num, dct_node_num);
  }
  // connect first layer and second layer
  connect_layer_2d(node, stage_num, node_num, stage_idx + dct_stage_num - 1,
                   node_idx, dct_node_num);
}

// Stage count for a 1-D transform of the given type: the DCT count for
// TYPE_DCT / TYPE_IDCT, the ADST count for TYPE_ADST / TYPE_IADST, and 0 for
// any other value.
int get_hybrid_stage_num(int type, int hybrid_node_num) {
  if (type == TYPE_DCT || type == TYPE_IDCT) {
    return get_dct_stage_num(hybrid_node_num);
  } else if (type == TYPE_ADST || type == TYPE_IADST) {
    return get_adst_stage_num(hybrid_node_num);
  }
  return 0;
}

// Stage count for a square 2-D transform: sum of the two 1-D stage counts.
int get_hybrid_2d_stage_num(int type0, int type1, int hybrid_node_num) {
  int stage_num = 0;
  stage_num += get_hybrid_stage_num(type0, hybrid_node_num);
  stage_num += get_hybrid_stage_num(type1, hybrid_node_num);
  return stage_num;
}

// Stage count for a rectangular 2-D transform: sum of the two 1-D stage
// counts, each computed with its own size.
int get_hybrid_2d_stage_num_new(int type0, int type1, int hybrid_node_num0,
                                int hybrid_node_num1) {
  int stage_num = 0;
  stage_num += get_hybrid_stage_num(type0, hybrid_node_num0);
  stage_num += get_hybrid_stage_num(type1, hybrid_node_num1);
  return stage_num;
}

// Returns get_max_bit(hybrid_node_num) - 1. `type` is not used.
int get_hybrid_amplify_factor(int type, int hybrid_node_num) {
  return get_max_bit(hybrid_node_num) - 1;
}

// Generates a 1-D transform of the requested type at (stage_idx, node_idx).
// TYPE_DCT and TYPE_ADST are generated directly. TYPE_IDCT / TYPE_IADST build
// the corresponding forward graph in a temporary buffer and write its inverse
// into `node` via gen_inv_graph(). Any other type leaves `node` untouched.
void gen_hybrid_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
                         int node_idx, int hybrid_node_num, int type) {
  if (type == TYPE_DCT) {
    gen_DCT_graph_1d(node, stage_num, node_num, stage_idx, node_idx,
                     hybrid_node_num);
  } else if (type == TYPE_ADST) {
    gen_adst_graph(node, stage_num, node_num, stage_idx, node_idx,
                   hybrid_node_num);
  } else if (type == TYPE_IDCT) {
    int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
    // generate a dct tempNode
    Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
    init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
    gen_DCT_graph_1d(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
                     hybrid_node_num);

    // tempNode's inverse graph to node[stage_idx][node_idx]
    gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
                  node_num, stage_idx, node_idx);
    delete[] tempNode;
  } else if (type == TYPE_IADST) {
    int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
    // generate an adst tempNode
    Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
    init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
    gen_adst_graph(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
                   hybrid_node_num);

    // tempNode's inverse graph to node[stage_idx][node_idx]
    gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
                  node_num, stage_idx, node_idx);
    delete[] tempNode;
  }
}

// Square 2-D hybrid transform: hybrid_node_num 1-D transforms of type0
// starting at stage_idx, the same number of type1 starting after type0's
// stage count, and a transpose layer between the two.
void gen_hybrid_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
                         int node_idx, int hybrid_node_num, int type0,
                         int type1) {
  int hybrid_stage_num = get_hybrid_stage_num(type0, hybrid_node_num);

  for (int ni = 0; ni < hybrid_node_num; ni++) {
    gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
                        node_idx + ni * hybrid_node_num, hybrid_node_num,
                        type0);
    gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx + hybrid_stage_num,
                        node_idx + ni * hybrid_node_num, hybrid_node_num,
                        type1);
  }

  // connect first layer and second layer
  connect_layer_2d(node, stage_num, node_num, stage_idx + hybrid_stage_num - 1,
                   node_idx, hybrid_node_num);
}

// Rectangular 2-D hybrid transform: hybrid_node_num1 transforms of type0 and
// size hybrid_node_num0, then hybrid_node_num0 transforms of type1 and size
// hybrid_node_num1 starting after type0's stage count, joined by
// connect_layer_2d_new().
void gen_hybrid_graph_2d_new(Node *node, int stage_num, int node_num,
                             int stage_idx, int node_idx, int hybrid_node_num0,
                             int hybrid_node_num1, int type0, int type1) {
  int hybrid_stage_num0 = get_hybrid_stage_num(type0, hybrid_node_num0);

  for (int ni = 0; ni < hybrid_node_num1; ni++) {
    gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
                        node_idx + ni * hybrid_node_num0, hybrid_node_num0,
                        type0);
  }
  for (int ni = 0; ni < hybrid_node_num0; ni++) {
    gen_hybrid_graph_1d(
        node, stage_num, node_num, stage_idx + hybrid_stage_num0,
        node_idx + ni * hybrid_node_num1, hybrid_node_num1, type1);
  }

  // connect first layer and second layer
  connect_layer_2d_new(node, stage_num, node_num,
                       stage_idx + hybrid_stage_num0 - 1, node_idx,
                       hybrid_node_num0, hybrid_node_num1);
}

// Writes the inverse of the stage_num x node_num graph `node` into `invNode`
// at offset (inv_stage_idx, inv_node_idx). The target region's input counts
// are cleared first; then every input edge k of node (si, ni) is handed to
// add_node() at inverse stage stage_num - si, with the roles of ni and
// inNodeIdx[k] swapped and the same weight.
// NOTE(review): add_node() is defined outside this chunk; it presumably
// appends one weighted input to the target node -- confirm.
void gen_inv_graph(Node *node, int stage_num, int node_num, Node *invNode,
                   int inv_stage_num, int inv_node_num, int inv_stage_idx,
                   int inv_node_idx) {
  // clean up inNodeNum in invNode because of add_node
  for (int si = 1 + inv_stage_idx; si < inv_stage_idx + stage_num; si++) {
    for (int ni = inv_node_idx; ni < inv_node_idx + node_num; ni++) {
      int idx = get_idx(si, ni, inv_node_num);
      invNode[idx].inNodeNum = 0;
    }
  }
  // generate inverse graph of node on invNode
  for (int si = 1; si < stage_num; si++) {
    for (int ni = 0; ni < node_num; ni++) {
      int invSi = stage_num - si;
      int idx = get_idx(si, ni, node_num);
      for (int k = 0; k < node[idx].inNodeNum; k++) {
        int invNi = node[idx].inNodeIdx[k];
        add_node(invNode, inv_stage_num, inv_node_num, invSi + inv_stage_idx,
                 invNi + inv_node_idx, ni + inv_node_idx,
                 node[idx].inWeight[k]);
      }
    }
  }
}
diff --git a/third_party/aom/tools/txfm_analyzer/txfm_graph.h b/third_party/aom/tools/txfm_analyzer/txfm_graph.h
new file mode 100644
index 0000000000..8dc36146dd
--- /dev/null
+++ b/third_party/aom/tools/txfm_analyzer/txfm_graph.h
@@ -0,0 +1,160 @@
/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */ + +#ifndef AOM_TOOLS_TXFM_ANALYZER_TXFM_GRAPH_H_ +#define AOM_TOOLS_TXFM_ANALYZER_TXFM_GRAPH_H_ + +struct Node { + Node *inNode[2]; + int inNodeNum; + int inNodeIdx[2]; + double inWeight[2]; + double value; + int nodeIdx; + int stageIdx; + int visited; +}; + +#define STAGENUM (10) +#define NODENUM (32) +#define COS_MOD (128) + +typedef enum { + TYPE_DCT = 0, + TYPE_ADST, + TYPE_IDCT, + TYPE_IADST, + TYPE_LAST +} TYPE_TXFM; + +TYPE_TXFM get_inv_type(TYPE_TXFM type); +void get_fun_name(char *str_fun_name, int str_buf_size, const TYPE_TXFM type, + const int txfm_size); + +void get_txfm_type_name(char *str_fun_name, int str_buf_size, + const TYPE_TXFM type, const int txfm_size); +void get_hybrid_2d_type_name(char *buf, int buf_size, const TYPE_TXFM type0, + const TYPE_TXFM type1, const int txfm_size0, + const int txfm_size1); +unsigned int get_max_bit(unsigned int x); +unsigned int bitwise_reverse(unsigned int x, int max_bit); +int get_idx(int ri, int ci, int cSize); + +int get_dct_stage_num(int size); +void reference_dct_1d(double *in, double *out, int size); +void reference_dct_2d(double *in, double *out, int size); +void connect_node(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int in0, double w0, int in1, double w1); +void propagate(Node *node, int stage_num, int node_num, int stage); +void init_graph(Node *node, int stage_num, int node_num); +void graph_reset_visited(Node *node, int stage_num, int node_num); +void gen_B_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int N, int star); +void gen_P_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int N); + +void gen_type1_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int N); +void gen_type2_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int N); +void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int idx, int N); +void 
gen_type4_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int idx, int N); + +void gen_R_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int N); + +void gen_DCT_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int N); + +void gen_DCT_graph_1d(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int dct_node_num); +void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int dct_node_num); + +void gen_DCT_graph_2d(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int dct_node_num); + +void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_idx); + +void gen_adst_U_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_idx, int adst_node_num); +void gen_adst_T_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, double freq); + +void gen_adst_E_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_idx); + +void gen_adst_V_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_idx, int adst_node_num); + +void gen_adst_VJ_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_node_num); +void gen_adst_Q_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_node_num); +void gen_adst_Ibar_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_node_num); + +void gen_adst_D_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_node_num); + +int get_hadamard_idx(int x, int adst_node_num); +void gen_adst_Ht_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_node_num); + +int gen_adst_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_node_num); 
+int gen_iadst_graph(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int adst_node_num); +void reference_adst_1d(double *in, double *out, int size); + +int get_adst_stage_num(int adst_node_num); +int get_hybrid_stage_num(int type, int hybrid_node_num); +int get_hybrid_2d_stage_num(int type0, int type1, int hybrid_node_num); +int get_hybrid_2d_stage_num_new(int type0, int type1, int hybrid_node_num0, + int hybrid_node_num1); +int get_hybrid_amplify_factor(int type, int hybrid_node_num); +void gen_hybrid_graph_1d(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int hybrid_node_num, int type); +void gen_hybrid_graph_2d(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int hybrid_node_num, int type0, + int type1); +void gen_hybrid_graph_2d_new(Node *node, int stage_num, int node_num, + int stage_idx, int node_idx, int hybrid_node_num0, + int hybrid_node_num1, int type0, int type1); + +void reference_hybrid_2d(double *in, double *out, int size, int type0, + int type1); + +void reference_hybrid_2d_new(double *in, double *out, int size0, int size1, + int type0, int type1); +void reference_adst_dct_2d(double *in, double *out, int size); + +void gen_code(Node *node, int stage_num, int node_num, TYPE_TXFM type); + +void gen_inv_graph(Node *node, int stage_num, int node_num, Node *invNode, + int inv_stage_num, int inv_node_num, int inv_stage_idx, + int inv_node_idx); + +TYPE_TXFM hybrid_char_to_int(char ctype); + +int64_t round_shift(int64_t value, int bit); +void round_shift_array(int32_t *arr, int size, int bit); +void estimate_value(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int estimate_bit); +void amplify_value(Node *node, int stage_num, int node_num, int stage_idx, + int node_idx, int estimate_bit); +void propagate_estimate_amlify(Node *node, int stage_num, int node_num, + int stage_idx, int amplify_bit, + int estimate_bit); +#endif // AOM_TOOLS_TXFM_ANALYZER_TXFM_GRAPH_H_ 
# Patch framing for this file, kept as comments so the script stays valid
# Python:
# diff --git a/third_party/aom/tools/wrap-commit-msg.py b/third_party/aom/tools/wrap-commit-msg.py
# new file mode 100755
# index 0000000000..c51ed093d3
# --- /dev/null
# +++ b/third_party/aom/tools/wrap-commit-msg.py
# @@ -0,0 +1,72 @@
#!/usr/bin/env python3
##
## Copyright (c) 2016, Alliance for Open Media. All rights reserved
##
## This source code is subject to the terms of the BSD 2 Clause License and
## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
## was not distributed with this source code in the LICENSE file, you can
## obtain it at www.aomedia.org/license/software. If the Alliance for Open
## Media Patent License 1.0 was not distributed with this source code in the
## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
##
"""Wraps paragraphs of text, preserving manual formatting

This is like fold(1), but has the special convention of not modifying lines
that start with whitespace. This allows you to intersperse blocks with
special formatting, like code blocks, with written prose. The prose will
be wordwrapped, and the manual formatting will be preserved.

 * This won't handle the case of a bulleted (or ordered) list specially, so
   manual wrapping must be done.

Occasionally it's useful to put something with explicit formatting that
doesn't look at all like a block of text inline.

  indicator = has_leading_whitespace(line);
  if (indicator)
    preserve_formatting(line);

The intent is that this docstring would make it through the transform
and still be legible and presented as it is in the source. If additional
cases are handled, update this doc to describe the effect.
"""

__author__ = "jkoleszar@google.com"
import textwrap
import sys


def wrap(text: str) -> str:
    """Return *text* re-filled to textwrap's default width, newline-terminated.

    Words longer than the width are left intact rather than split. Empty
    input yields an empty string (no newline is added).
    """
    if text:
        return textwrap.fill(text, break_long_words=False) + '\n'
    return ""


def main(fileobj) -> None:
    """Re-wrap the prose in *fileobj*, leaving whitespace-led lines untouched.

    Consecutive lines that do not start with whitespace form a paragraph and
    are re-wrapped together. Any line that starts with whitespace -- which
    includes blank lines -- ends the current paragraph and is copied through
    unchanged.

    If *fileobj* is ``sys.stdin`` the result goes to ``sys.stdout``; otherwise
    *fileobj* is rewound, truncated and overwritten in place, so it must be
    seekable and writable.
    """
    pieces = []     # finished output fragments, joined once at the end
    paragraph = []  # prose lines collected since the last preserved line
    for line in fileobj:
        if line.lstrip() == line:
            paragraph.append(line)
        else:
            # A whitespace-led line closes the pending paragraph.
            pieces.append(wrap("".join(paragraph)))
            paragraph.clear()
            pieces.append(line)
    pieces.append(wrap("".join(paragraph)))
    output = "".join(pieces)

    # Replace the file or write to stdout.
    if fileobj is sys.stdin:
        sys.stdout.write(output)
    else:
        fileobj.seek(0)
        fileobj.truncate(0)
        fileobj.write(output)


if __name__ == "__main__":
    if len(sys.argv) > 1:
        # Close (and flush) the target file deterministically.
        with open(sys.argv[1], "r+") as target:
            main(target)
    else:
        main(sys.stdin)

# Patch trailer:
# -- cgit v1.2.3