author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-19 00:47:55 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-19 00:47:55 +0000
commit     26a029d407be480d791972afb5975cf62c9360a6 (patch)
tree       f435a8308119effd964b339f76abb83a57c29483 /third_party/aom/tools
parent     Initial commit. (diff)
download   firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz
           firefox-26a029d407be480d791972afb5975cf62c9360a6.zip

Adding upstream version 124.0.1. (tag: upstream/124.0.1)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/aom/tools')
-rw-r--r--  third_party/aom/tools/aggregate_entropy_stats.py  39
-rw-r--r--  third_party/aom/tools/aom_entropy_optimizer.c  761
-rw-r--r--  third_party/aom/tools/auto_refactor/auto_refactor.py  919
-rw-r--r--  third_party/aom/tools/auto_refactor/av1_preprocess.py  113
-rw-r--r--  third_party/aom/tools/auto_refactor/c_files/decl_status_code.c  31
-rw-r--r--  third_party/aom/tools/auto_refactor/c_files/func_in_out.c  208
-rw-r--r--  third_party/aom/tools/auto_refactor/c_files/global_variable.c  27
-rw-r--r--  third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c  46
-rw-r--r--  third_party/aom/tools/auto_refactor/c_files/simple_code.c  64
-rw-r--r--  third_party/aom/tools/auto_refactor/c_files/struct_code.c  49
-rw-r--r--  third_party/aom/tools/auto_refactor/test_auto_refactor.py  675
-rwxr-xr-x  third_party/aom/tools/cpplint.py  6244
-rw-r--r--  third_party/aom/tools/diff.py  132
-rw-r--r--  third_party/aom/tools/dump_obu.cc  168
-rw-r--r--  third_party/aom/tools/frame_size_variation_analyzer.py  74
-rwxr-xr-x  third_party/aom/tools/gen_authors.sh  10
-rwxr-xr-x  third_party/aom/tools/gen_constrained_tokenset.py  120
-rw-r--r--  third_party/aom/tools/gop_bitrate/analyze_data.py  18
-rwxr-xr-x  third_party/aom/tools/gop_bitrate/encode_all_script.sh  13
-rw-r--r--  third_party/aom/tools/gop_bitrate/python/bitrate_accuracy.py  185
-rw-r--r--  third_party/aom/tools/inspect-cli.js  39
-rw-r--r--  third_party/aom/tools/inspect-post.js  1
-rwxr-xr-x  third_party/aom/tools/intersect-diffs.py  78
-rwxr-xr-x  third_party/aom/tools/lint-hunks.py  150
-rw-r--r--  third_party/aom/tools/obu_parser.cc  190
-rw-r--r--  third_party/aom/tools/obu_parser.h  27
-rw-r--r--  third_party/aom/tools/ratectrl_log_analyzer/analyze_ratectrl_log.py  154
-rw-r--r--  third_party/aom/tools/txfm_analyzer/txfm_gen_code.cc  580
-rw-r--r--  third_party/aom/tools/txfm_analyzer/txfm_graph.cc  943
-rw-r--r--  third_party/aom/tools/txfm_analyzer/txfm_graph.h  160
-rwxr-xr-x  third_party/aom/tools/wrap-commit-msg.py  72
31 files changed, 12290 insertions, 0 deletions
diff --git a/third_party/aom/tools/aggregate_entropy_stats.py b/third_party/aom/tools/aggregate_entropy_stats.py
new file mode 100644
index 0000000000..0311681f2d
--- /dev/null
+++ b/third_party/aom/tools/aggregate_entropy_stats.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+## Copyright (c) 2017, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 2 Clause License and
+## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+## was not distributed with this source code in the LICENSE file, you can
+## obtain it at www.aomedia.org/license/software. If the Alliance for Open
+## Media Patent License 1.0 was not distributed with this source code in the
+## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+##
+"""Aggregate multiple entropy stats output which is written in 32-bit int.
+
+python ./aggregate_entropy_stats.py [dir of stats files] [keyword of filenames]
+ [filename of final stats]
+"""
+
+__author__ = "yuec@google.com"
+
+import os
+import sys
+import numpy as np
+
+def main():
+    """Sum all int32 stats files in a directory whose names match a keyword."""
+    stats_dir = sys.argv[1]
+    keyword = sys.argv[2]
+    total = None
+    for fn in os.listdir(stats_dir):
+        if keyword in fn:
+            # os.path.join() so the directory argument does not need a
+            # trailing slash.
+            stats = np.fromfile(os.path.join(stats_dir, fn), dtype=np.int32)
+            total = stats if total is None else np.add(total, stats)
+    if total is None:
+        print("No stats file found. Double-check the directory and keyword?")
+    else:
+        total.tofile(os.path.join(stats_dir, sys.argv[3]))
+
+if __name__ == '__main__':
+    main()
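The script performs an element-wise sum of raw native-endian int32 arrays, so every input must come from the same encoder build (identical FRAME_COUNTS layout and length). A minimal self-contained sketch of the same aggregation, with hypothetical file names:

    import numpy as np

    # Two hypothetical stats dumps from the same encoder build.
    np.arange(8, dtype=np.int32).tofile('entropy_stats_0.bin')
    np.ones(8, dtype=np.int32).tofile('entropy_stats_1.bin')

    # Element-wise sum, exactly what the script does across matching files.
    total = sum(np.fromfile(f, dtype=np.int32)
                for f in ('entropy_stats_0.bin', 'entropy_stats_1.bin'))
    total.astype(np.int32).tofile('entropy_stats_total.bin')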
diff --git a/third_party/aom/tools/aom_entropy_optimizer.c b/third_party/aom/tools/aom_entropy_optimizer.c
new file mode 100644
index 0000000000..fa7bf7ea9e
--- /dev/null
+++ b/third_party/aom/tools/aom_entropy_optimizer.c
@@ -0,0 +1,761 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+// This tool performs offline probability training.
+// The binary executable aom_entropy_optimizer is generated in tools/. It
+// parses a binary file of counts laid out as FRAME_COUNTS (defined in
+// entropymode.h) and computes optimized probability and CDF tables, which
+// are written to a new C file, optimized_probs.c, following the format of
+// the default tables in the codebase.
+//
+// Command line: ./aom_entropy_optimizer [path to the counts file]
+//
+// The input file can be generated either by encoding a single clip with the
+// entropy_stats experiment turned on, or at larger scale by aggregating the
+// stats output of many encodes with tools/aggregate_entropy_stats.py.
+
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>  // exit(), EXIT_FAILURE
+
+#include "config/aom_config.h"
+
+#include "av1/encoder/encoder.h"
+
+#define SPACES_PER_TAB 2
+#define CDF_MAX_SIZE 16
+
+typedef unsigned int aom_count_type;
+// A log file recording parsed counts
+static FILE *logfile; // TODO(yuec): make it a command line option
+
+static void counts_to_cdf(const aom_count_type *counts, aom_cdf_prob *cdf,
+ int modes) {
+ int64_t csum[CDF_MAX_SIZE];
+ assert(modes <= CDF_MAX_SIZE);
+
+ csum[0] = counts[0] + 1;
+ for (int i = 1; i < modes; ++i) csum[i] = counts[i] + 1 + csum[i - 1];
+
+ for (int i = 0; i < modes; ++i) fprintf(logfile, "%u ", counts[i]);
+ fprintf(logfile, "\n");
+
+ int64_t sum = csum[modes - 1];
+ const int64_t round_shift = sum >> 1;
+ for (int i = 0; i < modes; ++i) {
+ cdf[i] = (csum[i] * CDF_PROB_TOP + round_shift) / sum;
+ cdf[i] = AOMMIN(cdf[i], CDF_PROB_TOP - (modes - 1 + i) * 4);
+ cdf[i] = (i == 0) ? AOMMAX(cdf[i], 4) : AOMMAX(cdf[i], cdf[i - 1] + 4);
+ }
+}
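+// Worked example: counts = {1, 1, 1, 1} and modes = 4. With the +1
+// smoothing, csum = {2, 4, 6, 8} and sum = 8, so the rounded cumulative
+// probabilities (CDF_PROB_TOP = 32768) are {8192, 16384, 24576, 32768}.
+// The last entry is clamped to 32744 by the minimum-gap rule, and since the
+// caller prints only the first modes - 1 entries, the emitted literal is
+// AOM_CDF4(8192, 16384, 24576).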
+
+static int parse_counts_for_cdf_opt(aom_count_type **ct_ptr,
+ FILE *const probsfile, int tabs,
+ int dim_of_cts, int *cts_each_dim) {
+ if (dim_of_cts < 1) {
+ fprintf(stderr, "The dimension of a counts vector should be at least 1!\n");
+ return 1;
+ }
+ const int total_modes = cts_each_dim[0];
+ if (dim_of_cts == 1) {
+ assert(total_modes <= CDF_MAX_SIZE);
+ aom_cdf_prob cdfs[CDF_MAX_SIZE];
+ aom_count_type *counts1d = *ct_ptr;
+
+ counts_to_cdf(counts1d, cdfs, total_modes);
+ (*ct_ptr) += total_modes;
+
+ if (tabs > 0) fprintf(probsfile, "%*c", tabs * SPACES_PER_TAB, ' ');
+ fprintf(probsfile, "AOM_CDF%d(", total_modes);
+ for (int k = 0; k < total_modes - 1; ++k) {
+ fprintf(probsfile, "%d", cdfs[k]);
+ if (k < total_modes - 2) fprintf(probsfile, ", ");
+ }
+ fprintf(probsfile, ")");
+ } else {
+ for (int k = 0; k < total_modes; ++k) {
+ int tabs_next_level;
+
+ if (dim_of_cts == 2)
+ fprintf(probsfile, "%*c{ ", tabs * SPACES_PER_TAB, ' ');
+ else
+ fprintf(probsfile, "%*c{\n", tabs * SPACES_PER_TAB, ' ');
+ tabs_next_level = dim_of_cts == 2 ? 0 : tabs + 1;
+
+ if (parse_counts_for_cdf_opt(ct_ptr, probsfile, tabs_next_level,
+ dim_of_cts - 1, cts_each_dim + 1)) {
+ return 1;
+ }
+
+ if (dim_of_cts == 2) {
+ if (k == total_modes - 1)
+ fprintf(probsfile, " }\n");
+ else
+ fprintf(probsfile, " },\n");
+ } else {
+ if (k == total_modes - 1)
+ fprintf(probsfile, "%*c}\n", tabs * SPACES_PER_TAB, ' ');
+ else
+ fprintf(probsfile, "%*c},\n", tabs * SPACES_PER_TAB, ' ');
+ }
+ }
+ }
+ return 0;
+}
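+// For example, with cts_each_dim = {2, 3} this emits two lines of the form
+// "{ AOM_CDF3(v0, v1) }," (one innermost CDF vector per outer context),
+// matching the layout of the default tables in the codebase.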
+
+static void optimize_cdf_table(aom_count_type *counts, FILE *const probsfile,
+ int dim_of_cts, int *cts_each_dim,
+ char *prefix) {
+ aom_count_type *ct_ptr = counts;
+
+ fprintf(probsfile, "%s = {\n", prefix);
+ fprintf(logfile, "%s\n", prefix);
+ if (parse_counts_for_cdf_opt(&ct_ptr, probsfile, 1, dim_of_cts,
+ cts_each_dim)) {
+ fprintf(probsfile, "Optimizer failed!\n");
+ }
+ fprintf(probsfile, "};\n\n");
+ fprintf(logfile, "============================\n");
+}
+
+static void optimize_uv_mode(aom_count_type *counts, FILE *const probsfile,
+ int dim_of_cts, int *cts_each_dim, char *prefix) {
+ aom_count_type *ct_ptr = counts;
+
+ fprintf(probsfile, "%s = {\n", prefix);
+ fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' ');
+ fprintf(logfile, "%s\n", prefix);
+ cts_each_dim[2] = UV_INTRA_MODES - 1;
+ for (int k = 0; k < cts_each_dim[1]; ++k) {
+ fprintf(probsfile, "%*c{ ", 2 * SPACES_PER_TAB, ' ');
+ parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, dim_of_cts - 2,
+ cts_each_dim + 2);
+ if (k + 1 == cts_each_dim[1]) {
+ fprintf(probsfile, " }\n");
+ } else {
+ fprintf(probsfile, " },\n");
+ }
+ ++ct_ptr;
+ }
+ fprintf(probsfile, "%*c},\n", SPACES_PER_TAB, ' ');
+ fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' ');
+ cts_each_dim[2] = UV_INTRA_MODES;
+ parse_counts_for_cdf_opt(&ct_ptr, probsfile, 2, dim_of_cts - 1,
+ cts_each_dim + 1);
+ fprintf(probsfile, "%*c}\n", SPACES_PER_TAB, ' ');
+ fprintf(probsfile, "};\n\n");
+ fprintf(logfile, "============================\n");
+}
+
+static void optimize_cdf_table_var_modes_2d(aom_count_type *counts,
+ FILE *const probsfile,
+ int dim_of_cts, int *cts_each_dim,
+ int *modes_each_ctx, char *prefix) {
+ aom_count_type *ct_ptr = counts;
+
+ assert(dim_of_cts == 2);
+ (void)dim_of_cts;
+
+ fprintf(probsfile, "%s = {\n", prefix);
+ fprintf(logfile, "%s\n", prefix);
+
+ for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) {
+ int num_of_modes = modes_each_ctx[d0_idx];
+
+ if (num_of_modes > 0) {
+ fprintf(probsfile, "%*c{ ", SPACES_PER_TAB, ' ');
+ parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes);
+ ct_ptr += cts_each_dim[1] - num_of_modes;
+ fprintf(probsfile, " },\n");
+ } else {
+ fprintf(probsfile, "%*c{ 0 },\n", SPACES_PER_TAB, ' ');
+ fprintf(logfile, "dummy cdf, no need to optimize\n");
+ ct_ptr += cts_each_dim[1];
+ }
+ }
+ fprintf(probsfile, "};\n\n");
+ fprintf(logfile, "============================\n");
+}
+
+static void optimize_cdf_table_var_modes_3d(aom_count_type *counts,
+ FILE *const probsfile,
+ int dim_of_cts, int *cts_each_dim,
+ int *modes_each_ctx, char *prefix) {
+ aom_count_type *ct_ptr = counts;
+
+ assert(dim_of_cts == 3);
+ (void)dim_of_cts;
+
+ fprintf(probsfile, "%s = {\n", prefix);
+ fprintf(logfile, "%s\n", prefix);
+
+ for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) {
+ fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' ');
+ for (int d1_idx = 0; d1_idx < cts_each_dim[1]; ++d1_idx) {
+ int num_of_modes = modes_each_ctx[d0_idx];
+
+ if (num_of_modes > 0) {
+ fprintf(probsfile, "%*c{ ", 2 * SPACES_PER_TAB, ' ');
+ parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes);
+ ct_ptr += cts_each_dim[2] - num_of_modes;
+ fprintf(probsfile, " },\n");
+ } else {
+ fprintf(probsfile, "%*c{ 0 },\n", 2 * SPACES_PER_TAB, ' ');
+ fprintf(logfile, "dummy cdf, no need to optimize\n");
+ ct_ptr += cts_each_dim[2];
+ }
+ }
+ fprintf(probsfile, "%*c},\n", SPACES_PER_TAB, ' ');
+ }
+ fprintf(probsfile, "};\n\n");
+ fprintf(logfile, "============================\n");
+}
+
+static void optimize_cdf_table_var_modes_4d(aom_count_type *counts,
+ FILE *const probsfile,
+ int dim_of_cts, int *cts_each_dim,
+ int *modes_each_ctx, char *prefix) {
+ aom_count_type *ct_ptr = counts;
+
+ assert(dim_of_cts == 4);
+ (void)dim_of_cts;
+
+ fprintf(probsfile, "%s = {\n", prefix);
+ fprintf(logfile, "%s\n", prefix);
+
+ for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) {
+ fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' ');
+ for (int d1_idx = 0; d1_idx < cts_each_dim[1]; ++d1_idx) {
+ fprintf(probsfile, "%*c{\n", 2 * SPACES_PER_TAB, ' ');
+ for (int d2_idx = 0; d2_idx < cts_each_dim[2]; ++d2_idx) {
+ int num_of_modes = modes_each_ctx[d0_idx];
+
+ if (num_of_modes > 0) {
+ fprintf(probsfile, "%*c{ ", 3 * SPACES_PER_TAB, ' ');
+ parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes);
+ ct_ptr += cts_each_dim[3] - num_of_modes;
+ fprintf(probsfile, " },\n");
+ } else {
+ fprintf(probsfile, "%*c{ 0 },\n", 3 * SPACES_PER_TAB, ' ');
+ fprintf(logfile, "dummy cdf, no need to optimize\n");
+ ct_ptr += cts_each_dim[3];
+ }
+ }
+ fprintf(probsfile, "%*c},\n", 2 * SPACES_PER_TAB, ' ');
+ }
+ fprintf(probsfile, "%*c},\n", SPACES_PER_TAB, ' ');
+ }
+ fprintf(probsfile, "};\n\n");
+ fprintf(logfile, "============================\n");
+}
+
+int main(int argc, const char **argv) {
+ if (argc < 2) {
+ fprintf(stderr, "Please specify the input stats file!\n");
+ exit(EXIT_FAILURE);
+ }
+
+ FILE *const statsfile = fopen(argv[1], "rb");
+ if (statsfile == NULL) {
+ fprintf(stderr, "Failed to open input file!\n");
+ exit(EXIT_FAILURE);
+ }
+
+ FRAME_COUNTS fc;
+ // fread() returns the number of complete items read, 0 or 1 here.
+ const size_t num_read = fread(&fc, sizeof(FRAME_COUNTS), 1, statsfile);
+ if (!num_read) {
+ fclose(statsfile);
+ return 1;
+ }
+
+ FILE *const probsfile = fopen("optimized_probs.c", "w");
+ if (probsfile == NULL) {
+ fprintf(stderr,
+ "Failed to create output file for optimized entropy tables!\n");
+ exit(EXIT_FAILURE);
+ }
+
+ logfile = fopen("aom_entropy_optimizer_parsed_counts.log", "w");
+ if (logfile == NULL) {
+ fprintf(stderr, "Failed to create log file for parsed counts!\n");
+ exit(EXIT_FAILURE);
+ }
+
+ int cts_each_dim[10];
+
+ /* Intra mode (keyframe luma) */
+ cts_each_dim[0] = KF_MODE_CONTEXTS;
+ cts_each_dim[1] = KF_MODE_CONTEXTS;
+ cts_each_dim[2] = INTRA_MODES;
+ optimize_cdf_table(&fc.kf_y_mode[0][0][0], probsfile, 3, cts_each_dim,
+ "const aom_cdf_prob\n"
+ "default_kf_y_mode_cdf[KF_MODE_CONTEXTS][KF_MODE_CONTEXTS]"
+ "[CDF_SIZE(INTRA_MODES)]");
+
+ cts_each_dim[0] = DIRECTIONAL_MODES;
+ cts_each_dim[1] = 2 * MAX_ANGLE_DELTA + 1;
+ optimize_cdf_table(&fc.angle_delta[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob default_angle_delta_cdf"
+ "[DIRECTIONAL_MODES][CDF_SIZE(2 * MAX_ANGLE_DELTA + 1)]");
+
+ /* Intra mode (non-keyframe luma) */
+ cts_each_dim[0] = BLOCK_SIZE_GROUPS;
+ cts_each_dim[1] = INTRA_MODES;
+ optimize_cdf_table(
+ &fc.y_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_if_y_mode_cdf[BLOCK_SIZE_GROUPS][CDF_SIZE(INTRA_MODES)]");
+
+ /* Intra mode (chroma) */
+ cts_each_dim[0] = CFL_ALLOWED_TYPES;
+ cts_each_dim[1] = INTRA_MODES;
+ cts_each_dim[2] = UV_INTRA_MODES;
+ optimize_uv_mode(&fc.uv_mode[0][0][0], probsfile, 3, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_uv_mode_cdf[CFL_ALLOWED_TYPES][INTRA_MODES]"
+ "[CDF_SIZE(UV_INTRA_MODES)]");
+
+ /* block partition */
+ cts_each_dim[0] = PARTITION_CONTEXTS;
+ cts_each_dim[1] = EXT_PARTITION_TYPES;
+ int part_types_each_ctx[PARTITION_CONTEXTS] = { 4, 4, 4, 4, 10, 10, 10,
+ 10, 10, 10, 10, 10, 10, 10,
+ 10, 10, 8, 8, 8, 8 };
+ optimize_cdf_table_var_modes_2d(
+ &fc.partition[0][0], probsfile, 2, cts_each_dim, part_types_each_ctx,
+ "static const aom_cdf_prob default_partition_cdf[PARTITION_CONTEXTS]"
+ "[CDF_SIZE(EXT_PARTITION_TYPES)]");
+
+ /* tx type */
+ cts_each_dim[0] = EXT_TX_SETS_INTRA;
+ cts_each_dim[1] = EXT_TX_SIZES;
+ cts_each_dim[2] = INTRA_MODES;
+ cts_each_dim[3] = TX_TYPES;
+ int intra_ext_tx_types_each_ctx[EXT_TX_SETS_INTRA] = { 0, 7, 5 };
+ optimize_cdf_table_var_modes_4d(
+ &fc.intra_ext_tx[0][0][0][0], probsfile, 4, cts_each_dim,
+ intra_ext_tx_types_each_ctx,
+ "static const aom_cdf_prob default_intra_ext_tx_cdf[EXT_TX_SETS_INTRA]"
+ "[EXT_TX_SIZES][INTRA_MODES][CDF_SIZE(TX_TYPES)]");
+
+ cts_each_dim[0] = EXT_TX_SETS_INTER;
+ cts_each_dim[1] = EXT_TX_SIZES;
+ cts_each_dim[2] = TX_TYPES;
+ int inter_ext_tx_types_each_ctx[EXT_TX_SETS_INTER] = { 0, 16, 12, 2 };
+ optimize_cdf_table_var_modes_3d(
+ &fc.inter_ext_tx[0][0][0], probsfile, 3, cts_each_dim,
+ inter_ext_tx_types_each_ctx,
+ "static const aom_cdf_prob default_inter_ext_tx_cdf[EXT_TX_SETS_INTER]"
+ "[EXT_TX_SIZES][CDF_SIZE(TX_TYPES)]");
+
+ /* Chroma from Luma */
+ cts_each_dim[0] = CFL_JOINT_SIGNS;
+ optimize_cdf_table(&fc.cfl_sign[0], probsfile, 1, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_cfl_sign_cdf[CDF_SIZE(CFL_JOINT_SIGNS)]");
+ cts_each_dim[0] = CFL_ALPHA_CONTEXTS;
+ cts_each_dim[1] = CFL_ALPHABET_SIZE;
+ optimize_cdf_table(&fc.cfl_alpha[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_cfl_alpha_cdf[CFL_ALPHA_CONTEXTS]"
+ "[CDF_SIZE(CFL_ALPHABET_SIZE)]");
+
+ /* Interpolation filter */
+ cts_each_dim[0] = SWITCHABLE_FILTER_CONTEXTS;
+ cts_each_dim[1] = SWITCHABLE_FILTERS;
+ optimize_cdf_table(&fc.switchable_interp[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_switchable_interp_cdf[SWITCHABLE_FILTER_CONTEXTS]"
+ "[CDF_SIZE(SWITCHABLE_FILTERS)]");
+
+ /* Motion vector referencing */
+ cts_each_dim[0] = NEWMV_MODE_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.newmv_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_newmv_cdf[NEWMV_MODE_CONTEXTS][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = GLOBALMV_MODE_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.zeromv_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_zeromv_cdf[GLOBALMV_MODE_CONTEXTS][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = REFMV_MODE_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.refmv_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_refmv_cdf[REFMV_MODE_CONTEXTS][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = DRL_MODE_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.drl_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_drl_cdf[DRL_MODE_CONTEXTS][CDF_SIZE(2)]");
+
+ /* ext_inter experiment */
+ /* New compound mode */
+ cts_each_dim[0] = INTER_MODE_CONTEXTS;
+ cts_each_dim[1] = INTER_COMPOUND_MODES;
+ optimize_cdf_table(&fc.inter_compound_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_inter_compound_mode_cdf[INTER_MODE_CONTEXTS][CDF_"
+ "SIZE(INTER_COMPOUND_MODES)]");
+
+ /* Interintra */
+ cts_each_dim[0] = BLOCK_SIZE_GROUPS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.interintra[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_interintra_cdf[BLOCK_SIZE_GROUPS][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = BLOCK_SIZE_GROUPS;
+ cts_each_dim[1] = INTERINTRA_MODES;
+ optimize_cdf_table(&fc.interintra_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_interintra_mode_cdf[BLOCK_SIZE_GROUPS][CDF_SIZE("
+ "INTERINTRA_MODES)]");
+
+ cts_each_dim[0] = BLOCK_SIZES_ALL;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(
+ &fc.wedge_interintra[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_wedge_interintra_cdf[BLOCK_SIZES_ALL][CDF_SIZE(2)]");
+
+ /* Compound type */
+ cts_each_dim[0] = BLOCK_SIZES_ALL;
+ cts_each_dim[1] = COMPOUND_TYPES - 1;
+ optimize_cdf_table(&fc.compound_type[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob default_compound_type_cdf"
+ "[BLOCK_SIZES_ALL][CDF_SIZE(COMPOUND_TYPES - 1)]");
+
+ cts_each_dim[0] = BLOCK_SIZES_ALL;
+ cts_each_dim[1] = 16;
+ optimize_cdf_table(&fc.wedge_idx[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_wedge_idx_cdf[BLOCK_SIZES_ALL][CDF_SIZE(16)]");
+
+ /* motion_var and warped_motion experiments */
+ cts_each_dim[0] = BLOCK_SIZES_ALL;
+ cts_each_dim[1] = MOTION_MODES;
+ optimize_cdf_table(
+ &fc.motion_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_motion_mode_cdf[BLOCK_SIZES_ALL][CDF_SIZE(MOTION_MODES)]");
+ cts_each_dim[0] = BLOCK_SIZES_ALL;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.obmc[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_obmc_cdf[BLOCK_SIZES_ALL][CDF_SIZE(2)]");
+
+ /* Intra/inter flag */
+ cts_each_dim[0] = INTRA_INTER_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(
+ &fc.intra_inter[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_intra_inter_cdf[INTRA_INTER_CONTEXTS][CDF_SIZE(2)]");
+
+ /* Single/comp ref flag */
+ cts_each_dim[0] = COMP_INTER_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(
+ &fc.comp_inter[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_comp_inter_cdf[COMP_INTER_CONTEXTS][CDF_SIZE(2)]");
+
+ /* ext_comp_refs experiment */
+ cts_each_dim[0] = COMP_REF_TYPE_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(
+ &fc.comp_ref_type[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_comp_ref_type_cdf[COMP_REF_TYPE_CONTEXTS][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = UNI_COMP_REF_CONTEXTS;
+ cts_each_dim[1] = UNIDIR_COMP_REFS - 1;
+ cts_each_dim[2] = 2;
+ optimize_cdf_table(&fc.uni_comp_ref[0][0][0], probsfile, 3, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_uni_comp_ref_cdf[UNI_COMP_REF_CONTEXTS][UNIDIR_"
+ "COMP_REFS - 1][CDF_SIZE(2)]");
+
+ /* Reference frame (single ref) */
+ cts_each_dim[0] = REF_CONTEXTS;
+ cts_each_dim[1] = SINGLE_REFS - 1;
+ cts_each_dim[2] = 2;
+ optimize_cdf_table(
+ &fc.single_ref[0][0][0], probsfile, 3, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_single_ref_cdf[REF_CONTEXTS][SINGLE_REFS - 1][CDF_SIZE(2)]");
+
+ /* ext_refs experiment */
+ cts_each_dim[0] = REF_CONTEXTS;
+ cts_each_dim[1] = FWD_REFS - 1;
+ cts_each_dim[2] = 2;
+ optimize_cdf_table(
+ &fc.comp_ref[0][0][0], probsfile, 3, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_comp_ref_cdf[REF_CONTEXTS][FWD_REFS - 1][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = REF_CONTEXTS;
+ cts_each_dim[1] = BWD_REFS - 1;
+ cts_each_dim[2] = 2;
+ optimize_cdf_table(
+ &fc.comp_bwdref[0][0][0], probsfile, 3, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_comp_bwdref_cdf[REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)]");
+
+ /* palette */
+ cts_each_dim[0] = PALATTE_BSIZE_CTXS;
+ cts_each_dim[1] = PALETTE_SIZES;
+ optimize_cdf_table(&fc.palette_y_size[0][0], probsfile, 2, cts_each_dim,
+ "const aom_cdf_prob default_palette_y_size_cdf"
+ "[PALATTE_BSIZE_CTXS][CDF_SIZE(PALETTE_SIZES)]");
+
+ cts_each_dim[0] = PALATTE_BSIZE_CTXS;
+ cts_each_dim[1] = PALETTE_SIZES;
+ optimize_cdf_table(&fc.palette_uv_size[0][0], probsfile, 2, cts_each_dim,
+ "const aom_cdf_prob default_palette_uv_size_cdf"
+ "[PALATTE_BSIZE_CTXS][CDF_SIZE(PALETTE_SIZES)]");
+
+ cts_each_dim[0] = PALATTE_BSIZE_CTXS;
+ cts_each_dim[1] = PALETTE_Y_MODE_CONTEXTS;
+ cts_each_dim[2] = 2;
+ optimize_cdf_table(&fc.palette_y_mode[0][0][0], probsfile, 3, cts_each_dim,
+ "const aom_cdf_prob default_palette_y_mode_cdf"
+ "[PALATTE_BSIZE_CTXS][PALETTE_Y_MODE_CONTEXTS]"
+ "[CDF_SIZE(2)]");
+
+ cts_each_dim[0] = PALETTE_UV_MODE_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.palette_uv_mode[0][0], probsfile, 2, cts_each_dim,
+ "const aom_cdf_prob default_palette_uv_mode_cdf"
+ "[PALETTE_UV_MODE_CONTEXTS][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = PALETTE_SIZES;
+ cts_each_dim[1] = PALETTE_COLOR_INDEX_CONTEXTS;
+ cts_each_dim[2] = PALETTE_COLORS;
+ int palette_color_indexes_each_ctx[PALETTE_SIZES] = { 2, 3, 4, 5, 6, 7, 8 };
+ optimize_cdf_table_var_modes_3d(
+ &fc.palette_y_color_index[0][0][0], probsfile, 3, cts_each_dim,
+ palette_color_indexes_each_ctx,
+ "const aom_cdf_prob default_palette_y_color_index_cdf[PALETTE_SIZES]"
+ "[PALETTE_COLOR_INDEX_CONTEXTS][CDF_SIZE(PALETTE_COLORS)]");
+
+ cts_each_dim[0] = PALETTE_SIZES;
+ cts_each_dim[1] = PALETTE_COLOR_INDEX_CONTEXTS;
+ cts_each_dim[2] = PALETTE_COLORS;
+ optimize_cdf_table_var_modes_3d(
+ &fc.palette_uv_color_index[0][0][0], probsfile, 3, cts_each_dim,
+ palette_color_indexes_each_ctx,
+ "const aom_cdf_prob default_palette_uv_color_index_cdf[PALETTE_SIZES]"
+ "[PALETTE_COLOR_INDEX_CONTEXTS][CDF_SIZE(PALETTE_COLORS)]");
+
+ /* Transform size */
+ cts_each_dim[0] = TXFM_PARTITION_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(
+ &fc.txfm_partition[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob\n"
+ "default_txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)]");
+
+ /* Skip flag */
+ cts_each_dim[0] = SKIP_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.skip_txfm[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_skip_txfm_cdfs[SKIP_CONTEXTS][CDF_SIZE(2)]");
+
+ /* Skip mode flag */
+ cts_each_dim[0] = SKIP_MODE_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.skip_mode[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_skip_mode_cdfs[SKIP_MODE_CONTEXTS][CDF_SIZE(2)]");
+
+ /* joint compound flag */
+ cts_each_dim[0] = COMP_INDEX_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.compound_index[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob default_compound_idx_cdfs"
+ "[COMP_INDEX_CONTEXTS][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = COMP_GROUP_IDX_CONTEXTS;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.comp_group_idx[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob default_comp_group_idx_cdfs"
+ "[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)]");
+
+ /* intrabc */
+ cts_each_dim[0] = 2;
+ optimize_cdf_table(
+ &fc.intrabc[0], probsfile, 1, cts_each_dim,
+ "static const aom_cdf_prob default_intrabc_cdf[CDF_SIZE(2)]");
+
+ /* filter_intra experiment */
+ cts_each_dim[0] = FILTER_INTRA_MODES;
+ optimize_cdf_table(
+ &fc.filter_intra_mode[0], probsfile, 1, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_filter_intra_mode_cdf[CDF_SIZE(FILTER_INTRA_MODES)]");
+
+ cts_each_dim[0] = BLOCK_SIZES_ALL;
+ cts_each_dim[1] = 2;
+ optimize_cdf_table(&fc.filter_intra[0][0], probsfile, 2, cts_each_dim,
+ "static const aom_cdf_prob "
+ "default_filter_intra_cdfs[BLOCK_SIZES_ALL][CDF_SIZE(2)]");
+
+ /* restoration type */
+ cts_each_dim[0] = RESTORE_SWITCHABLE_TYPES;
+ optimize_cdf_table(&fc.switchable_restore[0], probsfile, 1, cts_each_dim,
+ "static const aom_cdf_prob default_switchable_restore_cdf"
+ "[CDF_SIZE(RESTORE_SWITCHABLE_TYPES)]");
+
+ cts_each_dim[0] = 2;
+ optimize_cdf_table(&fc.wiener_restore[0], probsfile, 1, cts_each_dim,
+ "static const aom_cdf_prob default_wiener_restore_cdf"
+ "[CDF_SIZE(2)]");
+
+ cts_each_dim[0] = 2;
+ optimize_cdf_table(&fc.sgrproj_restore[0], probsfile, 1, cts_each_dim,
+ "static const aom_cdf_prob default_sgrproj_restore_cdf"
+ "[CDF_SIZE(2)]");
+
+ /* intra tx size */
+ cts_each_dim[0] = MAX_TX_CATS;
+ cts_each_dim[1] = TX_SIZE_CONTEXTS;
+ cts_each_dim[2] = MAX_TX_DEPTH + 1;
+ int intra_tx_sizes_each_ctx[MAX_TX_CATS] = { 2, 3, 3, 3 };
+ optimize_cdf_table_var_modes_3d(
+ &fc.intra_tx_size[0][0][0], probsfile, 3, cts_each_dim,
+ intra_tx_sizes_each_ctx,
+ "static const aom_cdf_prob default_tx_size_cdf"
+ "[MAX_TX_CATS][TX_SIZE_CONTEXTS][CDF_SIZE(MAX_TX_DEPTH + 1)]");
+
+ /* transform coding */
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = TX_SIZES;
+ cts_each_dim[2] = TXB_SKIP_CONTEXTS;
+ cts_each_dim[3] = 2;
+ optimize_cdf_table(&fc.txb_skip[0][0][0][0], probsfile, 4, cts_each_dim,
+ "static const aom_cdf_prob "
+ "av1_default_txb_skip_cdfs[TOKEN_CDF_Q_CTXS][TX_SIZES]"
+ "[TXB_SKIP_CONTEXTS][CDF_SIZE(2)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = TX_SIZES;
+ cts_each_dim[2] = PLANE_TYPES;
+ cts_each_dim[3] = EOB_COEF_CONTEXTS;
+ cts_each_dim[4] = 2;
+ optimize_cdf_table(
+ &fc.eob_extra[0][0][0][0][0], probsfile, 5, cts_each_dim,
+ "static const aom_cdf_prob av1_default_eob_extra_cdfs "
+ "[TOKEN_CDF_Q_CTXS][TX_SIZES][PLANE_TYPES][EOB_COEF_CONTEXTS]"
+ "[CDF_SIZE(2)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = PLANE_TYPES;
+ cts_each_dim[2] = 2;
+ cts_each_dim[3] = 5;
+ optimize_cdf_table(&fc.eob_multi16[0][0][0][0], probsfile, 4, cts_each_dim,
+ "static const aom_cdf_prob av1_default_eob_multi16_cdfs"
+ "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(5)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = PLANE_TYPES;
+ cts_each_dim[2] = 2;
+ cts_each_dim[3] = 6;
+ optimize_cdf_table(&fc.eob_multi32[0][0][0][0], probsfile, 4, cts_each_dim,
+ "static const aom_cdf_prob av1_default_eob_multi32_cdfs"
+ "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(6)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = PLANE_TYPES;
+ cts_each_dim[2] = 2;
+ cts_each_dim[3] = 7;
+ optimize_cdf_table(&fc.eob_multi64[0][0][0][0], probsfile, 4, cts_each_dim,
+ "static const aom_cdf_prob av1_default_eob_multi64_cdfs"
+ "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(7)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = PLANE_TYPES;
+ cts_each_dim[2] = 2;
+ cts_each_dim[3] = 8;
+ optimize_cdf_table(&fc.eob_multi128[0][0][0][0], probsfile, 4, cts_each_dim,
+ "static const aom_cdf_prob av1_default_eob_multi128_cdfs"
+ "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(8)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = PLANE_TYPES;
+ cts_each_dim[2] = 2;
+ cts_each_dim[3] = 9;
+ optimize_cdf_table(&fc.eob_multi256[0][0][0][0], probsfile, 4, cts_each_dim,
+ "static const aom_cdf_prob av1_default_eob_multi256_cdfs"
+ "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(9)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = PLANE_TYPES;
+ cts_each_dim[2] = 2;
+ cts_each_dim[3] = 10;
+ optimize_cdf_table(&fc.eob_multi512[0][0][0][0], probsfile, 4, cts_each_dim,
+ "static const aom_cdf_prob av1_default_eob_multi512_cdfs"
+ "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(10)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = PLANE_TYPES;
+ cts_each_dim[2] = 2;
+ cts_each_dim[3] = 11;
+ optimize_cdf_table(&fc.eob_multi1024[0][0][0][0], probsfile, 4, cts_each_dim,
+ "static const aom_cdf_prob av1_default_eob_multi1024_cdfs"
+ "[TOKEN_CDF_Q_CTXS][PLANE_TYPES][2][CDF_SIZE(11)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = TX_SIZES;
+ cts_each_dim[2] = PLANE_TYPES;
+ cts_each_dim[3] = LEVEL_CONTEXTS;
+ cts_each_dim[4] = BR_CDF_SIZE;
+ optimize_cdf_table(&fc.coeff_lps_multi[0][0][0][0][0], probsfile, 5,
+ cts_each_dim,
+ "static const aom_cdf_prob "
+ "av1_default_coeff_lps_multi_cdfs[TOKEN_CDF_Q_CTXS]"
+ "[TX_SIZES][PLANE_TYPES][LEVEL_CONTEXTS]"
+ "[CDF_SIZE(BR_CDF_SIZE)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = TX_SIZES;
+ cts_each_dim[2] = PLANE_TYPES;
+ cts_each_dim[3] = SIG_COEF_CONTEXTS;
+ cts_each_dim[4] = NUM_BASE_LEVELS + 2;
+ optimize_cdf_table(
+ &fc.coeff_base_multi[0][0][0][0][0], probsfile, 5, cts_each_dim,
+ "static const aom_cdf_prob av1_default_coeff_base_multi_cdfs"
+ "[TOKEN_CDF_Q_CTXS][TX_SIZES][PLANE_TYPES][SIG_COEF_CONTEXTS]"
+ "[CDF_SIZE(NUM_BASE_LEVELS + 2)]");
+
+ cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
+ cts_each_dim[1] = TX_SIZES;
+ cts_each_dim[2] = PLANE_TYPES;
+ cts_each_dim[3] = SIG_COEF_CONTEXTS_EOB;
+ cts_each_dim[4] = NUM_BASE_LEVELS + 1;
+ optimize_cdf_table(
+ &fc.coeff_base_eob_multi[0][0][0][0][0], probsfile, 5, cts_each_dim,
+ "static const aom_cdf_prob av1_default_coeff_base_eob_multi_cdfs"
+ "[TOKEN_CDF_Q_CTXS][TX_SIZES][PLANE_TYPES][SIG_COEF_CONTEXTS_EOB]"
+ "[CDF_SIZE(NUM_BASE_LEVELS + 1)]");
+
+ fclose(statsfile);
+ fclose(logfile);
+ fclose(probsfile);
+
+ return 0;
+}
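To cross-check the emitted tables, the rounding and clamping in counts_to_cdf above is small enough to mirror in Python. A sketch assuming CDF_PROB_TOP = 32768, the value used by AV1's 15-bit CDF representation:

    def counts_to_cdf(counts, cdf_prob_top=32768):
        """Mirror of the C counts_to_cdf(): +1 smoothing, rounded division,
        then clamping to keep a minimum gap of 4 between entries."""
        modes = len(counts)
        csum, acc = [], 0
        for c in counts:
            acc += c + 1  # +1 so no symbol gets zero probability
            csum.append(acc)
        total = csum[-1]
        cdf = []
        for i, cs in enumerate(csum):
            v = (cs * cdf_prob_top + total // 2) // total   # rounded division
            v = min(v, cdf_prob_top - (modes - 1 + i) * 4)  # headroom clamp
            v = max(v, 4 if i == 0 else cdf[i - 1] + 4)     # min gap of 4
            cdf.append(v)
        return cdf

    print(counts_to_cdf([1, 1, 1, 1]))  # [8192, 16384, 24576, 32744]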
diff --git a/third_party/aom/tools/auto_refactor/auto_refactor.py b/third_party/aom/tools/auto_refactor/auto_refactor.py
new file mode 100644
index 0000000000..dd0d4415f9
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/auto_refactor.py
@@ -0,0 +1,919 @@
+# Copyright (c) 2021, Alliance for Open Media. All rights reserved
+#
+# This source code is subject to the terms of the BSD 2 Clause License and
+# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+# was not distributed with this source code in the LICENSE file, you can
+# obtain it at www.aomedia.org/license/software. If the Alliance for Open
+# Media Patent License 1.0 was not distributed with this source code in the
+# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+#
+
+from __future__ import print_function
+import sys
+import os
+import operator
+import pickle
+from pycparser import c_parser, c_ast, parse_file
+from math import *
+
+from inspect import currentframe, getframeinfo
+from collections import deque
+
+
+def debug_print(frameinfo):
+ print('******** ERROR:', frameinfo.filename, frameinfo.lineno, '********')
+
+
+class StructItem():
+
+ def __init__(self,
+ typedef_name=None,
+ struct_name=None,
+ struct_node=None,
+ is_union=False):
+ self.typedef_name = typedef_name
+ self.struct_name = struct_name
+ self.struct_node = struct_node
+ self.is_union = is_union
+ self.child_decl_map = None
+
+ def __str__(self):
+ return str(self.typedef_name) + ' ' + str(self.struct_name) + ' ' + str(
+ self.is_union)
+
+ def compute_child_decl_map(self, struct_info):
+ self.child_decl_map = {}
+ if self.struct_node != None and self.struct_node.decls != None:
+ for decl_node in self.struct_node.decls:
+ if decl_node.name == None:
+ for sub_decl_node in decl_node.type.decls:
+ sub_decl_status = parse_decl_node(struct_info, sub_decl_node)
+ self.child_decl_map[sub_decl_node.name] = sub_decl_status
+ else:
+ decl_status = parse_decl_node(struct_info, decl_node)
+ self.child_decl_map[decl_status.name] = decl_status
+
+ def get_child_decl_status(self, decl_name):
+ if self.child_decl_map == None:
+ debug_print(getframeinfo(currentframe()))
+ print('child_decl_map is None')
+ return None
+ if decl_name not in self.child_decl_map:
+ debug_print(getframeinfo(currentframe()))
+ print(decl_name, 'does not exist ')
+ return None
+ return self.child_decl_map[decl_name]
+
+
+class StructInfo():
+
+ def __init__(self):
+ self.struct_name_dic = {}
+ self.typedef_name_dic = {}
+ self.enum_value_dic = {} # enum value -> enum_node
+ self.enum_name_dic = {} # enum name -> enum_node
+ self.struct_item_list = []
+
+ def get_struct_by_typedef_name(self, typedef_name):
+ if typedef_name in self.typedef_name_dic:
+ return self.typedef_name_dic[typedef_name]
+ else:
+ return None
+
+ def get_struct_by_struct_name(self, struct_name):
+ if struct_name in self.struct_name_dic:
+ return self.struct_name_dic[struct_name]
+ else:
+ debug_print(getframeinfo(currentframe()))
+ print("Can't find", struct_name)
+ return None
+
+ def update_struct_item_list(self):
+ # Collect all struct_items from struct_name_dic and typedef_name_dic
+ # Compute child_decl_map for each struct item.
+ for struct_name in self.struct_name_dic.keys():
+ struct_item = self.struct_name_dic[struct_name]
+ struct_item.compute_child_decl_map(self)
+ self.struct_item_list.append(struct_item)
+
+ for typedef_name in self.typedef_name_dic.keys():
+ struct_item = self.typedef_name_dic[typedef_name]
+ if struct_item.struct_name not in self.struct_name_dic:
+ struct_item.compute_child_decl_map(self)
+ self.struct_item_list.append(struct_item)
+
+ def update_enum(self, enum_node):
+ if enum_node.name != None:
+ self.enum_name_dic[enum_node.name] = enum_node
+
+ if enum_node.values != None:
+ enumerator_list = enum_node.values.enumerators
+ for enumerator in enumerator_list:
+ self.enum_value_dic[enumerator.name] = enum_node
+
+ def update(self,
+ typedef_name=None,
+ struct_name=None,
+ struct_node=None,
+ is_union=False):
+ """T: typedef_name S: struct_name N: struct_node
+
+ T S N
+ case 1: o o o
+ typedef struct P {
+ int u;
+ } K;
+ T S N
+ case 2: o o x
+ typedef struct P K;
+
+ T S N
+ case 3: x o o
+ struct P {
+ int u;
+ };
+
+ T S N
+ case 4: o x o
+ typedef struct {
+ int u;
+ } K;
+ """
+ struct_item = None
+
+ # Check whether struct_name or typedef_name is already in the dictionary
+ if struct_name in self.struct_name_dic:
+ struct_item = self.struct_name_dic[struct_name]
+
+ if typedef_name in self.typedef_name_dic:
+ struct_item = self.typedef_name_dic[typedef_name]
+
+ if struct_item == None:
+ struct_item = StructItem(typedef_name, struct_name, struct_node, is_union)
+
+ if struct_node.decls != None:
+ struct_item.struct_node = struct_node
+
+ if struct_name != None:
+ self.struct_name_dic[struct_name] = struct_item
+
+ if typedef_name != None:
+ self.typedef_name_dic[typedef_name] = struct_item
+
+
+class StructDefVisitor(c_ast.NodeVisitor):
+
+ def __init__(self):
+ self.struct_info = StructInfo()
+
+ def visit_Struct(self, node):
+ if node.decls != None:
+ self.struct_info.update(None, node.name, node)
+ self.generic_visit(node)
+
+ def visit_Union(self, node):
+ if node.decls != None:
+ self.struct_info.update(None, node.name, node, True)
+ self.generic_visit(node)
+
+ def visit_Enum(self, node):
+ self.struct_info.update_enum(node)
+ self.generic_visit(node)
+
+ def visit_Typedef(self, node):
+ if node.type.__class__.__name__ == 'TypeDecl':
+ typedecl = node.type
+ if typedecl.type.__class__.__name__ == 'Struct':
+ struct_node = typedecl.type
+ typedef_name = node.name
+ struct_name = struct_node.name
+ self.struct_info.update(typedef_name, struct_name, struct_node)
+ elif typedecl.type.__class__.__name__ == 'Union':
+ union_node = typedecl.type
+ typedef_name = node.name
+ union_name = union_node.name
+ self.struct_info.update(typedef_name, union_name, union_node, True)
+ # TODO(angiebird): Do we need to deal with enum here?
+ self.generic_visit(node)
+
+
+def build_struct_info(ast):
+ v = StructDefVisitor()
+ v.visit(ast)
+ struct_info = v.struct_info
+ struct_info.update_struct_item_list()
+ return v.struct_info
+
+
+class DeclStatus():
+
+ def __init__(self, name, struct_item=None, is_ptr_decl=False):
+ self.name = name
+ self.struct_item = struct_item
+ self.is_ptr_decl = is_ptr_decl
+
+ def get_child_decl_status(self, decl_name):
+ if self.struct_item != None:
+ return self.struct_item.get_child_decl_status(decl_name)
+ else:
+ #TODO(angiebird): 2. Investigate the situation when a struct's definition can't be found.
+ return None
+
+ def __str__(self):
+ return str(self.struct_item) + ' ' + str(self.name) + ' ' + str(
+ self.is_ptr_decl)
+
+
+def peel_ptr_decl(decl_type_node):
+ """ Remove PtrDecl and ArrayDecl layer """
+ is_ptr_decl = False
+ peeled_decl_type_node = decl_type_node
+ while peeled_decl_type_node.__class__.__name__ == 'PtrDecl' or peeled_decl_type_node.__class__.__name__ == 'ArrayDecl':
+ is_ptr_decl = True
+ peeled_decl_type_node = peeled_decl_type_node.type
+ return is_ptr_decl, peeled_decl_type_node
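+# Example: a declaration like 'MACROBLOCK *x[2]' has the type chain
+# ArrayDecl -> PtrDecl -> TypeDecl; peel_ptr_decl() returns (True, <TypeDecl>).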
+
+
+def parse_peeled_decl_type_node(struct_info, node):
+ struct_item = None
+ if node.__class__.__name__ == 'TypeDecl':
+ if node.type.__class__.__name__ == 'IdentifierType':
+ identifier_type_node = node.type
+ typedef_name = identifier_type_node.names[0]
+ struct_item = struct_info.get_struct_by_typedef_name(typedef_name)
+ elif node.type.__class__.__name__ == 'Struct':
+ struct_node = node.type
+ if struct_node.name != None:
+ struct_item = struct_info.get_struct_by_struct_name(struct_node.name)
+ else:
+ struct_item = StructItem(None, None, struct_node, False)
+ struct_item.compute_child_decl_map(struct_info)
+ elif node.type.__class__.__name__ == 'Union':
+ # TODO(angiebird): Special treatment for Union?
+ struct_node = node.type
+ if struct_node.name != None:
+ struct_item = struct_info.get_struct_by_struct_name(struct_node.name)
+ else:
+ struct_item = StructItem(None, None, struct_node, True)
+ struct_item.compute_child_decl_map(struct_info)
+ elif node.type.__class__.__name__ == 'Enum':
+ # TODO(angiebird): Special treatment for Enum?
+ struct_node = node.type
+ struct_item = None
+ else:
+ print('Unrecognized peeled_decl_type_node.type',
+ node.type.__class__.__name__)
+ else:
+ # debug_print(getframeinfo(currentframe()))
+ # print(node.__class__.__name__)
+ #TODO(angiebird): Do we need to take care of this part?
+ pass
+
+ return struct_item
+
+
+def parse_decl_node(struct_info, decl_node):
+ # struct_item is None if this decl_node is not a struct_item
+ decl_node_name = decl_node.name
+ decl_type_node = decl_node.type
+ is_ptr_decl, peeled_decl_type_node = peel_ptr_decl(decl_type_node)
+ struct_item = parse_peeled_decl_type_node(struct_info, peeled_decl_type_node)
+ return DeclStatus(decl_node_name, struct_item, is_ptr_decl)
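+# For a parameter declaration like 'AV1_COMP *cpi' this returns a DeclStatus
+# named 'cpi' with is_ptr_decl True and struct_item resolved via the typedef
+# name 'AV1_COMP' (struct_item is None when the definition was never parsed).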
+
+
+def get_lvalue_lead(lvalue_node):
+ """return '&' or '*' of lvalue if available"""
+ if lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&':
+ return '&'
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*':
+ return '*'
+ return None
+
+
+def parse_lvalue(lvalue_node):
+ """get id_chain from lvalue"""
+ id_chain = parse_lvalue_recursive(lvalue_node, [])
+ return id_chain
+
+
+def parse_lvalue_recursive(lvalue_node, id_chain):
+ """cpi->rd->u -> (cpi->rd)->u"""
+ if lvalue_node.__class__.__name__ == 'ID':
+ id_chain.append(lvalue_node.name)
+ id_chain.reverse()
+ return id_chain
+ elif lvalue_node.__class__.__name__ == 'StructRef':
+ id_chain.append(lvalue_node.field.name)
+ return parse_lvalue_recursive(lvalue_node.name, id_chain)
+ elif lvalue_node.__class__.__name__ == 'ArrayRef':
+ return parse_lvalue_recursive(lvalue_node.name, id_chain)
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&':
+ return parse_lvalue_recursive(lvalue_node.expr, id_chain)
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*':
+ return parse_lvalue_recursive(lvalue_node.expr, id_chain)
+ else:
+ return None
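+# Example: 'cpi->rd->u' yields ['cpi', 'rd', 'u'] and 'arr[i].field' yields
+# ['arr', 'field'] (array subscripts are dropped). Unsupported expressions
+# yield None.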
+
+
+class FuncDefVisitor(c_ast.NodeVisitor):
+ func_dictionary = {}
+
+ def visit_FuncDef(self, node):
+ func_name = node.decl.name
+ self.func_dictionary[func_name] = node
+
+
+def build_func_dictionary(ast):
+ v = FuncDefVisitor()
+ v.visit(ast)
+ return v.func_dictionary
+
+
+def get_func_start_coord(func_node):
+ return func_node.coord
+
+
+def find_end_node(node):
+ node_list = []
+ for c in node:
+ node_list.append(c)
+ if len(node_list) == 0:
+ return node
+ else:
+ return find_end_node(node_list[-1])
+
+
+def get_func_end_coord(func_node):
+ return find_end_node(func_node).coord
+
+
+def get_func_size(func_node):
+ start_coord = get_func_start_coord(func_node)
+ end_coord = get_func_end_coord(func_node)
+ if start_coord.file == end_coord.file:
+ return end_coord.line - start_coord.line + 1
+ else:
+ return None
+
+
+def save_object(obj, filename):
+ with open(filename, 'wb') as obj_fp:
+ pickle.dump(obj, obj_fp, protocol=-1)
+
+
+def load_object(filename):
+ obj = None
+ with open(filename, 'rb') as obj_fp:
+ obj = pickle.load(obj_fp)
+ return obj
+
+
+def get_av1_ast(gen_ast=False):
+ # TODO(angiebird): Generalize this path
+ c_filename = './av1_pp.c'
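+ # av1_pp.c is the preprocessed AV1 source, presumably produced with
+ # av1_preprocess.py in this directory.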
+ print('generate ast')
+ ast = parse_file(c_filename)
+ #save_object(ast, ast_file)
+ print('finished generate ast')
+ return ast
+
+
+def get_func_param_id_map(func_def_node):
+ param_id_map = {}
+ func_decl = func_def_node.decl.type
+ param_list = func_decl.args.params
+ for decl in param_list:
+ param_id_map[decl.name] = decl
+ return param_id_map
+
+
+class IDTreeStack():
+
+ def __init__(self, global_id_tree):
+ self.stack = deque()
+ self.global_id_tree = global_id_tree
+
+ def add_link_node(self, node, link_id_chain):
+ link_node = self.add_id_node(link_id_chain)
+ node.link_node = link_node
+ node.link_id_chain = link_id_chain
+
+ def push_id_tree(self, id_tree=None):
+ if id_tree == None:
+ id_tree = IDStatusNode()
+ self.stack.append(id_tree)
+ return id_tree
+
+ def pop_id_tree(self):
+ return self.stack.pop()
+
+ def add_id_seed_node(self, id_seed, decl_status):
+ return self.stack[-1].add_child(id_seed, decl_status)
+
+ def get_id_seed_node(self, id_seed):
+ idx = len(self.stack) - 1
+ while idx >= 0:
+ id_node = self.stack[idx].get_child(id_seed)
+ if id_node != None:
+ return id_node
+ idx -= 1
+
+ id_node = self.global_id_tree.get_child(id_seed)
+ if id_node != None:
+ return id_node
+ return None
+
+ def add_id_node(self, id_chain):
+ id_seed = id_chain[0]
+ id_seed_node = self.get_id_seed_node(id_seed)
+ if id_seed_node == None:
+ return None
+ if len(id_chain) == 1:
+ return id_seed_node
+ return id_seed_node.add_descendant(id_chain[1:])
+
+ def get_id_node(self, id_chain):
+ id_seed = id_chain[0]
+ id_seed_node = self.get_id_seed_node(id_seed)
+ if id_seed_node == None:
+ return None
+ if len(id_chain) == 1:
+ return id_seed_node
+ return id_seed_node.get_descendant(id_chain[1:])
+
+ def top(self):
+ return self.stack[-1]
+
+
+class IDStatusNode():
+
+ def __init__(self, name=None, root=None):
+ if root is None:
+ self.root = self
+ else:
+ self.root = root
+
+ self.name = name
+
+ self.parent = None
+ self.children = {}
+
+ self.assign = False
+ self.last_assign_coord = None
+ self.refer = False
+ self.last_refer_coord = None
+
+ self.decl_status = None
+
+ self.link_id_chain = None
+ self.link_node = None
+
+ self.visit = False
+
+ def set_link_id_chain(self, link_id_chain):
+ self.set_assign(False)
+ self.link_id_chain = link_id_chain
+ self.link_node = self.root.get_descendant(link_id_chain)
+
+ def set_link_node(self, link_node):
+ self.set_assign(False)
+ self.link_id_chain = ['*']
+ self.link_node = link_node
+
+ def get_link_id_chain(self):
+ return self.link_id_chain
+
+ def get_concrete_node(self):
+ if self.visit == True:
+ # return None when there is a loop
+ return None
+ self.visit = True
+ if self.link_node == None:
+ self.visit = False
+ return self
+ else:
+ concrete_node = self.link_node.get_concrete_node()
+ self.visit = False
+ if concrete_node == None:
+ return self
+ return concrete_node
+
+ def set_assign(self, assign, coord=None):
+ concrete_node = self.get_concrete_node()
+ concrete_node.assign = assign
+ concrete_node.last_assign_coord = coord
+
+ def get_assign(self):
+ concrete_node = self.get_concrete_node()
+ return concrete_node.assign
+
+ def set_refer(self, refer, coord=None):
+ concrete_node = self.get_concrete_node()
+ concrete_node.refer = refer
+ concrete_node.last_refer_coord = coord
+
+ def get_refer(self):
+ concrete_node = self.get_concrete_node()
+ return concrete_node.refer
+
+ def set_parent(self, parent):
+ concrete_node = self.get_concrete_node()
+ concrete_node.parent = parent
+
+ def add_child(self, name, decl_status=None):
+ concrete_node = self.get_concrete_node()
+ if name not in concrete_node.children:
+ child_id_node = IDStatusNode(name, concrete_node.root)
+ concrete_node.children[name] = child_id_node
+ if decl_status == None:
+ # Check if the child decl_status can be inferred from its parent's
+ # decl_status
+ if self.decl_status != None:
+ decl_status = self.decl_status.get_child_decl_status(name)
+ child_id_node.set_decl_status(decl_status)
+ return concrete_node.children[name]
+
+ def get_child(self, name):
+ concrete_node = self.get_concrete_node()
+ if name in concrete_node.children:
+ return concrete_node.children[name]
+ else:
+ return None
+
+ def add_descendant(self, id_chain):
+ current_node = self.get_concrete_node()
+ for name in id_chain:
+ current_node.add_child(name)
+ parent_node = current_node
+ current_node = current_node.get_child(name)
+ current_node.set_parent(parent_node)
+ return current_node
+
+ def get_descendant(self, id_chain):
+ current_node = self.get_concrete_node()
+ for name in id_chain:
+ current_node = current_node.get_child(name)
+ if current_node == None:
+ return None
+ return current_node
+
+ def get_children(self):
+ current_node = self.get_concrete_node()
+ return current_node.children
+
+ def set_decl_status(self, decl_status):
+ current_node = self.get_concrete_node()
+ current_node.decl_status = decl_status
+
+ def get_decl_status(self):
+ current_node = self.get_concrete_node()
+ return current_node.decl_status
+
+ def __str__(self):
+ if self.link_id_chain is None:
+ return str(self.name) + ' a: ' + str(int(self.assign)) + ' r: ' + str(
+ int(self.refer))
+ else:
+ return str(self.name) + ' -> ' + ' '.join(self.link_id_chain)
+
+ def collect_assign_refer_status(self,
+ id_chain=None,
+ assign_ls=None,
+ refer_ls=None):
+ if id_chain == None:
+ id_chain = []
+ if assign_ls == None:
+ assign_ls = []
+ if refer_ls == None:
+ refer_ls = []
+ id_chain.append(self.name)
+ if self.assign:
+ info_str = ' '.join([
+ ' '.join(id_chain[1:]), 'a:',
+ str(int(self.assign)), 'r:',
+ str(int(self.refer)),
+ str(self.last_assign_coord)
+ ])
+ assign_ls.append(info_str)
+ if self.refer:
+ info_str = ' '.join([
+ ' '.join(id_chain[1:]), 'a:',
+ str(int(self.assign)), 'r:',
+ str(int(self.refer)),
+ str(self.last_refer_coord)
+ ])
+ refer_ls.append(info_str)
+ for c in self.children:
+ self.children[c].collect_assign_refer_status(id_chain, assign_ls,
+ refer_ls)
+ id_chain.pop()
+ return assign_ls, refer_ls
+
+ def show(self):
+ assign_ls, refer_ls = self.collect_assign_refer_status()
+ print('---- assign ----')
+ for item in assign_ls:
+ print(item)
+ print('---- refer ----')
+ for item in refer_ls:
+ print(item)
+
+
+class FuncInOutVisitor(c_ast.NodeVisitor):
+
+ def __init__(self,
+ func_def_node,
+ struct_info,
+ func_dictionary,
+ keep_body_id_tree=True,
+ call_param_map=None,
+ global_id_tree=None,
+ func_history=None,
+ unknown=None):
+ self.func_dictionary = func_dictionary
+ self.struct_info = struct_info
+ self.param_id_map = get_func_param_id_map(func_def_node)
+ self.parent_node = None
+ self.global_id_tree = global_id_tree
+ self.body_id_tree = None
+ self.keep_body_id_tree = keep_body_id_tree
+ if func_history == None:
+ self.func_history = {}
+ else:
+ self.func_history = func_history
+
+ if unknown == None:
+ self.unknown = []
+ else:
+ self.unknown = unknown
+
+ self.id_tree_stack = IDTreeStack(global_id_tree)
+ self.id_tree_stack.push_id_tree()
+
+ #TODO move this part into a function
+ for param in self.param_id_map:
+ decl_node = self.param_id_map[param]
+ decl_status = parse_decl_node(self.struct_info, decl_node)
+ descendant = self.id_tree_stack.add_id_seed_node(decl_status.name,
+ decl_status)
+ if call_param_map is not None and param in call_param_map:
+ # This is a function call.
+ # Map the input parameter to the caller's nodes
+ # TODO(angiebird): Can we use add_link_node here?
+ descendant.set_link_node(call_param_map[param])
+
+ def get_id_tree_stack(self):
+ return self.id_tree_stack
+
+ def generic_visit(self, node):
+ prev_parent = self.parent_node
+ self.parent_node = node
+ for c in node:
+ self.visit(c)
+ self.parent_node = prev_parent
+
+ # TODO rename
+ def add_new_id_tree(self, node):
+ self.id_tree_stack.push_id_tree()
+ self.generic_visit(node)
+ id_tree = self.id_tree_stack.pop_id_tree()
+ if self.parent_node == None and self.keep_body_id_tree == True:
+ # this is function body
+ self.body_id_tree = id_tree
+
+ def visit_For(self, node):
+ self.add_new_id_tree(node)
+
+ def visit_Compound(self, node):
+ self.add_new_id_tree(node)
+
+ def visit_Decl(self, node):
+ if node.type.__class__.__name__ != 'FuncDecl':
+ decl_status = parse_decl_node(self.struct_info, node)
+ descendant = self.id_tree_stack.add_id_seed_node(decl_status.name,
+ decl_status)
+ if node.init is not None:
+ init_id_chain = self.process_lvalue(node.init)
+ if init_id_chain != None:
+ if decl_status.struct_item is None:
+ init_descendant = self.id_tree_stack.add_id_node(init_id_chain)
+ if init_descendant != None:
+ init_descendant.set_refer(True, node.coord)
+ else:
+ self.unknown.append(node)
+ descendant.set_assign(True, node.coord)
+ else:
+ self.id_tree_stack.add_link_node(descendant, init_id_chain)
+ else:
+ self.unknown.append(node)
+ else:
+ descendant.set_assign(True, node.coord)
+ self.generic_visit(node)
+
+ def is_lvalue(self, node):
+ if self.parent_node is None:
+ # TODO(angiebird): Do every lvalue has parent_node != None?
+ return False
+ if self.parent_node.__class__.__name__ == 'StructRef':
+ return False
+ if self.parent_node.__class__.__name__ == 'ArrayRef' and node == self.parent_node.name:
+ # if node == self.parent_node.subscript, the node could be lvalue
+ return False
+ if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '&':
+ return False
+ if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '*':
+ return False
+ return True
+
+ def process_lvalue(self, node):
+ id_chain = parse_lvalue(node)
+ if id_chain == None:
+ return id_chain
+ elif id_chain[0] in self.struct_info.enum_value_dic:
+ return None
+ else:
+ return id_chain
+
+ def process_possible_lvalue(self, node):
+ if self.is_lvalue(node):
+ id_chain = self.process_lvalue(node)
+ lead_char = get_lvalue_lead(node)
+ # make sure the id is not an enum value
+ if id_chain == None:
+ self.unknown.append(node)
+ return
+ descendant = self.id_tree_stack.add_id_node(id_chain)
+ if descendant == None:
+ self.unknown.append(node)
+ return
+ decl_status = descendant.get_decl_status()
+ if decl_status == None:
+ descendant.set_assign(True, node.coord)
+ descendant.set_refer(True, node.coord)
+ self.unknown.append(node)
+ return
+ if self.parent_node.__class__.__name__ == 'Assignment':
+ if node is self.parent_node.lvalue:
+ if decl_status.struct_item != None:
+ if len(id_chain) > 1:
+ descendant.set_assign(True, node.coord)
+ elif len(id_chain) == 1:
+ if lead_char == '*':
+ descendant.set_assign(True, node.coord)
+ else:
+ right_id_chain = self.process_lvalue(self.parent_node.rvalue)
+ if right_id_chain != None:
+ self.id_tree_stack.add_link_node(descendant, right_id_chain)
+ else:
+ #TODO(angiebird): 1.Find a better way to deal with this case.
+ descendant.set_assign(True, node.coord)
+ else:
+ debug_print(getframeinfo(currentframe()))
+ else:
+ descendant.set_assign(True, node.coord)
+ elif node is self.parent_node.rvalue:
+ if decl_status.struct_item is None:
+ descendant.set_refer(True, node.coord)
+ if lead_char == '&':
+ descendant.set_assign(True, node.coord)
+ else:
+ left_id_chain = self.process_lvalue(self.parent_node.lvalue)
+ left_lead_char = get_lvalue_lead(self.parent_node.lvalue)
+ if left_id_chain != None:
+ if len(left_id_chain) > 1:
+ descendant.set_refer(True, node.coord)
+ elif len(left_id_chain) == 1:
+ if left_lead_char == '*':
+ descendant.set_refer(True, node.coord)
+ else:
+ #TODO(angiebird): Check whether the other node is linked to this node.
+ pass
+ else:
+ self.unknown.append(self.parent_node.lvalue)
+ debug_print(getframeinfo(currentframe()))
+ else:
+ self.unknown.append(self.parent_node.lvalue)
+ debug_print(getframeinfo(currentframe()))
+ else:
+ debug_print(getframeinfo(currentframe()))
+ elif self.parent_node.__class__.__name__ == 'UnaryOp':
+ # TODO(angiebird): Consider +=, *=, -=, /= etc
+ if self.parent_node.op == '--' or self.parent_node.op == '++' or\
+ self.parent_node.op == 'p--' or self.parent_node.op == 'p++':
+ descendant.set_assign(True, node.coord)
+ descendant.set_refer(True, node.coord)
+ else:
+ descendant.set_refer(True, node.coord)
+ elif self.parent_node.__class__.__name__ == 'Decl':
+ #The logic is at visit_Decl
+ pass
+ elif self.parent_node.__class__.__name__ == 'ExprList':
+ #The logic is at visit_FuncCall
+ pass
+ else:
+ descendant.set_refer(True, node.coord)
+
+ def visit_ID(self, node):
+ # If the parent is a FuncCall, this ID is a function name.
+ if self.parent_node.__class__.__name__ != 'FuncCall':
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_StructRef(self, node):
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_ArrayRef(self, node):
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_UnaryOp(self, node):
+ if node.op == '&' or node.op == '*':
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_FuncCall(self, node):
+ if node.name.__class__.__name__ == 'ID':
+ if node.name.name in self.func_dictionary:
+ if node.name.name not in self.func_history:
+ self.func_history[node.name.name] = True
+ func_def_node = self.func_dictionary[node.name.name]
+ call_param_map = self.process_func_call(node, func_def_node)
+
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary, False,
+ call_param_map, self.global_id_tree,
+ self.func_history, self.unknown)
+ visitor.visit(func_def_node.body)
+ else:
+ self.unknown.append(node)
+ self.generic_visit(node)
+
+ def process_func_call(self, func_call_node, func_def_node):
+ # set up a refer/assign for func parameters
+ # return call_param_map
+ call_param_ls = func_call_node.args.exprs
+ call_param_map = {}
+
+ func_decl = func_def_node.decl.type
+ decl_param_ls = func_decl.args.params
+ for param_node, decl_node in zip(call_param_ls, decl_param_ls):
+ id_chain = self.process_lvalue(param_node)
+ if id_chain is not None:
+ descendant = self.id_tree_stack.add_id_node(id_chain)
+ if descendant is None:
+ self.unknown.append(param_node)
+ else:
+ decl_status = descendant.get_decl_status()
+ if decl_status is not None:
+ if decl_status.struct_item is None:
+ if decl_status.is_ptr_decl:
+ descendant.set_assign(True, param_node.coord)
+ descendant.set_refer(True, param_node.coord)
+ else:
+ descendant.set_refer(True, param_node.coord)
+ else:
+ call_param_map[decl_node.name] = descendant
+ else:
+ self.unknown.append(param_node)
+ else:
+ self.unknown.append(param_node)
+ return call_param_map
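+
+# Illustrative sketch (not part of the tool itself): for the fixture call
+#   func_sub_call_2(cpi, &cpi->rd, y)
+# matched against the declaration
+#   int func_sub_call_2(VP9_COMP *cpi2, RD *rd, int x)
+# the zip above pairs each argument with its declared parameter; plain ints
+# like y only update refer status, while struct-typed arguments end up in
+# call_param_map, e.g. {'cpi2': <node for cpi>, 'rd': <node for cpi->rd>}.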
+
+
+def build_global_id_tree(ast, struct_info):
+ global_id_tree = IDStatusNode()
+ for node in ast.ext:
+ if node.__class__.__name__ == 'Decl':
+ # id tree is for tracking assign/refer status
+ # we don't care about function id because they can't be changed
+ if node.type.__class__.__name__ != 'FuncDecl':
+ decl_status = parse_decl_node(struct_info, node)
+ descendant = global_id_tree.add_child(decl_status.name, decl_status)
+ return global_id_tree
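+
+# Minimal sketch of the result, assuming a translation unit with one
+# top-level declaration `int frame_count;`: the returned root gains a single
+# child keyed 'frame_count' whose decl status has struct_item == None and
+# is_ptr_decl == False; function declarations are skipped entirely.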
+
+
+class FuncAnalyzer():
+
+ def __init__(self):
+ self.ast = get_av1_ast()
+ self.struct_info = build_struct_info(self.ast)
+ self.func_dictionary = build_func_dictionary(self.ast)
+ self.global_id_tree = build_global_id_tree(self.ast, self.struct_info)
+
+ def analyze(self, func_name):
+ if func_name in self.func_dictionary:
+ func_def_node = self.func_dictionary[func_name]
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary, True, None,
+ self.global_id_tree)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ root.top().show()
+ else:
+ print(func_name, "doesn't exist")
+
+
+if __name__ == '__main__':
+ fa = FuncAnalyzer()
+ fa.analyze('tpl_get_satd_cost')
diff --git a/third_party/aom/tools/auto_refactor/av1_preprocess.py b/third_party/aom/tools/auto_refactor/av1_preprocess.py
new file mode 100644
index 0000000000..ea76912cf1
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/av1_preprocess.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2021, Alliance for Open Media. All rights reserved
+#
+# This source code is subject to the terms of the BSD 2 Clause License and
+# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+# was not distributed with this source code in the LICENSE file, you can
+# obtain it at www.aomedia.org/license/software. If the Alliance for Open
+# Media Patent License 1.0 was not distributed with this source code in the
+# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+#
+
+import os
+import sys
+
+
+def is_code_file(filename):
+ return filename.endswith(".c") or filename.endswith(".h")
+
+
+def is_simd_file(filename):
+ simd_keywords = [
+ "avx2", "sse2", "sse3", "ssse3", "sse4", "dspr2", "neon", "msa", "simd",
+ "x86"
+ ]
+ for keyword in simd_keywords:
+ if filename.find(keyword) >= 0:
+ return True
+ return False
+
+
+def get_code_file_list(path, exclude_file_set):
+ code_file_list = []
+ for cur_dir, sub_dir, file_list in os.walk(path):
+ for filename in file_list:
+ if is_code_file(filename) and not is_simd_file(
+ filename) and filename not in exclude_file_set:
+ file_path = os.path.join(cur_dir, filename)
+ code_file_list.append(file_path)
+ return code_file_list
+
+
+def av1_exclude_file_set():
+ exclude_file_set = {
+ "cfl_ppc.c",
+ "ppc_cpudetect.c",
+ }
+ return exclude_file_set
+
+
+def get_av1_pp_command(fake_header_dir, code_file_list):
+ pre_command = (
+ "gcc -w -nostdinc -E -I./ -I../ -I" + fake_header_dir +
+ " -D'ATTRIBUTE_PACKED='"
+ " -D'__attribute__(x)='"
+ " -D'__inline__='"
+ " -D'float_t=float'"
+ " -D'DECLARE_ALIGNED(n, typ, val)=typ val'"
+ " -D'volatile='"
+ " -D'AV1_K_MEANS_DIM=2'"
+ " -D'INLINE='"
+ " -D'AOM_INLINE='"
+ " -D'AOM_FORCE_INLINE='"
+ " -D'inline='")
+ return pre_command + " " + " ".join(code_file_list)
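+
+# For example (placeholder paths), with fake_header_dir='fake' and
+# code_file_list=['av1/encoder/encoder.c'], the command begins:
+#   gcc -w -nostdinc -E -I./ -I../ -Ifake -D'ATTRIBUTE_PACKED=' ...
+# and ends with the listed source files.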
+
+
+def modify_av1_rtcd(build_dir):
+ av1_rtcd = os.path.join(build_dir, "config/av1_rtcd.h")
+ with open(av1_rtcd) as fp:
+ string = fp.read()
+ new_string = string.replace("#ifdef RTCD_C", "#if 0")
+ with open(av1_rtcd, "w") as fp:
+ fp.write(new_string)
+
+
+def preprocess_av1(aom_dir, build_dir, fake_header_dir):
+ cur_dir = os.getcwd()
+ output = os.path.join(cur_dir, "av1_pp.c")
+ path_list = [
+ os.path.join(aom_dir, "av1/encoder"),
+ os.path.join(aom_dir, "av1/common")
+ ]
+ code_file_list = []
+ for path in path_list:
+ path = os.path.realpath(path)
+ code_file_list.extend(get_code_file_list(path, av1_exclude_file_set()))
+ modify_av1_rtcd(build_dir)
+ cmd = get_av1_pp_command(fake_header_dir, code_file_list) + " >" + output
+ os.chdir(build_dir)
+ os.system(cmd)
+ os.chdir(cur_dir)
+
+
+if __name__ == "__main__":
+ aom_dir = sys.argv[1]
+ build_dir = sys.argv[2]
+ fake_header_dir = sys.argv[3]
+ preprocess_av1(aom_dir, build_dir, fake_header_dir)
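+
+# Usage sketch (all three paths are placeholders):
+#   python av1_preprocess.py /path/to/aom /path/to/aom_build /path/to/fake_headers
+# The combined preprocessor output lands in av1_pp.c in the current working
+# directory, which is presumably what get_av1_ast() in auto_refactor.py reads.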
diff --git a/third_party/aom/tools/auto_refactor/c_files/decl_status_code.c b/third_party/aom/tools/auto_refactor/c_files/decl_status_code.c
new file mode 100644
index 0000000000..a444553bb1
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/decl_status_code.c
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct S1 {
+ int x;
+} T1;
+
+int parse_decl_node_2(void) { int arr[3]; }
+
+int parse_decl_node_3(void) { int *a; }
+
+int parse_decl_node_4(void) { T1 t1[3]; }
+
+int parse_decl_node_5(void) { T1 *t2[3]; }
+
+int parse_decl_node_6(void) { T1 t3[3][3]; }
+
+int main(void) {
+ int a;
+ T1 t1;
+ struct S1 s1;
+ T1 *t2;
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/func_in_out.c b/third_party/aom/tools/auto_refactor/c_files/func_in_out.c
new file mode 100644
index 0000000000..7f37bbae7e
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/func_in_out.c
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct XD {
+ int u;
+ int v;
+} XD;
+
+typedef struct RD {
+ XD *xd;
+ int u;
+ int v;
+} RD;
+
+typedef struct VP9_COMP {
+ int y;
+ RD *rd;
+ RD rd2;
+ int arr[3];
+ union {
+ int z;
+ };
+ struct {
+ int w;
+ };
+} VP9_COMP;
+
+int sub_func(VP9_COMP *cpi, int b) {
+ int d;
+ cpi->y += 1;
+ cpi->y -= b;
+ d = cpi->y * 2;
+ return d;
+}
+
+int func_id_forrest_show(VP9_COMP *cpi, int b) {
+ int c = 2;
+ int x = cpi->y + c * 2 + 1;
+ int y;
+ RD *rd = cpi->rd;
+ y = cpi->rd->u;
+ return x + y;
+}
+
+int func_link_id_chain_1(VP9_COMP *cpi) {
+ RD *rd = cpi->rd;
+ rd->u = 0;
+}
+
+int func_link_id_chain_2(VP9_COMP *cpi) {
+ RD *rd = cpi->rd;
+ XD *xd = rd->xd;
+ xd->u = 0;
+}
+
+int func_assign_refer_status_1(VP9_COMP *cpi) { RD *rd = cpi->rd; }
+
+int func_assign_refer_status_2(VP9_COMP *cpi) {
+ RD *rd2;
+ rd2 = cpi->rd;
+}
+
+int func_assign_refer_status_3(VP9_COMP *cpi) {
+ int a;
+ a = cpi->y;
+}
+
+int func_assign_refer_status_4(VP9_COMP *cpi) {
+ int *b;
+ b = &cpi->y;
+}
+
+int func_assign_refer_status_5(VP9_COMP *cpi) {
+ RD *rd5;
+ rd5 = &cpi->rd2;
+}
+
+int func_assign_refer_status_6(VP9_COMP *cpi, VP9_COMP *cpi2) {
+ cpi->rd = cpi2->rd;
+}
+
+int func_assign_refer_status_7(VP9_COMP *cpi, VP9_COMP *cpi2) {
+ cpi->arr[3] = 0;
+}
+
+int func_assign_refer_status_8(VP9_COMP *cpi, VP9_COMP *cpi2) {
+ int x = cpi->arr[3];
+}
+
+int func_assign_refer_status_9(VP9_COMP *cpi) {
+ {
+ RD *rd = cpi->rd;
+ { rd->u = 0; }
+ }
+}
+
+int func_assign_refer_status_10(VP9_COMP *cpi) { cpi->arr[cpi->rd->u] = 0; }
+
+int func_assign_refer_status_11(VP9_COMP *cpi) {
+ RD *rd11 = &cpi->rd2;
+ rd11->v = 1;
+}
+
+int func_assign_refer_status_12(VP9_COMP *cpi, VP9_COMP *cpi2) {
+ *cpi->rd = *cpi2->rd;
+}
+
+int func_assign_refer_status_13(VP9_COMP *cpi) {
+ cpi->z = 0;
+ cpi->w = 0;
+}
+
+int func(VP9_COMP *cpi, int x) {
+ int a;
+ cpi->y = 4;
+ a = 3 + cpi->y;
+ a = a * x;
+ cpi->y *= 4;
+ RD *ref_rd = cpi->rd;
+ ref_rd->u = 0;
+ cpi->rd2.v = 1;
+ cpi->rd->v = 1;
+ RD *ref_rd2 = &cpi->rd2;
+ RD **ref_rd3 = &(&cpi->rd2);
+ int b = sub_func(cpi, a);
+ cpi->rd->v++;
+ return b;
+}
+
+int func_sub_call_1(VP9_COMP *cpi2, int x) { cpi2->y = 4; }
+
+int func_call_1(VP9_COMP *cpi, int y) { func_sub_call_1(cpi, y); }
+
+int func_sub_call_2(VP9_COMP *cpi2, RD *rd, int x) { rd->u = 0; }
+
+int func_call_2(VP9_COMP *cpi, int y) { func_sub_call_2(cpi, &cpi->rd, y); }
+
+int func_sub_call_3(VP9_COMP *cpi2, int x) {}
+
+int func_call_3(VP9_COMP *cpi, int y) { func_sub_call_3(cpi, ++cpi->y); }
+
+int func_sub_sub_call_4(VP9_COMP *cpi3, XD *xd) {
+ cpi3->rd->u = 0;
+ xd->u = 0;
+}
+
+int func_sub_call_4(VP9_COMP *cpi2, RD *rd) {
+ func_sub_sub_call_4(cpi2, rd->xd);
+}
+
+int func_call_4(VP9_COMP *cpi, int y) { func_sub_call_4(cpi, &cpi->rd); }
+
+int func_sub_call_5(VP9_COMP *cpi) {
+ cpi->y = 2;
+ func_call_5(cpi);
+}
+
+int func_call_5(VP9_COMP *cpi) { func_sub_call_5(cpi); }
+
+int func_compound_1(VP9_COMP *cpi) {
+ for (int i = 0; i < 10; ++i) {
+ cpi->y++;
+ }
+}
+
+int func_compound_2(VP9_COMP *cpi) {
+ for (int i = 0; i < cpi->y; ++i) {
+ cpi->rd->u = i;
+ }
+}
+
+int func_compound_3(VP9_COMP *cpi) {
+ int i = 3;
+ while (i > 0) {
+ cpi->rd->u = i;
+ i--;
+ }
+}
+
+int func_compound_4(VP9_COMP *cpi) {
+ while (cpi->y-- >= 0) {
+ }
+}
+
+int func_compound_5(VP9_COMP *cpi) {
+ do {
+ } while (cpi->y-- >= 0);
+}
+
+int func_compound_6(VP9_COMP *cpi) {
+ for (int i = 0; i < 10; ++i) cpi->y--;
+}
+
+int main(void) {
+ int x;
+ VP9_COMP cpi;
+ RD rd;
+ cpi.rd = &rd;
+ func(&cpi, x);
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/global_variable.c b/third_party/aom/tools/auto_refactor/c_files/global_variable.c
new file mode 100644
index 0000000000..26d5385e97
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/global_variable.c
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+extern const int global_a[13];
+
+const int global_b = 0;
+
+typedef struct S1 {
+ int x;
+} T1;
+
+struct S3 {
+ int x;
+} s3;
+
+int func_global_1(int *a) {
+ *a = global_a[3];
+ return 0;
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c b/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c
new file mode 100644
index 0000000000..fa44d72381
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct RD {
+ int u;
+ int v;
+ int arr[3];
+} RD;
+
+typedef struct VP9_COMP {
+ int y;
+ RD *rd;
+ RD rd2;
+ RD rd3[2];
+} VP9_COMP;
+
+int parse_lvalue_2(VP9_COMP *cpi) { RD *rd2 = &cpi->rd2; }
+
+int func(VP9_COMP *cpi, int x) {
+ cpi->rd->u = 0;
+
+ int y;
+ y = 0;
+
+ cpi->rd2.v = 0;
+
+ cpi->rd->arr[2] = 0;
+
+ cpi->rd3[1].arr[2] = 0;
+
+ return 0;
+}
+
+int main(void) {
+ int x = 0;
+ VP9_COMP cpi;
+ func(&cpi, x);
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/simple_code.c b/third_party/aom/tools/auto_refactor/c_files/simple_code.c
new file mode 100644
index 0000000000..902cd1d826
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/simple_code.c
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct S {
+ int x;
+ int y;
+ int z;
+} S;
+
+typedef struct T {
+ S s;
+} T;
+
+int d(S *s) {
+ ++s->x;
+ s->x--;
+ s->y = s->y + 1;
+ int *c = &s->x;
+ S ss;
+ ss.x = 1;
+ ss.x += 2;
+ ss.z *= 2;
+ return 0;
+}
+int b(S *s) {
+ d(s);
+ return 0;
+}
+int c(int x) {
+ if (x) {
+ c(x - 1);
+ } else {
+ S s;
+ d(&s);
+ }
+ return 0;
+}
+int a(S *s) {
+ b(s);
+ c(1);
+ return 0;
+}
+int e(void) {
+ c(0);
+ return 0;
+}
+int main(void) {
+ int p = 3;
+ S s;
+ s.x = p + 1;
+ s.y = 2;
+ s.z = 3;
+ a(&s);
+ T t;
+ t.s.x = 3;
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/struct_code.c b/third_party/aom/tools/auto_refactor/c_files/struct_code.c
new file mode 100644
index 0000000000..7f24d41075
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/struct_code.c
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct S1 {
+ int x;
+} T1;
+
+struct S3 {
+ int x;
+};
+
+typedef struct {
+ int x;
+ struct S3 s3;
+} T4;
+
+typedef union U5 {
+ int x;
+ double y;
+} T5;
+
+typedef struct S6 {
+ struct {
+ int x;
+ };
+ union {
+ int y;
+ int z;
+ };
+} T6;
+
+typedef struct S7 {
+ struct {
+ int x;
+ } y;
+ union {
+ int w;
+ } z;
+} T7;
+
+int main(void) {}
diff --git a/third_party/aom/tools/auto_refactor/test_auto_refactor.py b/third_party/aom/tools/auto_refactor/test_auto_refactor.py
new file mode 100644
index 0000000000..6b1e269efa
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/test_auto_refactor.py
@@ -0,0 +1,675 @@
+#!/usr/bin/env python
+# Copyright (c) 2021, Alliance for Open Media. All rights reserved
+#
+# This source code is subject to the terms of the BSD 2 Clause License and
+# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+# was not distributed with this source code in the LICENSE file, you can
+# obtain it at www.aomedia.org/license/software. If the Alliance for Open
+# Media Patent License 1.0 was not distributed with this source code in the
+# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+#
+
+import pprint
+import re
+import os, sys
+import io
+import unittest as googletest
+
+sys.path[0:0] = ['.', '..']
+
+from pycparser import c_parser, parse_file
+from pycparser.c_ast import *
+from pycparser.c_parser import CParser, Coord, ParseError
+
+from auto_refactor import *
+
+
+def get_c_file_path(filename):
+ return os.path.join('c_files', filename)
+
+
+class TestStructInfo(googletest.TestCase):
+
+ def setUp(self):
+ filename = get_c_file_path('struct_code.c')
+ self.ast = parse_file(filename)
+
+ def test_build_struct_info(self):
+ struct_info = build_struct_info(self.ast)
+ typedef_name_dic = struct_info.typedef_name_dic
+ self.assertEqual('T1' in typedef_name_dic, True)
+ self.assertEqual('T4' in typedef_name_dic, True)
+ self.assertEqual('T5' in typedef_name_dic, True)
+
+ struct_name_dic = struct_info.struct_name_dic
+ struct_name = 'S1'
+ self.assertEqual(struct_name in struct_name_dic, True)
+ struct_item = struct_name_dic[struct_name]
+ self.assertEqual(struct_item.is_union, False)
+
+ struct_name = 'S3'
+ self.assertEqual(struct_name in struct_name_dic, True)
+ struct_item = struct_name_dic[struct_name]
+ self.assertEqual(struct_item.is_union, False)
+
+ struct_name = 'U5'
+ self.assertEqual(struct_name in struct_name_dic, True)
+ struct_item = struct_name_dic[struct_name]
+ self.assertEqual(struct_item.is_union, True)
+
+ self.assertEqual(len(struct_info.struct_item_list), 6)
+
+ def test_get_child_decl_status(self):
+ struct_info = build_struct_info(self.ast)
+ struct_item = struct_info.typedef_name_dic['T4']
+
+ decl_status = struct_item.child_decl_map['x']
+ self.assertEqual(decl_status.struct_item, None)
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = struct_item.child_decl_map['s3']
+ self.assertEqual(decl_status.struct_item.struct_name, 'S3')
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ struct_item = struct_info.typedef_name_dic['T6']
+ decl_status = struct_item.child_decl_map['x']
+ self.assertEqual(decl_status.struct_item, None)
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = struct_item.child_decl_map['y']
+ self.assertEqual(decl_status.struct_item, None)
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = struct_item.child_decl_map['z']
+ self.assertEqual(decl_status.struct_item, None)
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ struct_item = struct_info.typedef_name_dic['T7']
+ decl_status = struct_item.child_decl_map['y']
+ self.assertEqual('x' in decl_status.struct_item.child_decl_map, True)
+
+ struct_item = struct_info.typedef_name_dic['T7']
+ decl_status = struct_item.child_decl_map['z']
+ self.assertEqual('w' in decl_status.struct_item.child_decl_map, True)
+
+
+class TestParseLvalue(googletest.TestCase):
+
+ def setUp(self):
+ filename = get_c_file_path('parse_lvalue.c')
+ self.ast = parse_file(filename)
+ self.func_dictionary = build_func_dictionary(self.ast)
+
+ def test_parse_lvalue(self):
+ func_node = self.func_dictionary['func']
+ func_body_items = func_node.body.block_items
+ id_list = parse_lvalue(func_body_items[0].lvalue)
+ ref_id_list = ['cpi', 'rd', 'u']
+ self.assertEqual(id_list, ref_id_list)
+
+ id_list = parse_lvalue(func_body_items[2].lvalue)
+ ref_id_list = ['y']
+ self.assertEqual(id_list, ref_id_list)
+
+ id_list = parse_lvalue(func_body_items[3].lvalue)
+ ref_id_list = ['cpi', 'rd2', 'v']
+ self.assertEqual(id_list, ref_id_list)
+
+ id_list = parse_lvalue(func_body_items[4].lvalue)
+ ref_id_list = ['cpi', 'rd', 'arr']
+ self.assertEqual(id_list, ref_id_list)
+
+ id_list = parse_lvalue(func_body_items[5].lvalue)
+ ref_id_list = ['cpi', 'rd3', 'arr']
+ self.assertEqual(id_list, ref_id_list)
+
+ def test_parse_lvalue_2(self):
+ func_node = self.func_dictionary['parse_lvalue_2']
+ func_body_items = func_node.body.block_items
+ id_list = parse_lvalue(func_body_items[0].init)
+ ref_id_list = ['cpi', 'rd2']
+ self.assertEqual(id_list, ref_id_list)
+
+
+class TestIDStatusNode(googletest.TestCase):
+
+ def test_add_descendant(self):
+ root = IDStatusNode('root')
+ id_chain1 = ['cpi', 'rd', 'u']
+ id_chain2 = ['cpi', 'rd', 'v']
+ root.add_descendant(id_chain1)
+ root.add_descendant(id_chain2)
+
+ ref_children_list1 = ['cpi']
+ children_list1 = list(root.children.keys())
+ self.assertEqual(children_list1, ref_children_list1)
+
+ ref_children_list2 = ['rd']
+ children_list2 = list(root.children['cpi'].children.keys())
+ self.assertEqual(children_list2, ref_children_list2)
+
+ ref_children_list3 = ['u', 'v']
+ children_list3 = list(root.children['cpi'].children['rd'].children.keys())
+ self.assertEqual(children_list3, ref_children_list3)
+
+ def test_get_descendant(self):
+ root = IDStatusNode('root')
+ id_chain1 = ['cpi', 'rd', 'u']
+ id_chain2 = ['cpi', 'rd', 'v']
+ ref_descendant_1 = root.add_descendant(id_chain1)
+ ref_descendant_2 = root.add_descendant(id_chain2)
+
+ descendant_1 = root.get_descendant(id_chain1)
+ self.assertEqual(descendant_1 is ref_descendant_1, True)
+
+ descendant_2 = root.get_descendant(id_chain2)
+ self.assertEqual(descendant_2 is ref_descendant_2, True)
+
+ id_chain3 = ['cpi', 'rd', 'h']
+ descendant_3 = root.get_descendant(id_chain3)
+ self.assertEqual(descendant_3, None)
+
+
+class TestFuncInOut(googletest.TestCase):
+
+ def setUp(self):
+ c_filename = get_c_file_path('func_in_out.c')
+ self.ast = parse_file(c_filename)
+ self.func_dictionary = build_func_dictionary(self.ast)
+ self.struct_info = build_struct_info(self.ast)
+
+ def test_get_func_param_id_map(self):
+ func_def_node = self.func_dictionary['func']
+ param_id_map = get_func_param_id_map(func_def_node)
+ ref_param_id_map_keys = ['cpi', 'x']
+ self.assertEqual(list(param_id_map.keys()), ref_param_id_map_keys)
+
+ def test_assign_refer_status_1(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_1']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['rd']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+ ref_link_id_chain = ['cpi', 'rd']
+ self.assertEqual(ref_link_id_chain, descendant.get_link_id_chain())
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_2(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_2']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['rd2']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+
+ ref_link_id_chain = ['cpi', 'rd']
+ self.assertEqual(ref_link_id_chain, descendant.get_link_id_chain())
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_3(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_3']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['a']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_4(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_4']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['b']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_5(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_5']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['rd5']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi', 'rd2']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_6(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_6']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ id_chain = ['cpi2', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_7(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_7']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'arr']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_assign_refer_status_8(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_8']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'arr']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_assign_refer_status_9(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_9']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_assign_refer_status_10(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_10']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ id_chain = ['cpi', 'arr']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_assign_refer_status_11(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_11']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd2', 'v']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_assign_refer_status_12(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_12']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi2', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_assign_refer_status_13(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_13']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'z']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi', 'w']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_id_status_forrest_1(self):
+ func_def_node = self.func_dictionary['func']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack().top()
+ children_names = set(root.get_children().keys())
+ ref_children_names = set(['cpi', 'x'])
+ self.assertEqual(children_names, ref_children_names)
+
+ root = visitor.body_id_tree
+ children_names = set(root.get_children().keys())
+ ref_children_names = set(['a', 'ref_rd', 'ref_rd2', 'ref_rd3', 'b'])
+ self.assertEqual(children_names, ref_children_names)
+
+ def test_id_status_forrest_show(self):
+ func_def_node = self.func_dictionary['func_id_forrest_show']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ visitor.get_id_tree_stack().top().show()
+
+ def test_id_status_forrest_2(self):
+ func_def_node = self.func_dictionary['func_id_forrest_show']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack().top()
+ self.assertEqual(root, root.root)
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_descendant(id_chain)
+ self.assertEqual(root, descendant.root)
+
+ id_chain = ['b']
+ descendant = root.get_descendant(id_chain)
+ self.assertEqual(root, descendant.root)
+
+ def test_link_id_chain_1(self):
+ func_def_node = self.func_dictionary['func_link_id_chain_1']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+
+ def test_link_id_chain_2(self):
+ func_def_node = self.func_dictionary['func_link_id_chain_2']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'xd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+
+ def test_func_call_1(self):
+ func_def_node = self.func_dictionary['func_call_1']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_call_2(self):
+ func_def_node = self.func_dictionary['func_call_2']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_call_3(self):
+ func_def_node = self.func_dictionary['func_call_3']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_call_4(self):
+ func_def_node = self.func_dictionary['func_call_4']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi', 'rd', 'xd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_call_5(self):
+ func_def_node = self.func_dictionary['func_call_5']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_compound_1(self):
+ func_def_node = self.func_dictionary['func_compound_1']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_compound_2(self):
+ func_def_node = self.func_dictionary['func_compound_2']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_compound_3(self):
+ func_def_node = self.func_dictionary['func_compound_3']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_compound_4(self):
+ func_def_node = self.func_dictionary['func_compound_4']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_compound_5(self):
+ func_def_node = self.func_dictionary['func_compound_5']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_compound_6(self):
+ func_def_node = self.func_dictionary['func_compound_6']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+
+class TestDeclStatus(googletest.TestCase):
+
+ def setUp(self):
+ filename = get_c_file_path('decl_status_code.c')
+ self.ast = parse_file(filename)
+ self.func_dictionary = build_func_dictionary(self.ast)
+ self.struct_info = build_struct_info(self.ast)
+
+ def test_parse_decl_node(self):
+ func_def_node = self.func_dictionary['main']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 'a')
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = parse_decl_node(self.struct_info, decl_list[1])
+ self.assertEqual(decl_status.name, 't1')
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = parse_decl_node(self.struct_info, decl_list[2])
+ self.assertEqual(decl_status.name, 's1')
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = parse_decl_node(self.struct_info, decl_list[3])
+ self.assertEqual(decl_status.name, 't2')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+
+ def test_parse_decl_node_2(self):
+ func_def_node = self.func_dictionary['parse_decl_node_2']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 'arr')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item, None)
+
+ def test_parse_decl_node_3(self):
+ func_def_node = self.func_dictionary['parse_decl_node_3']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 'a')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item, None)
+
+ def test_parse_decl_node_4(self):
+ func_def_node = self.func_dictionary['parse_decl_node_4']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 't1')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item.typedef_name, 'T1')
+ self.assertEqual(decl_status.struct_item.struct_name, 'S1')
+
+ def test_parse_decl_node_5(self):
+ func_def_node = self.func_dictionary['parse_decl_node_5']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 't2')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item.typedef_name, 'T1')
+ self.assertEqual(decl_status.struct_item.struct_name, 'S1')
+
+ def test_parse_decl_node_6(self):
+ func_def_node = self.func_dictionary['parse_decl_node_6']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 't3')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item.typedef_name, 'T1')
+ self.assertEqual(decl_status.struct_item.struct_name, 'S1')
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/third_party/aom/tools/cpplint.py b/third_party/aom/tools/cpplint.py
new file mode 100755
index 0000000000..e3ebde2f5a
--- /dev/null
+++ b/third_party/aom/tools/cpplint.py
@@ -0,0 +1,6244 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2009 Google Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Does google-lint on c++ files.
+
+The goal of this script is to identify places in the code that *may*
+be in non-compliance with google style. It does not attempt to fix
+up these problems -- the point is to educate. It does also not
+attempt to find all problems, or to ensure that everything it does
+find is legitimately a problem.
+
+In particular, we can get very confused by /* and // inside strings!
+We do a small hack, which is to ignore //'s with "'s after them on the
+same line, but it is far from perfect (in either direction).
+"""
+
+import codecs
+import copy
+import getopt
+import math # for log
+import os
+import re
+import sre_compile
+import string
+import sys
+import unicodedata
+import sysconfig
+
+try:
+ xrange # Python 2
+except NameError:
+ xrange = range # Python 3
+
+
+_USAGE = """
+Syntax: cpplint.py [--verbose=#] [--output=vs7] [--filter=-x,+y,...]
+ [--counting=total|toplevel|detailed] [--root=subdir]
+ [--linelength=digits] [--headers=x,y,...]
+ [--quiet]
+ <file> [file] ...
+
+ The style guidelines this tries to follow are those in
+ https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml
+
+ Every problem is given a confidence score from 1-5, with 5 meaning we are
+ certain of the problem, and 1 meaning it could be a legitimate construct.
+ This will miss some errors, and is not a substitute for a code review.
+
+ To suppress false-positive errors of a certain category, add a
+ 'NOLINT(category)' comment to the line. NOLINT or NOLINT(*)
+ suppresses errors of all categories on that line.
+
+ The files passed in will be linted; at least one file must be provided.
+ Default linted extensions are .cc, .cpp, .cu, .cuh and .h. Change the
+ extensions with the --extensions flag.
+
+ Flags:
+
+ output=vs7
+ By default, the output is formatted to ease emacs parsing. Visual Studio
+ compatible output (vs7) may also be used. Other formats are unsupported.
+
+ verbose=#
+ Specify a number 0-5 to restrict errors to certain verbosity levels.
+
+ quiet
+ Don't print anything if no errors are found.
+
+ filter=-x,+y,...
+ Specify a comma-separated list of category-filters to apply: only
+ error messages whose category names pass the filters will be printed.
+ (Category names are printed with the message and look like
+ "[whitespace/indent]".) Filters are evaluated left to right.
+ "-FOO" and "FOO" mean "do not print categories that start with FOO".
+ "+FOO" means "do print categories that start with FOO".
+
+ Examples: --filter=-whitespace,+whitespace/braces
+ --filter=whitespace,runtime/printf,+runtime/printf_format
+ --filter=-,+build/include_what_you_use
+
+ To see a list of all the categories used in cpplint, pass no arg:
+ --filter=
+
+ counting=total|toplevel|detailed
+ The total number of errors found is always printed. If
+ 'toplevel' is provided, then the count of errors in each of
+ the top-level categories like 'build' and 'whitespace' will
+ also be printed. If 'detailed' is provided, then a count
+ is provided for each category like 'build/class'.
+
+ root=subdir
+ The root directory used for deriving header guard CPP variable.
+ By default, the header guard CPP variable is calculated as the relative
+ path to the directory that contains .git, .hg, or .svn. When this flag
+ is specified, the relative path is calculated from the specified
+ directory. If the specified directory does not exist, this flag is
+ ignored.
+
+ Examples:
+ Assuming that top/src/.git exists (and cwd=top/src), the header guard
+ CPP variables for top/src/chrome/browser/ui/browser.h are:
+
+ No flag => CHROME_BROWSER_UI_BROWSER_H_
+ --root=chrome => BROWSER_UI_BROWSER_H_
+ --root=chrome/browser => UI_BROWSER_H_
+ --root=.. => SRC_CHROME_BROWSER_UI_BROWSER_H_
+
+ linelength=digits
+ This is the allowed line length for the project. The default value is
+ 80 characters.
+
+ Examples:
+ --linelength=120
+
+ extensions=extension,extension,...
+ The allowed file extensions that cpplint will check
+
+ Examples:
+ --extensions=hpp,cpp
+
+ headers=x,y,...
+ The header extensions that cpplint will treat as .h in checks. Values are
+ automatically added to --extensions list.
+
+ Examples:
+ --headers=hpp,hxx
+ --headers=hpp
+
+ cpplint.py supports per-directory configurations specified in CPPLINT.cfg
+ files. CPPLINT.cfg file can contain a number of key=value pairs.
+ Currently the following options are supported:
+
+ set noparent
+ filter=+filter1,-filter2,...
+ exclude_files=regex
+ linelength=80
+ root=subdir
+ headers=x,y,...
+
+ The "set noparent" option prevents cpplint from traversing the directory
+ tree upwards looking for more .cfg files in parent directories. This option
+ is usually placed in the top-level project directory.
+
+ The "filter" option is similar in function to the --filter flag. It specifies
+ message filters in addition to the |_DEFAULT_FILTERS| and those specified
+ through --filter command-line flag.
+
+ "exclude_files" allows you to specify a regular expression to be matched
+ against a file name. If the expression matches, the file is skipped and not
+ run through the linter.
+
+ "linelength" allows you to specify the allowed line length for the project.
+
+ The "root" option is similar in function to the --root flag (see example
+ above). Paths are relative to the directory of the CPPLINT.cfg.
+
+ The "headers" option is similar in function to the --headers flag
+ (see example above).
+
+ CPPLINT.cfg has an effect on files in the same directory and all
+ sub-directories, unless overridden by a nested configuration file.
+
+ Example file:
+ filter=-build/include_order,+build/include_alpha
+ exclude_files=.*\.cc
+
+ The above example disables build/include_order warning and enables
+ build/include_alpha as well as excludes all .cc from being
+ processed by linter, in the current directory (where the .cfg
+ file is located) and all sub-directories.
+"""
+
+# We categorize each error message we print. Here are the categories.
+# We want an explicit list so we can list them all in cpplint --filter=.
+# If you add a new error message with a new category, add it to the list
+# here! cpplint_unittest.py should tell you if you forget to do this.
+_ERROR_CATEGORIES = [
+ 'build/class',
+ 'build/c++11',
+ 'build/c++14',
+ 'build/c++tr1',
+ 'build/deprecated',
+ 'build/endif_comment',
+ 'build/explicit_make_pair',
+ 'build/forward_decl',
+ 'build/header_guard',
+ 'build/include',
+ 'build/include_alpha',
+ 'build/include_order',
+ 'build/include_what_you_use',
+ 'build/namespaces',
+ 'build/printf_format',
+ 'build/storage_class',
+ 'legal/copyright',
+ 'readability/alt_tokens',
+ 'readability/braces',
+ 'readability/casting',
+ 'readability/check',
+ 'readability/constructors',
+ 'readability/fn_size',
+ 'readability/inheritance',
+ 'readability/multiline_comment',
+ 'readability/multiline_string',
+ 'readability/namespace',
+ 'readability/nolint',
+ 'readability/nul',
+ 'readability/strings',
+ 'readability/todo',
+ 'readability/utf8',
+ 'runtime/arrays',
+ 'runtime/casting',
+ 'runtime/explicit',
+ 'runtime/int',
+ 'runtime/init',
+ 'runtime/invalid_increment',
+ 'runtime/member_string_references',
+ 'runtime/memset',
+ 'runtime/indentation_namespace',
+ 'runtime/operator',
+ 'runtime/printf',
+ 'runtime/printf_format',
+ 'runtime/references',
+ 'runtime/string',
+ 'runtime/threadsafe_fn',
+ 'runtime/vlog',
+ 'whitespace/blank_line',
+ 'whitespace/braces',
+ 'whitespace/comma',
+ 'whitespace/comments',
+ 'whitespace/empty_conditional_body',
+ 'whitespace/empty_if_body',
+ 'whitespace/empty_loop_body',
+ 'whitespace/end_of_line',
+ 'whitespace/ending_newline',
+ 'whitespace/forcolon',
+ 'whitespace/indent',
+ 'whitespace/line_length',
+ 'whitespace/newline',
+ 'whitespace/operators',
+ 'whitespace/parens',
+ 'whitespace/semicolon',
+ 'whitespace/tab',
+ 'whitespace/todo',
+ ]
+
+# These error categories are no longer enforced by cpplint, but for backwards-
+# compatibility they may still appear in NOLINT comments.
+_LEGACY_ERROR_CATEGORIES = [
+ 'readability/streams',
+ 'readability/function',
+ ]
+
+# The default state of the category filter. This is overridden by the --filter=
+# flag. By default all errors are on, so only add here categories that should be
+# off by default (i.e., categories that must be enabled by the --filter= flags).
+# All entries here should start with a '-' or '+', as in the --filter= flag.
+_DEFAULT_FILTERS = ['-build/include_alpha']
+
+# The default list of categories suppressed for C (not C++) files.
+_DEFAULT_C_SUPPRESSED_CATEGORIES = [
+ 'readability/casting',
+ ]
+
+# The default list of categories suppressed for Linux Kernel files.
+_DEFAULT_KERNEL_SUPPRESSED_CATEGORIES = [
+ 'whitespace/tab',
+ ]
+
+# We used to check for high-bit characters, but after much discussion we
+# decided those were OK, as long as they were in UTF-8 and didn't represent
+# hard-coded international strings, which belong in a separate i18n file.
+
+# C++ headers
+_CPP_HEADERS = frozenset([
+ # Legacy
+ 'algobase.h',
+ 'algo.h',
+ 'alloc.h',
+ 'builtinbuf.h',
+ 'bvector.h',
+ 'complex.h',
+ 'defalloc.h',
+ 'deque.h',
+ 'editbuf.h',
+ 'fstream.h',
+ 'function.h',
+ 'hash_map',
+ 'hash_map.h',
+ 'hash_set',
+ 'hash_set.h',
+ 'hashtable.h',
+ 'heap.h',
+ 'indstream.h',
+ 'iomanip.h',
+ 'iostream.h',
+ 'istream.h',
+ 'iterator.h',
+ 'list.h',
+ 'map.h',
+ 'multimap.h',
+ 'multiset.h',
+ 'ostream.h',
+ 'pair.h',
+ 'parsestream.h',
+ 'pfstream.h',
+ 'procbuf.h',
+ 'pthread_alloc',
+ 'pthread_alloc.h',
+ 'rope',
+ 'rope.h',
+ 'ropeimpl.h',
+ 'set.h',
+ 'slist',
+ 'slist.h',
+ 'stack.h',
+ 'stdiostream.h',
+ 'stl_alloc.h',
+ 'stl_relops.h',
+ 'streambuf.h',
+ 'stream.h',
+ 'strfile.h',
+ 'strstream.h',
+ 'tempbuf.h',
+ 'tree.h',
+ 'type_traits.h',
+ 'vector.h',
+ # 17.6.1.2 C++ library headers
+ 'algorithm',
+ 'array',
+ 'atomic',
+ 'bitset',
+ 'chrono',
+ 'codecvt',
+ 'complex',
+ 'condition_variable',
+ 'deque',
+ 'exception',
+ 'forward_list',
+ 'fstream',
+ 'functional',
+ 'future',
+ 'initializer_list',
+ 'iomanip',
+ 'ios',
+ 'iosfwd',
+ 'iostream',
+ 'istream',
+ 'iterator',
+ 'limits',
+ 'list',
+ 'locale',
+ 'map',
+ 'memory',
+ 'mutex',
+ 'new',
+ 'numeric',
+ 'ostream',
+ 'queue',
+ 'random',
+ 'ratio',
+ 'regex',
+ 'scoped_allocator',
+ 'set',
+ 'sstream',
+ 'stack',
+ 'stdexcept',
+ 'streambuf',
+ 'string',
+ 'strstream',
+ 'system_error',
+ 'thread',
+ 'tuple',
+ 'typeindex',
+ 'typeinfo',
+ 'type_traits',
+ 'unordered_map',
+ 'unordered_set',
+ 'utility',
+ 'valarray',
+ 'vector',
+ # 17.6.1.2 C++ headers for C library facilities
+ 'cassert',
+ 'ccomplex',
+ 'cctype',
+ 'cerrno',
+ 'cfenv',
+ 'cfloat',
+ 'cinttypes',
+ 'ciso646',
+ 'climits',
+ 'clocale',
+ 'cmath',
+ 'csetjmp',
+ 'csignal',
+ 'cstdalign',
+ 'cstdarg',
+ 'cstdbool',
+ 'cstddef',
+ 'cstdint',
+ 'cstdio',
+ 'cstdlib',
+ 'cstring',
+ 'ctgmath',
+ 'ctime',
+ 'cuchar',
+ 'cwchar',
+ 'cwctype',
+ ])
+
+# Type names
+_TYPES = re.compile(
+ r'^(?:'
+ # [dcl.type.simple]
+ r'(char(16_t|32_t)?)|wchar_t|'
+ r'bool|short|int|long|signed|unsigned|float|double|'
+ # [support.types]
+ r'(ptrdiff_t|size_t|max_align_t|nullptr_t)|'
+ # [cstdint.syn]
+ r'(u?int(_fast|_least)?(8|16|32|64)_t)|'
+ r'(u?int(max|ptr)_t)|'
+ r')$')
+
+
+# These headers are excluded from [build/include] and [build/include_order]
+# checks:
+# - Anything not following google file name conventions (containing an
+# uppercase character, such as Python.h or nsStringAPI.h, for example).
+# - Lua headers.
+_THIRD_PARTY_HEADERS_PATTERN = re.compile(
+ r'^(?:[^/]*[A-Z][^/]*\.h|lua\.h|lauxlib\.h|lualib\.h)$')
+
+# Pattern for matching FileInfo.BaseName() against test file name
+_TEST_FILE_SUFFIX = r'(_test|_unittest|_regtest)$'
+
+# Pattern that matches only complete whitespace, possibly across multiple lines.
+_EMPTY_CONDITIONAL_BODY_PATTERN = re.compile(r'^\s*$', re.DOTALL)
+
+# Assertion macros. These are defined in base/logging.h and
+# testing/base/public/gunit.h.
+_CHECK_MACROS = [
+ 'DCHECK', 'CHECK',
+ 'EXPECT_TRUE', 'ASSERT_TRUE',
+ 'EXPECT_FALSE', 'ASSERT_FALSE',
+ ]
+
+# Replacement macros for CHECK/DCHECK/EXPECT_TRUE/EXPECT_FALSE
+_CHECK_REPLACEMENT = dict([(m, {}) for m in _CHECK_MACROS])
+
+for op, replacement in [('==', 'EQ'), ('!=', 'NE'),
+ ('>=', 'GE'), ('>', 'GT'),
+ ('<=', 'LE'), ('<', 'LT')]:
+ _CHECK_REPLACEMENT['DCHECK'][op] = 'DCHECK_%s' % replacement
+ _CHECK_REPLACEMENT['CHECK'][op] = 'CHECK_%s' % replacement
+ _CHECK_REPLACEMENT['EXPECT_TRUE'][op] = 'EXPECT_%s' % replacement
+ _CHECK_REPLACEMENT['ASSERT_TRUE'][op] = 'ASSERT_%s' % replacement
+
+for op, inv_replacement in [('==', 'NE'), ('!=', 'EQ'),
+ ('>=', 'LT'), ('>', 'LE'),
+ ('<=', 'GT'), ('<', 'GE')]:
+ _CHECK_REPLACEMENT['EXPECT_FALSE'][op] = 'EXPECT_%s' % inv_replacement
+ _CHECK_REPLACEMENT['ASSERT_FALSE'][op] = 'ASSERT_%s' % inv_replacement
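+
+# After the two loops above, for example:
+#   _CHECK_REPLACEMENT['EXPECT_TRUE']['=='] == 'EXPECT_EQ'
+#   _CHECK_REPLACEMENT['EXPECT_FALSE']['=='] == 'EXPECT_NE'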
+
+# Alternative tokens and their replacements. For full list, see section 2.5
+# Alternative tokens [lex.digraph] in the C++ standard.
+#
+# Digraphs (such as '%:') are not included here since it's a mess to
+# match those on a word boundary.
+_ALT_TOKEN_REPLACEMENT = {
+ 'and': '&&',
+ 'bitor': '|',
+ 'or': '||',
+ 'xor': '^',
+ 'compl': '~',
+ 'bitand': '&',
+ 'and_eq': '&=',
+ 'or_eq': '|=',
+ 'xor_eq': '^=',
+ 'not': '!',
+ 'not_eq': '!='
+ }
+
+# Compile regular expression that matches all the above keywords. The "[ =()]"
+# bit is meant to avoid matching these keywords outside of boolean expressions.
+#
+# False positives include C-style multi-line comments and multi-line strings
+# but those have always been troublesome for cpplint.
+_ALT_TOKEN_REPLACEMENT_PATTERN = re.compile(
+ r'[ =()](' + ('|'.join(_ALT_TOKEN_REPLACEMENT.keys())) + r')(?=[ (]|$)')
+
+
+# These constants define types of headers for use with
+# _IncludeState.CheckNextIncludeOrder().
+_C_SYS_HEADER = 1
+_CPP_SYS_HEADER = 2
+_LIKELY_MY_HEADER = 3
+_POSSIBLE_MY_HEADER = 4
+_OTHER_HEADER = 5
+
+# These constants define the current inline assembly state
+_NO_ASM = 0 # Outside of inline assembly block
+_INSIDE_ASM = 1 # Inside inline assembly block
+_END_ASM = 2 # Last line of inline assembly block
+_BLOCK_ASM = 3 # The whole block is an inline assembly block
+
+# Match start of assembly blocks
+_MATCH_ASM = re.compile(r'^\s*(?:asm|_asm|__asm|__asm__)'
+ r'(?:\s+(volatile|__volatile__))?'
+ r'\s*[{(]')
+
+# Match strings that indicate we're working on a C (not C++) file.
+_SEARCH_C_FILE = re.compile(r'\b(?:LINT_C_FILE|'
+ r'vim?:\s*.*(\s*|:)filetype=c(\s*|:|$))')
+
+# Match string that indicates we're working on a Linux Kernel file.
+_SEARCH_KERNEL_FILE = re.compile(r'\b(?:LINT_KERNEL_FILE)')
+
+_regexp_compile_cache = {}
+
+# {str, set(int)}: a map from error categories to sets of linenumbers
+# on which those errors are expected and should be suppressed.
+_error_suppressions = {}
+
+# The root directory used for deriving header guard CPP variable.
+# This is set by --root flag.
+_root = None
+_root_debug = False
+
+# The allowed line length of files.
+# This is set by --linelength flag.
+_line_length = 80
+
+# The allowed extensions for file names
+# This is set by --extensions flag.
+_valid_extensions = set(['cc', 'h', 'cpp', 'cu', 'cuh'])
+
+# Treat all headers starting with 'h' equally: .h, .hpp, .hxx etc.
+# This is set by --headers flag.
+_hpp_headers = set(['h'])
+
+# {str, bool}: a map from error categories to booleans which indicate if the
+# category should be suppressed for every line.
+_global_error_suppressions = {}
+
+def ProcessHppHeadersOption(val):
+ global _hpp_headers
+ try:
+ _hpp_headers = set(val.split(','))
+    # Automatically append to the extensions list so it does not have to be
+    # set twice.
+    _valid_extensions.update(_hpp_headers)
+  except ValueError:
+    PrintUsage('Header extensions must be a comma-separated list.')
+
+def IsHeaderExtension(file_extension):
+ return file_extension in _hpp_headers
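+
+# Illustrative: after ProcessHppHeadersOption('hpp,hxx'), both
+# IsHeaderExtension('hpp') and IsHeaderExtension('hxx') return True, and
+# 'hpp'/'hxx' are also added to _valid_extensions.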
+
+def ParseNolintSuppressions(filename, raw_line, linenum, error):
+ """Updates the global list of line error-suppressions.
+
+ Parses any NOLINT comments on the current line, updating the global
+ error_suppressions store. Reports an error if the NOLINT comment
+ was malformed.
+
+ Args:
+ filename: str, the name of the input file.
+ raw_line: str, the line of input text, with comments.
+ linenum: int, the number of the current line.
+ error: function, an error handler.
+ """
+ matched = Search(r'\bNOLINT(NEXTLINE)?\b(\([^)]+\))?', raw_line)
+ if matched:
+ if matched.group(1):
+ suppressed_line = linenum + 1
+ else:
+ suppressed_line = linenum
+ category = matched.group(2)
+ if category in (None, '(*)'): # => "suppress all"
+ _error_suppressions.setdefault(None, set()).add(suppressed_line)
+ else:
+ if category.startswith('(') and category.endswith(')'):
+ category = category[1:-1]
+ if category in _ERROR_CATEGORIES:
+ _error_suppressions.setdefault(category, set()).add(suppressed_line)
+ elif category not in _LEGACY_ERROR_CATEGORIES:
+ error(filename, linenum, 'readability/nolint', 5,
+ 'Unknown NOLINT error category: %s' % category)
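+
+# Illustrative NOLINT forms handled above:
+#   "int x;  // NOLINT"                 suppresses all categories here
+#   "int x;  // NOLINT(runtime/int)"    suppresses only runtime/int here
+#   "// NOLINTNEXTLINE(whitespace/tab)" suppresses on the following line
+# An unknown category such as NOLINT(bogus/category) is itself reported as
+# a readability/nolint error.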
+
+
+def ProcessGlobalSuppresions(lines):
+ """Updates the list of global error suppressions.
+
+ Parses any lint directives in the file that have global effect.
+
+ Args:
+ lines: An array of strings, each representing a line of the file, with the
+ last element being empty if the file is terminated with a newline.
+ """
+ for line in lines:
+ if _SEARCH_C_FILE.search(line):
+ for category in _DEFAULT_C_SUPPRESSED_CATEGORIES:
+ _global_error_suppressions[category] = True
+ if _SEARCH_KERNEL_FILE.search(line):
+ for category in _DEFAULT_KERNEL_SUPPRESSED_CATEGORIES:
+ _global_error_suppressions[category] = True
+
+
+def ResetNolintSuppressions():
+ """Resets the set of NOLINT suppressions to empty."""
+ _error_suppressions.clear()
+ _global_error_suppressions.clear()
+
+
+def IsErrorSuppressedByNolint(category, linenum):
+ """Returns true if the specified error category is suppressed on this line.
+
+ Consults the global error_suppressions map populated by
+ ParseNolintSuppressions/ProcessGlobalSuppresions/ResetNolintSuppressions.
+
+ Args:
+ category: str, the category of the error.
+ linenum: int, the current line number.
+ Returns:
+ bool, True iff the error should be suppressed due to a NOLINT comment or
+ global suppression.
+ """
+ return (_global_error_suppressions.get(category, False) or
+ linenum in _error_suppressions.get(category, set()) or
+ linenum in _error_suppressions.get(None, set()))
+
+
+def Match(pattern, s):
+ """Matches the string with the pattern, caching the compiled regexp."""
+ # The regexp compilation caching is inlined in both Match and Search for
+ # performance reasons; factoring it out into a separate function turns out
+ # to be noticeably expensive.
+ if pattern not in _regexp_compile_cache:
+ _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
+ return _regexp_compile_cache[pattern].match(s)
+
+
+def ReplaceAll(pattern, rep, s):
+ """Replaces instances of pattern in a string with a replacement.
+
+ The compiled regex is kept in a cache shared by Match and Search.
+
+ Args:
+ pattern: regex pattern
+ rep: replacement text
+ s: search string
+
+ Returns:
+ string with replacements made (or original string if no replacements)
+ """
+ if pattern not in _regexp_compile_cache:
+ _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
+ return _regexp_compile_cache[pattern].sub(rep, s)
+
+
+def Search(pattern, s):
+ """Searches the string for the pattern, caching the compiled regexp."""
+ if pattern not in _regexp_compile_cache:
+ _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
+ return _regexp_compile_cache[pattern].search(s)
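+
+# Illustrative usage of the cached-regexp helpers above:
+#   Match(r'(\w+)', 'foo bar').group(1)   # 'foo' -- anchored at the start
+#   Search(r'bar$', 'foo bar')            # match object -- searches anywhere
+#   ReplaceAll(r'\s+', ' ', 'a   b')      # 'a b'
+# All three share _regexp_compile_cache, so a repeated pattern is compiled
+# only once.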
+
+
+def _IsSourceExtension(s):
+ """File extension (excluding dot) matches a source file extension."""
+ return s in ('c', 'cc', 'cpp', 'cxx')
+
+
+class _IncludeState(object):
+ """Tracks line numbers for includes, and the order in which includes appear.
+
+  include_list contains a list of lists of (header, line number) pairs.
+  It's a list of lists rather than just one flat list to make it
+ easier to update across preprocessor boundaries.
+
+ Call CheckNextIncludeOrder() once for each header in the file, passing
+ in the type constants defined above. Calls in an illegal order will
+ raise an _IncludeError with an appropriate error message.
+
+ """
+ # self._section will move monotonically through this set. If it ever
+ # needs to move backwards, CheckNextIncludeOrder will raise an error.
+ _INITIAL_SECTION = 0
+ _MY_H_SECTION = 1
+ _C_SECTION = 2
+ _CPP_SECTION = 3
+ _OTHER_H_SECTION = 4
+
+ _TYPE_NAMES = {
+ _C_SYS_HEADER: 'C system header',
+ _CPP_SYS_HEADER: 'C++ system header',
+ _LIKELY_MY_HEADER: 'header this file implements',
+ _POSSIBLE_MY_HEADER: 'header this file may implement',
+ _OTHER_HEADER: 'other header',
+ }
+ _SECTION_NAMES = {
+ _INITIAL_SECTION: "... nothing. (This can't be an error.)",
+ _MY_H_SECTION: 'a header this file implements',
+ _C_SECTION: 'C system header',
+ _CPP_SECTION: 'C++ system header',
+ _OTHER_H_SECTION: 'other header',
+ }
+
+ def __init__(self):
+ self.include_list = [[]]
+ self.ResetSection('')
+
+ def FindHeader(self, header):
+ """Check if a header has already been included.
+
+ Args:
+ header: header to check.
+ Returns:
+ Line number of previous occurrence, or -1 if the header has not
+ been seen before.
+ """
+ for section_list in self.include_list:
+ for f in section_list:
+ if f[0] == header:
+ return f[1]
+ return -1
+
+ def ResetSection(self, directive):
+ """Reset section checking for preprocessor directive.
+
+ Args:
+ directive: preprocessor directive (e.g. "if", "else").
+ """
+ # The name of the current section.
+ self._section = self._INITIAL_SECTION
+ # The path of last found header.
+ self._last_header = ''
+
+ # Update list of includes. Note that we never pop from the
+ # include list.
+ if directive in ('if', 'ifdef', 'ifndef'):
+ self.include_list.append([])
+ elif directive in ('else', 'elif'):
+ self.include_list[-1] = []
+
+ def SetLastHeader(self, header_path):
+ self._last_header = header_path
+
+ def CanonicalizeAlphabeticalOrder(self, header_path):
+ """Returns a path canonicalized for alphabetical comparison.
+
+    - replaces "-" with "_" so they both compare the same.
+    - removes '-inl' since we don't require it to be after the main header.
+    - lowercases everything, just in case.
+
+ Args:
+ header_path: Path to be canonicalized.
+
+ Returns:
+ Canonicalized path.
+ """
+ return header_path.replace('-inl.h', '.h').replace('-', '_').lower()
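+
+  # Illustrative: CanonicalizeAlphabeticalOrder('Foo-Bar-inl.h') returns
+  # 'foo_bar.h', so 'foo-bar-inl.h' and 'foo_bar.h' compare as equal when
+  # checking alphabetical order.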
+
+ def IsInAlphabeticalOrder(self, clean_lines, linenum, header_path):
+ """Check if a header is in alphabetical order with the previous header.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ header_path: Canonicalized header to be checked.
+
+ Returns:
+ Returns true if the header is in alphabetical order.
+ """
+ # If previous section is different from current section, _last_header will
+ # be reset to empty string, so it's always less than current header.
+ #
+ # If previous line was a blank line, assume that the headers are
+ # intentionally sorted the way they are.
+ if (self._last_header > header_path and
+ Match(r'^\s*#\s*include\b', clean_lines.elided[linenum - 1])):
+ return False
+ return True
+
+ def CheckNextIncludeOrder(self, header_type):
+ """Returns a non-empty error message if the next header is out of order.
+
+ This function also updates the internal state to be ready to check
+ the next include.
+
+ Args:
+ header_type: One of the _XXX_HEADER constants defined above.
+
+ Returns:
+ The empty string if the header is in the right order, or an
+ error message describing what's wrong.
+
+ """
+ error_message = ('Found %s after %s' %
+ (self._TYPE_NAMES[header_type],
+ self._SECTION_NAMES[self._section]))
+
+ last_section = self._section
+
+ if header_type == _C_SYS_HEADER:
+ if self._section <= self._C_SECTION:
+ self._section = self._C_SECTION
+ else:
+ self._last_header = ''
+ return error_message
+ elif header_type == _CPP_SYS_HEADER:
+ if self._section <= self._CPP_SECTION:
+ self._section = self._CPP_SECTION
+ else:
+ self._last_header = ''
+ return error_message
+ elif header_type == _LIKELY_MY_HEADER:
+ if self._section <= self._MY_H_SECTION:
+ self._section = self._MY_H_SECTION
+ else:
+ self._section = self._OTHER_H_SECTION
+ elif header_type == _POSSIBLE_MY_HEADER:
+ if self._section <= self._MY_H_SECTION:
+ self._section = self._MY_H_SECTION
+ else:
+ # This will always be the fallback because we're not sure
+ # enough that the header is associated with this file.
+ self._section = self._OTHER_H_SECTION
+ else:
+ assert header_type == _OTHER_HEADER
+ self._section = self._OTHER_H_SECTION
+
+ if last_section != self._section:
+ self._last_header = ''
+
+ return ''
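+
+  # Illustrative walkthrough for foo.cc; each call returns '' (no error):
+  #   CheckNextIncludeOrder(_LIKELY_MY_HEADER)   # "foo.h"
+  #   CheckNextIncludeOrder(_C_SYS_HEADER)       # <stdio.h>
+  #   CheckNextIncludeOrder(_CPP_SYS_HEADER)     # <vector>
+  #   CheckNextIncludeOrder(_OTHER_HEADER)       # "bar/baz.h"
+  # Passing _C_SYS_HEADER again at this point would return the message
+  # 'Found C system header after other header'.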
+
+
+class _CppLintState(object):
+  """Maintains module-wide state."""
+
+ def __init__(self):
+ self.verbose_level = 1 # global setting.
+ self.error_count = 0 # global count of reported errors
+ # filters to apply when emitting error messages
+ self.filters = _DEFAULT_FILTERS[:]
+ # backup of filter list. Used to restore the state after each file.
+ self._filters_backup = self.filters[:]
+ self.counting = 'total' # In what way are we counting errors?
+ self.errors_by_category = {} # string to int dict storing error counts
+    self.quiet = False  # Suppress non-error messages?
+
+ # output format:
+ # "emacs" - format that emacs can parse (default)
+ # "vs7" - format that Microsoft Visual Studio 7 can parse
+ self.output_format = 'emacs'
+
+ def SetOutputFormat(self, output_format):
+ """Sets the output format for errors."""
+ self.output_format = output_format
+
+ def SetQuiet(self, quiet):
+    """Sets the module's quiet setting, and returns the previous setting."""
+ last_quiet = self.quiet
+ self.quiet = quiet
+ return last_quiet
+
+ def SetVerboseLevel(self, level):
+ """Sets the module's verbosity, and returns the previous setting."""
+ last_verbose_level = self.verbose_level
+ self.verbose_level = level
+ return last_verbose_level
+
+ def SetCountingStyle(self, counting_style):
+ """Sets the module's counting options."""
+ self.counting = counting_style
+
+ def SetFilters(self, filters):
+ """Sets the error-message filters.
+
+ These filters are applied when deciding whether to emit a given
+ error message.
+
+ Args:
+ filters: A string of comma-separated filters (eg "+whitespace/indent").
+ Each filter should start with + or -; else we die.
+
+ Raises:
+ ValueError: The comma-separated filters did not all start with '+' or '-'.
+ E.g. "-,+whitespace,-whitespace/indent,whitespace/badfilter"
+ """
+ # Default filters always have less priority than the flag ones.
+ self.filters = _DEFAULT_FILTERS[:]
+ self.AddFilters(filters)
+
+ def AddFilters(self, filters):
+    """Adds more filters to the existing list of error-message filters."""
+ for filt in filters.split(','):
+ clean_filt = filt.strip()
+ if clean_filt:
+ self.filters.append(clean_filt)
+ for filt in self.filters:
+ if not (filt.startswith('+') or filt.startswith('-')):
+ raise ValueError('Every filter in --filters must start with + or -'
+ ' (%s does not)' % filt)
+
+ def BackupFilters(self):
+    """Saves the current filter list to backup storage."""
+ self._filters_backup = self.filters[:]
+
+ def RestoreFilters(self):
+    """Restores filters previously backed up."""
+ self.filters = self._filters_backup[:]
+
+ def ResetErrorCounts(self):
+ """Sets the module's error statistic back to zero."""
+ self.error_count = 0
+ self.errors_by_category = {}
+
+ def IncrementErrorCount(self, category):
+ """Bumps the module's error statistic."""
+ self.error_count += 1
+ if self.counting in ('toplevel', 'detailed'):
+ if self.counting != 'detailed':
+ category = category.split('/')[0]
+ if category not in self.errors_by_category:
+ self.errors_by_category[category] = 0
+ self.errors_by_category[category] += 1
+
+ def PrintErrorCounts(self):
+ """Print a summary of errors by category, and the total."""
+ for category, count in self.errors_by_category.iteritems():
+ sys.stderr.write('Category \'%s\' errors found: %d\n' %
+ (category, count))
+ sys.stdout.write('Total errors found: %d\n' % self.error_count)
+
+_cpplint_state = _CppLintState()
+
+
+def _OutputFormat():
+ """Gets the module's output format."""
+ return _cpplint_state.output_format
+
+
+def _SetOutputFormat(output_format):
+ """Sets the module's output format."""
+ _cpplint_state.SetOutputFormat(output_format)
+
+def _Quiet():
+  """Returns the module's quiet setting."""
+ return _cpplint_state.quiet
+
+def _SetQuiet(quiet):
+ """Set the module's quiet status, and return previous setting."""
+ return _cpplint_state.SetQuiet(quiet)
+
+
+def _VerboseLevel():
+ """Returns the module's verbosity setting."""
+ return _cpplint_state.verbose_level
+
+
+def _SetVerboseLevel(level):
+ """Sets the module's verbosity, and returns the previous setting."""
+ return _cpplint_state.SetVerboseLevel(level)
+
+
+def _SetCountingStyle(level):
+ """Sets the module's counting options."""
+ _cpplint_state.SetCountingStyle(level)
+
+
+def _Filters():
+ """Returns the module's list of output filters, as a list."""
+ return _cpplint_state.filters
+
+
+def _SetFilters(filters):
+ """Sets the module's error-message filters.
+
+ These filters are applied when deciding whether to emit a given
+ error message.
+
+ Args:
+ filters: A string of comma-separated filters (eg "whitespace/indent").
+ Each filter should start with + or -; else we die.
+ """
+ _cpplint_state.SetFilters(filters)
+
+def _AddFilters(filters):
+ """Adds more filter overrides.
+
+ Unlike _SetFilters, this function does not reset the current list of filters
+ available.
+
+ Args:
+ filters: A string of comma-separated filters (eg "whitespace/indent").
+ Each filter should start with + or -; else we die.
+ """
+ _cpplint_state.AddFilters(filters)
+
+def _BackupFilters():
+  """Saves the current filter list to backup storage."""
+ _cpplint_state.BackupFilters()
+
+def _RestoreFilters():
+  """Restores filters previously backed up."""
+ _cpplint_state.RestoreFilters()
+
+class _FunctionState(object):
+ """Tracks current function name and the number of lines in its body."""
+
+ _NORMAL_TRIGGER = 250 # for --v=0, 500 for --v=1, etc.
+  _TEST_TRIGGER = 400    # about 60% more than _NORMAL_TRIGGER.
+
+ def __init__(self):
+ self.in_a_function = False
+ self.lines_in_function = 0
+ self.current_function = ''
+
+ def Begin(self, function_name):
+ """Start analyzing function body.
+
+ Args:
+ function_name: The name of the function being tracked.
+ """
+ self.in_a_function = True
+ self.lines_in_function = 0
+ self.current_function = function_name
+
+ def Count(self):
+ """Count line in current function body."""
+ if self.in_a_function:
+ self.lines_in_function += 1
+
+ def Check(self, error, filename, linenum):
+ """Report if too many lines in function body.
+
+ Args:
+ error: The function to call with any errors found.
+ filename: The name of the current file.
+ linenum: The number of the line to check.
+ """
+ if not self.in_a_function:
+ return
+
+ if Match(r'T(EST|est)', self.current_function):
+ base_trigger = self._TEST_TRIGGER
+ else:
+ base_trigger = self._NORMAL_TRIGGER
+ trigger = base_trigger * 2**_VerboseLevel()
+
+ if self.lines_in_function > trigger:
+ error_level = int(math.log(self.lines_in_function / base_trigger, 2))
+      # e.g. with base_trigger 250: 500 => 1, 1000 => 2, 2000 => 3, 4000 => 4.
+ if error_level > 5:
+ error_level = 5
+ error(filename, linenum, 'readability/fn_size', error_level,
+ 'Small and focused functions are preferred:'
+ ' %s has %d non-comment lines'
+ ' (error triggered by exceeding %d lines).' % (
+ self.current_function, self.lines_in_function, trigger))
+
+ def End(self):
+ """Stop analyzing function body."""
+ self.in_a_function = False
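+
+  # Illustrative arithmetic: at --v=0 a non-test function triggers at
+  # 250 * 2**0 = 250 lines, so a 600-line body reports
+  # error_level = int(log(600 / 250, 2)) = 1; at --v=1 the trigger doubles
+  # to 500 and bodies of up to 500 lines pass unflagged.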
+
+
+class _IncludeError(Exception):
+ """Indicates a problem with the include order in a file."""
+ pass
+
+
+class FileInfo(object):
+ """Provides utility functions for filenames.
+
+ FileInfo provides easy access to the components of a file's path
+ relative to the project root.
+ """
+
+ def __init__(self, filename):
+ self._filename = filename
+
+ def FullName(self):
+ """Make Windows paths like Unix."""
+ return os.path.abspath(self._filename).replace('\\', '/')
+
+ def RepositoryName(self):
+ """FullName after removing the local path to the repository.
+
+    If we have a real absolute path name here we can try to do something
+    smart: detect the root of the checkout and truncate /path/to/checkout
+    from the name, so that the header guards don't include things like
+    "C:\Documents and Settings\..." or "/home/username/..." in them. That
+    way, people who have checked the source out to different locations
+    won't see bogus errors.
+ """
+ fullname = self.FullName()
+
+ if os.path.exists(fullname):
+ project_dir = os.path.dirname(fullname)
+
+ if os.path.exists(os.path.join(project_dir, ".svn")):
+ # If there's a .svn file in the current directory, we recursively look
+ # up the directory tree for the top of the SVN checkout
+ root_dir = project_dir
+ one_up_dir = os.path.dirname(root_dir)
+ while os.path.exists(os.path.join(one_up_dir, ".svn")):
+ root_dir = os.path.dirname(root_dir)
+ one_up_dir = os.path.dirname(one_up_dir)
+
+ prefix = os.path.commonprefix([root_dir, project_dir])
+ return fullname[len(prefix) + 1:]
+
+ # Not SVN <= 1.6? Try to find a git, hg, or svn top level directory by
+ # searching up from the current path.
+ root_dir = current_dir = os.path.dirname(fullname)
+ while current_dir != os.path.dirname(current_dir):
+ if (os.path.exists(os.path.join(current_dir, ".git")) or
+ os.path.exists(os.path.join(current_dir, ".hg")) or
+ os.path.exists(os.path.join(current_dir, ".svn"))):
+ root_dir = current_dir
+ current_dir = os.path.dirname(current_dir)
+
+ if (os.path.exists(os.path.join(root_dir, ".git")) or
+ os.path.exists(os.path.join(root_dir, ".hg")) or
+ os.path.exists(os.path.join(root_dir, ".svn"))):
+ prefix = os.path.commonprefix([root_dir, project_dir])
+ return fullname[len(prefix) + 1:]
+
+ # Don't know what to do; header guard warnings may be wrong...
+ return fullname
+
+ def Split(self):
+ """Splits the file into the directory, basename, and extension.
+
+ For 'chrome/browser/browser.cc', Split() would
+ return ('chrome/browser', 'browser', '.cc')
+
+ Returns:
+ A tuple of (directory, basename, extension).
+ """
+
+ googlename = self.RepositoryName()
+ project, rest = os.path.split(googlename)
+ return (project,) + os.path.splitext(rest)
+
+ def BaseName(self):
+ """File base name - text after the final slash, before the final period."""
+ return self.Split()[1]
+
+ def Extension(self):
+ """File extension - text following the final period."""
+ return self.Split()[2]
+
+ def NoExtension(self):
+    """File path without the extension."""
+ return '/'.join(self.Split()[0:2])
+
+ def IsSource(self):
+ """File has a source file extension."""
+ return _IsSourceExtension(self.Extension()[1:])
+
+
+def _ShouldPrintError(category, confidence, linenum):
+ """If confidence >= verbose, category passes filter and is not suppressed."""
+
+ # There are three ways we might decide not to print an error message:
+ # a "NOLINT(category)" comment appears in the source,
+ # the verbosity level isn't high enough, or the filters filter it out.
+ if IsErrorSuppressedByNolint(category, linenum):
+ return False
+
+ if confidence < _cpplint_state.verbose_level:
+ return False
+
+ is_filtered = False
+ for one_filter in _Filters():
+ if one_filter.startswith('-'):
+ if category.startswith(one_filter[1:]):
+ is_filtered = True
+ elif one_filter.startswith('+'):
+ if category.startswith(one_filter[1:]):
+ is_filtered = False
+ else:
+ assert False # should have been checked for in SetFilter.
+ if is_filtered:
+ return False
+
+ return True
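+
+# Illustrative: with filters ['-whitespace', '+whitespace/braces'], a
+# whitespace/indent error is filtered out while whitespace/braces is kept;
+# later filters win because the loop above applies them in order.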
+
+
+def Error(filename, linenum, category, confidence, message):
+ """Logs the fact we've found a lint error.
+
+ We log where the error was found, and also our confidence in the error,
+ that is, how certain we are this is a legitimate style regression, and
+ not a misidentification or a use that's sometimes justified.
+
+ False positives can be suppressed by the use of
+ "cpplint(category)" comments on the offending line. These are
+ parsed into _error_suppressions.
+
+ Args:
+ filename: The name of the file containing the error.
+ linenum: The number of the line containing the error.
+ category: A string used to describe the "category" this bug
+ falls under: "whitespace", say, or "runtime". Categories
+ may have a hierarchy separated by slashes: "whitespace/indent".
+ confidence: A number from 1-5 representing a confidence score for
+ the error, with 5 meaning that we are certain of the problem,
+ and 1 meaning that it could be a legitimate construct.
+ message: The error message.
+ """
+ if _ShouldPrintError(category, confidence, linenum):
+ _cpplint_state.IncrementErrorCount(category)
+ if _cpplint_state.output_format == 'vs7':
+ sys.stderr.write('%s(%s): error cpplint: [%s] %s [%d]\n' % (
+ filename, linenum, category, message, confidence))
+ elif _cpplint_state.output_format == 'eclipse':
+ sys.stderr.write('%s:%s: warning: %s [%s] [%d]\n' % (
+ filename, linenum, message, category, confidence))
+ else:
+ sys.stderr.write('%s:%s: %s [%s] [%d]\n' % (
+ filename, linenum, message, category, confidence))
+
+
+# Matches standard C++ escape sequences per 2.13.2.3 of the C++ standard.
+_RE_PATTERN_CLEANSE_LINE_ESCAPES = re.compile(
+ r'\\([abfnrtv?"\\\']|\d+|x[0-9a-fA-F]+)')
+# Match a single C style comment on the same line.
+_RE_PATTERN_C_COMMENTS = r'/\*(?:[^*]|\*(?!/))*\*/'
+# Matches multi-line C style comments.
+# This RE is a little more complicated than one might expect, because we
+# have to be careful about how much surrounding whitespace we remove, so
+# that comments inside statements are handled well.
+# The current rule is: we only clear spaces from both sides when we're at
+# the end of the line. Otherwise, we try to remove spaces from the right
+# side; if that doesn't work, we try the left side, but only if a non-word
+# character follows the comment.
+_RE_PATTERN_CLEANSE_LINE_C_COMMENTS = re.compile(
+ r'(\s*' + _RE_PATTERN_C_COMMENTS + r'\s*$|' +
+ _RE_PATTERN_C_COMMENTS + r'\s+|' +
+ r'\s+' + _RE_PATTERN_C_COMMENTS + r'(?=\W)|' +
+ _RE_PATTERN_C_COMMENTS + r')')
+
+
+def IsCppString(line):
+  """Checks whether the line ends while inside a string constant.
+
+  This function does not consider single-line or multi-line comments.
+
+ Args:
+    line: a partial line of code, from position 0 up to some position n.
+
+ Returns:
+ True, if next character appended to 'line' is inside a
+ string constant.
+ """
+
+ line = line.replace(r'\\', 'XX') # after this, \\" does not match to \"
+ return ((line.count('"') - line.count(r'\"') - line.count("'\"'")) & 1) == 1
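+
+# Illustrative:
+#   IsCppString('int a = 5;')    # False
+#   IsCppString('x = "abc')      # True  -- the string is still open
+#   IsCppString('x = "abc"; y')  # False -- the string was closed again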
+
+
+def CleanseRawStrings(raw_lines):
+ """Removes C++11 raw strings from lines.
+
+ Before:
+ static const char kData[] = R"(
+ multi-line string
+ )";
+
+ After:
+ static const char kData[] = ""
+ (replaced by blank line)
+ "";
+
+ Args:
+ raw_lines: list of raw lines.
+
+ Returns:
+ list of lines with C++11 raw strings replaced by empty strings.
+ """
+
+ delimiter = None
+ lines_without_raw_strings = []
+ for line in raw_lines:
+ if delimiter:
+ # Inside a raw string, look for the end
+ end = line.find(delimiter)
+ if end >= 0:
+        # Found the end of the string. Keep this line's leading space,
+        # insert a "" where the raw string ended, and resume copying the
+        # original lines.
+ leading_space = Match(r'^(\s*)\S', line)
+ line = leading_space.group(1) + '""' + line[end + len(delimiter):]
+ delimiter = None
+ else:
+ # Haven't found the end yet, append a blank line.
+ line = '""'
+
+ # Look for beginning of a raw string, and replace them with
+ # empty strings. This is done in a loop to handle multiple raw
+ # strings on the same line.
+ while delimiter is None:
+ # Look for beginning of a raw string.
+ # See 2.14.15 [lex.string] for syntax.
+ #
+ # Once we have matched a raw string, we check the prefix of the
+ # line to make sure that the line is not part of a single line
+ # comment. It's done this way because we remove raw strings
+ # before removing comments as opposed to removing comments
+ # before removing raw strings. This is because there are some
+      # cpplint checks that require the comments to be preserved, but
+ # we don't want to check comments that are inside raw strings.
+ matched = Match(r'^(.*?)\b(?:R|u8R|uR|UR|LR)"([^\s\\()]*)\((.*)$', line)
+ if (matched and
+ not Match(r'^([^\'"]|\'(\\.|[^\'])*\'|"(\\.|[^"])*")*//',
+ matched.group(1))):
+ delimiter = ')' + matched.group(2) + '"'
+
+ end = matched.group(3).find(delimiter)
+ if end >= 0:
+ # Raw string ended on same line
+ line = (matched.group(1) + '""' +
+ matched.group(3)[end + len(delimiter):])
+ delimiter = None
+ else:
+ # Start of a multi-line raw string
+ line = matched.group(1) + '""'
+ else:
+ break
+
+ lines_without_raw_strings.append(line)
+
+ # TODO(unknown): if delimiter is not None here, we might want to
+ # emit a warning for unterminated string.
+ return lines_without_raw_strings
+
+
+def FindNextMultiLineCommentStart(lines, lineix):
+ """Find the beginning marker for a multiline comment."""
+ while lineix < len(lines):
+ if lines[lineix].strip().startswith('/*'):
+ # Only return this marker if the comment goes beyond this line
+ if lines[lineix].strip().find('*/', 2) < 0:
+ return lineix
+ lineix += 1
+ return len(lines)
+
+
+def FindNextMultiLineCommentEnd(lines, lineix):
+ """We are inside a comment, find the end marker."""
+ while lineix < len(lines):
+ if lines[lineix].strip().endswith('*/'):
+ return lineix
+ lineix += 1
+ return len(lines)
+
+
+def RemoveMultiLineCommentsFromRange(lines, begin, end):
+ """Clears a range of lines for multi-line comments."""
+  # Having '/**/' comments makes the lines non-empty, so we will not get
+ # unnecessary blank line warnings later in the code.
+ for i in range(begin, end):
+ lines[i] = '/**/'
+
+
+def RemoveMultiLineComments(filename, lines, error):
+ """Removes multiline (c-style) comments from lines."""
+ lineix = 0
+ while lineix < len(lines):
+ lineix_begin = FindNextMultiLineCommentStart(lines, lineix)
+ if lineix_begin >= len(lines):
+ return
+ lineix_end = FindNextMultiLineCommentEnd(lines, lineix_begin)
+ if lineix_end >= len(lines):
+ error(filename, lineix_begin + 1, 'readability/multiline_comment', 5,
+ 'Could not find end of multi-line comment')
+ return
+ RemoveMultiLineCommentsFromRange(lines, lineix_begin, lineix_end + 1)
+ lineix = lineix_end + 1
+
+
+def CleanseComments(line):
+ """Removes //-comments and single-line C-style /* */ comments.
+
+ Args:
+ line: A line of C++ source.
+
+ Returns:
+ The line with single-line comments removed.
+ """
+ commentpos = line.find('//')
+ if commentpos != -1 and not IsCppString(line[:commentpos]):
+ line = line[:commentpos].rstrip()
+ # get rid of /* ... */
+ return _RE_PATTERN_CLEANSE_LINE_C_COMMENTS.sub('', line)
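+
+# Illustrative: CleanseComments('f();  // trailing') returns 'f();', and
+# CleanseComments('a /* mid */ b') returns 'a b'. Comments spanning several
+# lines are handled separately by RemoveMultiLineComments above.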
+
+
+class CleansedLines(object):
+ """Holds 4 copies of all lines with different preprocessing applied to them.
+
+ 1) elided member contains lines without strings and comments.
+ 2) lines member contains lines without comments.
+ 3) raw_lines member contains all the lines without processing.
+  4) lines_without_raw_strings member is the same as raw_lines, but with
+     C++11 raw strings removed.
+ All these members are of <type 'list'>, and of the same length.
+ """
+
+ def __init__(self, lines):
+ self.elided = []
+ self.lines = []
+ self.raw_lines = lines
+ self.num_lines = len(lines)
+ self.lines_without_raw_strings = CleanseRawStrings(lines)
+ for linenum in range(len(self.lines_without_raw_strings)):
+ self.lines.append(CleanseComments(
+ self.lines_without_raw_strings[linenum]))
+ elided = self._CollapseStrings(self.lines_without_raw_strings[linenum])
+ self.elided.append(CleanseComments(elided))
+
+ def NumLines(self):
+ """Returns the number of lines represented."""
+ return self.num_lines
+
+ @staticmethod
+ def _CollapseStrings(elided):
+ """Collapses strings and chars on a line to simple "" or '' blocks.
+
+ We nix strings first so we're not fooled by text like '"http://"'
+
+ Args:
+ elided: The line being processed.
+
+ Returns:
+ The line with collapsed strings.
+ """
+ if _RE_PATTERN_INCLUDE.match(elided):
+ return elided
+
+ # Remove escaped characters first to make quote/single quote collapsing
+ # basic. Things that look like escaped characters shouldn't occur
+ # outside of strings and chars.
+ elided = _RE_PATTERN_CLEANSE_LINE_ESCAPES.sub('', elided)
+
+ # Replace quoted strings and digit separators. Both single quotes
+ # and double quotes are processed in the same loop, otherwise
+ # nested quotes wouldn't work.
+ collapsed = ''
+ while True:
+ # Find the first quote character
+ match = Match(r'^([^\'"]*)([\'"])(.*)$', elided)
+ if not match:
+ collapsed += elided
+ break
+ head, quote, tail = match.groups()
+
+ if quote == '"':
+ # Collapse double quoted strings
+ second_quote = tail.find('"')
+ if second_quote >= 0:
+ collapsed += head + '""'
+ elided = tail[second_quote + 1:]
+ else:
+ # Unmatched double quote, don't bother processing the rest
+ # of the line since this is probably a multiline string.
+ collapsed += elided
+ break
+ else:
+ # Found single quote, check nearby text to eliminate digit separators.
+ #
+ # There is no special handling for floating point here, because
+ # the integer/fractional/exponent parts would all be parsed
+ # correctly as long as there are digits on both sides of the
+ # separator. So we are fine as long as we don't see something
+ # like "0.'3" (gcc 4.9.0 will not allow this literal).
+ if Search(r'\b(?:0[bBxX]?|[1-9])[0-9a-fA-F]*$', head):
+ match_literal = Match(r'^((?:\'?[0-9a-zA-Z_])*)(.*)$', "'" + tail)
+ collapsed += head + match_literal.group(1).replace("'", '')
+ elided = match_literal.group(2)
+ else:
+ second_quote = tail.find('\'')
+ if second_quote >= 0:
+ collapsed += head + "''"
+ elided = tail[second_quote + 1:]
+ else:
+ # Unmatched single quote
+ collapsed += elided
+ break
+
+ return collapsed
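+
+  # Illustrative:
+  #   _CollapseStrings('s = "url: http://x";')  ->  's = "";'
+  #   _CollapseStrings("x = 1'000'000;")        ->  'x = 1000000;'
+  # so later checks are not fooled by operators or comment markers that
+  # appear inside string literals or digit separators.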
+
+
+def FindEndOfExpressionInLine(line, startpos, stack):
+ """Find the position just after the end of current parenthesized expression.
+
+ Args:
+ line: a CleansedLines line.
+ startpos: start searching at this position.
+ stack: nesting stack at startpos.
+
+ Returns:
+ On finding matching end: (index just after matching end, None)
+ On finding an unclosed expression: (-1, None)
+ Otherwise: (-1, new stack at end of this line)
+ """
+ for i in xrange(startpos, len(line)):
+ char = line[i]
+ if char in '([{':
+ # Found start of parenthesized expression, push to expression stack
+ stack.append(char)
+ elif char == '<':
+ # Found potential start of template argument list
+ if i > 0 and line[i - 1] == '<':
+ # Left shift operator
+ if stack and stack[-1] == '<':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+ elif i > 0 and Search(r'\boperator\s*$', line[0:i]):
+ # operator<, don't add to stack
+ continue
+ else:
+ # Tentative start of template argument list
+ stack.append('<')
+ elif char in ')]}':
+ # Found end of parenthesized expression.
+ #
+ # If we are currently expecting a matching '>', the pending '<'
+ # must have been an operator. Remove them from expression stack.
+ while stack and stack[-1] == '<':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+ if ((stack[-1] == '(' and char == ')') or
+ (stack[-1] == '[' and char == ']') or
+ (stack[-1] == '{' and char == '}')):
+ stack.pop()
+ if not stack:
+ return (i + 1, None)
+ else:
+ # Mismatched parentheses
+ return (-1, None)
+ elif char == '>':
+ # Found potential end of template argument list.
+
+ # Ignore "->" and operator functions
+ if (i > 0 and
+ (line[i - 1] == '-' or Search(r'\boperator\s*$', line[0:i - 1]))):
+ continue
+
+ # Pop the stack if there is a matching '<'. Otherwise, ignore
+ # this '>' since it must be an operator.
+ if stack:
+ if stack[-1] == '<':
+ stack.pop()
+ if not stack:
+ return (i + 1, None)
+ elif char == ';':
+      # Found something that looks like the end of a statement. If we are
+      # currently expecting a '>', the matching '<' must have been an
+      # operator, since a template argument list should not contain
+      # statements.
+ while stack and stack[-1] == '<':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+
+ # Did not find end of expression or unbalanced parentheses on this line
+ return (-1, stack)
+
+
+def CloseExpression(clean_lines, linenum, pos):
+ """If input points to ( or { or [ or <, finds the position that closes it.
+
+ If lines[linenum][pos] points to a '(' or '{' or '[' or '<', finds the
+ linenum/pos that correspond to the closing of the expression.
+
+ TODO(unknown): cpplint spends a fair bit of time matching parentheses.
+ Ideally we would want to index all opening and closing parentheses once
+ and have CloseExpression be just a simple lookup, but due to preprocessor
+ tricks, this is not so easy.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ pos: A position on the line.
+
+ Returns:
+ A tuple (line, linenum, pos) pointer *past* the closing brace, or
+ (line, len(lines), -1) if we never find a close. Note we ignore
+ strings and comments when matching; and the line we return is the
+ 'cleansed' line at linenum.
+ """
+
+ line = clean_lines.elided[linenum]
+ if (line[pos] not in '({[<') or Match(r'<[<=]', line[pos:]):
+ return (line, clean_lines.NumLines(), -1)
+
+ # Check first line
+ (end_pos, stack) = FindEndOfExpressionInLine(line, pos, [])
+ if end_pos > -1:
+ return (line, linenum, end_pos)
+
+ # Continue scanning forward
+ while stack and linenum < clean_lines.NumLines() - 1:
+ linenum += 1
+ line = clean_lines.elided[linenum]
+ (end_pos, stack) = FindEndOfExpressionInLine(line, 0, stack)
+ if end_pos > -1:
+ return (line, linenum, end_pos)
+
+ # Did not find end of expression before end of file, give up
+ return (line, clean_lines.NumLines(), -1)
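+
+# Illustrative: if clean_lines.elided[linenum] is 'foo(bar(x), y);' and pos
+# is the index of the first '(', CloseExpression returns (line, linenum, 14),
+# the position just past the matching ')' (i.e. the index of the ';').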
+
+
+def FindStartOfExpressionInLine(line, endpos, stack):
+ """Find position at the matching start of current expression.
+
+ This is almost the reverse of FindEndOfExpressionInLine, but note
+  that the input position and returned position differ by 1.
+
+ Args:
+ line: a CleansedLines line.
+ endpos: start searching at this position.
+ stack: nesting stack at endpos.
+
+ Returns:
+ On finding matching start: (index at matching start, None)
+ On finding an unclosed expression: (-1, None)
+ Otherwise: (-1, new stack at beginning of this line)
+ """
+ i = endpos
+ while i >= 0:
+ char = line[i]
+ if char in ')]}':
+ # Found end of expression, push to expression stack
+ stack.append(char)
+ elif char == '>':
+ # Found potential end of template argument list.
+ #
+ # Ignore it if it's a "->" or ">=" or "operator>"
+ if (i > 0 and
+ (line[i - 1] == '-' or
+ Match(r'\s>=\s', line[i - 1:]) or
+ Search(r'\boperator\s*$', line[0:i]))):
+ i -= 1
+ else:
+ stack.append('>')
+ elif char == '<':
+ # Found potential start of template argument list
+ if i > 0 and line[i - 1] == '<':
+ # Left shift operator
+ i -= 1
+ else:
+ # If there is a matching '>', we can pop the expression stack.
+ # Otherwise, ignore this '<' since it must be an operator.
+ if stack and stack[-1] == '>':
+ stack.pop()
+ if not stack:
+ return (i, None)
+ elif char in '([{':
+ # Found start of expression.
+ #
+ # If there are any unmatched '>' on the stack, they must be
+ # operators. Remove those.
+ while stack and stack[-1] == '>':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+ if ((char == '(' and stack[-1] == ')') or
+ (char == '[' and stack[-1] == ']') or
+ (char == '{' and stack[-1] == '}')):
+ stack.pop()
+ if not stack:
+ return (i, None)
+ else:
+ # Mismatched parentheses
+ return (-1, None)
+ elif char == ';':
+      # Found something that looks like the end of a statement. If we are
+      # currently expecting a '<', the matching '>' must have been an
+      # operator, since a template argument list should not contain
+      # statements.
+ while stack and stack[-1] == '>':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+
+ i -= 1
+
+ return (-1, stack)
+
+
+def ReverseCloseExpression(clean_lines, linenum, pos):
+ """If input points to ) or } or ] or >, finds the position that opens it.
+
+ If lines[linenum][pos] points to a ')' or '}' or ']' or '>', finds the
+ linenum/pos that correspond to the opening of the expression.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ pos: A position on the line.
+
+ Returns:
+ A tuple (line, linenum, pos) pointer *at* the opening brace, or
+ (line, 0, -1) if we never find the matching opening brace. Note
+ we ignore strings and comments when matching; and the line we
+ return is the 'cleansed' line at linenum.
+ """
+ line = clean_lines.elided[linenum]
+ if line[pos] not in ')}]>':
+ return (line, 0, -1)
+
+ # Check last line
+ (start_pos, stack) = FindStartOfExpressionInLine(line, pos, [])
+ if start_pos > -1:
+ return (line, linenum, start_pos)
+
+ # Continue scanning backward
+ while stack and linenum > 0:
+ linenum -= 1
+ line = clean_lines.elided[linenum]
+ (start_pos, stack) = FindStartOfExpressionInLine(line, len(line) - 1, stack)
+ if start_pos > -1:
+ return (line, linenum, start_pos)
+
+ # Did not find start of expression before beginning of file, give up
+ return (line, 0, -1)
+
+
+def CheckForCopyright(filename, lines, error):
+ """Logs an error if no Copyright message appears at the top of the file."""
+
+ # We'll say it should occur by line 10. Don't forget there's a
+ # placeholder line at the front.
+ for line in xrange(1, min(len(lines), 11)):
+ if re.search(r'Copyright', lines[line], re.I): break
+ else: # means no copyright line was found
+ error(filename, 0, 'legal/copyright', 5,
+ 'No copyright message found. '
+ 'You should have a line: "Copyright [year] <Copyright Owner>"')
+
+
+def GetIndentLevel(line):
+ """Return the number of leading spaces in line.
+
+ Args:
+ line: A string to check.
+
+ Returns:
+ An integer count of leading spaces, possibly zero.
+ """
+ indent = Match(r'^( *)\S', line)
+ if indent:
+ return len(indent.group(1))
+ else:
+ return 0
+
+def PathSplitToList(path):
+ """Returns the path split into a list by the separator.
+
+ Args:
+ path: An absolute or relative path (e.g. '/a/b/c/' or '../a')
+
+ Returns:
+    A list of path components (e.g. ['a', 'b', 'c']).
+ """
+ lst = []
+ while True:
+ (head, tail) = os.path.split(path)
+ if head == path: # absolute paths end
+ lst.append(head)
+ break
+ if tail == path: # relative paths end
+ lst.append(tail)
+ break
+
+ path = head
+ lst.append(tail)
+
+ lst.reverse()
+ return lst
+
+def GetHeaderGuardCPPVariable(filename):
+ """Returns the CPP variable that should be used as a header guard.
+
+ Args:
+ filename: The name of a C++ header file.
+
+ Returns:
+ The CPP variable that should be used as a header guard in the
+ named file.
+
+ """
+
+ # Restores original filename in case that cpplint is invoked from Emacs's
+ # flymake.
+ filename = re.sub(r'_flymake\.h$', '.h', filename)
+ filename = re.sub(r'/\.flymake/([^/]*)$', r'/\1', filename)
+ # Replace 'c++' with 'cpp'.
+ filename = filename.replace('C++', 'cpp').replace('c++', 'cpp')
+
+ fileinfo = FileInfo(filename)
+ file_path_from_root = fileinfo.RepositoryName()
+
+ def FixupPathFromRoot():
+ if _root_debug:
+ sys.stderr.write("\n_root fixup, _root = '%s', repository name = '%s'\n"
+ %(_root, fileinfo.RepositoryName()))
+
+ # Process the file path with the --root flag if it was set.
+ if not _root:
+ if _root_debug:
+ sys.stderr.write("_root unspecified\n")
+ return file_path_from_root
+
+ def StripListPrefix(lst, prefix):
+      # f(['x', 'y'], ['w', 'z']) -> None (not a valid prefix)
+ if lst[:len(prefix)] != prefix:
+ return None
+      # f(['a', 'b', 'c', 'd'], ['a', 'b']) -> ['c', 'd']
+ return lst[(len(prefix)):]
+
+ # root behavior:
+ # --root=subdir , lstrips subdir from the header guard
+ maybe_path = StripListPrefix(PathSplitToList(file_path_from_root),
+ PathSplitToList(_root))
+
+ if _root_debug:
+ sys.stderr.write(("_root lstrip (maybe_path=%s, file_path_from_root=%s," +
+ " _root=%s)\n") %(maybe_path, file_path_from_root, _root))
+
+ if maybe_path:
+ return os.path.join(*maybe_path)
+
+ # --root=.. , will prepend the outer directory to the header guard
+ full_path = fileinfo.FullName()
+ root_abspath = os.path.abspath(_root)
+
+ maybe_path = StripListPrefix(PathSplitToList(full_path),
+ PathSplitToList(root_abspath))
+
+ if _root_debug:
+ sys.stderr.write(("_root prepend (maybe_path=%s, full_path=%s, " +
+ "root_abspath=%s)\n") %(maybe_path, full_path, root_abspath))
+
+ if maybe_path:
+ return os.path.join(*maybe_path)
+
+ if _root_debug:
+ sys.stderr.write("_root ignore, returning %s\n" %(file_path_from_root))
+
+ # --root=FAKE_DIR is ignored
+ return file_path_from_root
+
+ file_path_from_root = FixupPathFromRoot()
+ return re.sub(r'[^a-zA-Z0-9]', '_', file_path_from_root).upper() + '_'
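+
+# Illustrative: with no --root flag and a checkout rooted at /home/user/src
+# (detected via its .git/.hg/.svn directory), the file
+# chrome/browser/ui/browser.h yields the guard CHROME_BROWSER_UI_BROWSER_H_.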
+
+
+def CheckForHeaderGuard(filename, clean_lines, error):
+ """Checks that the file contains a header guard.
+
+ Logs an error if no #ifndef header guard is present. For other
+ headers, checks that the full pathname is used.
+
+ Args:
+ filename: The name of the C++ header file.
+ clean_lines: A CleansedLines instance containing the file.
+ error: The function to call with any errors found.
+ """
+
+ # Don't check for header guards if there are error suppression
+ # comments somewhere in this file.
+ #
+ # Because this is silencing a warning for a nonexistent line, we
+ # only support the very specific NOLINT(build/header_guard) syntax,
+ # and not the general NOLINT or NOLINT(*) syntax.
+ raw_lines = clean_lines.lines_without_raw_strings
+ for i in raw_lines:
+ if Search(r'//\s*NOLINT\(build/header_guard\)', i):
+ return
+
+ cppvar = GetHeaderGuardCPPVariable(filename)
+
+ ifndef = ''
+ ifndef_linenum = 0
+ define = ''
+ endif = ''
+ endif_linenum = 0
+ for linenum, line in enumerate(raw_lines):
+ linesplit = line.split()
+ if len(linesplit) >= 2:
+ # find the first occurrence of #ifndef and #define, save arg
+ if not ifndef and linesplit[0] == '#ifndef':
+ # set ifndef to the header guard presented on the #ifndef line.
+ ifndef = linesplit[1]
+ ifndef_linenum = linenum
+ if not define and linesplit[0] == '#define':
+ define = linesplit[1]
+ # find the last occurrence of #endif, save entire line
+ if line.startswith('#endif'):
+ endif = line
+ endif_linenum = linenum
+
+ if not ifndef or not define or ifndef != define:
+ error(filename, 0, 'build/header_guard', 5,
+ 'No #ifndef header guard found, suggested CPP variable is: %s' %
+ cppvar)
+ return
+
+ # The guard should be PATH_FILE_H_, but we also allow PATH_FILE_H__
+ # for backward compatibility.
+ if ifndef != cppvar:
+ error_level = 0
+ if ifndef != cppvar + '_':
+ error_level = 5
+
+ ParseNolintSuppressions(filename, raw_lines[ifndef_linenum], ifndef_linenum,
+ error)
+ error(filename, ifndef_linenum, 'build/header_guard', error_level,
+ '#ifndef header guard has wrong style, please use: %s' % cppvar)
+
+ # Check for "//" comments on endif line.
+ ParseNolintSuppressions(filename, raw_lines[endif_linenum], endif_linenum,
+ error)
+ match = Match(r'#endif\s*//\s*' + cppvar + r'(_)?\b', endif)
+ if match:
+ if match.group(1) == '_':
+ # Issue low severity warning for deprecated double trailing underscore
+ error(filename, endif_linenum, 'build/header_guard', 0,
+ '#endif line should be "#endif // %s"' % cppvar)
+ return
+
+ # Didn't find the corresponding "//" comment. If this file does not
+ # contain any "//" comments at all, it could be that the compiler
+  # only wants "/**/" comments; look for those instead.
+ no_single_line_comments = True
+ for i in xrange(1, len(raw_lines) - 1):
+ line = raw_lines[i]
+ if Match(r'^(?:(?:\'(?:\.|[^\'])*\')|(?:"(?:\.|[^"])*")|[^\'"])*//', line):
+ no_single_line_comments = False
+ break
+
+ if no_single_line_comments:
+ match = Match(r'#endif\s*/\*\s*' + cppvar + r'(_)?\s*\*/', endif)
+ if match:
+ if match.group(1) == '_':
+ # Low severity warning for double trailing underscore
+ error(filename, endif_linenum, 'build/header_guard', 0,
+ '#endif line should be "#endif /* %s */"' % cppvar)
+ return
+
+ # Didn't find anything
+ error(filename, endif_linenum, 'build/header_guard', 5,
+ '#endif line should be "#endif // %s"' % cppvar)
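+
+# Illustrative well-formed guard for a header whose guard variable is
+# FOO_BAR_H_ (accepted by the check above):
+#   #ifndef FOO_BAR_H_
+#   #define FOO_BAR_H_
+#   ...
+#   #endif  // FOO_BAR_H_
+# "#endif  /* FOO_BAR_H_ */" is also accepted when the file contains no
+# '//' comments at all.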
+
+
+def CheckHeaderFileIncluded(filename, include_state, error):
+ """Logs an error if a .cc file does not include its header."""
+
+ # Do not check test files
+ fileinfo = FileInfo(filename)
+ if Search(_TEST_FILE_SUFFIX, fileinfo.BaseName()):
+ return
+
+ headerfile = filename[0:len(filename) - len(fileinfo.Extension())] + '.h'
+ if not os.path.exists(headerfile):
+ return
+ headername = FileInfo(headerfile).RepositoryName()
+ first_include = 0
+ for section_list in include_state.include_list:
+ for f in section_list:
+ if headername in f[0] or f[0] in headername:
+ return
+ if not first_include:
+ first_include = f[1]
+
+ error(filename, first_include, 'build/include', 5,
+ '%s should include its header file %s' % (fileinfo.RepositoryName(),
+ headername))
+
+
+def CheckForBadCharacters(filename, lines, error):
+ """Logs an error for each line containing bad characters.
+
+ Two kinds of bad characters:
+
+ 1. Unicode replacement characters: These indicate that either the file
+ contained invalid UTF-8 (likely) or Unicode replacement characters (which
+ it shouldn't). Note that it's possible for this to throw off line
+ numbering if the invalid UTF-8 occurred adjacent to a newline.
+
+ 2. NUL bytes. These are problematic for some tools.
+
+ Args:
+ filename: The name of the current file.
+ lines: An array of strings, each representing a line of the file.
+ error: The function to call with any errors found.
+ """
+ for linenum, line in enumerate(lines):
+ if u'\ufffd' in line:
+ error(filename, linenum, 'readability/utf8', 5,
+ 'Line contains invalid UTF-8 (or Unicode replacement character).')
+ if '\0' in line:
+ error(filename, linenum, 'readability/nul', 5, 'Line contains NUL byte.')
+
+
+def CheckForNewlineAtEOF(filename, lines, error):
+ """Logs an error if there is no newline char at the end of the file.
+
+ Args:
+ filename: The name of the current file.
+ lines: An array of strings, each representing a line of the file.
+ error: The function to call with any errors found.
+ """
+
+ # The array lines() was created by adding two newlines to the
+ # original file (go figure), then splitting on \n.
+ # To verify that the file ends in \n, we just have to make sure the
+ # last-but-two element of lines() exists and is empty.
+ if len(lines) < 3 or lines[-2]:
+ error(filename, len(lines) - 2, 'whitespace/ending_newline', 5,
+ 'Could not find a newline character at the end of the file.')
+
+
+def CheckForMultilineCommentsAndStrings(filename, clean_lines, linenum, error):
+ """Logs an error if we see /* ... */ or "..." that extend past one line.
+
+ /* ... */ comments are legit inside macros, for one line.
+ Otherwise, we prefer // comments, so it's ok to warn about the
+ other. Likewise, it's ok for strings to extend across multiple
+ lines, as long as a line continuation character (backslash)
+ terminates each line. Although not currently prohibited by the C++
+ style guide, it's ugly and unnecessary. We don't do well with either
+ in this lint program, so we warn about both.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Remove all \\ (escaped backslashes) from the line. They are OK, and the
+ # second (escaped) slash may trigger later \" detection erroneously.
+ line = line.replace('\\\\', '')
+
+ if line.count('/*') > line.count('*/'):
+ error(filename, linenum, 'readability/multiline_comment', 5,
+ 'Complex multi-line /*...*/-style comment found. '
+ 'Lint may give bogus warnings. '
+ 'Consider replacing these with //-style comments, '
+ 'with #if 0...#endif, '
+ 'or with more clearly structured multi-line comments.')
+
+ if (line.count('"') - line.count('\\"')) % 2:
+ error(filename, linenum, 'readability/multiline_string', 5,
+ 'Multi-line string ("...") found. This lint script doesn\'t '
+ 'do well with such strings, and may give bogus warnings. '
+ 'Use C++11 raw strings or concatenation instead.')
+
+
+# (non-threadsafe name, thread-safe alternative, validation pattern)
+#
+# The validation pattern is used to eliminate false positives such as:
+# _rand(); // false positive due to substring match.
+# ->rand(); // some member function rand().
+# ACMRandom rand(seed); // some variable named rand.
+# ISAACRandom rand(); // another variable named rand.
+#
+# Basically we require the return value of these functions to be used
+# in some expression context on the same line by matching on some
+# operator before the function name. This eliminates constructors and
+# member function calls.
+_UNSAFE_FUNC_PREFIX = r'(?:[-+*/=%^&|(<]\s*|>\s+)'
+_THREADING_LIST = (
+ ('asctime(', 'asctime_r(', _UNSAFE_FUNC_PREFIX + r'asctime\([^)]+\)'),
+ ('ctime(', 'ctime_r(', _UNSAFE_FUNC_PREFIX + r'ctime\([^)]+\)'),
+ ('getgrgid(', 'getgrgid_r(', _UNSAFE_FUNC_PREFIX + r'getgrgid\([^)]+\)'),
+ ('getgrnam(', 'getgrnam_r(', _UNSAFE_FUNC_PREFIX + r'getgrnam\([^)]+\)'),
+ ('getlogin(', 'getlogin_r(', _UNSAFE_FUNC_PREFIX + r'getlogin\(\)'),
+ ('getpwnam(', 'getpwnam_r(', _UNSAFE_FUNC_PREFIX + r'getpwnam\([^)]+\)'),
+ ('getpwuid(', 'getpwuid_r(', _UNSAFE_FUNC_PREFIX + r'getpwuid\([^)]+\)'),
+ ('gmtime(', 'gmtime_r(', _UNSAFE_FUNC_PREFIX + r'gmtime\([^)]+\)'),
+ ('localtime(', 'localtime_r(', _UNSAFE_FUNC_PREFIX + r'localtime\([^)]+\)'),
+ ('rand(', 'rand_r(', _UNSAFE_FUNC_PREFIX + r'rand\(\)'),
+    ('strtok(', 'strtok_r(', _UNSAFE_FUNC_PREFIX + r'strtok\([^)]+\)'),
+ ('ttyname(', 'ttyname_r(', _UNSAFE_FUNC_PREFIX + r'ttyname\([^)]+\)'),
+ )
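+
+# Illustrative: _UNSAFE_FUNC_PREFIX makes 'seed = rand();' match ('=' before
+# the call), while '->rand();' (no space after '>'), 'ACMRandom rand(seed);'
+# and 'my_rand()' do not match.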
+
+
+def CheckPosixThreading(filename, clean_lines, linenum, error):
+ """Checks for calls to thread-unsafe functions.
+
+  Much code was originally written without multi-threading in mind. Also,
+  engineers often rely on old experience: they learned POSIX before the
+  threading extensions were added. These checks guide engineers toward the
+  thread-safe functions (when using POSIX directly).
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+ for single_thread_func, multithread_safe_func, pattern in _THREADING_LIST:
+ # Additional pattern matching check to confirm that this is the
+ # function we are looking for
+ if Search(pattern, line):
+ error(filename, linenum, 'runtime/threadsafe_fn', 2,
+ 'Consider using ' + multithread_safe_func +
+ '...) instead of ' + single_thread_func +
+ '...) for improved thread safety.')
+
+
+def CheckVlogArguments(filename, clean_lines, linenum, error):
+ """Checks that VLOG() is only used for defining a logging level.
+
+ For example, VLOG(2) is correct. VLOG(INFO), VLOG(WARNING), VLOG(ERROR), and
+ VLOG(FATAL) are not.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+ if Search(r'\bVLOG\((INFO|ERROR|WARNING|DFATAL|FATAL)\)', line):
+ error(filename, linenum, 'runtime/vlog', 5,
+ 'VLOG() should be used with numeric verbosity level. '
+ 'Use LOG() if you want symbolic severity levels.')
+
+# Matches invalid increment: *count++, which moves pointer instead of
+# incrementing a value.
+_RE_PATTERN_INVALID_INCREMENT = re.compile(
+ r'^\s*\*\w+(\+\+|--);')
+
+
+def CheckInvalidIncrement(filename, clean_lines, linenum, error):
+ """Checks for invalid increment *count++.
+
+  For example, the following function:
+    void increment_counter(int* count) {
+      *count++;
+    }
+  is invalid, because it effectively does count++, moving the pointer, and
+  should be replaced with ++*count, (*count)++ or *count += 1.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+ if _RE_PATTERN_INVALID_INCREMENT.match(line):
+ error(filename, linenum, 'runtime/invalid_increment', 5,
+ 'Changing pointer instead of value (or unused value of operator*).')
+
+
+def IsMacroDefinition(clean_lines, linenum):
+ if Search(r'^#define', clean_lines[linenum]):
+ return True
+
+ if linenum > 0 and Search(r'\\$', clean_lines[linenum - 1]):
+ return True
+
+ return False
+
+
+def IsForwardClassDeclaration(clean_lines, linenum):
+ return Match(r'^\s*(\btemplate\b)*.*class\s+\w+;\s*$', clean_lines[linenum])
+
+
+class _BlockInfo(object):
+ """Stores information about a generic block of code."""
+
+ def __init__(self, linenum, seen_open_brace):
+ self.starting_linenum = linenum
+ self.seen_open_brace = seen_open_brace
+ self.open_parentheses = 0
+ self.inline_asm = _NO_ASM
+ self.check_namespace_indentation = False
+
+ def CheckBegin(self, filename, clean_lines, linenum, error):
+    """Run checks that apply to text up to the opening brace.
+
+ This is mostly for checking the text after the class identifier
+ and the "{", usually where the base class is specified. For other
+ blocks, there isn't much to check, so we always pass.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ pass
+
+ def CheckEnd(self, filename, clean_lines, linenum, error):
+    """Run checks that apply to text after the closing brace.
+
+ This is mostly used for checking end of namespace comments.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ pass
+
+ def IsBlockInfo(self):
+ """Returns true if this block is a _BlockInfo.
+
+ This is convenient for verifying that an object is an instance of
+ a _BlockInfo, but not an instance of any of the derived classes.
+
+ Returns:
+ True for this class, False for derived classes.
+ """
+ return self.__class__ == _BlockInfo
+
+
+class _ExternCInfo(_BlockInfo):
+ """Stores information about an 'extern "C"' block."""
+
+ def __init__(self, linenum):
+ _BlockInfo.__init__(self, linenum, True)
+
+
+class _ClassInfo(_BlockInfo):
+ """Stores information about a class."""
+
+ def __init__(self, name, class_or_struct, clean_lines, linenum):
+ _BlockInfo.__init__(self, linenum, False)
+ self.name = name
+ self.is_derived = False
+ self.check_namespace_indentation = True
+ if class_or_struct == 'struct':
+ self.access = 'public'
+ self.is_struct = True
+ else:
+ self.access = 'private'
+ self.is_struct = False
+
+ # Remember initial indentation level for this class. Using raw_lines here
+ # instead of elided to account for leading comments.
+ self.class_indent = GetIndentLevel(clean_lines.raw_lines[linenum])
+
+ # Try to find the end of the class. This will be confused by things like:
+ # class A {
+ # } *x = { ...
+ #
+ # But it's still good enough for CheckSectionSpacing.
+ self.last_line = 0
+ depth = 0
+ for i in range(linenum, clean_lines.NumLines()):
+ line = clean_lines.elided[i]
+ depth += line.count('{') - line.count('}')
+ if not depth:
+ self.last_line = i
+ break
+
+ def CheckBegin(self, filename, clean_lines, linenum, error):
+ # Look for a bare ':'
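+    # (e.g. "class Derived : public Base {"); requiring a non-colon on
+    # both sides excludes "::" scope qualifiers.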
+ if Search('(^|[^:]):($|[^:])', clean_lines.elided[linenum]):
+ self.is_derived = True
+
+ def CheckEnd(self, filename, clean_lines, linenum, error):
+ # If there is a DISALLOW macro, it should appear near the end of
+ # the class.
+ seen_last_thing_in_class = False
+ for i in xrange(linenum - 1, self.starting_linenum, -1):
+ match = Search(
+ r'\b(DISALLOW_COPY_AND_ASSIGN|DISALLOW_IMPLICIT_CONSTRUCTORS)\(' +
+ self.name + r'\)',
+ clean_lines.elided[i])
+ if match:
+ if seen_last_thing_in_class:
+ error(filename, i, 'readability/constructors', 3,
+ match.group(1) + ' should be the last thing in the class')
+ break
+
+ if not Match(r'^\s*$', clean_lines.elided[i]):
+ seen_last_thing_in_class = True
+
+ # Check that closing brace is aligned with beginning of the class.
+ # Only do this if the closing brace is indented by only whitespaces.
+ # This means we will not check single-line class definitions.
+ indent = Match(r'^( *)\}', clean_lines.elided[linenum])
+ if indent and len(indent.group(1)) != self.class_indent:
+ if self.is_struct:
+ parent = 'struct ' + self.name
+ else:
+ parent = 'class ' + self.name
+ error(filename, linenum, 'whitespace/indent', 3,
+ 'Closing brace should be aligned with beginning of %s' % parent)
+
+
+class _NamespaceInfo(_BlockInfo):
+ """Stores information about a namespace."""
+
+ def __init__(self, name, linenum):
+ _BlockInfo.__init__(self, linenum, False)
+ self.name = name or ''
+ self.check_namespace_indentation = True
+
+ def CheckEnd(self, filename, clean_lines, linenum, error):
+ """Check end of namespace comments."""
+ line = clean_lines.raw_lines[linenum]
+
+    # Check how many lines are enclosed in this namespace. Don't issue
+ # warning for missing namespace comments if there aren't enough
+ # lines. However, do apply checks if there is already an end of
+ # namespace comment and it's incorrect.
+ #
+ # TODO(unknown): We always want to check end of namespace comments
+ # if a namespace is large, but sometimes we also want to apply the
+ # check if a short namespace contained nontrivial things (something
+ # other than forward declarations). There is currently no logic on
+ # deciding what these nontrivial things are, so this check is
+ # triggered by namespace size only, which works most of the time.
+ if (linenum - self.starting_linenum < 10
+ and not Match(r'^\s*};*\s*(//|/\*).*\bnamespace\b', line)):
+ return
+
+ # Look for matching comment at end of namespace.
+ #
+ # Note that we accept C style "/* */" comments for terminating
+    # namespaces, so that code that terminates namespaces inside
+ # preprocessor macros can be cpplint clean.
+ #
+ # We also accept stuff like "// end of namespace <name>." with the
+ # period at the end.
+ #
+ # Besides these, we don't accept anything else, otherwise we might
+    # get false negatives when the existing comment is a substring of the
+    # expected namespace name.
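+    #
+    # For example, all of these are accepted for "namespace foo":
+    #   }  // namespace foo
+    #   }  /* namespace foo */
+    #   }  // end of namespace foo.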
+ if self.name:
+ # Named namespace
+ if not Match((r'^\s*};*\s*(//|/\*).*\bnamespace\s+' +
+ re.escape(self.name) + r'[\*/\.\\\s]*$'),
+ line):
+ error(filename, linenum, 'readability/namespace', 5,
+ 'Namespace should be terminated with "// namespace %s"' %
+ self.name)
+ else:
+ # Anonymous namespace
+ if not Match(r'^\s*};*\s*(//|/\*).*\bnamespace[\*/\.\\\s]*$', line):
+ # If "// namespace anonymous" or "// anonymous namespace (more text)",
+ # mention "// anonymous namespace" as an acceptable form
+ if Match(r'^\s*}.*\b(namespace anonymous|anonymous namespace)\b', line):
+ error(filename, linenum, 'readability/namespace', 5,
+ 'Anonymous namespace should be terminated with "// namespace"'
+ ' or "// anonymous namespace"')
+ else:
+ error(filename, linenum, 'readability/namespace', 5,
+ 'Anonymous namespace should be terminated with "// namespace"')
+
+
+class _PreprocessorInfo(object):
+ """Stores checkpoints of nesting stacks when #if/#else is seen."""
+
+ def __init__(self, stack_before_if):
+ # The entire nesting stack before #if
+ self.stack_before_if = stack_before_if
+
+ # The entire nesting stack up to #else
+ self.stack_before_else = []
+
+ # Whether we have already seen #else or #elif
+ self.seen_else = False
+
+
+class NestingState(object):
+ """Holds states related to parsing braces."""
+
+ def __init__(self):
+ # Stack for tracking all braces. An object is pushed whenever we
+ # see a "{", and popped when we see a "}". Only 3 types of
+ # objects are possible:
+ # - _ClassInfo: a class or struct.
+ # - _NamespaceInfo: a namespace.
+ # - _BlockInfo: some other type of block.
+ self.stack = []
+
+ # Top of the previous stack before each Update().
+ #
+ # Because the nesting_stack is updated at the end of each line, we
+    # had to do some convoluted checks to find out what the current
+    # scope is at the beginning of the line. This check is simplified by
+ # saving the previous top of nesting stack.
+ #
+ # We could save the full stack, but we only need the top. Copying
+ # the full nesting stack would slow down cpplint by ~10%.
+ self.previous_stack_top = []
+
+ # Stack of _PreprocessorInfo objects.
+ self.pp_stack = []
+
+ def SeenOpenBrace(self):
+ """Check if we have seen the opening brace for the innermost block.
+
+ Returns:
+ True if we have seen the opening brace, False if the innermost
+ block is still expecting an opening brace.
+ """
+ return (not self.stack) or self.stack[-1].seen_open_brace
+
+ def InNamespaceBody(self):
+ """Check if we are currently one level inside a namespace body.
+
+ Returns:
+ True if top of the stack is a namespace block, False otherwise.
+ """
+ return self.stack and isinstance(self.stack[-1], _NamespaceInfo)
+
+ def InExternC(self):
+ """Check if we are currently one level inside an 'extern "C"' block.
+
+ Returns:
+ True if top of the stack is an extern block, False otherwise.
+ """
+ return self.stack and isinstance(self.stack[-1], _ExternCInfo)
+
+ def InClassDeclaration(self):
+ """Check if we are currently one level inside a class or struct declaration.
+
+ Returns:
+ True if top of the stack is a class/struct, False otherwise.
+ """
+ return self.stack and isinstance(self.stack[-1], _ClassInfo)
+
+ def InAsmBlock(self):
+ """Check if we are currently one level inside an inline ASM block.
+
+ Returns:
+ True if the top of the stack is a block containing inline ASM.
+ """
+ return self.stack and self.stack[-1].inline_asm != _NO_ASM
+
+ def InTemplateArgumentList(self, clean_lines, linenum, pos):
+ """Check if current position is inside template argument list.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ pos: position just after the suspected template argument.
+ Returns:
+ True if (linenum, pos) is inside template arguments.
+ """
+ while linenum < clean_lines.NumLines():
+ # Find the earliest character that might indicate a template argument
+ line = clean_lines.elided[linenum]
+ match = Match(r'^[^{};=\[\]\.<>]*(.)', line[pos:])
+ if not match:
+ linenum += 1
+ pos = 0
+ continue
+ token = match.group(1)
+ pos += len(match.group(0))
+
+      # These things do not look like a template argument list:
+ # class Suspect {
+ # class Suspect x; }
+ if token in ('{', '}', ';'): return False
+
+      # These things look like a template argument list:
+ # template <class Suspect>
+ # template <class Suspect = default_value>
+ # template <class Suspect[]>
+ # template <class Suspect...>
+ if token in ('>', '=', '[', ']', '.'): return True
+
+ # Check if token is an unmatched '<'.
+ # If not, move on to the next character.
+ if token != '<':
+ pos += 1
+ if pos >= len(line):
+ linenum += 1
+ pos = 0
+ continue
+
+      # We can't be sure when we just find a single '<', so we need to
+      # find the matching '>'.
+ (_, end_line, end_pos) = CloseExpression(clean_lines, linenum, pos - 1)
+ if end_pos < 0:
+ # Not sure if template argument list or syntax error in file
+ return False
+ linenum = end_line
+ pos = end_pos
+ return False
+
+ def UpdatePreprocessor(self, line):
+ """Update preprocessor stack.
+
+ We need to handle preprocessors due to classes like this:
+ #ifdef SWIG
+ struct ResultDetailsPageElementExtensionPoint {
+ #else
+ struct ResultDetailsPageElementExtensionPoint : public Extension {
+ #endif
+
+ We make the following assumptions (good enough for most files):
+ - Preprocessor condition evaluates to true from #if up to first
+ #else/#elif/#endif.
+
+ - Preprocessor condition evaluates to false from #else/#elif up
+ to #endif. We still perform lint checks on these lines, but
+ these do not affect nesting stack.
+
+ Args:
+ line: current line to check.
+ """
+ if Match(r'^\s*#\s*(if|ifdef|ifndef)\b', line):
+ # Beginning of #if block, save the nesting stack here. The saved
+ # stack will allow us to restore the parsing state in the #else case.
+ self.pp_stack.append(_PreprocessorInfo(copy.deepcopy(self.stack)))
+ elif Match(r'^\s*#\s*(else|elif)\b', line):
+ # Beginning of #else block
+ if self.pp_stack:
+ if not self.pp_stack[-1].seen_else:
+ # This is the first #else or #elif block. Remember the
+ # whole nesting stack up to this point. This is what we
+ # keep after the #endif.
+ self.pp_stack[-1].seen_else = True
+ self.pp_stack[-1].stack_before_else = copy.deepcopy(self.stack)
+
+ # Restore the stack to how it was before the #if
+ self.stack = copy.deepcopy(self.pp_stack[-1].stack_before_if)
+ else:
+ # TODO(unknown): unexpected #else, issue warning?
+ pass
+ elif Match(r'^\s*#\s*endif\b', line):
+ # End of #if or #else blocks.
+ if self.pp_stack:
+ # If we saw an #else, we will need to restore the nesting
+ # stack to its former state before the #else, otherwise we
+ # will just continue from where we left off.
+ if self.pp_stack[-1].seen_else:
+ # Here we can just use a shallow copy since we are the last
+ # reference to it.
+ self.stack = self.pp_stack[-1].stack_before_else
+ # Drop the corresponding #if
+ self.pp_stack.pop()
+ else:
+ # TODO(unknown): unexpected #endif, issue warning?
+ pass
+
+ # TODO(unknown): Update() is too long, but we will refactor later.
+ def Update(self, filename, clean_lines, linenum, error):
+ """Update nesting state with current line.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Remember top of the previous nesting stack.
+ #
+ # The stack is always pushed/popped and not modified in place, so
+ # we can just do a shallow copy instead of copy.deepcopy. Using
+ # deepcopy would slow down cpplint by ~28%.
+ if self.stack:
+ self.previous_stack_top = self.stack[-1]
+ else:
+ self.previous_stack_top = None
+
+ # Update pp_stack
+ self.UpdatePreprocessor(line)
+
+ # Count parentheses. This is to avoid adding struct arguments to
+ # the nesting stack.
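+    # For example, in "void Function(struct Foo *f)", the "struct" keyword
+    # inside the parentheses should not push a new _ClassInfo.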
+ if self.stack:
+ inner_block = self.stack[-1]
+ depth_change = line.count('(') - line.count(')')
+ inner_block.open_parentheses += depth_change
+
+ # Also check if we are starting or ending an inline assembly block.
+ if inner_block.inline_asm in (_NO_ASM, _END_ASM):
+ if (depth_change != 0 and
+ inner_block.open_parentheses == 1 and
+ _MATCH_ASM.match(line)):
+ # Enter assembly block
+ inner_block.inline_asm = _INSIDE_ASM
+ else:
+ # Not entering assembly block. If previous line was _END_ASM,
+ # we will now shift to _NO_ASM state.
+ inner_block.inline_asm = _NO_ASM
+ elif (inner_block.inline_asm == _INSIDE_ASM and
+ inner_block.open_parentheses == 0):
+ # Exit assembly block
+ inner_block.inline_asm = _END_ASM
+
+ # Consume namespace declaration at the beginning of the line. Do
+ # this in a loop so that we catch same line declarations like this:
+ # namespace proto2 { namespace bridge { class MessageSet; } }
+ while True:
+ # Match start of namespace. The "\b\s*" below catches namespace
+      # declarations even if they aren't followed by whitespace; this
+ # is so that we don't confuse our namespace checker. The
+ # missing spaces will be flagged by CheckSpacing.
+ namespace_decl_match = Match(r'^\s*namespace\b\s*([:\w]+)?(.*)$', line)
+ if not namespace_decl_match:
+ break
+
+ new_namespace = _NamespaceInfo(namespace_decl_match.group(1), linenum)
+ self.stack.append(new_namespace)
+
+ line = namespace_decl_match.group(2)
+ if line.find('{') != -1:
+ new_namespace.seen_open_brace = True
+ line = line[line.find('{') + 1:]
+
+ # Look for a class declaration in whatever is left of the line
+ # after parsing namespaces. The regexp accounts for decorated classes
+ # such as in:
+ # class LOCKABLE API Object {
+ # };
+ class_decl_match = Match(
+ r'^(\s*(?:template\s*<[\w\s<>,:]*>\s*)?'
+ r'(class|struct)\s+(?:[A-Z_]+\s+)*(\w+(?:::\w+)*))'
+ r'(.*)$', line)
+ if (class_decl_match and
+ (not self.stack or self.stack[-1].open_parentheses == 0)):
+ # We do not want to accept classes that are actually template arguments:
+ # template <class Ignore1,
+ # class Ignore2 = Default<Args>,
+ # template <Args> class Ignore3>
+ # void Function() {};
+ #
+ # To avoid template argument cases, we scan forward and look for
+ # an unmatched '>'. If we see one, assume we are inside a
+ # template argument list.
+ end_declaration = len(class_decl_match.group(1))
+ if not self.InTemplateArgumentList(clean_lines, linenum, end_declaration):
+ self.stack.append(_ClassInfo(
+ class_decl_match.group(3), class_decl_match.group(2),
+ clean_lines, linenum))
+ line = class_decl_match.group(4)
+
+ # If we have not yet seen the opening brace for the innermost block,
+ # run checks here.
+ if not self.SeenOpenBrace():
+ self.stack[-1].CheckBegin(filename, clean_lines, linenum, error)
+
+ # Update access control if we are inside a class/struct
+ if self.stack and isinstance(self.stack[-1], _ClassInfo):
+ classinfo = self.stack[-1]
+ access_match = Match(
+ r'^(.*)\b(public|private|protected|signals)(\s+(?:slots\s*)?)?'
+ r':(?:[^:]|$)',
+ line)
+ if access_match:
+ classinfo.access = access_match.group(2)
+
+ # Check that access keywords are indented +1 space. Skip this
+ # check if the keywords are not preceded by whitespaces.
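+      # For example, with "class Foo {" at indentation 0, cpplint expects:
+      #   class Foo {
+      #    public:    // indented one space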
+ indent = access_match.group(1)
+ if (len(indent) != classinfo.class_indent + 1 and
+ Match(r'^\s*$', indent)):
+ if classinfo.is_struct:
+ parent = 'struct ' + classinfo.name
+ else:
+ parent = 'class ' + classinfo.name
+ slots = ''
+ if access_match.group(3):
+ slots = access_match.group(3)
+ error(filename, linenum, 'whitespace/indent', 3,
+ '%s%s: should be indented +1 space inside %s' % (
+ access_match.group(2), slots, parent))
+
+ # Consume braces or semicolons from what's left of the line
+ while True:
+ # Match first brace, semicolon, or closed parenthesis.
+ matched = Match(r'^[^{;)}]*([{;)}])(.*)$', line)
+ if not matched:
+ break
+
+ token = matched.group(1)
+ if token == '{':
+        # If namespace or class hasn't seen an opening brace yet, mark
+ # namespace/class head as complete. Push a new block onto the
+ # stack otherwise.
+ if not self.SeenOpenBrace():
+ self.stack[-1].seen_open_brace = True
+ elif Match(r'^extern\s*"[^"]*"\s*\{', line):
+ self.stack.append(_ExternCInfo(linenum))
+ else:
+ self.stack.append(_BlockInfo(linenum, True))
+ if _MATCH_ASM.match(line):
+ self.stack[-1].inline_asm = _BLOCK_ASM
+
+ elif token == ';' or token == ')':
+ # If we haven't seen an opening brace yet, but we already saw
+ # a semicolon, this is probably a forward declaration. Pop
+ # the stack for these.
+ #
+ # Similarly, if we haven't seen an opening brace yet, but we
+ # already saw a closing parenthesis, then these are probably
+ # function arguments with extra "class" or "struct" keywords.
+        # Also pop the stack for these.
+ if not self.SeenOpenBrace():
+ self.stack.pop()
+ else: # token == '}'
+ # Perform end of block checks and pop the stack.
+ if self.stack:
+ self.stack[-1].CheckEnd(filename, clean_lines, linenum, error)
+ self.stack.pop()
+ line = matched.group(2)
+
+ def InnermostClass(self):
+ """Get class info on the top of the stack.
+
+ Returns:
+ A _ClassInfo object if we are inside a class, or None otherwise.
+ """
+ for i in range(len(self.stack), 0, -1):
+ classinfo = self.stack[i - 1]
+ if isinstance(classinfo, _ClassInfo):
+ return classinfo
+ return None
+
+ def CheckCompletedBlocks(self, filename, error):
+ """Checks that all classes and namespaces have been completely parsed.
+
+ Call this when all lines in a file have been processed.
+ Args:
+ filename: The name of the current file.
+ error: The function to call with any errors found.
+ """
+ # Note: This test can result in false positives if #ifdef constructs
+ # get in the way of brace matching. See the testBuildClass test in
+ # cpplint_unittest.py for an example of this.
+ for obj in self.stack:
+ if isinstance(obj, _ClassInfo):
+ error(filename, obj.starting_linenum, 'build/class', 5,
+ 'Failed to find complete declaration of class %s' %
+ obj.name)
+ elif isinstance(obj, _NamespaceInfo):
+ error(filename, obj.starting_linenum, 'build/namespaces', 5,
+ 'Failed to find complete declaration of namespace %s' %
+ obj.name)
+
+
+def CheckForNonStandardConstructs(filename, clean_lines, linenum,
+ nesting_state, error):
+ r"""Logs an error if we see certain non-ANSI constructs ignored by gcc-2.
+
+ Complain about several constructs which gcc-2 accepts, but which are
+ not standard C++. Warning about these in lint is one way to ease the
+ transition to new compilers.
+ - put storage class first (e.g. "static const" instead of "const static").
+ - "%lld" instead of %qd" in printf-type functions.
+ - "%1$d" is non-standard in printf-type functions.
+ - "\%" is an undefined character escape sequence.
+ - text after #endif is not allowed.
+ - invalid inner-style forward declaration.
+ - >? and <? operators, and their >?= and <?= cousins.
+
+ Additionally, check for constructor/destructor style violations and reference
+ members, as it is very convenient to do so while checking for
+ gcc-2 compliance.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: A callable to which errors are reported, which takes 4 arguments:
+ filename, line number, error level, and message
+ """
+
+ # Remove comments from the line, but leave in strings for now.
+ line = clean_lines.lines[linenum]
+
+ if Search(r'printf\s*\(.*".*%[-+ ]?\d*q', line):
+ error(filename, linenum, 'runtime/printf_format', 3,
+ '%q in format strings is deprecated. Use %ll instead.')
+
+ if Search(r'printf\s*\(.*".*%\d+\$', line):
+ error(filename, linenum, 'runtime/printf_format', 2,
+ '%N$ formats are unconventional. Try rewriting to avoid them.')
+
+ # Remove escaped backslashes before looking for undefined escapes.
+ line = line.replace('\\\\', '')
+
+ if Search(r'("|\').*\\(%|\[|\(|{)', line):
+ error(filename, linenum, 'build/printf_format', 3,
+ '%, [, (, and { are undefined character escapes. Unescape them.')
+
+ # For the rest, work with both comments and strings removed.
+ line = clean_lines.elided[linenum]
+
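+  # Flag storage-class specifiers that follow the type, e.g.
+  # "const static int x = 1;" should be "static const int x = 1;".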
+ if Search(r'\b(const|volatile|void|char|short|int|long'
+ r'|float|double|signed|unsigned'
+ r'|schar|u?int8|u?int16|u?int32|u?int64)'
+ r'\s+(register|static|extern|typedef)\b',
+ line):
+ error(filename, linenum, 'build/storage_class', 5,
+ 'Storage-class specifier (static, extern, typedef, etc) should be '
+ 'at the beginning of the declaration.')
+
+ if Match(r'\s*#\s*endif\s*[^/\s]+', line):
+ error(filename, linenum, 'build/endif_comment', 5,
+ 'Uncommented text after #endif is non-standard. Use a comment.')
+
+ if Match(r'\s*class\s+(\w+\s*::\s*)+\w+\s*;', line):
+ error(filename, linenum, 'build/forward_decl', 5,
+ 'Inner-style forward declarations are invalid. Remove this line.')
+
+ if Search(r'(\w+|[+-]?\d+(\.\d*)?)\s*(<|>)\?=?\s*(\w+|[+-]?\d+)(\.\d*)?',
+ line):
+ error(filename, linenum, 'build/deprecated', 3,
+ '>? and <? (max and min) operators are non-standard and deprecated.')
+
+ if Search(r'^\s*const\s*string\s*&\s*\w+\s*;', line):
+ # TODO(unknown): Could it be expanded safely to arbitrary references,
+ # without triggering too many false positives? The first
+ # attempt triggered 5 warnings for mostly benign code in the regtest, hence
+ # the restriction.
+ # Here's the original regexp, for the reference:
+ # type_name = r'\w+((\s*::\s*\w+)|(\s*<\s*\w+?\s*>))?'
+ # r'\s*const\s*' + type_name + '\s*&\s*\w+\s*;'
+ error(filename, linenum, 'runtime/member_string_references', 2,
+ 'const string& members are dangerous. It is much better to use '
+ 'alternatives, such as pointers or simple constants.')
+
+ # Everything else in this function operates on class declarations.
+ # Return early if the top of the nesting stack is not a class, or if
+ # the class head is not completed yet.
+ classinfo = nesting_state.InnermostClass()
+ if not classinfo or not classinfo.seen_open_brace:
+ return
+
+ # The class may have been declared with namespace or classname qualifiers.
+ # The constructor and destructor will not have those qualifiers.
+ base_classname = classinfo.name.split('::')[-1]
+
+ # Look for single-argument constructors that aren't marked explicit.
+ # Technically a valid construct, but against style.
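+  # For example, "Foo(int x);" should be "explicit Foo(int x);", and
+  # "Foo(int x, int y = 0);" is flagged too, since it is callable with a
+  # single argument.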
+ explicit_constructor_match = Match(
+ r'\s+(?:(?:inline|constexpr)\s+)*(explicit\s+)?'
+ r'(?:(?:inline|constexpr)\s+)*%s\s*'
+ r'\(((?:[^()]|\([^()]*\))*)\)'
+ % re.escape(base_classname),
+ line)
+
+ if explicit_constructor_match:
+ is_marked_explicit = explicit_constructor_match.group(1)
+
+ if not explicit_constructor_match.group(2):
+ constructor_args = []
+ else:
+ constructor_args = explicit_constructor_match.group(2).split(',')
+
+ # collapse arguments so that commas in template parameter lists and function
+ # argument parameter lists don't split arguments in two
+ i = 0
+ while i < len(constructor_args):
+ constructor_arg = constructor_args[i]
+ while (constructor_arg.count('<') > constructor_arg.count('>') or
+ constructor_arg.count('(') > constructor_arg.count(')')):
+ constructor_arg += ',' + constructor_args[i + 1]
+ del constructor_args[i + 1]
+ constructor_args[i] = constructor_arg
+ i += 1
+
+ defaulted_args = [arg for arg in constructor_args if '=' in arg]
+ noarg_constructor = (not constructor_args or # empty arg list
+ # 'void' arg specifier
+ (len(constructor_args) == 1 and
+ constructor_args[0].strip() == 'void'))
+ onearg_constructor = ((len(constructor_args) == 1 and # exactly one arg
+ not noarg_constructor) or
+ # all but at most one arg defaulted
+ (len(constructor_args) >= 1 and
+ not noarg_constructor and
+ len(defaulted_args) >= len(constructor_args) - 1))
+ initializer_list_constructor = bool(
+ onearg_constructor and
+ Search(r'\bstd\s*::\s*initializer_list\b', constructor_args[0]))
+ copy_constructor = bool(
+ onearg_constructor and
+ Match(r'(const\s+)?%s(\s*<[^>]*>)?(\s+const)?\s*(?:<\w+>\s*)?&'
+ % re.escape(base_classname), constructor_args[0].strip()))
+
+ if (not is_marked_explicit and
+ onearg_constructor and
+ not initializer_list_constructor and
+ not copy_constructor):
+ if defaulted_args:
+ error(filename, linenum, 'runtime/explicit', 5,
+ 'Constructors callable with one argument '
+ 'should be marked explicit.')
+ else:
+ error(filename, linenum, 'runtime/explicit', 5,
+ 'Single-parameter constructors should be marked explicit.')
+ elif is_marked_explicit and not onearg_constructor:
+ if noarg_constructor:
+ error(filename, linenum, 'runtime/explicit', 5,
+ 'Zero-parameter constructors should not be marked explicit.')
+
+
+def CheckSpacingForFunctionCall(filename, clean_lines, linenum, error):
+ """Checks for the correctness of various spacing around function calls.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Since function calls often occur inside if/for/while/switch
+ # expressions - which have their own, more liberal conventions - we
+ # first see if we should be looking inside such an expression for a
+ # function call, to which we can apply more strict standards.
+ fncall = line # if there's no control flow construct, look at whole line
+ for pattern in (r'\bif\s*\((.*)\)\s*{',
+ r'\bfor\s*\((.*)\)\s*{',
+ r'\bwhile\s*\((.*)\)\s*[{;]',
+ r'\bswitch\s*\((.*)\)\s*{'):
+ match = Search(pattern, line)
+ if match:
+ fncall = match.group(1) # look inside the parens for function calls
+ break
+
+ # Except in if/for/while/switch, there should never be space
+ # immediately inside parens (eg "f( 3, 4 )"). We make an exception
+ # for nested parens ( (a+b) + c ). Likewise, there should never be
+ # a space before a ( when it's a function argument. I assume it's a
+ # function argument when the char before the whitespace is legal in
+ # a function name (alnum + _) and we're not starting a macro. Also ignore
+  # pointers and references to arrays and functions because they're too tricky:
+ # we use a very simple way to recognize these:
+ # " (something)(maybe-something)" or
+ # " (something)(maybe-something," or
+ # " (something)[something]"
+ # Note that we assume the contents of [] to be short enough that
+ # they'll never need to wrap.
+ if ( # Ignore control structures.
+ not Search(r'\b(if|for|while|switch|return|new|delete|catch|sizeof)\b',
+ fncall) and
+ # Ignore pointers/references to functions.
+ not Search(r' \([^)]+\)\([^)]*(\)|,$)', fncall) and
+ # Ignore pointers/references to arrays.
+ not Search(r' \([^)]+\)\[[^\]]+\]', fncall)):
+ if Search(r'\w\s*\(\s(?!\s*\\$)', fncall): # a ( used for a fn call
+ error(filename, linenum, 'whitespace/parens', 4,
+ 'Extra space after ( in function call')
+ elif Search(r'\(\s+(?!(\s*\\)|\()', fncall):
+ error(filename, linenum, 'whitespace/parens', 2,
+ 'Extra space after (')
+ if (Search(r'\w\s+\(', fncall) and
+ not Search(r'_{0,2}asm_{0,2}\s+_{0,2}volatile_{0,2}\s+\(', fncall) and
+ not Search(r'#\s*define|typedef|using\s+\w+\s*=', fncall) and
+ not Search(r'\w\s+\((\w+::)*\*\w+\)\(', fncall) and
+ not Search(r'\bcase\s+\(', fncall)):
+    # TODO(unknown): Space after an operator function seems to be a common
+    # error; silence those for now by restricting them to highest verbosity.
+ if Search(r'\boperator_*\b', line):
+ error(filename, linenum, 'whitespace/parens', 0,
+ 'Extra space before ( in function call')
+ else:
+ error(filename, linenum, 'whitespace/parens', 4,
+ 'Extra space before ( in function call')
+ # If the ) is followed only by a newline or a { + newline, assume it's
+ # part of a control statement (if/while/etc), and don't complain
+ if Search(r'[^)]\s+\)\s*[^{\s]', fncall):
+ # If the closing parenthesis is preceded by only whitespaces,
+ # try to give a more descriptive error message.
+ if Search(r'^\s+\)', fncall):
+ error(filename, linenum, 'whitespace/parens', 2,
+ 'Closing ) should be moved to the previous line')
+ else:
+ error(filename, linenum, 'whitespace/parens', 2,
+ 'Extra space before )')
+
+
+def IsBlankLine(line):
+ """Returns true if the given line is blank.
+
+ We consider a line to be blank if the line is empty or consists of
+ only white spaces.
+
+ Args:
+ line: A line of a string.
+
+ Returns:
+ True, if the given line is blank.
+ """
+ return not line or line.isspace()
+
+
+def CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line,
+ error):
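+  """Checks that items directly inside a namespace body are not indented."""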
+ is_namespace_indent_item = (
+ len(nesting_state.stack) > 1 and
+ nesting_state.stack[-1].check_namespace_indentation and
+ isinstance(nesting_state.previous_stack_top, _NamespaceInfo) and
+ nesting_state.previous_stack_top == nesting_state.stack[-2])
+
+ if ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item,
+ clean_lines.elided, line):
+ CheckItemIndentationInNamespace(filename, clean_lines.elided,
+ line, error)
+
+
+def CheckForFunctionLengths(filename, clean_lines, linenum,
+ function_state, error):
+ """Reports for long function bodies.
+
+ For an overview why this is done, see:
+ https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions
+
+ Uses a simplistic algorithm assuming other style guidelines
+ (especially spacing) are followed.
+ Only checks unindented functions, so class members are unchecked.
+ Trivial bodies are unchecked, so constructors with huge initializer lists
+ may be missed.
+ Blank/comment lines are not counted so as to avoid encouraging the removal
+ of vertical space and comments just to get through a lint check.
+ NOLINT *on the last line of a function* disables this check.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ function_state: Current function name and lines in body so far.
+ error: The function to call with any errors found.
+ """
+ lines = clean_lines.lines
+ line = lines[linenum]
+ joined_line = ''
+
+ starting_func = False
+ regexp = r'(\w(\w|::|\*|\&|\s)*)\(' # decls * & space::name( ...
+ match_result = Match(regexp, line)
+ if match_result:
+ # If the name is all caps and underscores, figure it's a macro and
+ # ignore it, unless it's TEST or TEST_F.
+ function_name = match_result.group(1).split()[-1]
+ if function_name == 'TEST' or function_name == 'TEST_F' or (
+ not Match(r'[A-Z_]+$', function_name)):
+ starting_func = True
+
+ if starting_func:
+ body_found = False
+ for start_linenum in xrange(linenum, clean_lines.NumLines()):
+ start_line = lines[start_linenum]
+ joined_line += ' ' + start_line.lstrip()
+ if Search(r'(;|})', start_line): # Declarations and trivial functions
+ body_found = True
+ break # ... ignore
+ elif Search(r'{', start_line):
+ body_found = True
+ function = Search(r'((\w|:)*)\(', line).group(1)
+ if Match(r'TEST', function): # Handle TEST... macros
+ parameter_regexp = Search(r'(\(.*\))', joined_line)
+ if parameter_regexp: # Ignore bad syntax
+ function += parameter_regexp.group(1)
+ else:
+ function += '()'
+ function_state.Begin(function)
+ break
+ if not body_found:
+ # No body for the function (or evidence of a non-function) was found.
+ error(filename, linenum, 'readability/fn_size', 5,
+ 'Lint failed to find start of function body.')
+ elif Match(r'^\}\s*$', line): # function end
+ function_state.Check(error, filename, linenum)
+ function_state.End()
+ elif not Match(r'^\s*$', line):
+ function_state.Count() # Count non-blank/non-comment lines.
+
+
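+# Matches TODO comments, capturing the whitespace after "//", the optional
+# "(username)" owner, and the separator that follows, e.g.:
+#   // TODO(my_username): Stuff.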
+_RE_PATTERN_TODO = re.compile(r'^//(\s*)TODO(\(.+?\))?:?(\s|$)?')
+
+
+def CheckComment(line, filename, linenum, next_line_start, error):
+ """Checks for common mistakes in comments.
+
+ Args:
+ line: The line in question.
+ filename: The name of the current file.
+ linenum: The number of the line to check.
+ next_line_start: The first non-whitespace column of the next line.
+ error: The function to call with any errors found.
+ """
+ commentpos = line.find('//')
+ if commentpos != -1:
+ # Check if the // may be in quotes. If so, ignore it
+ if re.sub(r'\\.', '', line[0:commentpos]).count('"') % 2 == 0:
+ # Allow one space for new scopes, two spaces otherwise:
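+      #   int x = 0;  // two spaces between code and comment
+      #   if (foo) { // one space can be enough when opening a new scope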
+ if (not (Match(r'^.*{ *//', line) and next_line_start == commentpos) and
+ ((commentpos >= 1 and
+ line[commentpos-1] not in string.whitespace) or
+ (commentpos >= 2 and
+ line[commentpos-2] not in string.whitespace))):
+ error(filename, linenum, 'whitespace/comments', 2,
+ 'At least two spaces is best between code and comments')
+
+ # Checks for common mistakes in TODO comments.
+ comment = line[commentpos:]
+ match = _RE_PATTERN_TODO.match(comment)
+ if match:
+ # One whitespace is correct; zero whitespace is handled elsewhere.
+ leading_whitespace = match.group(1)
+ if len(leading_whitespace) > 1:
+ error(filename, linenum, 'whitespace/todo', 2,
+ 'Too many spaces before TODO')
+
+ username = match.group(2)
+ if not username:
+ error(filename, linenum, 'readability/todo', 2,
+ 'Missing username in TODO; it should look like '
+ '"// TODO(my_username): Stuff."')
+
+ middle_whitespace = match.group(3)
+ # Comparisons made explicit for correctness -- pylint: disable=g-explicit-bool-comparison
+ if middle_whitespace != ' ' and middle_whitespace != '':
+ error(filename, linenum, 'whitespace/todo', 2,
+ 'TODO(my_username) should be followed by a space')
+
+ # If the comment contains an alphanumeric character, there
+ # should be a space somewhere between it and the // unless
+ # it's a /// or //! Doxygen comment.
+ if (Match(r'//[^ ]*\w', comment) and
+ not Match(r'(///|//\!)(\s+|$)', comment)):
+ error(filename, linenum, 'whitespace/comments', 4,
+ 'Should have a space between // and comment')
+
+
+def CheckSpacing(filename, clean_lines, linenum, nesting_state, error):
+ """Checks for the correctness of various spacing issues in the code.
+
+ Things we check for: spaces around operators, spaces after
+ if/for/while/switch, no spaces around parens in function calls, two
+ spaces between code and comment, don't start a block with a blank
+ line, don't end a function with a blank line, don't add a blank line
+ after public/protected/private, don't have too many blank lines in a row.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+
+ # Don't use "elided" lines here, otherwise we can't check commented lines.
+ # Don't want to use "raw" either, because we don't want to check inside C++11
+  # raw strings.
+ raw = clean_lines.lines_without_raw_strings
+ line = raw[linenum]
+
+ # Before nixing comments, check if the line is blank for no good
+ # reason. This includes the first line after a block is opened, and
+ # blank lines at the end of a function (ie, right before a line like '}'
+ #
+ # Skip all the blank line checks if we are immediately inside a
+ # namespace body. In other words, don't issue blank line warnings
+ # for this block:
+ # namespace {
+ #
+ # }
+ #
+ # A warning about missing end of namespace comments will be issued instead.
+ #
+ # Also skip blank line checks for 'extern "C"' blocks, which are formatted
+ # like namespaces.
+ if (IsBlankLine(line) and
+ not nesting_state.InNamespaceBody() and
+ not nesting_state.InExternC()):
+ elided = clean_lines.elided
+ prev_line = elided[linenum - 1]
+ prevbrace = prev_line.rfind('{')
+ # TODO(unknown): Don't complain if line before blank line, and line after,
+ # both start with alnums and are indented the same amount.
+ # This ignores whitespace at the start of a namespace block
+ # because those are not usually indented.
+ if prevbrace != -1 and prev_line[prevbrace:].find('}') == -1:
+ # OK, we have a blank line at the start of a code block. Before we
+ # complain, we check if it is an exception to the rule: The previous
+ # non-empty line has the parameters of a function header that are indented
+      # 4 spaces (because they did not fit in an 80 column line when placed on
+      # the same line as the function name). We also check for the case where
+      # the previous line is indented 6 spaces, which may happen when the
+      # initializers of a constructor do not fit into an 80 column line.
+ exception = False
+ if Match(r' {6}\w', prev_line): # Initializer list?
+        # We are looking for the opening column of the initializer list, which
+ # should be indented 4 spaces to cause 6 space indentation afterwards.
+ search_position = linenum-2
+ while (search_position >= 0
+ and Match(r' {6}\w', elided[search_position])):
+ search_position -= 1
+ exception = (search_position >= 0
+ and elided[search_position][:5] == ' :')
+ else:
+ # Search for the function arguments or an initializer list. We use a
+ # simple heuristic here: If the line is indented 4 spaces; and we have a
+ # closing paren, without the opening paren, followed by an opening brace
+ # or colon (for initializer lists) we assume that it is the last line of
+ # a function header. If we have a colon indented 4 spaces, it is an
+ # initializer list.
+ exception = (Match(r' {4}\w[^\(]*\)\s*(const\s*)?(\{\s*$|:)',
+ prev_line)
+ or Match(r' {4}:', prev_line))
+
+ if not exception:
+ error(filename, linenum, 'whitespace/blank_line', 2,
+ 'Redundant blank line at the start of a code block '
+ 'should be deleted.')
+ # Ignore blank lines at the end of a block in a long if-else
+ # chain, like this:
+ # if (condition1) {
+ # // Something followed by a blank line
+ #
+ # } else if (condition2) {
+ # // Something else
+ # }
+ if linenum + 1 < clean_lines.NumLines():
+ next_line = raw[linenum + 1]
+ if (next_line
+ and Match(r'\s*}', next_line)
+ and next_line.find('} else ') == -1):
+ error(filename, linenum, 'whitespace/blank_line', 3,
+ 'Redundant blank line at the end of a code block '
+ 'should be deleted.')
+
+ matched = Match(r'\s*(public|protected|private):', prev_line)
+ if matched:
+ error(filename, linenum, 'whitespace/blank_line', 3,
+ 'Do not leave a blank line after "%s:"' % matched.group(1))
+
+ # Next, check comments
+ next_line_start = 0
+ if linenum + 1 < clean_lines.NumLines():
+ next_line = raw[linenum + 1]
+ next_line_start = len(next_line) - len(next_line.lstrip())
+ CheckComment(line, filename, linenum, next_line_start, error)
+
+ # get rid of comments and strings
+ line = clean_lines.elided[linenum]
+
+ # You shouldn't have spaces before your brackets, except maybe after
+ # 'delete []', 'return []() {};', or 'auto [abc, ...] = ...;'.
+ if Search(r'\w\s+\[', line) and not Search(r'(?:auto&?|delete|return)\s+\[', line):
+ error(filename, linenum, 'whitespace/braces', 5,
+ 'Extra space before [')
+
+  # In range-based for, we want spaces before and after the colon, but
+ # not around "::" tokens that might appear.
+ if (Search(r'for *\(.*[^:]:[^: ]', line) or
+ Search(r'for *\(.*[^: ]:[^:]', line)):
+ error(filename, linenum, 'whitespace/forcolon', 2,
+ 'Missing space around colon in range-based for loop')
+
+
+def CheckOperatorSpacing(filename, clean_lines, linenum, error):
+ """Checks for horizontal spacing around operators.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Don't try to do spacing checks for operator methods. Do this by
+ # replacing the troublesome characters with something else,
+ # preserving column position for all other characters.
+ #
+ # The replacement is done repeatedly to avoid false positives from
+ # operators that call operators.
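+  # For example, "operator<<(ostream& out, ...)" is rewritten as
+  # "operator__(ostream& out, ...)" before the checks below run.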
+ while True:
+ match = Match(r'^(.*\boperator\b)(\S+)(\s*\(.*)$', line)
+ if match:
+ line = match.group(1) + ('_' * len(match.group(2))) + match.group(3)
+ else:
+ break
+
+ # We allow no-spaces around = within an if: "if ( (a=Foo()) == 0 )".
+ # Otherwise not. Note we only check for non-spaces on *both* sides;
+ # sometimes people put non-spaces on one side when aligning ='s among
+ # many lines (not that this is behavior that I approve of...)
+ if ((Search(r'[\w.]=', line) or
+ Search(r'=[\w.]', line))
+ and not Search(r'\b(if|while|for) ', line)
+ # Operators taken from [lex.operators] in C++11 standard.
+ and not Search(r'(>=|<=|==|!=|&=|\^=|\|=|\+=|\*=|\/=|\%=)', line)
+ and not Search(r'operator=', line)):
+ error(filename, linenum, 'whitespace/operators', 4,
+ 'Missing spaces around =')
+
+ # It's ok not to have spaces around binary operators like + - * /, but if
+ # there's too little whitespace, we get concerned. It's hard to tell,
+ # though, so we punt on this one for now. TODO.
+
+ # You should always have whitespace around binary operators.
+ #
+ # Check <= and >= first to avoid false positives with < and >, then
+ # check non-include lines for spacing around < and >.
+ #
+  # If the operator is followed by a comma, assume it's being used in a
+ # macro context and don't do any checks. This avoids false
+ # positives.
+ #
+ # Note that && is not included here. This is because there are too
+ # many false positives due to RValue references.
+ match = Search(r'[^<>=!\s](==|!=|<=|>=|\|\|)[^<>=!\s,;\)]', line)
+ if match:
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around %s' % match.group(1))
+ elif not Match(r'#.*include', line):
+ # Look for < that is not surrounded by spaces. This is only
+ # triggered if both sides are missing spaces, even though
+    # technically we should flag it if at least one side is missing a
+ # space. This is done to avoid some false positives with shifts.
+ match = Match(r'^(.*[^\s<])<[^\s=<,]', line)
+ if match:
+ (_, _, end_pos) = CloseExpression(
+ clean_lines, linenum, len(match.group(1)))
+ if end_pos <= -1:
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around <')
+
+ # Look for > that is not surrounded by spaces. Similar to the
+ # above, we only trigger if both sides are missing spaces to avoid
+ # false positives with shifts.
+ match = Match(r'^(.*[^-\s>])>[^\s=>,]', line)
+ if match:
+ (_, _, start_pos) = ReverseCloseExpression(
+ clean_lines, linenum, len(match.group(1)))
+ if start_pos <= -1:
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around >')
+
+ # We allow no-spaces around << when used like this: 10<<20, but
+ # not otherwise (particularly, not when used as streams)
+ #
+ # We also allow operators following an opening parenthesis, since
+ # those tend to be macros that deal with operators.
+ match = Search(r'(operator|[^\s(<])(?:L|UL|LL|ULL|l|ul|ll|ull)?<<([^\s,=<])', line)
+ if (match and not (match.group(1).isdigit() and match.group(2).isdigit()) and
+ not (match.group(1) == 'operator' and match.group(2) == ';')):
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around <<')
+
+ # We allow no-spaces around >> for almost anything. This is because
+ # C++11 allows ">>" to close nested templates, which accounts for
+ # most cases when ">>" is not followed by a space.
+ #
+ # We still warn on ">>" followed by alpha character, because that is
+ # likely due to ">>" being used for right shifts, e.g.:
+ # value >> alpha
+ #
+ # When ">>" is used to close templates, the alphanumeric letter that
+ # follows would be part of an identifier, and there should still be
+ # a space separating the template type and the identifier.
+ # type<type<type>> alpha
+ match = Search(r'>>[a-zA-Z_]', line)
+ if match:
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around >>')
+
+ # There shouldn't be space around unary operators
+ match = Search(r'(!\s|~\s|[\s]--[\s;]|[\s]\+\+[\s;])', line)
+ if match:
+ error(filename, linenum, 'whitespace/operators', 4,
+ 'Extra space for operator %s' % match.group(1))
+
+
+def CheckParenthesisSpacing(filename, clean_lines, linenum, error):
+ """Checks for horizontal spacing around parentheses.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # No spaces after an if, while, switch, or for
+ match = Search(r' (if\(|for\(|while\(|switch\()', line)
+ if match:
+ error(filename, linenum, 'whitespace/parens', 5,
+ 'Missing space before ( in %s' % match.group(1))
+
+ # For if/for/while/switch, the left and right parens should be
+ # consistent about how many spaces are inside the parens, and
+ # there should either be zero or one spaces inside the parens.
+ # We don't want: "if ( foo)" or "if ( foo )".
+ # Exception: "for ( ; foo; bar)" and "for (foo; bar; )" are allowed.
+ match = Search(r'\b(if|for|while|switch)\s*'
+ r'\(([ ]*)(.).*[^ ]+([ ]*)\)\s*{\s*$',
+ line)
+ if match:
+ if len(match.group(2)) != len(match.group(4)):
+ if not (match.group(3) == ';' and
+ len(match.group(2)) == 1 + len(match.group(4)) or
+ not match.group(2) and Search(r'\bfor\s*\(.*; \)', line)):
+ error(filename, linenum, 'whitespace/parens', 5,
+ 'Mismatching spaces inside () in %s' % match.group(1))
+ if len(match.group(2)) not in [0, 1]:
+ error(filename, linenum, 'whitespace/parens', 5,
+ 'Should have zero or one spaces inside ( and ) in %s' %
+ match.group(1))
+
+
+def CheckCommaSpacing(filename, clean_lines, linenum, error):
+ """Checks for horizontal spacing near commas and semicolons.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ raw = clean_lines.lines_without_raw_strings
+ line = clean_lines.elided[linenum]
+
+ # You should always have a space after a comma (either as fn arg or operator)
+ #
+ # This does not apply when the non-space character following the
+ # comma is another comma, since the only time when that happens is
+ # for empty macro arguments.
+ #
+ # We run this check in two passes: first pass on elided lines to
+ # verify that lines contain missing whitespaces, second pass on raw
+ # lines to confirm that those missing whitespaces are not due to
+ # elided comments.
+ if (Search(r',[^,\s]', ReplaceAll(r'\boperator\s*,\s*\(', 'F(', line)) and
+ Search(r',[^,\s]', raw[linenum])):
+ error(filename, linenum, 'whitespace/comma', 3,
+ 'Missing space after ,')
+
+ # You should always have a space after a semicolon
+  # except for a few corner cases
+  # TODO(unknown): clarify if 'if (1) { return 1;}' requires one more
+ # space after ;
+ if Search(r';[^\s};\\)/]', line):
+ error(filename, linenum, 'whitespace/semicolon', 3,
+ 'Missing space after ;')
+
+
+def _IsType(clean_lines, nesting_state, expr):
+ """Check if expression looks like a type name, returns true if so.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ expr: The expression to check.
+ Returns:
+ True, if token looks like a type.
+ """
+ # Keep only the last token in the expression
+ last_word = Match(r'^.*(\b\S+)$', expr)
+ if last_word:
+ token = last_word.group(1)
+ else:
+ token = expr
+
+ # Match native types and stdint types
+ if _TYPES.match(token):
+ return True
+
+ # Try a bit harder to match templated types. Walk up the nesting
+ # stack until we find something that resembles a typename
+ # declaration for what we are looking for.
+ typename_pattern = (r'\b(?:typename|class|struct)\s+' + re.escape(token) +
+ r'\b')
+ block_index = len(nesting_state.stack) - 1
+ while block_index >= 0:
+ if isinstance(nesting_state.stack[block_index], _NamespaceInfo):
+ return False
+
+ # Found where the opening brace is. We want to scan from this
+ # line up to the beginning of the function, minus a few lines.
+ # template <typename Type1, // stop scanning here
+ # ...>
+ # class C
+ # : public ... { // start scanning here
+ last_line = nesting_state.stack[block_index].starting_linenum
+
+ next_block_start = 0
+ if block_index > 0:
+ next_block_start = nesting_state.stack[block_index - 1].starting_linenum
+ first_line = last_line
+ while first_line >= next_block_start:
+ if clean_lines.elided[first_line].find('template') >= 0:
+ break
+ first_line -= 1
+ if first_line < next_block_start:
+ # Didn't find any "template" keyword before reaching the next block,
+ # there are probably no template things to check for this block
+ block_index -= 1
+ continue
+
+ # Look for typename in the specified range
+ for i in xrange(first_line, last_line + 1, 1):
+ if Search(typename_pattern, clean_lines.elided[i]):
+ return True
+ block_index -= 1
+
+ return False
+
+
+def CheckBracesSpacing(filename, clean_lines, linenum, nesting_state, error):
+ """Checks for horizontal spacing near commas.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Except after an opening paren, or after another opening brace (in case of
+ # an initializer list, for instance), you should have spaces before your
+ # braces when they are delimiting blocks, classes, namespaces etc.
+ # And since you should never have braces at the beginning of a line,
+ # this is an easy test. Except that braces used for initialization don't
+ # follow the same rule; we often don't want spaces before those.
+ match = Match(r'^(.*[^ ({>]){', line)
+
+ if match:
+ # Try a bit harder to check for brace initialization. This
+ # happens in one of the following forms:
+ # Constructor() : initializer_list_{} { ... }
+ # Constructor{}.MemberFunction()
+ # Type variable{};
+ # FunctionCall(type{}, ...);
+ # LastArgument(..., type{});
+ # LOG(INFO) << type{} << " ...";
+ # map_of_type[{...}] = ...;
+ # ternary = expr ? new type{} : nullptr;
+ # OuterTemplate<InnerTemplateConstructor<Type>{}>
+ #
+ # We check for the character following the closing brace, and
+ # silence the warning if it's one of those listed above, i.e.
+ # "{.;,)<>]:".
+ #
+ # To account for nested initializer list, we allow any number of
+ # closing braces up to "{;,)<". We can't simply silence the
+ # warning on first sight of closing brace, because that would
+ # cause false negatives for things that are not initializer lists.
+ # Silence this: But not this:
+ # Outer{ if (...) {
+ # Inner{...} if (...){ // Missing space before {
+ # }; }
+ #
+ # There is a false negative with this approach if people inserted
+ # spurious semicolons, e.g. "if (cond){};", but we will catch the
+ # spurious semicolon with a separate check.
+ leading_text = match.group(1)
+ (endline, endlinenum, endpos) = CloseExpression(
+ clean_lines, linenum, len(match.group(1)))
+ trailing_text = ''
+ if endpos > -1:
+ trailing_text = endline[endpos:]
+ for offset in xrange(endlinenum + 1,
+ min(endlinenum + 3, clean_lines.NumLines() - 1)):
+ trailing_text += clean_lines.elided[offset]
+ # We also suppress warnings for `uint64_t{expression}` etc., as the style
+ # guide recommends brace initialization for integral types to avoid
+ # overflow/truncation.
+ if (not Match(r'^[\s}]*[{.;,)<>\]:]', trailing_text)
+ and not _IsType(clean_lines, nesting_state, leading_text)):
+ error(filename, linenum, 'whitespace/braces', 5,
+ 'Missing space before {')
+
+ # Make sure '} else {' has spaces.
+ if Search(r'}else', line):
+ error(filename, linenum, 'whitespace/braces', 5,
+ 'Missing space before else')
+
+ # You shouldn't have a space before a semicolon at the end of the line.
+ # There's a special case for "for" since the style guide allows space before
+ # the semicolon there.
+ if Search(r':\s*;\s*$', line):
+ error(filename, linenum, 'whitespace/semicolon', 5,
+ 'Semicolon defining empty statement. Use {} instead.')
+ elif Search(r'^\s*;\s*$', line):
+ error(filename, linenum, 'whitespace/semicolon', 5,
+ 'Line contains only semicolon. If this should be an empty statement, '
+ 'use {} instead.')
+ elif (Search(r'\s+;\s*$', line) and
+ not Search(r'\bfor\b', line)):
+ error(filename, linenum, 'whitespace/semicolon', 5,
+ 'Extra space before last semicolon. If this should be an empty '
+ 'statement, use {} instead.')
+
+
+def IsDecltype(clean_lines, linenum, column):
+ """Check if the token ending on (linenum, column) is decltype().
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: the number of the line to check.
+ column: end column of the token to check.
+ Returns:
+ True if this token is decltype() expression, False otherwise.
+ """
+ (text, _, start_col) = ReverseCloseExpression(clean_lines, linenum, column)
+ if start_col < 0:
+ return False
+ if Search(r'\bdecltype\s*$', text[0:start_col]):
+ return True
+ return False
+
+
+def CheckSectionSpacing(filename, clean_lines, class_info, linenum, error):
+ """Checks for additional blank line issues related to sections.
+
+ Currently the only thing checked here is blank line before protected/private.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+    class_info: A _ClassInfo object.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ # Skip checks if the class is small, where small means 25 lines or less.
+ # 25 lines seems like a good cutoff since that's the usual height of
+ # terminals, and any class that can't fit in one screen can't really
+ # be considered "small".
+ #
+ # Also skip checks if we are on the first line. This accounts for
+ # classes that look like
+ # class Foo { public: ... };
+ #
+ # If we didn't find the end of the class, last_line would be zero,
+ # and the check will be skipped by the first condition.
+ if (class_info.last_line - class_info.starting_linenum <= 24 or
+ linenum <= class_info.starting_linenum):
+ return
+
+ matched = Match(r'\s*(public|protected|private):', clean_lines.lines[linenum])
+ if matched:
+ # Issue warning if the line before public/protected/private was
+ # not a blank line, but don't do this if the previous line contains
+ # "class" or "struct". This can happen two ways:
+ # - We are at the beginning of the class.
+ # - We are forward-declaring an inner class that is semantically
+ # private, but needed to be public for implementation reasons.
+ # Also ignores cases where the previous line ends with a backslash as can be
+ # common when defining classes in C macros.
+ prev_line = clean_lines.lines[linenum - 1]
+ if (not IsBlankLine(prev_line) and
+ not Search(r'\b(class|struct)\b', prev_line) and
+ not Search(r'\\$', prev_line)):
+ # Try a bit harder to find the beginning of the class. This is to
+ # account for multi-line base-specifier lists, e.g.:
+ # class Derived
+ # : public Base {
+ end_class_head = class_info.starting_linenum
+ for i in range(class_info.starting_linenum, linenum):
+ if Search(r'\{\s*$', clean_lines.lines[i]):
+ end_class_head = i
+ break
+ if end_class_head < linenum - 1:
+ error(filename, linenum, 'whitespace/blank_line', 3,
+ '"%s:" should be preceded by a blank line' % matched.group(1))
+
+
+def GetPreviousNonBlankLine(clean_lines, linenum):
+ """Return the most recent non-blank line and its line number.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file contents.
+ linenum: The number of the line to check.
+
+ Returns:
+ A tuple with two elements. The first element is the contents of the last
+ non-blank line before the current line, or the empty string if this is the
+ first non-blank line. The second is the line number of that line, or -1
+ if this is the first non-blank line.
+ """
+
+ prevlinenum = linenum - 1
+ while prevlinenum >= 0:
+ prevline = clean_lines.elided[prevlinenum]
+ if not IsBlankLine(prevline): # if not a blank line...
+ return (prevline, prevlinenum)
+ prevlinenum -= 1
+ return ('', -1)
+
+
+def CheckBraces(filename, clean_lines, linenum, error):
+ """Looks for misplaced braces (e.g. at the end of line).
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+
+ line = clean_lines.elided[linenum] # get rid of comments and strings
+
+ if Match(r'\s*{\s*$', line):
+ # We allow an open brace to start a line in the case where someone is using
+ # braces in a block to explicitly create a new scope, which is commonly used
+ # to control the lifetime of stack-allocated variables. Braces are also
+ # used for brace initializers inside function calls. We don't detect this
+ # perfectly: we just don't complain if the last non-whitespace character on
+ # the previous non-blank line is ',', ';', ':', '(', '{', or '}', or if the
+ # previous line starts a preprocessor block. We also allow a brace on the
+ # following line if it is part of an array initialization and would not fit
+ # within the 80 character limit of the preceding line.
+ prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
+ if (not Search(r'[,;:}{(]\s*$', prevline) and
+ not Match(r'\s*#', prevline) and
+ not (GetLineWidth(prevline) > _line_length - 2 and '[]' in prevline)):
+ error(filename, linenum, 'whitespace/braces', 4,
+ '{ should almost always be at the end of the previous line')
+
+ # An else clause should be on the same line as the preceding closing brace.
+ if Match(r'\s*else\b\s*(?:if\b|\{|$)', line):
+ prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
+ if Match(r'\s*}\s*$', prevline):
+ error(filename, linenum, 'whitespace/newline', 4,
+ 'An else should appear on the same line as the preceding }')
+
+ # If braces come on one side of an else, they should be on both.
+ # However, we have to worry about "else if" that spans multiple lines!
+ if Search(r'else if\s*\(', line): # could be multi-line if
+ brace_on_left = bool(Search(r'}\s*else if\s*\(', line))
+ # find the ( after the if
+ pos = line.find('else if')
+ pos = line.find('(', pos)
+ if pos > 0:
+ (endline, _, endpos) = CloseExpression(clean_lines, linenum, pos)
+ brace_on_right = endline[endpos:].find('{') != -1
+ if brace_on_left != brace_on_right: # must be brace after if
+ error(filename, linenum, 'readability/braces', 5,
+ 'If an else has a brace on one side, it should have it on both')
+ elif Search(r'}\s*else[^{]*$', line) or Match(r'[^}]*else\s*{', line):
+ error(filename, linenum, 'readability/braces', 5,
+ 'If an else has a brace on one side, it should have it on both')
+
+ # Likewise, an else should never have the else clause on the same line
+ if Search(r'\belse [^\s{]', line) and not Search(r'\belse if\b', line):
+ error(filename, linenum, 'whitespace/newline', 4,
+ 'Else clause should never be on same line as else (use 2 lines)')
+
+ # In the same way, a do/while should never be on one line
+ if Match(r'\s*do [^\s{]', line):
+ error(filename, linenum, 'whitespace/newline', 4,
+ 'do/while clauses should not be on a single line')
+
+ # Check single-line if/else bodies. The style guide says 'curly braces are not
+ # required for single-line statements'. We additionally allow multi-line,
+ # single statements, but we reject anything with more than one semicolon in
+ # it. This means that the first semicolon after the if should be at the end of
+ # its line, and the line after that should have an indent level equal to or
+ # lower than the if. We also check for ambiguous if/else nesting without
+ # braces.
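+ # e.g. (illustrative): 'if (x) DoThis(); DoThat();' is flagged because
+ # the first semicolon is not at the end of its line, while a brace-less
+ # single statement spread over two lines is accepted.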
+ if_else_match = Search(r'\b(if\s*\(|else\b)', line)
+ if if_else_match and not Match(r'\s*#', line):
+ if_indent = GetIndentLevel(line)
+ endline, endlinenum, endpos = line, linenum, if_else_match.end()
+ if_match = Search(r'\bif\s*\(', line)
+ if if_match:
+ # This could be a multiline if condition, so find the end first.
+ pos = if_match.end() - 1
+ (endline, endlinenum, endpos) = CloseExpression(clean_lines, linenum, pos)
+ # Check for an opening brace, either directly after the if or on the next
+ # line. If found, this isn't a single-statement conditional.
+ if (not Match(r'\s*{', endline[endpos:])
+ and not (Match(r'\s*$', endline[endpos:])
+ and endlinenum < (len(clean_lines.elided) - 1)
+ and Match(r'\s*{', clean_lines.elided[endlinenum + 1]))):
+ while (endlinenum < len(clean_lines.elided)
+ and ';' not in clean_lines.elided[endlinenum][endpos:]):
+ endlinenum += 1
+ endpos = 0
+ if endlinenum < len(clean_lines.elided):
+ endline = clean_lines.elided[endlinenum]
+ # We allow a mix of whitespace and closing braces (e.g. for one-liner
+ # methods) and a single \ after the semicolon (for macros)
+ endpos = endline.find(';')
+ if not Match(r';[\s}]*(\\?)$', endline[endpos:]):
+ # Semicolon isn't the last character, there's something trailing.
+ # Output a warning if the semicolon is not contained inside
+ # a lambda expression.
+ if not Match(r'^[^{};]*\[[^\[\]]*\][^{}]*\{[^{}]*\}\s*\)*[;,]\s*$',
+ endline):
+ error(filename, linenum, 'readability/braces', 4,
+ 'If/else bodies with multiple statements require braces')
+ elif endlinenum < len(clean_lines.elided) - 1:
+ # Make sure the next line is dedented
+ next_line = clean_lines.elided[endlinenum + 1]
+ next_indent = GetIndentLevel(next_line)
+ # With ambiguous nested if statements, this will error out on the
+ # if that *doesn't* match the else, regardless of whether it's the
+ # inner one or outer one.
+ if (if_match and Match(r'\s*else\b', next_line)
+ and next_indent != if_indent):
+ error(filename, linenum, 'readability/braces', 4,
+ 'Else clause should be indented at the same level as if. '
+ 'Ambiguous nested if/else chains require braces.')
+ elif next_indent > if_indent:
+ error(filename, linenum, 'readability/braces', 4,
+ 'If/else bodies with multiple statements require braces')
+
+
+def CheckTrailingSemicolon(filename, clean_lines, linenum, error):
+ """Looks for redundant trailing semicolon.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+
+ line = clean_lines.elided[linenum]
+
+ # Block bodies should not be followed by a semicolon. Due to C++11
+ # brace initialization, there are more places where semicolons are
+ # required than not, so we explicitly list the allowed rules rather
+ # than listing the disallowed ones. These are the places where "};"
+ # should be replaced by just "}":
+ # 1. Some flavor of block following closing parenthesis:
+ # for (;;) {};
+ # while (...) {};
+ # switch (...) {};
+ # Function(...) {};
+ # if (...) {};
+ # if (...) else if (...) {};
+ #
+ # 2. else block:
+ # if (...) else {};
+ #
+ # 3. const member function:
+ # Function(...) const {};
+ #
+ # 4. Block following some statement:
+ # x = 42;
+ # {};
+ #
+ # 5. Block at the beginning of a function:
+ # Function(...) {
+ # {};
+ # }
+ #
+ # Note that naively checking for the preceding "{" will also match
+ # braces inside multi-dimensional arrays, but this is fine since
+ # that expression will not contain semicolons.
+ #
+ # 6. Block following another block:
+ # while (true) {}
+ # {};
+ #
+ # 7. End of namespaces:
+ # namespace {};
+ #
+ # These semicolons seem far more common than other kinds of
+ # redundant semicolons, possibly due to people converting classes
+ # to namespaces. For now we do not warn for this case.
+ #
+ # Try matching case 1 first.
+ match = Match(r'^(.*\)\s*)\{', line)
+ if match:
+ # Matched closing parenthesis (case 1). Check the token before the
+ # matching opening parenthesis, and don't warn if it looks like a
+ # macro. This avoids these false positives:
+ # - macro that defines a base class
+ # - multi-line macro that defines a base class
+ # - macro that defines the whole class-head
+ #
+ # But we still issue warnings for macros that we know are safe to
+ # warn, specifically:
+ # - TEST, TEST_F, TEST_P, MATCHER, MATCHER_P
+ # - TYPED_TEST
+ # - INTERFACE_DEF
+ # - EXCLUSIVE_LOCKS_REQUIRED, SHARED_LOCKS_REQUIRED, LOCKS_EXCLUDED:
+ #
+ # We implement a list of safe macros instead of a list of
+ # unsafe macros, even though the latter appears less frequently in
+ # google code and would have been easier to implement. This is because
+ # the downside for getting the allowed checks wrong means some extra
+ # semicolons, while the downside for getting disallowed checks wrong
+ # would result in compile errors.
+ #
+ # In addition to macros, we also don't want to warn on
+ # - Compound literals
+ # - Lambdas
+ # - alignas specifier with anonymous structs
+ # - decltype
+ closing_brace_pos = match.group(1).rfind(')')
+ opening_parenthesis = ReverseCloseExpression(
+ clean_lines, linenum, closing_brace_pos)
+ if opening_parenthesis[2] > -1:
+ line_prefix = opening_parenthesis[0][0:opening_parenthesis[2]]
+ macro = Search(r'\b([A-Z_][A-Z0-9_]*)\s*$', line_prefix)
+ func = Match(r'^(.*\])\s*$', line_prefix)
+ if ((macro and
+ macro.group(1) not in (
+ 'TEST', 'TEST_F', 'MATCHER', 'MATCHER_P', 'TYPED_TEST',
+ 'EXCLUSIVE_LOCKS_REQUIRED', 'SHARED_LOCKS_REQUIRED',
+ 'LOCKS_EXCLUDED', 'INTERFACE_DEF')) or
+ (func and not Search(r'\boperator\s*\[\s*\]', func.group(1))) or
+ Search(r'\b(?:struct|union)\s+alignas\s*$', line_prefix) or
+ Search(r'\bdecltype$', line_prefix) or
+ Search(r'\s+=\s*$', line_prefix)):
+ match = None
+ if (match and
+ opening_parenthesis[1] > 1 and
+ Search(r'\]\s*$', clean_lines.elided[opening_parenthesis[1] - 1])):
+ # Multi-line lambda-expression
+ match = None
+
+ else:
+ # Try matching cases 2-3.
+ match = Match(r'^(.*(?:else|\)\s*const)\s*)\{', line)
+ if not match:
+ # Try matching cases 4-6. These are always matched on separate lines.
+ #
+ # Note that we can't simply concatenate the previous line to the
+ # current line and do a single match, otherwise we may output
+ # duplicate warnings for the blank line case:
+ # if (cond) {
+ # // blank line
+ # }
+ prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
+ if prevline and Search(r'[;{}]\s*$', prevline):
+ match = Match(r'^(\s*)\{', line)
+
+ # Check matching closing brace
+ if match:
+ (endline, endlinenum, endpos) = CloseExpression(
+ clean_lines, linenum, len(match.group(1)))
+ if endpos > -1 and Match(r'^\s*;', endline[endpos:]):
+ # Current {} pair is eligible for semicolon check, and we have found
+ # the redundant semicolon, output warning here.
+ #
+ # Note: because we are scanning forward for opening braces, and
+ # outputting warnings for the matching closing brace, if there are
+ # nested blocks with trailing semicolons, we will get the error
+ # messages in reversed order.
+
+ # We need to check the line forward for NOLINT
+ raw_lines = clean_lines.raw_lines
+ ParseNolintSuppressions(filename, raw_lines[endlinenum-1], endlinenum-1,
+ error)
+ ParseNolintSuppressions(filename, raw_lines[endlinenum], endlinenum,
+ error)
+
+ error(filename, endlinenum, 'readability/braces', 4,
+ "You don't need a ; after a }")
+
+
+def CheckEmptyBlockBody(filename, clean_lines, linenum, error):
+ """Look for empty loop/conditional body with only a single semicolon.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+
+ # Search for loop keywords at the beginning of the line. Because only
+ # whitespace is allowed before the keywords, this will also ignore most
+ # do-while-loops, since those lines should start with a closing brace.
+ #
+ # We also check "if" blocks here, since an empty conditional block
+ # is likely an error.
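+ # e.g. (illustrative): 'while (Poll());' triggers the empty-loop-body
+ # warning and 'if (ok());' the empty-conditional one; 'if (ok) {}' with
+ # a completely empty body and no else clause is caught further below.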
+ line = clean_lines.elided[linenum]
+ matched = Match(r'\s*(for|while|if)\s*\(', line)
+ if matched:
+ # Find the end of the conditional expression.
+ (end_line, end_linenum, end_pos) = CloseExpression(
+ clean_lines, linenum, line.find('('))
+
+ # Output warning if what follows the condition expression is a semicolon.
+ # No warning for all other cases, including whitespace or newline, since we
+ # have a separate check for semicolons preceded by whitespace.
+ if end_pos >= 0 and Match(r';', end_line[end_pos:]):
+ if matched.group(1) == 'if':
+ error(filename, end_linenum, 'whitespace/empty_conditional_body', 5,
+ 'Empty conditional bodies should use {}')
+ else:
+ error(filename, end_linenum, 'whitespace/empty_loop_body', 5,
+ 'Empty loop bodies should use {} or continue')
+
+ # Check for if statements that have completely empty bodies (no comments)
+ # and no else clauses.
+ if end_pos >= 0 and matched.group(1) == 'if':
+ # Find the position of the opening { for the if statement.
+ # Return without logging an error if it has no brackets.
+ opening_linenum = end_linenum
+ opening_line_fragment = end_line[end_pos:]
+ # Loop until EOF or find anything that's not whitespace or opening {.
+ while not Search(r'^\s*\{', opening_line_fragment):
+ if Search(r'^(?!\s*$)', opening_line_fragment):
+ # Conditional has no brackets.
+ return
+ opening_linenum += 1
+ if opening_linenum == len(clean_lines.elided):
+ # Couldn't find conditional's opening { or any code before EOF.
+ return
+ opening_line_fragment = clean_lines.elided[opening_linenum]
+ # Set opening_line (opening_line_fragment may not be entire opening line).
+ opening_line = clean_lines.elided[opening_linenum]
+
+ # Find the position of the closing }.
+ opening_pos = opening_line_fragment.find('{')
+ if opening_linenum == end_linenum:
+ # We need to make opening_pos relative to the start of the entire line.
+ opening_pos += end_pos
+ (closing_line, closing_linenum, closing_pos) = CloseExpression(
+ clean_lines, opening_linenum, opening_pos)
+ if closing_pos < 0:
+ return
+
+ # Now construct the body of the conditional. This consists of the portion
+ # of the opening line after the {, all lines until the closing line,
+ # and the portion of the closing line before the }.
+ if (clean_lines.raw_lines[opening_linenum] !=
+ CleanseComments(clean_lines.raw_lines[opening_linenum])):
+ # Opening line ends with a comment, so conditional isn't empty.
+ return
+ if closing_linenum > opening_linenum:
+ # Opening line after the {. Ignore comments here since we checked above.
+ body = list(opening_line[opening_pos+1:])
+ # All lines until closing line, excluding closing line, with comments.
+ body.extend(clean_lines.raw_lines[opening_linenum+1:closing_linenum])
+ # Closing line before the }. Won't (and can't) have comments.
+ body.append(clean_lines.elided[closing_linenum][:closing_pos-1])
+ body = '\n'.join(body)
+ else:
+ # If statement has brackets and fits on a single line.
+ body = opening_line[opening_pos+1:closing_pos-1]
+
+ # Check if the body is empty
+ if not _EMPTY_CONDITIONAL_BODY_PATTERN.search(body):
+ return
+ # The body is empty. Now make sure there's not an else clause.
+ current_linenum = closing_linenum
+ current_line_fragment = closing_line[closing_pos:]
+ # Loop until EOF or find anything that's not whitespace or else clause.
+ while Search(r'^\s*$|^(?=\s*else)', current_line_fragment):
+ if Search(r'^(?=\s*else)', current_line_fragment):
+ # Found an else clause, so don't log an error.
+ return
+ current_linenum += 1
+ if current_linenum == len(clean_lines.elided):
+ break
+ current_line_fragment = clean_lines.elided[current_linenum]
+
+ # The body is empty and there's no else clause until EOF or other code.
+ error(filename, end_linenum, 'whitespace/empty_if_body', 4,
+ 'If statement had no body and no else clause')
+
+
+def FindCheckMacro(line):
+ """Find a replaceable CHECK-like macro.
+
+ Args:
+ line: line to search on.
+ Returns:
+ (macro name, start position), or (None, -1) if no replaceable
+ macro is found.
+ """
+ for macro in _CHECK_MACROS:
+ i = line.find(macro)
+ if i >= 0:
+ # Find opening parenthesis. Do a regular expression match here
+ # to make sure that we are matching the expected CHECK macro, as
+ # opposed to some other macro that happens to contain the CHECK
+ # substring.
+ matched = Match(r'^(.*\b' + macro + r'\s*)\(', line)
+ if not matched:
+ continue
+ return (macro, len(matched.group(1)))
+ return (None, -1)
+
+
+def CheckCheck(filename, clean_lines, linenum, error):
+ """Checks the use of CHECK and EXPECT macros.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+
+ # Decide the set of replacement macros that should be suggested
+ lines = clean_lines.elided
+ (check_macro, start_pos) = FindCheckMacro(lines[linenum])
+ if not check_macro:
+ return
+
+ # Find end of the boolean expression by matching parentheses
+ (last_line, end_line, end_pos) = CloseExpression(
+ clean_lines, linenum, start_pos)
+ if end_pos < 0:
+ return
+
+ # If the check macro is followed by something other than a
+ # semicolon, assume users will log their own custom error messages
+ # and don't suggest any replacements.
+ if not Match(r'\s*;', last_line[end_pos:]):
+ return
+
+ if linenum == end_line:
+ expression = lines[linenum][start_pos + 1:end_pos - 1]
+ else:
+ expression = lines[linenum][start_pos + 1:]
+ for i in xrange(linenum + 1, end_line):
+ expression += lines[i]
+ expression += last_line[0:end_pos - 1]
+
+ # Parse expression so that we can take parentheses into account.
+ # This avoids false positives for inputs like "CHECK((a < 4) == b)",
+ # which is not replaceable by CHECK_LE.
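+ # For example (illustrative): CHECK(result == 42) is split into
+ # lhs='result', operator='==', rhs='42' and yields the CHECK_EQ
+ # suggestion below; in CHECK((a < 4) == b) the '<' stays inside the
+ # parenthesized operand and is never treated as the comparison.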
+ lhs = ''
+ rhs = ''
+ operator = None
+ while expression:
+ matched = Match(r'^\s*(<<|<<=|>>|>>=|->\*|->|&&|\|\||'
+ r'==|!=|>=|>|<=|<|\()(.*)$', expression)
+ if matched:
+ token = matched.group(1)
+ if token == '(':
+ # Parenthesized operand
+ expression = matched.group(2)
+ (end, _) = FindEndOfExpressionInLine(expression, 0, ['('])
+ if end < 0:
+ return # Unmatched parenthesis
+ lhs += '(' + expression[0:end]
+ expression = expression[end:]
+ elif token in ('&&', '||'):
+ # Logical and/or operators. This means the expression
+ # contains more than one term, for example:
+ # CHECK(42 < a && a < b);
+ #
+ # These are not replaceable with CHECK_LE, so bail out early.
+ return
+ elif token in ('<<', '<<=', '>>', '>>=', '->*', '->'):
+ # Non-relational operator
+ lhs += token
+ expression = matched.group(2)
+ else:
+ # Relational operator
+ operator = token
+ rhs = matched.group(2)
+ break
+ else:
+ # Unparenthesized operand. Instead of appending to lhs one character
+ # at a time, we do another regular expression match to consume several
+ # characters at once if possible. Trivial benchmark shows that this
+ # is more efficient when the operands are longer than a single
+ # character, which is generally the case.
+ matched = Match(r'^([^-=!<>()&|]+)(.*)$', expression)
+ if not matched:
+ matched = Match(r'^(\s*\S)(.*)$', expression)
+ if not matched:
+ break
+ lhs += matched.group(1)
+ expression = matched.group(2)
+
+ # Only apply checks if we got all parts of the boolean expression
+ if not (lhs and operator and rhs):
+ return
+
+ # Check that rhs does not contain logical operators. We already know
+ # that lhs is fine since the loop above parses out && and ||.
+ if rhs.find('&&') > -1 or rhs.find('||') > -1:
+ return
+
+ # At least one of the operands must be a constant literal. This is
+ # to avoid suggesting replacements for unprintable things like
+ # CHECK(variable != iterator)
+ #
+ # The following pattern matches decimal, hex integers, strings, and
+ # characters (in that order).
+ lhs = lhs.strip()
+ rhs = rhs.strip()
+ match_constant = r'^([-+]?(\d+|0[xX][0-9a-fA-F]+)[lLuU]{0,3}|".*"|\'.*\')$'
+ if Match(match_constant, lhs) or Match(match_constant, rhs):
+ # Note: since we know both lhs and rhs, we can provide a more
+ # descriptive error message like:
+ # Consider using CHECK_EQ(x, 42) instead of CHECK(x == 42)
+ # Instead of:
+ # Consider using CHECK_EQ instead of CHECK(a == b)
+ #
+ # We are still keeping the less descriptive message because if lhs
+ # or rhs gets long, the error message might become unreadable.
+ error(filename, linenum, 'readability/check', 2,
+ 'Consider using %s instead of %s(a %s b)' % (
+ _CHECK_REPLACEMENT[check_macro][operator],
+ check_macro, operator))
+
+
+def CheckAltTokens(filename, clean_lines, linenum, error):
+ """Check alternative keywords being used in boolean expressions.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Avoid preprocessor lines
+ if Match(r'^\s*#', line):
+ return
+
+ # Last ditch effort to avoid multi-line comments. This will not help
+ # if the comment started before the current line or ended after the
+ # current line, but it catches most of the false positives. At least,
+ # it provides a way to work around this warning for people who use
+ # multi-line comments in preprocessor macros.
+ #
+ # TODO(unknown): remove this once cpplint has better support for
+ # multi-line comments.
+ if line.find('/*') >= 0 or line.find('*/') >= 0:
+ return
+
+ for match in _ALT_TOKEN_REPLACEMENT_PATTERN.finditer(line):
+ error(filename, linenum, 'readability/alt_tokens', 2,
+ 'Use operator %s instead of %s' % (
+ _ALT_TOKEN_REPLACEMENT[match.group(1)], match.group(1)))
+
+
+def GetLineWidth(line):
+ """Determines the width of the line in column positions.
+
+ Args:
+ line: A string, which may be a Unicode string.
+
+ Returns:
+ The width of the line in column positions, accounting for Unicode
+ combining characters and wide characters.
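+
+ For example (illustrative; a plain ASCII string has width equal to its
+ length, while fullwidth CJK characters count as two columns):
+ >>> GetLineWidth('foo')
+ 3
+ >>> GetLineWidth(u'\u4e2d\u6587')
+ 4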
+ """
+ if isinstance(line, unicode):
+ width = 0
+ for uc in unicodedata.normalize('NFC', line):
+ if unicodedata.east_asian_width(uc) in ('W', 'F'):
+ width += 2
+ elif not unicodedata.combining(uc):
+ # Issue 337
+ # https://mail.python.org/pipermail/python-list/2012-August/628809.html
+ if (sys.version_info.major, sys.version_info.minor) <= (3, 2):
+ # https://github.com/python/cpython/blob/2.7/Include/unicodeobject.h#L81
+ is_wide_build = sysconfig.get_config_var("Py_UNICODE_SIZE") >= 4
+ # https://github.com/python/cpython/blob/2.7/Objects/unicodeobject.c#L564
+ is_low_surrogate = 0xDC00 <= ord(uc) <= 0xDFFF
+ if not is_wide_build and is_low_surrogate:
+ width -= 1
+
+ width += 1
+ return width
+ else:
+ return len(line)
+
+
+def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state,
+ error):
+ """Checks rules from the 'C++ style rules' section of cppguide.html.
+
+ Most of these rules are hard to test (naming, comment style), but we
+ do what we can. In particular we check for 2-space indents, line lengths,
+ tab usage, spaces inside code, etc.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ file_extension: The extension (without the dot) of the filename.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+
+ # Don't use "elided" lines here, otherwise we can't check commented lines.
+ # Don't want to use "raw" either, because we don't want to check inside C++11
+ # raw strings.
+ raw_lines = clean_lines.lines_without_raw_strings
+ line = raw_lines[linenum]
+ prev = raw_lines[linenum - 1] if linenum > 0 else ''
+
+ if line.find('\t') != -1:
+ error(filename, linenum, 'whitespace/tab', 1,
+ 'Tab found; better to use spaces')
+
+ # One or three blank spaces at the beginning of the line is weird; it's
+ # hard to reconcile that with 2-space indents.
+ # NOTE: here are the conditions Rob Pike used for his tests. Mine aren't
+ # as sophisticated, but it may be worth becoming so: RLENGTH==initial_spaces
+ # if(RLENGTH > 20) complain = 0;
+ # if(match($0, " +(error|private|public|protected):")) complain = 0;
+ # if(match(prev, "&& *$")) complain = 0;
+ # if(match(prev, "\\|\\| *$")) complain = 0;
+ # if(match(prev, "[\",=><] *$")) complain = 0;
+ # if(match($0, " <<")) complain = 0;
+ # if(match(prev, " +for \\(")) complain = 0;
+ # if(prevodd && match(prevprev, " +for \\(")) complain = 0;
+ scope_or_label_pattern = r'\s*\w+\s*:\s*\\?$'
+ classinfo = nesting_state.InnermostClass()
+ initial_spaces = 0
+ cleansed_line = clean_lines.elided[linenum]
+ while initial_spaces < len(line) and line[initial_spaces] == ' ':
+ initial_spaces += 1
+ # There are certain situations where we allow one space, notably for
+ # section labels, and also lines containing multi-line raw strings.
+ # We also don't check for lines that look like continuation lines
+ # (of lines ending in double quotes, commas, equals, or angle brackets)
+ # because the rules for how to indent those are non-trivial.
+ if (not Search(r'[",=><] *$', prev) and
+ (initial_spaces == 1 or initial_spaces == 3) and
+ not Match(scope_or_label_pattern, cleansed_line) and
+ not (clean_lines.raw_lines[linenum] != line and
+ Match(r'^\s*""', line))):
+ error(filename, linenum, 'whitespace/indent', 3,
+ 'Weird number of spaces at line-start. '
+ 'Are you using a 2-space indent?')
+
+ if line and line[-1].isspace():
+ error(filename, linenum, 'whitespace/end_of_line', 4,
+ 'Line ends in whitespace. Consider deleting these extra spaces.')
+
+ # Check if the line is a header guard.
+ is_header_guard = False
+ if IsHeaderExtension(file_extension):
+ cppvar = GetHeaderGuardCPPVariable(filename)
+ if (line.startswith('#ifndef %s' % cppvar) or
+ line.startswith('#define %s' % cppvar) or
+ line.startswith('#endif // %s' % cppvar)):
+ is_header_guard = True
+ # #include lines and header guards can be long, since there's no clean way to
+ # split them.
+ #
+ # URLs can be long too. It's possible to split these, but it makes them
+ # harder to cut&paste.
+ #
+ # The "$Id:...$" comment may also get very long without it being the
+ # developer's fault.
+ if (not line.startswith('#include') and not is_header_guard and
+ not Match(r'^\s*//.*http(s?)://\S*$', line) and
+ not Match(r'^\s*//\s*[^\s]*$', line) and
+ not Match(r'^// \$Id:.*#[0-9]+ \$$', line)):
+ line_width = GetLineWidth(line)
+ if line_width > _line_length:
+ error(filename, linenum, 'whitespace/line_length', 2,
+ 'Lines should be <= %i characters long' % _line_length)
+
+ if (cleansed_line.count(';') > 1 and
+ # for loops are allowed two ;'s (and may run over two lines).
+ cleansed_line.find('for') == -1 and
+ (GetPreviousNonBlankLine(clean_lines, linenum)[0].find('for') == -1 or
+ GetPreviousNonBlankLine(clean_lines, linenum)[0].find(';') != -1) and
+ # It's ok to have many commands in a switch case that fits in 1 line
+ not ((cleansed_line.find('case ') != -1 or
+ cleansed_line.find('default:') != -1) and
+ cleansed_line.find('break;') != -1)):
+ error(filename, linenum, 'whitespace/newline', 0,
+ 'More than one command on the same line')
+
+ # Some more style checks
+ CheckBraces(filename, clean_lines, linenum, error)
+ CheckTrailingSemicolon(filename, clean_lines, linenum, error)
+ CheckEmptyBlockBody(filename, clean_lines, linenum, error)
+ CheckSpacing(filename, clean_lines, linenum, nesting_state, error)
+ CheckOperatorSpacing(filename, clean_lines, linenum, error)
+ CheckParenthesisSpacing(filename, clean_lines, linenum, error)
+ CheckCommaSpacing(filename, clean_lines, linenum, error)
+ CheckBracesSpacing(filename, clean_lines, linenum, nesting_state, error)
+ CheckSpacingForFunctionCall(filename, clean_lines, linenum, error)
+ CheckCheck(filename, clean_lines, linenum, error)
+ CheckAltTokens(filename, clean_lines, linenum, error)
+ classinfo = nesting_state.InnermostClass()
+ if classinfo:
+ CheckSectionSpacing(filename, clean_lines, classinfo, linenum, error)
+
+
+_RE_PATTERN_INCLUDE = re.compile(r'^\s*#\s*include\s*([<"])([^>"]*)[>"].*$')
+# Matches the first component of a filename delimited by -s and _s. That is:
+# _RE_FIRST_COMPONENT.match('foo').group(0) == 'foo'
+# _RE_FIRST_COMPONENT.match('foo.cc').group(0) == 'foo'
+# _RE_FIRST_COMPONENT.match('foo-bar_baz.cc').group(0) == 'foo'
+# _RE_FIRST_COMPONENT.match('foo_bar-baz.cc').group(0) == 'foo'
+_RE_FIRST_COMPONENT = re.compile(r'^[^-_.]+')
+
+
+def _DropCommonSuffixes(filename):
+ """Drops common suffixes like _test.cc or -inl.h from filename.
+
+ For example:
+ >>> _DropCommonSuffixes('foo/foo-inl.h')
+ 'foo/foo'
+ >>> _DropCommonSuffixes('foo/bar/foo.cc')
+ 'foo/bar/foo'
+ >>> _DropCommonSuffixes('foo/foo_internal.h')
+ 'foo/foo'
+ >>> _DropCommonSuffixes('foo/foo_unusualinternal.h')
+ 'foo/foo_unusualinternal'
+
+ Args:
+ filename: The input filename.
+
+ Returns:
+ The filename with the common suffix removed.
+ """
+ for suffix in ('test.cc', 'regtest.cc', 'unittest.cc',
+ 'inl.h', 'impl.h', 'internal.h'):
+ if (filename.endswith(suffix) and len(filename) > len(suffix) and
+ filename[-len(suffix) - 1] in ('-', '_')):
+ return filename[:-len(suffix) - 1]
+ return os.path.splitext(filename)[0]
+
+
+def _ClassifyInclude(fileinfo, include, is_system):
+ """Figures out what kind of header 'include' is.
+
+ Args:
+ fileinfo: The current file cpplint is running over. A FileInfo instance.
+ include: The path to a #included file.
+ is_system: True if the #include used <> rather than "".
+
+ Returns:
+ One of the _XXX_HEADER constants.
+
+ For example:
+ >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'stdio.h', True)
+ _C_SYS_HEADER
+ >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'string', True)
+ _CPP_SYS_HEADER
+ >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/foo.h', False)
+ _LIKELY_MY_HEADER
+ >>> _ClassifyInclude(FileInfo('foo/foo_unknown_extension.cc'),
+ ... 'bar/foo_other_ext.h', False)
+ _POSSIBLE_MY_HEADER
+ >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/bar.h', False)
+ _OTHER_HEADER
+ """
+ # This checks against the list of all standard C++ header files, except
+ # those already checked for above.
+ is_cpp_h = include in _CPP_HEADERS
+
+ if is_system:
+ if is_cpp_h:
+ return _CPP_SYS_HEADER
+ else:
+ return _C_SYS_HEADER
+
+ # If the target file and the include we're checking share a
+ # basename when we drop common extensions, and the include
+ # lives in the same directory, then it's likely to be owned by the target file.
+ target_dir, target_base = (
+ os.path.split(_DropCommonSuffixes(fileinfo.RepositoryName())))
+ include_dir, include_base = os.path.split(_DropCommonSuffixes(include))
+ if target_base == include_base and (
+ include_dir == target_dir or
+ include_dir == os.path.normpath(target_dir + '/../public')):
+ return _LIKELY_MY_HEADER
+
+ # If the target and include share some initial basename
+ # component, it's possible the target is implementing the
+ # include, so it's allowed to be first, but we'll never
+ # complain if it's not there.
+ target_first_component = _RE_FIRST_COMPONENT.match(target_base)
+ include_first_component = _RE_FIRST_COMPONENT.match(include_base)
+ if (target_first_component and include_first_component and
+ target_first_component.group(0) ==
+ include_first_component.group(0)):
+ return _POSSIBLE_MY_HEADER
+
+ return _OTHER_HEADER
+
+
+def CheckIncludeLine(filename, clean_lines, linenum, include_state, error):
+ """Check rules that are applicable to #include lines.
+
+ Strings on #include lines are NOT removed from elided line, to make
+ certain tasks easier. However, to prevent false positives, checks
+ applicable to #include lines in CheckLanguage must be put here.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ include_state: An _IncludeState instance in which the headers are inserted.
+ error: The function to call with any errors found.
+ """
+ fileinfo = FileInfo(filename)
+ line = clean_lines.lines[linenum]
+
+ # "include" should use the new style "foo/bar.h" instead of just "bar.h"
+ # Only do this check if the included header follows google naming
+ # conventions. If not, assume that it's a 3rd party API that
+ # requires special include conventions.
+ #
+ # We also make an exception for Lua headers, which follow google
+ # naming convention but not the include convention.
+ match = Match(r'#include\s*"([^/]+\.h)"', line)
+ if match and not _THIRD_PARTY_HEADERS_PATTERN.match(match.group(1)):
+ error(filename, linenum, 'build/include', 4,
+ 'Include the directory when naming .h files')
+
+ # We shouldn't include a file more than once. Actually, there are a
+ # handful of instances where doing so is okay, but in general it's
+ # not.
+ match = _RE_PATTERN_INCLUDE.search(line)
+ if match:
+ include = match.group(2)
+ is_system = (match.group(1) == '<')
+ duplicate_line = include_state.FindHeader(include)
+ if duplicate_line >= 0:
+ error(filename, linenum, 'build/include', 4,
+ '"%s" already included at %s:%s' %
+ (include, filename, duplicate_line))
+ elif (include.endswith('.cc') and
+ os.path.dirname(fileinfo.RepositoryName()) != os.path.dirname(include)):
+ error(filename, linenum, 'build/include', 4,
+ 'Do not include .cc files from other packages')
+ elif not _THIRD_PARTY_HEADERS_PATTERN.match(include):
+ include_state.include_list[-1].append((include, linenum))
+
+ # We want to ensure that headers appear in the right order:
+ # 1) for foo.cc, foo.h (preferred location)
+ # 2) c system files
+ # 3) cpp system files
+ # 4) for foo.cc, foo.h (deprecated location)
+ # 5) other google headers
+ #
+ # We classify each include statement as one of those 5 types
+ # using a number of techniques. The include_state object keeps
+ # track of the highest type seen, and complains if we see a
+ # lower type after that.
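+ # For example (illustrative), in foo.cc the expected order is:
+ # #include "foo/foo.h"
+ # #include <stdio.h> // 2) C system
+ # #include <string> // 3) C++ system
+ # #include "bar/helper.h" // 5) other headers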
+ error_message = include_state.CheckNextIncludeOrder(
+ _ClassifyInclude(fileinfo, include, is_system))
+ if error_message:
+ error(filename, linenum, 'build/include_order', 4,
+ '%s. Should be: %s.h, c system, c++ system, other.' %
+ (error_message, fileinfo.BaseName()))
+ canonical_include = include_state.CanonicalizeAlphabeticalOrder(include)
+ if not include_state.IsInAlphabeticalOrder(
+ clean_lines, linenum, canonical_include):
+ error(filename, linenum, 'build/include_alpha', 4,
+ 'Include "%s" not in alphabetical order' % include)
+ include_state.SetLastHeader(canonical_include)
+
+
+def _GetTextInside(text, start_pattern):
+ r"""Retrieves all the text between matching open and close parentheses.
+
+ Given a string of lines and a regular expression string, retrieve all the text
+ following the expression and between opening punctuation symbols like
+ (, [, or {, and the matching close-punctuation symbol. This handles
+ properly nested occurrences of the punctuation, so for text like
+ printf(a(), b(c()));
+ a call to _GetTextInside(text, r'printf\(') will return 'a(), b(c())'.
+ start_pattern must match a string ending with an opening punctuation symbol.
+
+ Args:
+ text: The text to extract from. Comments and strings must be elided.
+ It may be a single line or span multiple lines.
+ start_pattern: The regexp string indicating where to start extracting
+ the text.
+ Returns:
+ The extracted text.
+ None if either the opening string or ending punctuation could not be found.
+ """
+ # TODO(unknown): Audit cpplint.py to see what places could be profitably
+ # rewritten to use _GetTextInside (and currently use inferior regexp matching).
+
+ # Map opening punctuation to the matching closing punctuation.
+ matching_punctuation = {'(': ')', '{': '}', '[': ']'}
+ closing_punctuation = set(matching_punctuation.itervalues())
+
+ # Find the position to start extracting text.
+ match = re.search(start_pattern, text, re.M)
+ if not match: # start_pattern not found in text.
+ return None
+ start_position = match.end(0)
+
+ assert start_position > 0, (
+ 'start_pattern must end with an opening punctuation.')
+ assert text[start_position - 1] in matching_punctuation, (
+ 'start_pattern must end with an opening punctuation.')
+ # Stack of closing punctuations we expect to have in text after position.
+ punctuation_stack = [matching_punctuation[text[start_position - 1]]]
+ position = start_position
+ while punctuation_stack and position < len(text):
+ if text[position] == punctuation_stack[-1]:
+ punctuation_stack.pop()
+ elif text[position] in closing_punctuation:
+ # A closing punctuation without matching opening punctuations.
+ return None
+ elif text[position] in matching_punctuation:
+ punctuation_stack.append(matching_punctuation[text[position]])
+ position += 1
+ if punctuation_stack:
+ # Opening punctuations left without matching close-punctuations.
+ return None
+ # All punctuation matched.
+ return text[start_position:position - 1]
+
+
+# Patterns for matching call-by-reference parameters.
+#
+# Supports nested templates up to 2 levels deep using this messy pattern:
+# < (?: < (?: < [^<>]*
+# >
+# | [^<>] )*
+# >
+# | [^<>] )*
+# >
+_RE_PATTERN_IDENT = r'[_a-zA-Z]\w*' # =~ [[:alpha:]][[:alnum:]]*
+_RE_PATTERN_TYPE = (
+ r'(?:const\s+)?(?:typename\s+|class\s+|struct\s+|union\s+|enum\s+)?'
+ r'(?:\w|'
+ r'\s*<(?:<(?:<[^<>]*>|[^<>])*>|[^<>])*>|'
+ r'::)+')
+# A call-by-reference parameter ends with '& identifier'.
+_RE_PATTERN_REF_PARAM = re.compile(
+ r'(' + _RE_PATTERN_TYPE + r'(?:\s*(?:\bconst\b|[*]))*\s*'
+ r'&\s*' + _RE_PATTERN_IDENT + r')\s*(?:=[^,()]+)?[,)]')
+# A call-by-const-reference parameter either ends with 'const& identifier'
+# or looks like 'const type& identifier' when 'type' is atomic.
+_RE_PATTERN_CONST_REF_PARAM = (
+ r'(?:.*\s*\bconst\s*&\s*' + _RE_PATTERN_IDENT +
+ r'|const\s+' + _RE_PATTERN_TYPE + r'\s*&\s*' + _RE_PATTERN_IDENT + r')')
+# Stream types.
+_RE_PATTERN_REF_STREAM_PARAM = (
+ r'(?:.*stream\s*&\s*' + _RE_PATTERN_IDENT + r')')
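+# Illustrative behaviour of the patterns above (not exhaustive): a parameter
+# such as 'string &out,' matches _RE_PATTERN_REF_PARAM and is flagged by
+# CheckForNonConstReference below, while 'const string &in,' and
+# 'std::ostream &os,' match the const-reference and stream allowances.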
+
+
+def CheckLanguage(filename, clean_lines, linenum, file_extension,
+ include_state, nesting_state, error):
+ """Checks rules from the 'C++ language rules' section of cppguide.html.
+
+ Some of these rules are hard to test (function overloading, using
+ uint32 inappropriately), but we do the best we can.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ file_extension: The extension (without the dot) of the filename.
+ include_state: An _IncludeState instance in which the headers are inserted.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+ # If the line is empty or consists of entirely a comment, no need to
+ # check it.
+ line = clean_lines.elided[linenum]
+ if not line:
+ return
+
+ match = _RE_PATTERN_INCLUDE.search(line)
+ if match:
+ CheckIncludeLine(filename, clean_lines, linenum, include_state, error)
+ return
+
+ # Reset include state across preprocessor directives. This is meant
+ # to silence warnings for conditional includes.
+ match = Match(r'^\s*#\s*(if|ifdef|ifndef|elif|else|endif)\b', line)
+ if match:
+ include_state.ResetSection(match.group(1))
+
+ # Make Windows paths like Unix.
+ fullname = os.path.abspath(filename).replace('\\', '/')
+
+ # Perform other checks now that we are sure that this is not an include line
+ CheckCasts(filename, clean_lines, linenum, error)
+ CheckGlobalStatic(filename, clean_lines, linenum, error)
+ CheckPrintf(filename, clean_lines, linenum, error)
+
+ if IsHeaderExtension(file_extension):
+ # TODO(unknown): check that 1-arg constructors are explicit.
+ # How to tell it's a constructor?
+ # (handled in CheckForNonStandardConstructs for now)
+ # TODO(unknown): check that classes declare or disable copy/assign
+ # (level 1 error)
+ pass
+
+ # Check if people are using the verboten C basic types. The only exception
+ # we regularly allow is "unsigned short port" for port.
+ if Search(r'\bshort port\b', line):
+ if not Search(r'\bunsigned short port\b', line):
+ error(filename, linenum, 'runtime/int', 4,
+ 'Use "unsigned short" for ports, not "short"')
+ else:
+ match = Search(r'\b(short|long(?! +double)|long long)\b', line)
+ if match:
+ error(filename, linenum, 'runtime/int', 4,
+ 'Use int16/int64/etc, rather than the C type %s' % match.group(1))
+
+ # Check if some verboten operator overloading is going on
+ # TODO(unknown): catch out-of-line unary operator&:
+ # class X {};
+ # int operator&(const X& x) { return 42; } // unary operator&
+ # The trick is it's hard to tell apart from binary operator&:
+ # class Y { int operator&(const Y& x) { return 23; } }; // binary operator&
+ if Search(r'\boperator\s*&\s*\(\s*\)', line):
+ error(filename, linenum, 'runtime/operator', 4,
+ 'Unary operator& is dangerous. Do not use it.')
+
+ # Check for suspicious usage of "if" like
+ # } if (a == b) {
+ if Search(r'\}\s*if\s*\(', line):
+ error(filename, linenum, 'readability/braces', 4,
+ 'Did you mean "else if"? If not, start a new line for "if".')
+
+ # Check for potential format string bugs like printf(foo).
+ # We constrain the pattern not to pick things like DocidForPrintf(foo).
+ # Not perfect but it can catch printf(foo.c_str()) and printf(foo->c_str())
+ # TODO(unknown): Catch the following case. Need to change the calling
+ # convention of the whole function to process multiple line to handle it.
+ # printf(
+ # boy_this_is_a_really_long_variable_that_cannot_fit_on_the_prev_line);
+ printf_args = _GetTextInside(line, r'(?i)\b(string)?printf\s*\(')
+ if printf_args:
+ match = Match(r'([\w.\->()]+)$', printf_args)
+ if match and match.group(1) != '__VA_ARGS__':
+ function_name = re.search(r'\b((?:string)?printf)\s*\(',
+ line, re.I).group(1)
+ error(filename, linenum, 'runtime/printf', 4,
+ 'Potential format string bug. Do %s("%%s", %s) instead.'
+ % (function_name, match.group(1)))
+
+ # Check for potential memset bugs like memset(buf, sizeof(buf), 0).
+ match = Search(r'memset\s*\(([^,]*),\s*([^,]*),\s*0\s*\)', line)
+ if match and not Match(r"^''|-?[0-9]+|0x[0-9A-Fa-f]$", match.group(2)):
+ error(filename, linenum, 'runtime/memset', 4,
+ 'Did you mean "memset(%s, 0, %s)"?'
+ % (match.group(1), match.group(2)))
+
+ if Search(r'\busing namespace\b', line):
+ error(filename, linenum, 'build/namespaces', 5,
+ 'Do not use namespace using-directives. '
+ 'Use using-declarations instead.')
+
+ # Detect variable-length arrays.
+ match = Match(r'\s*(.+::)?(\w+) [a-z]\w*\[(.+)];', line)
+ if (match and match.group(2) != 'return' and match.group(2) != 'delete' and
+ match.group(3).find(']') == -1):
+ # Split the size using space and arithmetic operators as delimiters.
+ # If any of the resulting tokens are not compile time constants then
+ # report the error.
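+ # e.g. (illustrative): 'int buf[n + 1];' is reported, while
+ # 'int buf[kMaxSize * 2];' and 'int buf[sizeof(Foo)];' are treated as
+ # compile-time constant sizes.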
+ tokens = re.split(r'\s|\+|\-|\*|\/|<<|>>', match.group(3))
+ is_const = True
+ skip_next = False
+ for tok in tokens:
+ if skip_next:
+ skip_next = False
+ continue
+
+ if Search(r'sizeof\(.+\)', tok): continue
+ if Search(r'arraysize\(\w+\)', tok): continue
+
+ tok = tok.lstrip('(')
+ tok = tok.rstrip(')')
+ if not tok: continue
+ if Match(r'\d+', tok): continue
+ if Match(r'0[xX][0-9a-fA-F]+', tok): continue
+ if Match(r'k[A-Z0-9]\w*', tok): continue
+ if Match(r'(.+::)?k[A-Z0-9]\w*', tok): continue
+ if Match(r'(.+::)?[A-Z][A-Z0-9_]*', tok): continue
+ # A catch all for tricky sizeof cases, including 'sizeof expression',
+ # 'sizeof(*type)', 'sizeof(const type)', 'sizeof(struct StructName)'
+ # requires skipping the next token because we split on ' ' and '*'.
+ if tok.startswith('sizeof'):
+ skip_next = True
+ continue
+ is_const = False
+ break
+ if not is_const:
+ error(filename, linenum, 'runtime/arrays', 1,
+ 'Do not use variable-length arrays. Use an appropriately named '
+ "('k' followed by CamelCase) compile-time constant for the size.")
+
+ # Check for use of unnamed namespaces in header files. Registration
+ # macros are typically OK, so we allow use of "namespace {" on lines
+ # that end with backslashes.
+ if (IsHeaderExtension(file_extension)
+ and Search(r'\bnamespace\s*{', line)
+ and line[-1] != '\\'):
+ error(filename, linenum, 'build/namespaces', 4,
+ 'Do not use unnamed namespaces in header files. See '
+ 'https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces'
+ ' for more information.')
+
+
+def CheckGlobalStatic(filename, clean_lines, linenum, error):
+ """Check for unsafe global or static objects.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Match two lines at a time to support multiline declarations
+ if linenum + 1 < clean_lines.NumLines() and not Search(r'[;({]', line):
+ line += clean_lines.elided[linenum + 1].strip()
+
+ # Check for people declaring static/global STL strings at the top level.
+ # This is dangerous because the C++ language does not guarantee that
+ # globals with constructors are initialized before the first access, and
+ # also because globals can be destroyed when some threads are still running.
+ # TODO(unknown): Generalize this to also find static unique_ptr instances.
+ # TODO(unknown): File bugs for clang-tidy to find these.
+ match = Match(
+ r'((?:|static +)(?:|const +))(?::*std::)?string( +const)? +'
+ r'([a-zA-Z0-9_:]+)\b(.*)',
+ line)
+
+ # Remove false positives:
+ # - String pointers (as opposed to values).
+ # string *pointer
+ # const string *pointer
+ # string const *pointer
+ # string *const pointer
+ #
+ # - Functions and template specializations.
+ # string Function<Type>(...
+ # string Class<Type>::Method(...
+ #
+ # - Operators. These are matched separately because operator names
+ # cross non-word boundaries, and trying to match both operators
+ # and functions at the same time would decrease accuracy of
+ # matching identifiers.
+ # string Class::operator*()
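+ # e.g. (illustrative): 'static const string kName = "aom";' is flagged
+ # with a suggestion to use 'static const char kName[]' instead, while
+ # 'static string* cache_;' is skipped as a pointer.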
+ if (match and
+ not Search(r'\bstring\b(\s+const)?\s*[\*\&]\s*(const\s+)?\w', line) and
+ not Search(r'\boperator\W', line) and
+ not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)*\s*\(([^"]|$)', match.group(4))):
+ if Search(r'\bconst\b', line):
+ error(filename, linenum, 'runtime/string', 4,
+ 'For a static/global string constant, use a C style string '
+ 'instead: "%schar%s %s[]".' %
+ (match.group(1), match.group(2) or '', match.group(3)))
+ else:
+ error(filename, linenum, 'runtime/string', 4,
+ 'Static/global string variables are not permitted.')
+
+ if (Search(r'\b([A-Za-z0-9_]*_)\(\1\)', line) or
+ Search(r'\b([A-Za-z0-9_]*_)\(CHECK_NOTNULL\(\1\)\)', line)):
+ error(filename, linenum, 'runtime/init', 4,
+ 'You seem to be initializing a member variable with itself.')
+
+
+def CheckPrintf(filename, clean_lines, linenum, error):
+ """Check for printf related issues.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # When snprintf is used, the second argument shouldn't be a literal.
+ match = Search(r'snprintf\s*\(([^,]*),\s*([0-9]*)\s*,', line)
+ if match and match.group(2) != '0':
+ # If 2nd arg is zero, snprintf is used to calculate size.
+ error(filename, linenum, 'runtime/printf', 3,
+ 'If you can, use sizeof(%s) instead of %s as the 2nd arg '
+ 'to snprintf.' % (match.group(1), match.group(2)))
+
+ # Check if some verboten C functions are being used.
+ if Search(r'\bsprintf\s*\(', line):
+ error(filename, linenum, 'runtime/printf', 5,
+ 'Never use sprintf. Use snprintf instead.')
+ match = Search(r'\b(strcpy|strcat)\s*\(', line)
+ if match:
+ error(filename, linenum, 'runtime/printf', 4,
+ 'Almost always, snprintf is better than %s' % match.group(1))
+
+
+def IsDerivedFunction(clean_lines, linenum):
+ """Check if current line contains an inherited function.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ Returns:
+ True if current line contains a function with "override"
+ virt-specifier.
+ """
+ # Scan back a few lines for start of current function
+ for i in xrange(linenum, max(-1, linenum - 10), -1):
+ match = Match(r'^([^()]*\w+)\(', clean_lines.elided[i])
+ if match:
+ # Look for "override" after the matching closing parenthesis
+ line, _, closing_paren = CloseExpression(
+ clean_lines, i, len(match.group(1)))
+ return (closing_paren >= 0 and
+ Search(r'\boverride\b', line[closing_paren:]))
+ return False
+
+
+def IsOutOfLineMethodDefinition(clean_lines, linenum):
+ """Check if current line contains an out-of-line method definition.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ Returns:
+ True if current line contains an out-of-line method definition.
+ """
+ # Scan back a few lines for start of current function
+ for i in xrange(linenum, max(-1, linenum - 10), -1):
+ if Match(r'^([^()]*\w+)\(', clean_lines.elided[i]):
+ return Match(r'^[^()]*\w+::\w+\(', clean_lines.elided[i]) is not None
+ return False
+
+
+def IsInitializerList(clean_lines, linenum):
+ """Check if current line is inside constructor initializer list.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ Returns:
+ True if current line appears to be inside constructor initializer
+ list, False otherwise.
+ """
+ for i in xrange(linenum, 1, -1):
+ line = clean_lines.elided[i]
+ if i == linenum:
+ remove_function_body = Match(r'^(.*)\{\s*$', line)
+ if remove_function_body:
+ line = remove_function_body.group(1)
+
+ if Search(r'\s:\s*\w+[({]', line):
+ # A lone colon tends to indicate the start of a constructor
+ # initializer list. It could also be a ternary operator, which also
+ # tends to appear in constructor initializer lists as opposed to
+ # parameter lists.
+ return True
+ if Search(r'\}\s*,\s*$', line):
+ # A closing brace followed by a comma is probably the end of a
+ # brace-initialized member in constructor initializer list.
+ return True
+ if Search(r'[{};]\s*$', line):
+ # Found one of the following:
+ # - A closing brace or semicolon, probably the end of the previous
+ # function.
+ # - An opening brace, probably the start of current class or namespace.
+ #
+ # Current line is probably not inside an initializer list since
+ # we saw one of those things without seeing the starting colon.
+ return False
+
+ # Got to the beginning of the file without seeing the start of
+ # constructor initializer list.
+ return False
+
+
+def CheckForNonConstReference(filename, clean_lines, linenum,
+ nesting_state, error):
+ """Check for non-const references.
+
+ Separate from CheckLanguage since it scans backwards from current
+ line, instead of scanning forward.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+ # Do nothing if there is no '&' on current line.
+ line = clean_lines.elided[linenum]
+ if '&' not in line:
+ return
+
+ # If a function is inherited, current function doesn't have much of
+ # a choice, so any non-const references should not be blamed on
+ # derived function.
+ if IsDerivedFunction(clean_lines, linenum):
+ return
+
+ # Don't warn on out-of-line method definitions, as we would warn on the
+ # in-line declaration, if it isn't marked with 'override'.
+ if IsOutOfLineMethodDefinition(clean_lines, linenum):
+ return
+
+ # Long type names may be broken across multiple lines, usually in one
+ # of these forms:
+ # LongType
+ # ::LongTypeContinued &identifier
+ # LongType::
+ # LongTypeContinued &identifier
+ # LongType<
+ # ...>::LongTypeContinued &identifier
+ #
+ # If we detected a type split across two lines, join the previous
+ # line to current line so that we can match const references
+ # accordingly.
+ #
+ # Note that this only scans back one line, since scanning back
+ # arbitrary number of lines would be expensive. If you have a type
+ # that spans more than 2 lines, please use a typedef.
+ if linenum > 1:
+ previous = None
+ if Match(r'\s*::(?:[\w<>]|::)+\s*&\s*\S', line):
+ # previous_line\n + ::current_line
+ previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+[\w<>])\s*$',
+ clean_lines.elided[linenum - 1])
+ elif Match(r'\s*[a-zA-Z_]([\w<>]|::)+\s*&\s*\S', line):
+ # previous_line::\n + current_line
+ previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+::)\s*$',
+ clean_lines.elided[linenum - 1])
+ if previous:
+ line = previous.group(1) + line.lstrip()
+ else:
+ # Check for templated parameter that is split across multiple lines
+ endpos = line.rfind('>')
+ if endpos > -1:
+ (_, startline, startpos) = ReverseCloseExpression(
+ clean_lines, linenum, endpos)
+ if startpos > -1 and startline < linenum:
+ # Found the matching < on an earlier line, collect all
+ # pieces up to current line.
+ line = ''
+ for i in xrange(startline, linenum + 1):
+ line += clean_lines.elided[i].strip()
+
+ # Check for non-const references in function parameters. A single '&' may
+ # be found in the following places:
+ # inside expression: binary & for bitwise AND
+ # inside expression: unary & for taking the address of something
+ # inside declarators: reference parameter
+ # We will exclude the first two cases by checking that we are not inside a
+ # function body, including one that was just introduced by a trailing '{'.
+ # TODO(unknown): Doesn't account for 'catch(Exception& e)' [rare].
+ if (nesting_state.previous_stack_top and
+ not (isinstance(nesting_state.previous_stack_top, _ClassInfo) or
+ isinstance(nesting_state.previous_stack_top, _NamespaceInfo))):
+ # We are inside a function body (not at top level and not directly
+ # within a class or namespace), so any '&' here is likely an
+ # expression rather than a parameter declaration.
+ return
+
+ # Avoid initializer lists. We only need to scan back from the
+ # current line for something that starts with ':'.
+ #
+ # We don't need to check the current line, since the '&' would
+ # appear inside the second set of parentheses on the current line as
+ # opposed to the first set.
+ if linenum > 0:
+ for i in xrange(linenum - 1, max(0, linenum - 10), -1):
+ previous_line = clean_lines.elided[i]
+ if not Search(r'[),]\s*$', previous_line):
+ break
+ if Match(r'^\s*:\s+\S', previous_line):
+ return
+
+ # Avoid preprocessors
+ if Search(r'\\\s*$', line):
+ return
+
+ # Avoid constructor initializer lists
+ if IsInitializerList(clean_lines, linenum):
+ return
+
+ # We allow non-const references in a few standard places, like functions
+ # called "swap()" or iostream operators like "<<" or ">>". Do not check
+ # those function parameters.
+ #
+ # We also accept & in static_assert, which looks like a function but
+ # it's actually a declaration expression.
+ allowed_functions = (r'(?:[sS]wap(?:<[\w:]+>)?|'
+ r'operator\s*[<>][<>]|'
+ r'static_assert|COMPILE_ASSERT'
+ r')\s*\(')
+ if Search(allowed_functions, line):
+ return
+ elif not Search(r'\S+\([^)]*$', line):
+ # Don't see an allowed function on this line. Actually we
+ # didn't see any function name on this line, so this is likely a
+ # multi-line parameter list. Try a bit harder to catch this case.
+ for i in xrange(2):
+ if (linenum > i and
+ Search(allowed_functions, clean_lines.elided[linenum - i - 1])):
+ return
+
+ decls = ReplaceAll(r'{[^}]*}', ' ', line) # exclude function body
+ for parameter in re.findall(_RE_PATTERN_REF_PARAM, decls):
+ if (not Match(_RE_PATTERN_CONST_REF_PARAM, parameter) and
+ not Match(_RE_PATTERN_REF_STREAM_PARAM, parameter)):
+ error(filename, linenum, 'runtime/references', 2,
+ 'Is this a non-const reference? '
+ 'If so, make const or use a pointer: ' +
+ ReplaceAll(' *<', '<', parameter))
+
+
+def CheckCasts(filename, clean_lines, linenum, error):
+ """Various cast related checks.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Check to see if they're using a conversion function cast.
+ # I just try to capture the most common basic types, though there are more.
+ # Parameterless conversion functions, such as bool(), are allowed as they are
+ # probably a member operator declaration or default constructor.
+ match = Search(
+ r'(\bnew\s+(?:const\s+)?|\S<\s*(?:const\s+)?)?\b'
+ r'(int|float|double|bool|char|int32|uint32|int64|uint64)'
+ r'(\([^)].*)', line)
+ expecting_function = ExpectingFunctionArgs(clean_lines, linenum)
+ if match and not expecting_function:
+ matched_type = match.group(2)
+
+ # matched_new_or_template is used to silence two false positives:
+ # - New operators
+ # - Template arguments with function types
+ #
+ # For template arguments, we match on types immediately following
+ # an opening bracket without any spaces. This is a fast way to
+ # silence the common case where the function type is the first
+ # template argument. False negative with less-than comparison is
+ # avoided because those operators are usually followed by a space.
+ #
+ # function<double(double)> // bracket + no space = false positive
+ # value < double(42) // bracket + space = true positive
+ matched_new_or_template = match.group(1)
+
+ # Avoid arrays by looking for brackets that come after the closing
+ # parenthesis.
+ if Match(r'\([^()]+\)\s*\[', match.group(3)):
+ return
+
+ # Other things to ignore:
+ # - Function pointers
+ # - Casts to pointer types
+ # - Placement new
+ # - Alias declarations
+ matched_funcptr = match.group(3)
+ if (matched_new_or_template is None and
+ not (matched_funcptr and
+ (Match(r'\((?:[^() ]+::\s*\*\s*)?[^() ]+\)\s*\(',
+ matched_funcptr) or
+ matched_funcptr.startswith('(*)'))) and
+ not Match(r'\s*using\s+\S+\s*=\s*' + matched_type, line) and
+ not Search(r'new\(\S+\)\s*' + matched_type, line)):
+ error(filename, linenum, 'readability/casting', 4,
+ 'Using deprecated casting style. '
+ 'Use static_cast<%s>(...) instead' %
+ matched_type)
+
+ if not expecting_function:
+ CheckCStyleCast(filename, clean_lines, linenum, 'static_cast',
+ r'\((int|float|double|bool|char|u?int(16|32|64))\)', error)
+
+ # This doesn't catch all cases. Consider (const char * const)"hello".
+ #
+ # (char *) "foo" should always be a const_cast (reinterpret_cast won't
+ # compile).
+ if CheckCStyleCast(filename, clean_lines, linenum, 'const_cast',
+ r'\((char\s?\*+\s?)\)\s*"', error):
+ pass
+ else:
+ # Check pointer casts for other than string constants
+ CheckCStyleCast(filename, clean_lines, linenum, 'reinterpret_cast',
+ r'\((\w+\s?\*+\s?)\)', error)
+
+ # In addition, we look for people taking the address of a cast. This
+ # is dangerous -- casts can assign to temporaries, so the pointer doesn't
+ # point where you think.
+ #
+ # Some non-identifier character is required before the '&' for the
+ # expression to be recognized as a cast. These are casts:
+ # expression = &static_cast<int*>(temporary());
+ # function(&(int*)(temporary()));
+ #
+ # This is not a cast:
+ # reference_type&(int* function_param);
+ match = Search(
+ r'(?:[^\w]&\(([^)*][^)]*)\)[\w(])|'
+ r'(?:[^\w]&(static|dynamic|down|reinterpret)_cast\b)', line)
+ if match:
+ # Try a better error message when the & is bound to something
+ # dereferenced by the casted pointer, as opposed to the casted
+ # pointer itself.
+ parenthesis_error = False
+ match = Match(r'^(.*&(?:static|dynamic|down|reinterpret)_cast\b)<', line)
+ if match:
+ _, y1, x1 = CloseExpression(clean_lines, linenum, len(match.group(1)))
+ if x1 >= 0 and clean_lines.elided[y1][x1] == '(':
+ _, y2, x2 = CloseExpression(clean_lines, y1, x1)
+ if x2 >= 0:
+ extended_line = clean_lines.elided[y2][x2:]
+ if y2 < clean_lines.NumLines() - 1:
+ extended_line += clean_lines.elided[y2 + 1]
+ if Match(r'\s*(?:->|\[)', extended_line):
+ parenthesis_error = True
+
+ if parenthesis_error:
+ error(filename, linenum, 'readability/casting', 4,
+ ('Are you taking an address of something dereferenced '
+ 'from a cast? Wrapping the dereferenced expression in '
+ 'parentheses will make the binding more obvious'))
+ else:
+ error(filename, linenum, 'runtime/casting', 4,
+ ('Are you taking an address of a cast? '
+ 'This is dangerous: could be a temp var. '
+ 'Take the address before doing the cast, rather than after'))
+
+
+def CheckCStyleCast(filename, clean_lines, linenum, cast_type, pattern, error):
+ """Checks for a C-style cast by looking for the pattern.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ cast_type: The string for the C++ cast to recommend. This is either
+ reinterpret_cast, static_cast, or const_cast, depending.
+ pattern: The regular expression used to find C-style casts.
+ error: The function to call with any errors found.
+
+ Returns:
+ True if an error was emitted.
+ False otherwise.
+ """
+ line = clean_lines.elided[linenum]
+ match = Search(pattern, line)
+ if not match:
+ return False
+
+ # Exclude lines with keywords that tend to look like casts
+ context = line[0:match.start(1) - 1]
+ if Match(r'.*\b(?:sizeof|alignof|alignas|[_A-Z][_A-Z0-9]*)\s*$', context):
+ return False
+
+  # Try expanding the current context to see if we are one level of
+  # parentheses inside a macro.
+ if linenum > 0:
+ for i in xrange(linenum - 1, max(0, linenum - 5), -1):
+ context = clean_lines.elided[i] + context
+ if Match(r'.*\b[_A-Z][_A-Z0-9]*\s*\((?:\([^()]*\)|[^()])*$', context):
+ return False
+
+ # operator++(int) and operator--(int)
+ if context.endswith(' operator++') or context.endswith(' operator--'):
+ return False
+
+ # A single unnamed argument for a function tends to look like old style cast.
+ # If we see those, don't issue warnings for deprecated casts.
+ remainder = line[match.end(0):]
+ if Match(r'^\s*(?:;|const\b|throw\b|final\b|override\b|[=>{),]|->)',
+ remainder):
+ return False
+
+ # At this point, all that should be left is actual casts.
+ error(filename, linenum, 'readability/casting', 4,
+ 'Using C-style cast. Use %s<%s>(...) instead' %
+ (cast_type, match.group(1)))
+
+ return True
+
+
+def ExpectingFunctionArgs(clean_lines, linenum):
+  """Checks whether function type arguments are expected at this line.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+
+ Returns:
+ True if the line at 'linenum' is inside something that expects arguments
+ of function types.
+ """
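+  # For example (sketch; names are hypothetical): the "int(int)" in
+  # MOCK_METHOD1(GetX, int(int)), or following "std::function<" at the end of
+  # the previous line, is a function type rather than a C-style cast, so cast
+  # checks are suppressed there.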
+ line = clean_lines.elided[linenum]
+ return (Match(r'^\s*MOCK_(CONST_)?METHOD\d+(_T)?\(', line) or
+ (linenum >= 2 and
+ (Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\((?:\S+,)?\s*$',
+ clean_lines.elided[linenum - 1]) or
+ Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\(\s*$',
+ clean_lines.elided[linenum - 2]) or
+ Search(r'\bstd::m?function\s*\<\s*$',
+ clean_lines.elided[linenum - 1]))))
+
+
+_HEADERS_CONTAINING_TEMPLATES = (
+ ('<deque>', ('deque',)),
+ ('<functional>', ('unary_function', 'binary_function',
+ 'plus', 'minus', 'multiplies', 'divides', 'modulus',
+ 'negate',
+ 'equal_to', 'not_equal_to', 'greater', 'less',
+ 'greater_equal', 'less_equal',
+ 'logical_and', 'logical_or', 'logical_not',
+ 'unary_negate', 'not1', 'binary_negate', 'not2',
+ 'bind1st', 'bind2nd',
+ 'pointer_to_unary_function',
+ 'pointer_to_binary_function',
+ 'ptr_fun',
+ 'mem_fun_t', 'mem_fun', 'mem_fun1_t', 'mem_fun1_ref_t',
+ 'mem_fun_ref_t',
+ 'const_mem_fun_t', 'const_mem_fun1_t',
+ 'const_mem_fun_ref_t', 'const_mem_fun1_ref_t',
+ 'mem_fun_ref',
+ )),
+ ('<limits>', ('numeric_limits',)),
+ ('<list>', ('list',)),
+ ('<map>', ('map', 'multimap',)),
+ ('<memory>', ('allocator', 'make_shared', 'make_unique', 'shared_ptr',
+ 'unique_ptr', 'weak_ptr')),
+ ('<queue>', ('queue', 'priority_queue',)),
+ ('<set>', ('set', 'multiset',)),
+ ('<stack>', ('stack',)),
+ ('<string>', ('char_traits', 'basic_string',)),
+ ('<tuple>', ('tuple',)),
+ ('<unordered_map>', ('unordered_map', 'unordered_multimap')),
+ ('<unordered_set>', ('unordered_set', 'unordered_multiset')),
+ ('<utility>', ('pair',)),
+ ('<vector>', ('vector',)),
+
+ # gcc extensions.
+ # Note: std::hash is their hash, ::hash is our hash
+ ('<hash_map>', ('hash_map', 'hash_multimap',)),
+ ('<hash_set>', ('hash_set', 'hash_multiset',)),
+ ('<slist>', ('slist',)),
+ )
+
+_HEADERS_MAYBE_TEMPLATES = (
+ ('<algorithm>', ('copy', 'max', 'min', 'min_element', 'sort',
+ 'transform',
+ )),
+ ('<utility>', ('forward', 'make_pair', 'move', 'swap')),
+ )
+
+_RE_PATTERN_STRING = re.compile(r'\bstring\b')
+
+_re_pattern_headers_maybe_templates = []
+for _header, _templates in _HEADERS_MAYBE_TEMPLATES:
+ for _template in _templates:
+ # Match max<type>(..., ...), max(..., ...), but not foo->max, foo.max or
+ # type::max().
+ _re_pattern_headers_maybe_templates.append(
+ (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'),
+ _template,
+ _header))
+
+# Other scripts may reach in and modify this pattern.
+_re_pattern_templates = []
+for _header, _templates in _HEADERS_CONTAINING_TEMPLATES:
+ for _template in _templates:
+ _re_pattern_templates.append(
+ (re.compile(r'(\<|\b)' + _template + r'\s*\<'),
+ _template + '<>',
+ _header))
+
+
+def FilesBelongToSameModule(filename_cc, filename_h):
+ """Check if these two filenames belong to the same module.
+
+  The concept of a 'module' here is as follows:
+ foo.h, foo-inl.h, foo.cc, foo_test.cc and foo_unittest.cc belong to the
+ same 'module' if they are in the same directory.
+ some/path/public/xyzzy and some/path/internal/xyzzy are also considered
+ to belong to the same module here.
+
+ If the filename_cc contains a longer path than the filename_h, for example,
+ '/absolute/path/to/base/sysinfo.cc', and this file would include
+ 'base/sysinfo.h', this function also produces the prefix needed to open the
+ header. This is used by the caller of this function to more robustly open the
+ header file. We don't have access to the real include paths in this context,
+ so we need this guesswork here.
+
+ Known bugs: tools/base/bar.cc and base/bar.h belong to the same module
+ according to this implementation. Because of this, this function gives
+ some false positives. This should be sufficiently rare in practice.
+
+ Args:
+ filename_cc: is the path for the .cc file
+ filename_h: is the path for the header path
+
+ Returns:
+ Tuple with a bool and a string:
+ bool: True if filename_cc and filename_h belong to the same module.
+ string: the additional prefix needed to open the header file.
+ """
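+  # A sketch of the intended behavior (paths are hypothetical):
+  #   FilesBelongToSameModule('a/b/foo.cc', 'a/b/foo.h') -> (True, '')
+  #   FilesBelongToSameModule('/abs/path/base/sysinfo.cc', 'base/sysinfo.h')
+  #       -> (True, '/abs/path/')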
+
+ fileinfo = FileInfo(filename_cc)
+ if not fileinfo.IsSource():
+ return (False, '')
+ filename_cc = filename_cc[:-len(fileinfo.Extension())]
+ matched_test_suffix = Search(_TEST_FILE_SUFFIX, fileinfo.BaseName())
+ if matched_test_suffix:
+ filename_cc = filename_cc[:-len(matched_test_suffix.group(1))]
+ filename_cc = filename_cc.replace('/public/', '/')
+ filename_cc = filename_cc.replace('/internal/', '/')
+
+ if not filename_h.endswith('.h'):
+ return (False, '')
+ filename_h = filename_h[:-len('.h')]
+ if filename_h.endswith('-inl'):
+ filename_h = filename_h[:-len('-inl')]
+ filename_h = filename_h.replace('/public/', '/')
+ filename_h = filename_h.replace('/internal/', '/')
+
+ files_belong_to_same_module = filename_cc.endswith(filename_h)
+ common_path = ''
+ if files_belong_to_same_module:
+ common_path = filename_cc[:-len(filename_h)]
+ return files_belong_to_same_module, common_path
+
+
+def UpdateIncludeState(filename, include_dict, io=codecs):
+ """Fill up the include_dict with new includes found from the file.
+
+ Args:
+ filename: the name of the header to read.
+ include_dict: a dictionary in which the headers are inserted.
+ io: The io factory to use to read the file. Provided for testability.
+
+ Returns:
+ True if a header was successfully added. False otherwise.
+ """
+ headerfile = None
+ try:
+ headerfile = io.open(filename, 'r', 'utf8', 'replace')
+ except IOError:
+ return False
+ linenum = 0
+ for line in headerfile:
+ linenum += 1
+ clean_line = CleanseComments(line)
+ match = _RE_PATTERN_INCLUDE.search(clean_line)
+ if match:
+ include = match.group(2)
+ include_dict.setdefault(include, linenum)
+ return True
+
+
+def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error,
+ io=codecs):
+  """Reports missing STL includes.
+
+  This function outputs warnings to make sure you are including the headers
+  necessary for the STL containers and functions that you use. We only give
+  one reason to include a header. For example, if you use both equal_to<> and
+  less<> in a .h file, only one of them (whichever appears later in the file)
+  will be reported as a reason to include <functional>.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ include_state: An _IncludeState instance.
+ error: The function to call with any errors found.
+ io: The IO factory to use to read the header file. Provided for unittest
+ injection.
+ """
+ required = {} # A map of header name to linenumber and the template entity.
+ # Example of required: { '<functional>': (1219, 'less<>') }
+
+ for linenum in xrange(clean_lines.NumLines()):
+ line = clean_lines.elided[linenum]
+ if not line or line[0] == '#':
+ continue
+
+ # String is special -- it is a non-templatized type in STL.
+ matched = _RE_PATTERN_STRING.search(line)
+ if matched:
+ # Don't warn about strings in non-STL namespaces:
+ # (We check only the first match per line; good enough.)
+ prefix = line[:matched.start()]
+ if prefix.endswith('std::') or not prefix.endswith('::'):
+ required['<string>'] = (linenum, 'string')
+
+ for pattern, template, header in _re_pattern_headers_maybe_templates:
+ if pattern.search(line):
+ required[header] = (linenum, template)
+
+    # The following check is just a speed-up; no semantics are changed.
+    if '<' not in line:  # Reduces CPU time by skipping template-free lines.
+ continue
+
+ for pattern, template, header in _re_pattern_templates:
+ matched = pattern.search(line)
+ if matched:
+ # Don't warn about IWYU in non-STL namespaces:
+ # (We check only the first match per line; good enough.)
+ prefix = line[:matched.start()]
+ if prefix.endswith('std::') or not prefix.endswith('::'):
+ required[header] = (linenum, template)
+
+ # The policy is that if you #include something in foo.h you don't need to
+ # include it again in foo.cc. Here, we will look at possible includes.
+ # Let's flatten the include_state include_list and copy it into a dictionary.
+ include_dict = dict([item for sublist in include_state.include_list
+ for item in sublist])
+
+ # Did we find the header for this file (if any) and successfully load it?
+ header_found = False
+
+ # Use the absolute path so that matching works properly.
+ abs_filename = FileInfo(filename).FullName()
+
+ # For Emacs's flymake.
+ # If cpplint is invoked from Emacs's flymake, a temporary file is generated
+ # by flymake and that file name might end with '_flymake.cc'. In that case,
+ # restore original file name here so that the corresponding header file can be
+ # found.
+ # e.g. If the file name is 'foo_flymake.cc', we should search for 'foo.h'
+ # instead of 'foo_flymake.h'
+ abs_filename = re.sub(r'_flymake\.cc$', '.cc', abs_filename)
+
+ # include_dict is modified during iteration, so we iterate over a copy of
+ # the keys.
+  header_keys = list(include_dict.keys())
+ for header in header_keys:
+ (same_module, common_path) = FilesBelongToSameModule(abs_filename, header)
+ fullpath = common_path + header
+ if same_module and UpdateIncludeState(fullpath, include_dict, io):
+ header_found = True
+
+ # If we can't find the header file for a .cc, assume it's because we don't
+ # know where to look. In that case we'll give up as we're not sure they
+ # didn't include it in the .h file.
+ # TODO(unknown): Do a better job of finding .h files so we are confident that
+ # not having the .h file means there isn't one.
+ if filename.endswith('.cc') and not header_found:
+ return
+
+ # All the lines have been processed, report the errors found.
+ for required_header_unstripped in required:
+ template = required[required_header_unstripped][1]
+ if required_header_unstripped.strip('<>"') not in include_dict:
+ error(filename, required[required_header_unstripped][0],
+ 'build/include_what_you_use', 4,
+ 'Add #include ' + required_header_unstripped + ' for ' + template)
+
+
+_RE_PATTERN_EXPLICIT_MAKEPAIR = re.compile(r'\bmake_pair\s*<')
+
+
+def CheckMakePairUsesDeduction(filename, clean_lines, linenum, error):
+ """Check that make_pair's template arguments are deduced.
+
+ G++ 4.6 in C++11 mode fails badly if make_pair's template arguments are
+ specified explicitly, and such use isn't intended in any case.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+ match = _RE_PATTERN_EXPLICIT_MAKEPAIR.search(line)
+ if match:
+ error(filename, linenum, 'build/explicit_make_pair',
+ 4, # 4 = high confidence
+ 'For C++11-compatibility, omit template arguments from make_pair'
+ ' OR use pair directly OR if appropriate, construct a pair directly')
+
+
+def CheckRedundantVirtual(filename, clean_lines, linenum, error):
+ """Check if line contains a redundant "virtual" function-specifier.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
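+  # For example, a declaration like the following is flagged (sketch):
+  #   virtual void Draw() override;  // "virtual" redundant with "override"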
+ # Look for "virtual" on current line.
+ line = clean_lines.elided[linenum]
+ virtual = Match(r'^(.*)(\bvirtual\b)(.*)$', line)
+ if not virtual: return
+
+ # Ignore "virtual" keywords that are near access-specifiers. These
+ # are only used in class base-specifier and do not apply to member
+ # functions.
+ if (Search(r'\b(public|protected|private)\s+$', virtual.group(1)) or
+ Match(r'^\s+(public|protected|private)\b', virtual.group(3))):
+ return
+
+ # Ignore the "virtual" keyword from virtual base classes. Usually
+  # there is a colon on the same line in these cases (virtual base
+ # classes are rare in google3 because multiple inheritance is rare).
+ if Match(r'^.*[^:]:[^:].*$', line): return
+
+ # Look for the next opening parenthesis. This is the start of the
+ # parameter list (possibly on the next line shortly after virtual).
+ # TODO(unknown): doesn't work if there are virtual functions with
+ # decltype() or other things that use parentheses, but csearch suggests
+ # that this is rare.
+ end_col = -1
+ end_line = -1
+ start_col = len(virtual.group(2))
+ for start_line in xrange(linenum, min(linenum + 3, clean_lines.NumLines())):
+ line = clean_lines.elided[start_line][start_col:]
+ parameter_list = Match(r'^([^(]*)\(', line)
+ if parameter_list:
+ # Match parentheses to find the end of the parameter list
+ (_, end_line, end_col) = CloseExpression(
+ clean_lines, start_line, start_col + len(parameter_list.group(1)))
+ break
+ start_col = 0
+
+ if end_col < 0:
+ return # Couldn't find end of parameter list, give up
+
+ # Look for "override" or "final" after the parameter list
+ # (possibly on the next few lines).
+ for i in xrange(end_line, min(end_line + 3, clean_lines.NumLines())):
+ line = clean_lines.elided[i][end_col:]
+ match = Search(r'\b(override|final)\b', line)
+ if match:
+ error(filename, linenum, 'readability/inheritance', 4,
+ ('"virtual" is redundant since function is '
+ 'already declared as "%s"' % match.group(1)))
+
+ # Set end_col to check whole lines after we are done with the
+ # first line.
+ end_col = 0
+ if Search(r'[^\w]\s*$', line):
+ break
+
+
+def CheckRedundantOverrideOrFinal(filename, clean_lines, linenum, error):
+ """Check if line contains a redundant "override" or "final" virt-specifier.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ # Look for closing parenthesis nearby. We need one to confirm where
+ # the declarator ends and where the virt-specifier starts to avoid
+ # false positives.
+ line = clean_lines.elided[linenum]
+ declarator_end = line.rfind(')')
+ if declarator_end >= 0:
+ fragment = line[declarator_end:]
+ else:
+ if linenum > 1 and clean_lines.elided[linenum - 1].rfind(')') >= 0:
+ fragment = line
+ else:
+ return
+
+ # Check that at most one of "override" or "final" is present, not both
+ if Search(r'\boverride\b', fragment) and Search(r'\bfinal\b', fragment):
+ error(filename, linenum, 'readability/inheritance', 4,
+ ('"override" is redundant since function is '
+ 'already declared as "final"'))
+
+
+
+
+# Returns true if we are at a new block, and it is directly
+# inside of a namespace.
+def IsBlockInNameSpace(nesting_state, is_forward_declaration):
+ """Checks that the new block is directly in a namespace.
+
+ Args:
+ nesting_state: The _NestingState object that contains info about our state.
+ is_forward_declaration: If the class is a forward declared class.
+ Returns:
+ Whether or not the new block is directly in a namespace.
+ """
+ if is_forward_declaration:
+ if len(nesting_state.stack) >= 1 and (
+ isinstance(nesting_state.stack[-1], _NamespaceInfo)):
+ return True
+ else:
+ return False
+
+ return (len(nesting_state.stack) > 1 and
+ nesting_state.stack[-1].check_namespace_indentation and
+ isinstance(nesting_state.stack[-2], _NamespaceInfo))
+
+
+def ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item,
+ raw_lines_no_comments, linenum):
+ """This method determines if we should apply our namespace indentation check.
+
+ Args:
+ nesting_state: The current nesting state.
+ is_namespace_indent_item: If we just put a new class on the stack, True.
+ If the top of the stack is not a class, or we did not recently
+ add the class, False.
+ raw_lines_no_comments: The lines without the comments.
+ linenum: The current line number we are processing.
+
+ Returns:
+ True if we should apply our namespace indentation check. Currently, it
+ only works for classes and namespaces inside of a namespace.
+ """
+
+ is_forward_declaration = IsForwardClassDeclaration(raw_lines_no_comments,
+ linenum)
+
+ if not (is_namespace_indent_item or is_forward_declaration):
+ return False
+
+ # If we are in a macro, we do not want to check the namespace indentation.
+ if IsMacroDefinition(raw_lines_no_comments, linenum):
+ return False
+
+ return IsBlockInNameSpace(nesting_state, is_forward_declaration)
+
+
+# Call this method if the line is directly inside of a namespace.
+# If the line above is blank (excluding comments) or the start of
+# an inner namespace, it cannot be indented.
+def CheckItemIndentationInNamespace(filename, raw_lines_no_comments, linenum,
+ error):
+ line = raw_lines_no_comments[linenum]
+ if Match(r'^\s+', line):
+ error(filename, linenum, 'runtime/indentation_namespace', 4,
+ 'Do not indent within a namespace')
+
+
+def ProcessLine(filename, file_extension, clean_lines, line,
+ include_state, function_state, nesting_state, error,
+ extra_check_functions=[]):
+ """Processes a single line in the file.
+
+ Args:
+ filename: Filename of the file that is being processed.
+ file_extension: The extension (dot not included) of the file.
+ clean_lines: An array of strings, each representing a line of the file,
+ with comments stripped.
+ line: Number of line being processed.
+ include_state: An _IncludeState instance in which the headers are inserted.
+ function_state: A _FunctionState instance which counts function lines, etc.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: A callable to which errors are reported, which takes 4 arguments:
+ filename, line number, error level, and message
+ extra_check_functions: An array of additional check functions that will be
+ run on each source line. Each function takes 4
+ arguments: filename, clean_lines, line, error
+ """
+ raw_lines = clean_lines.raw_lines
+ ParseNolintSuppressions(filename, raw_lines[line], line, error)
+ nesting_state.Update(filename, clean_lines, line, error)
+ CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line,
+ error)
+ if nesting_state.InAsmBlock(): return
+ CheckForFunctionLengths(filename, clean_lines, line, function_state, error)
+ CheckForMultilineCommentsAndStrings(filename, clean_lines, line, error)
+ CheckStyle(filename, clean_lines, line, file_extension, nesting_state, error)
+ CheckLanguage(filename, clean_lines, line, file_extension, include_state,
+ nesting_state, error)
+ CheckForNonConstReference(filename, clean_lines, line, nesting_state, error)
+ CheckForNonStandardConstructs(filename, clean_lines, line,
+ nesting_state, error)
+ CheckVlogArguments(filename, clean_lines, line, error)
+ CheckPosixThreading(filename, clean_lines, line, error)
+ CheckInvalidIncrement(filename, clean_lines, line, error)
+ CheckMakePairUsesDeduction(filename, clean_lines, line, error)
+ CheckRedundantVirtual(filename, clean_lines, line, error)
+ CheckRedundantOverrideOrFinal(filename, clean_lines, line, error)
+ for check_fn in extra_check_functions:
+ check_fn(filename, clean_lines, line, error)
+
+def FlagCxx11Features(filename, clean_lines, linenum, error):
+ """Flag those c++11 features that we only allow in certain places.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line)
+
+ # Flag unapproved C++ TR1 headers.
+ if include and include.group(1).startswith('tr1/'):
+ error(filename, linenum, 'build/c++tr1', 5,
+ ('C++ TR1 headers such as <%s> are unapproved.') % include.group(1))
+
+ # Flag unapproved C++11 headers.
+ if include and include.group(1) in ('cfenv',
+ 'condition_variable',
+ 'fenv.h',
+ 'future',
+ 'mutex',
+ 'thread',
+ 'chrono',
+ 'ratio',
+ 'regex',
+ 'system_error',
+ ):
+ error(filename, linenum, 'build/c++11', 5,
+ ('<%s> is an unapproved C++11 header.') % include.group(1))
+
+ # The only place where we need to worry about C++11 keywords and library
+ # features in preprocessor directives is in macro definitions.
+ if Match(r'\s*#', line) and not Match(r'\s*#\s*define\b', line): return
+
+ # These are classes and free functions. The classes are always
+ # mentioned as std::*, but we only catch the free functions if
+ # they're not found by ADL. They're alphabetical by header.
+ for top_name in (
+ # type_traits
+ 'alignment_of',
+ 'aligned_union',
+ ):
+ if Search(r'\bstd::%s\b' % top_name, line):
+ error(filename, linenum, 'build/c++11', 5,
+            ('std::%s is an unapproved C++11 class or function. Send the '
+             'style arbiters an example of where it would make your code '
+             'more readable, and they may let you use it.') % top_name)
+
+
+def FlagCxx14Features(filename, clean_lines, linenum, error):
+ """Flag those C++14 features that we restrict.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line)
+
+ # Flag unapproved C++14 headers.
+ if include and include.group(1) in ('scoped_allocator', 'shared_mutex'):
+ error(filename, linenum, 'build/c++14', 5,
+ ('<%s> is an unapproved C++14 header.') % include.group(1))
+
+
+def ProcessFileData(filename, file_extension, lines, error,
+ extra_check_functions=[]):
+ """Performs lint checks and reports any errors to the given error function.
+
+ Args:
+ filename: Filename of the file that is being processed.
+ file_extension: The extension (dot not included) of the file.
+ lines: An array of strings, each representing a line of the file, with the
+ last element being empty if the file is terminated with a newline.
+ error: A callable to which errors are reported, which takes 4 arguments:
+ filename, line number, error level, and message
+ extra_check_functions: An array of additional check functions that will be
+ run on each source line. Each function takes 4
+ arguments: filename, clean_lines, line, error
+ """
+ lines = (['// marker so line numbers and indices both start at 1'] + lines +
+ ['// marker so line numbers end in a known way'])
+
+ include_state = _IncludeState()
+ function_state = _FunctionState()
+ nesting_state = NestingState()
+
+ ResetNolintSuppressions()
+
+ CheckForCopyright(filename, lines, error)
+ ProcessGlobalSuppresions(lines)
+ RemoveMultiLineComments(filename, lines, error)
+ clean_lines = CleansedLines(lines)
+
+ if IsHeaderExtension(file_extension):
+ CheckForHeaderGuard(filename, clean_lines, error)
+
+ for line in xrange(clean_lines.NumLines()):
+ ProcessLine(filename, file_extension, clean_lines, line,
+ include_state, function_state, nesting_state, error,
+ extra_check_functions)
+ FlagCxx11Features(filename, clean_lines, line, error)
+ nesting_state.CheckCompletedBlocks(filename, error)
+
+ CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error)
+
+ # Check that the .cc file has included its header if it exists.
+ if _IsSourceExtension(file_extension):
+ CheckHeaderFileIncluded(filename, include_state, error)
+
+ # We check here rather than inside ProcessLine so that we see raw
+ # lines rather than "cleaned" lines.
+ CheckForBadCharacters(filename, lines, error)
+
+ CheckForNewlineAtEOF(filename, lines, error)
+
+def ProcessConfigOverrides(filename):
+  """Loads the configuration files and processes the config overrides.
+
+ Args:
+ filename: The name of the file being processed by the linter.
+
+ Returns:
+ False if the current |filename| should not be processed further.
+ """
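+  # A minimal example CPPLINT.cfg (values are illustrative), using only the
+  # options parsed below:
+  #   set noparent
+  #   filter=-build/include_order,+build/include_alpha
+  #   exclude_files=.*_test\.cc
+  #   linelength=100
+  #   root=src
+  #   headers=hpp,hxx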
+
+ abs_filename = os.path.abspath(filename)
+ cfg_filters = []
+ keep_looking = True
+ while keep_looking:
+ abs_path, base_name = os.path.split(abs_filename)
+ if not base_name:
+ break # Reached the root directory.
+
+ cfg_file = os.path.join(abs_path, "CPPLINT.cfg")
+ abs_filename = abs_path
+ if not os.path.isfile(cfg_file):
+ continue
+
+ try:
+ with open(cfg_file) as file_handle:
+ for line in file_handle:
+ line, _, _ = line.partition('#') # Remove comments.
+ if not line.strip():
+ continue
+
+ name, _, val = line.partition('=')
+ name = name.strip()
+ val = val.strip()
+ if name == 'set noparent':
+ keep_looking = False
+ elif name == 'filter':
+ cfg_filters.append(val)
+ elif name == 'exclude_files':
+ # When matching exclude_files pattern, use the base_name of
+ # the current file name or the directory name we are processing.
+ # For example, if we are checking for lint errors in /foo/bar/baz.cc
+ # and we found the .cfg file at /foo/CPPLINT.cfg, then the config
+ # file's "exclude_files" filter is meant to be checked against "bar"
+ # and not "baz" nor "bar/baz.cc".
+ if base_name:
+ pattern = re.compile(val)
+ if pattern.match(base_name):
+ if _cpplint_state.quiet:
+ # Suppress "Ignoring file" warning when using --quiet.
+ return False
+ sys.stderr.write('Ignoring "%s": file excluded by "%s". '
+ 'File path component "%s" matches '
+ 'pattern "%s"\n' %
+ (filename, cfg_file, base_name, val))
+ return False
+ elif name == 'linelength':
+ global _line_length
+ try:
+ _line_length = int(val)
+ except ValueError:
+              sys.stderr.write('Line length must be numeric.\n')
+ elif name == 'root':
+ global _root
+ # root directories are specified relative to CPPLINT.cfg dir.
+ _root = os.path.join(os.path.dirname(cfg_file), val)
+ elif name == 'headers':
+ ProcessHppHeadersOption(val)
+ else:
+ sys.stderr.write(
+ 'Invalid configuration option (%s) in file %s\n' %
+ (name, cfg_file))
+
+ except IOError:
+ sys.stderr.write(
+ "Skipping config file '%s': Can't open for reading\n" % cfg_file)
+ keep_looking = False
+
+ # Apply all the accumulated filters in reverse order (top-level directory
+ # config options having the least priority).
+  for cfg_filter in reversed(cfg_filters):
+    _AddFilters(cfg_filter)
+
+ return True
+
+
+def ProcessFile(filename, vlevel, extra_check_functions=[]):
+ """Does google-lint on a single file.
+
+ Args:
+ filename: The name of the file to parse.
+
+ vlevel: The level of errors to report. Every error of confidence
+ >= verbose_level will be reported. 0 is a good default.
+
+ extra_check_functions: An array of additional check functions that will be
+ run on each source line. Each function takes 4
+ arguments: filename, clean_lines, line, error
+ """
+
+ _SetVerboseLevel(vlevel)
+ _BackupFilters()
+ old_errors = _cpplint_state.error_count
+
+ if not ProcessConfigOverrides(filename):
+ _RestoreFilters()
+ return
+
+ lf_lines = []
+ crlf_lines = []
+ try:
+ # Support the UNIX convention of using "-" for stdin. Note that
+ # we are not opening the file with universal newline support
+ # (which codecs doesn't support anyway), so the resulting lines do
+ # contain trailing '\r' characters if we are reading a file that
+ # has CRLF endings.
+ # If after the split a trailing '\r' is present, it is removed
+ # below.
+ if filename == '-':
+ lines = codecs.StreamReaderWriter(sys.stdin,
+ codecs.getreader('utf8'),
+ codecs.getwriter('utf8'),
+ 'replace').read().split('\n')
+ else:
+ lines = codecs.open(filename, 'r', 'utf8', 'replace').read().split('\n')
+
+ # Remove trailing '\r'.
+ # The -1 accounts for the extra trailing blank line we get from split()
+ for linenum in range(len(lines) - 1):
+ if lines[linenum].endswith('\r'):
+ lines[linenum] = lines[linenum].rstrip('\r')
+ crlf_lines.append(linenum + 1)
+ else:
+ lf_lines.append(linenum + 1)
+
+ except IOError:
+ sys.stderr.write(
+ "Skipping input '%s': Can't open for reading\n" % filename)
+ _RestoreFilters()
+ return
+
+ # Note, if no dot is found, this will give the entire filename as the ext.
+ file_extension = filename[filename.rfind('.') + 1:]
+
+ # When reading from stdin, the extension is unknown, so no cpplint tests
+ # should rely on the extension.
+ if filename != '-' and file_extension not in _valid_extensions:
+ sys.stderr.write('Ignoring %s; not a valid file name '
+ '(%s)\n' % (filename, ', '.join(_valid_extensions)))
+ else:
+ ProcessFileData(filename, file_extension, lines, Error,
+ extra_check_functions)
+
+ # If end-of-line sequences are a mix of LF and CR-LF, issue
+ # warnings on the lines with CR.
+ #
+ # Don't issue any warnings if all lines are uniformly LF or CR-LF,
+ # since critique can handle these just fine, and the style guide
+ # doesn't dictate a particular end of line sequence.
+ #
+ # We can't depend on os.linesep to determine what the desired
+ # end-of-line sequence should be, since that will return the
+ # server-side end-of-line sequence.
+ if lf_lines and crlf_lines:
+ # Warn on every line with CR. An alternative approach might be to
+  # check whether the file is mostly CRLF or just LF, and warn on the
+  # minority; we bias toward LF here since most tools prefer LF.
+ for linenum in crlf_lines:
+ Error(filename, linenum, 'whitespace/newline', 1,
+ 'Unexpected \\r (^M) found; better to use only \\n')
+
+ # Suppress printing anything if --quiet was passed unless the error
+ # count has increased after processing this file.
+ if not _cpplint_state.quiet or old_errors != _cpplint_state.error_count:
+ sys.stdout.write('Done processing %s\n' % filename)
+ _RestoreFilters()
+
+
+def PrintUsage(message):
+ """Prints a brief usage string and exits, optionally with an error message.
+
+ Args:
+ message: The optional error message.
+ """
+ sys.stderr.write(_USAGE)
+ if message:
+ sys.exit('\nFATAL ERROR: ' + message)
+ else:
+ sys.exit(1)
+
+
+def PrintCategories():
+ """Prints a list of all the error-categories used by error messages.
+
+ These are the categories used to filter messages via --filter.
+ """
+ sys.stderr.write(''.join(' %s\n' % cat for cat in _ERROR_CATEGORIES))
+ sys.exit(0)
+
+
+def ParseArguments(args):
+ """Parses the command line arguments.
+
+ This may set the output format and verbosity level as side-effects.
+
+ Args:
+    args: The command line arguments.
+
+ Returns:
+ The list of filenames to lint.
+ """
+ try:
+ (opts, filenames) = getopt.getopt(args, '', ['help', 'output=', 'verbose=',
+ 'counting=',
+ 'filter=',
+ 'root=',
+ 'linelength=',
+ 'extensions=',
+ 'headers=',
+ 'quiet'])
+ except getopt.GetoptError:
+ PrintUsage('Invalid arguments.')
+
+ verbosity = _VerboseLevel()
+ output_format = _OutputFormat()
+ filters = ''
+ quiet = _Quiet()
+ counting_style = ''
+
+ for (opt, val) in opts:
+ if opt == '--help':
+ PrintUsage(None)
+ elif opt == '--output':
+ if val not in ('emacs', 'vs7', 'eclipse'):
+ PrintUsage('The only allowed output formats are emacs, vs7 and eclipse.')
+ output_format = val
+ elif opt == '--quiet':
+ quiet = True
+ elif opt == '--verbose':
+ verbosity = int(val)
+ elif opt == '--filter':
+ filters = val
+ if not filters:
+ PrintCategories()
+ elif opt == '--counting':
+ if val not in ('total', 'toplevel', 'detailed'):
+ PrintUsage('Valid counting options are total, toplevel, and detailed')
+ counting_style = val
+ elif opt == '--root':
+ global _root
+ _root = val
+ elif opt == '--linelength':
+ global _line_length
+ try:
+ _line_length = int(val)
+ except ValueError:
+ PrintUsage('Line length must be digits.')
+ elif opt == '--extensions':
+ global _valid_extensions
+ try:
+ _valid_extensions = set(val.split(','))
+ except ValueError:
+ PrintUsage('Extensions must be comma separated list.')
+ elif opt == '--headers':
+ ProcessHppHeadersOption(val)
+
+ if not filenames:
+ PrintUsage('No files were specified.')
+
+ _SetOutputFormat(output_format)
+ _SetQuiet(quiet)
+ _SetVerboseLevel(verbosity)
+ _SetFilters(filters)
+ _SetCountingStyle(counting_style)
+
+ return filenames
+
+
+def main():
+ filenames = ParseArguments(sys.argv[1:])
+
+ # Change stderr to write with replacement characters so we don't die
+ # if we try to print something containing non-ASCII characters.
+ sys.stderr = codecs.StreamReaderWriter(sys.stderr,
+ codecs.getreader('utf8'),
+ codecs.getwriter('utf8'),
+ 'replace')
+
+ _cpplint_state.ResetErrorCounts()
+ for filename in filenames:
+ ProcessFile(filename, _cpplint_state.verbose_level)
+ # If --quiet is passed, suppress printing error count unless there are errors.
+ if not _cpplint_state.quiet or _cpplint_state.error_count > 0:
+ _cpplint_state.PrintErrorCounts()
+
+ sys.exit(_cpplint_state.error_count > 0)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/third_party/aom/tools/diff.py b/third_party/aom/tools/diff.py
new file mode 100644
index 0000000000..7bb6b7fcb4
--- /dev/null
+++ b/third_party/aom/tools/diff.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+##
+## Copyright (c) 2016, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 2 Clause License and
+## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+## was not distributed with this source code in the LICENSE file, you can
+## obtain it at www.aomedia.org/license/software. If the Alliance for Open
+## Media Patent License 1.0 was not distributed with this source code in the
+## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+##
+"""Classes for representing diff pieces."""
+
+__author__ = "jkoleszar@google.com"
+
+import re
+
+
+class DiffLines(object):
+ """A container for one half of a diff."""
+
+ def __init__(self, filename, offset, length):
+ self.filename = filename
+ self.offset = offset
+ self.length = length
+ self.lines = []
+ self.delta_line_nums = []
+
+  def Append(self, line):
+    if line[0] != " ":
+      self.delta_line_nums.append(self.offset + len(self.lines))
+    self.lines.append(line[1:])
+    assert len(self.lines) <= self.length
+
+ def Complete(self):
+ return len(self.lines) == self.length
+
+ def __contains__(self, item):
+    return self.offset <= item < self.offset + self.length
+
+
+class DiffHunk(object):
+ """A container for one diff hunk, consisting of two DiffLines."""
+
+ def __init__(self, header, file_a, file_b, start_a, len_a, start_b, len_b):
+ self.header = header
+ self.left = DiffLines(file_a, start_a, len_a)
+ self.right = DiffLines(file_b, start_b, len_b)
+ self.lines = []
+
+ def Append(self, line):
+ """Adds a line to the DiffHunk and its DiffLines children."""
+ if line[0] == "-":
+ self.left.Append(line)
+ elif line[0] == "+":
+ self.right.Append(line)
+ elif line[0] == " ":
+ self.left.Append(line)
+ self.right.Append(line)
+ elif line[0] == "\\":
+ # Ignore newline messages from git diff.
+ pass
+ else:
+ assert False, ("Unrecognized character at start of diff line "
+ "%r" % line[0])
+ self.lines.append(line)
+
+ def Complete(self):
+ return self.left.Complete() and self.right.Complete()
+
+ def __repr__(self):
+ return "DiffHunk(%s, %s, len %d)" % (
+ self.left.filename, self.right.filename,
+ max(self.left.length, self.right.length))
+
+
+def ParseDiffHunks(stream):
+ """Walk a file-like object, yielding DiffHunks as they're parsed."""
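+  # Typical usage (sketch; "patch.diff" is a hypothetical unified diff):
+  #   with open("patch.diff") as stream:
+  #     for hunk in ParseDiffHunks(stream):
+  #       print(hunk, hunk.right.delta_line_nums)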
+
+ file_regex = re.compile(r"(\+\+\+|---) (\S+)")
+ range_regex = re.compile(r"@@ -(\d+)(,(\d+))? \+(\d+)(,(\d+))?")
+ hunk = None
+ while True:
+ line = stream.readline()
+ if not line:
+ break
+
+ if hunk is None:
+ # Parse file names
+ diff_file = file_regex.match(line)
+ if diff_file:
+ if line.startswith("---"):
+ a_line = line
+ a = diff_file.group(2)
+ continue
+ if line.startswith("+++"):
+ b_line = line
+ b = diff_file.group(2)
+ continue
+
+ # Parse offset/lengths
+ diffrange = range_regex.match(line)
+ if diffrange:
+ if diffrange.group(2):
+ start_a = int(diffrange.group(1))
+ len_a = int(diffrange.group(3))
+ else:
+ start_a = 1
+ len_a = int(diffrange.group(1))
+
+ if diffrange.group(5):
+ start_b = int(diffrange.group(4))
+ len_b = int(diffrange.group(6))
+ else:
+ start_b = 1
+ len_b = int(diffrange.group(4))
+
+ header = [a_line, b_line, line]
+ hunk = DiffHunk(header, a, b, start_a, len_a, start_b, len_b)
+ else:
+ # Add the current line to the hunk
+ hunk.Append(line)
+
+ # See if the whole hunk has been parsed. If so, yield it and prepare
+ # for the next hunk.
+ if hunk.Complete():
+ yield hunk
+ hunk = None
+
+ # Partial hunks are a parse error
+ assert hunk is None
diff --git a/third_party/aom/tools/dump_obu.cc b/third_party/aom/tools/dump_obu.cc
new file mode 100644
index 0000000000..b9ff985c44
--- /dev/null
+++ b/third_party/aom/tools/dump_obu.cc
@@ -0,0 +1,168 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <inttypes.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "config/aom_config.h"
+
+#include "common/ivfdec.h"
+#include "common/obudec.h"
+#include "common/tools_common.h"
+#include "common/webmdec.h"
+#include "tools/obu_parser.h"
+
+namespace {
+
+const size_t kInitialBufferSize = 100 * 1024;
+
+struct InputContext {
+ InputContext() = default;
+ ~InputContext() { free(unit_buffer); }
+
+ void Init() {
+ memset(avx_ctx, 0, sizeof(*avx_ctx));
+ memset(obu_ctx, 0, sizeof(*obu_ctx));
+ obu_ctx->avx_ctx = avx_ctx;
+#if CONFIG_WEBM_IO
+ memset(webm_ctx, 0, sizeof(*webm_ctx));
+#endif
+ }
+
+ AvxInputContext *avx_ctx = nullptr;
+ ObuDecInputContext *obu_ctx = nullptr;
+#if CONFIG_WEBM_IO
+ WebmInputContext *webm_ctx = nullptr;
+#endif
+ uint8_t *unit_buffer = nullptr;
+ size_t unit_buffer_size = 0;
+};
+
+void PrintUsage() {
+ printf("Libaom OBU dump.\nUsage: dump_obu <input_file>\n");
+}
+
+VideoFileType GetFileType(InputContext *ctx) {
+ // TODO(https://crbug.com/aomedia/1706): webm type does not support reading
+ // from stdin yet, and file_is_webm is not using the detect buffer when
+ // determining the type. Therefore it should only be checked when using a file
+ // and needs to be checked prior to other types.
+#if CONFIG_WEBM_IO
+ if (file_is_webm(ctx->webm_ctx, ctx->avx_ctx)) return FILE_TYPE_WEBM;
+#endif
+ if (file_is_ivf(ctx->avx_ctx)) return FILE_TYPE_IVF;
+ if (file_is_obu(ctx->obu_ctx)) return FILE_TYPE_OBU;
+ return FILE_TYPE_RAW;
+}
+
+bool ReadTemporalUnit(InputContext *ctx, size_t *unit_size) {
+ const VideoFileType file_type = ctx->avx_ctx->file_type;
+ switch (file_type) {
+ case FILE_TYPE_IVF: {
+ if (ivf_read_frame(ctx->avx_ctx, &ctx->unit_buffer, unit_size,
+ &ctx->unit_buffer_size, NULL)) {
+ return false;
+ }
+ break;
+ }
+ case FILE_TYPE_OBU: {
+ if (obudec_read_temporal_unit(ctx->obu_ctx, &ctx->unit_buffer, unit_size,
+ &ctx->unit_buffer_size)) {
+ return false;
+ }
+ break;
+ }
+#if CONFIG_WEBM_IO
+ case FILE_TYPE_WEBM: {
+ if (webm_read_frame(ctx->webm_ctx, &ctx->unit_buffer, unit_size,
+ &ctx->unit_buffer_size)) {
+ return false;
+ }
+ break;
+ }
+#endif
+ default:
+ // TODO(tomfinegan): Abuse FILE_TYPE_RAW for AV1/OBU elementary streams?
+ fprintf(stderr, "Error: Unsupported file type.\n");
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+int main(int argc, const char *argv[]) {
+ // TODO(tomfinegan): Could do with some params for verbosity.
+ if (argc < 2) {
+ PrintUsage();
+ return EXIT_SUCCESS;
+ }
+
+ const std::string filename = argv[1];
+
+ using FilePtr = std::unique_ptr<FILE, decltype(&fclose)>;
+ FilePtr input_file(fopen(filename.c_str(), "rb"), &fclose);
+ if (input_file.get() == nullptr) {
+ input_file.release();
+ fprintf(stderr, "Error: Cannot open input file.\n");
+ return EXIT_FAILURE;
+ }
+
+ AvxInputContext avx_ctx;
+ InputContext input_ctx;
+ input_ctx.avx_ctx = &avx_ctx;
+ ObuDecInputContext obu_ctx;
+ input_ctx.obu_ctx = &obu_ctx;
+#if CONFIG_WEBM_IO
+ WebmInputContext webm_ctx;
+ input_ctx.webm_ctx = &webm_ctx;
+#endif
+
+ input_ctx.Init();
+ avx_ctx.file = input_file.get();
+ avx_ctx.file_type = GetFileType(&input_ctx);
+
+ // Note: the reader utilities will realloc the buffer using realloc() etc.
+ // Can't have nice things like unique_ptr wrappers with that type of
+ // behavior underneath the function calls.
+ input_ctx.unit_buffer =
+ reinterpret_cast<uint8_t *>(calloc(kInitialBufferSize, 1));
+ if (!input_ctx.unit_buffer) {
+ fprintf(stderr, "Error: No memory, can't alloc input buffer.\n");
+ return EXIT_FAILURE;
+ }
+ input_ctx.unit_buffer_size = kInitialBufferSize;
+
+ size_t unit_size = 0;
+ int unit_number = 0;
+ int64_t obu_overhead_bytes_total = 0;
+ while (ReadTemporalUnit(&input_ctx, &unit_size)) {
+ printf("Temporal unit %d\n", unit_number);
+
+ int obu_overhead_current_unit = 0;
+ if (!aom_tools::DumpObu(input_ctx.unit_buffer, static_cast<int>(unit_size),
+ &obu_overhead_current_unit)) {
+ fprintf(stderr, "Error: Temporal Unit parse failed on unit number %d.\n",
+ unit_number);
+ return EXIT_FAILURE;
+ }
+ printf(" OBU overhead: %d\n", obu_overhead_current_unit);
+ ++unit_number;
+ obu_overhead_bytes_total += obu_overhead_current_unit;
+ }
+
+ printf("File total OBU overhead: %" PRId64 "\n", obu_overhead_bytes_total);
+ return EXIT_SUCCESS;
+}
diff --git a/third_party/aom/tools/frame_size_variation_analyzer.py b/third_party/aom/tools/frame_size_variation_analyzer.py
new file mode 100644
index 0000000000..5c02319df1
--- /dev/null
+++ b/third_party/aom/tools/frame_size_variation_analyzer.py
@@ -0,0 +1,74 @@
+# RTC frame size variation analyzer
+# Usage:
+# 1. Config with "-DCONFIG_OUTPUT_FRAME_SIZE=1".
+# 2. Build aomenc. Encode a file, and generate output file: frame_sizes.csv
+# 3. Run: python ./frame_size_variation_analyzer.py frame_sizes.csv \
+#    target-bitrate fps
+#    where target-bitrate is the target bitrate (kbps) and fps is the frame
+#    rate in frames per second.
+# Example: python ../aom/tools/frame_size_variation_analyzer.py frame_sizes.csv
+# 1000 30
+
+import numpy as np
+import csv
+import sys
+import matplotlib.pyplot as plt
+
+# return the moving average
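+# e.g. moving_average(np.array([1., 2., 3., 4.]), 2) -> array([1.5, 2.5, 3.5])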
+def moving_average(x, w):
+ return np.convolve(x, np.ones(w), 'valid') / w
+
+def frame_size_analysis(filename, target_br, fps):
+ tbr = target_br * 1000 / fps
+
+ with open(filename, 'r') as infile:
+ raw_data = list(csv.reader(infile, delimiter=','))
+
+ data = np.array(raw_data).astype(float)
+ fsize = data[:, 0].astype(float) # frame size
+ qindex = data[:, 1].astype(float) # qindex
+
+ # Frame bit rate mismatch
+ mismatch = np.absolute(fsize - np.full(fsize.size, tbr))
+
+ # Count how many frames are more than 2.5x of frame target bit rate.
+ tbr_thr = tbr * 2.5
+ cnt = 0
+ idx = np.arange(fsize.size)
+ for i in idx:
+ if fsize[i] > tbr_thr:
+ cnt = cnt + 1
+
+ # Use the 15-frame moving window
+ win = 15
+ avg_fsize = moving_average(fsize, win)
+ win_mismatch = np.absolute(avg_fsize - np.full(avg_fsize.size, tbr))
+
+  print('[Target frame size (bits)]:', "%.2f"%tbr)
+  print('[Average frame size (bits)]:', "%.2f"%np.average(fsize))
+  print('[Frame size standard deviation]:', "%.2f"%np.std(fsize))
+  print('[Max/min frame size (bits)]:', "%.2f"%np.max(fsize), '/', "%.2f"%np.min(fsize))
+  print('[Average frame size mismatch (bits)]:', "%.2f"%np.average(mismatch))
+  print('[Number of frames (frame size > 2.5x of target frame size)]:', cnt)
+  print(' Moving window size:', win)
+  print('[Moving average frame size mismatch (bits)]:', "%.2f"%np.average(win_mismatch))
+ print('------------------------------')
+
+ figure, axis = plt.subplots(2)
+ x = np.arange(fsize.size)
+ axis[0].plot(x, fsize, color='blue')
+ axis[0].set_title("frame sizes")
+ axis[1].plot(x, qindex, color='blue')
+ axis[1].set_title("frame qindex")
+ plt.tight_layout()
+
+ # Save the plot
+ plotname = filename + '.png'
+ plt.savefig(plotname)
+ plt.show()
+
+if __name__ == '__main__':
+ if (len(sys.argv) < 4):
+ print(sys.argv[0], 'input_file, target_bitrate, fps')
+ sys.exit()
+ target_br = int(sys.argv[2])
+ fps = int(sys.argv[3])
+ frame_size_analysis(sys.argv[1], target_br, fps)
diff --git a/third_party/aom/tools/gen_authors.sh b/third_party/aom/tools/gen_authors.sh
new file mode 100755
index 0000000000..5def8bc898
--- /dev/null
+++ b/third_party/aom/tools/gen_authors.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# Add organization names manually.
+
+cat <<EOF
+# This file is automatically generated from the git commit history
+# by tools/gen_authors.sh.
+
+$(git log --pretty=format:"%aN <%aE>" | sort | uniq | grep -v "corp.google\|clang-format")
+EOF
diff --git a/third_party/aom/tools/gen_constrained_tokenset.py b/third_party/aom/tools/gen_constrained_tokenset.py
new file mode 100755
index 0000000000..f5b0816dbf
--- /dev/null
+++ b/third_party/aom/tools/gen_constrained_tokenset.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+##
+## Copyright (c) 2016, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 2 Clause License and
+## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+## was not distributed with this source code in the LICENSE file, you can
+## obtain it at www.aomedia.org/license/software. If the Alliance for Open
+## Media Patent License 1.0 was not distributed with this source code in the
+## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+##
+"""Generate the probability model for the constrained token set.
+
+Model obtained from a 2-sided zero-centered distribution derived
+from a Pareto distribution. The cdf of the distribution is:
+cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha/(alpha + |x|)} ^ beta]
+
+For a given beta and a given probability of the 1-node, the alpha
+is first solved, and then the {alpha, beta} pair is used to generate
+the probabilities for the rest of the nodes.
+"""
+
+import heapq
+import sys
+import numpy as np
+import scipy.optimize
+import scipy.stats
+
+
+def cdf_spareto(x, xm, beta):
+ p = 1 - (xm / (np.abs(x) + xm))**beta
+ p = 0.5 + 0.5 * np.sign(x) * p
+ return p
+
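+# Sanity check (illustrative): the distribution is symmetric about zero, so
+# cdf_spareto(0, xm, beta) == 0.5 and
+# cdf_spareto(x, xm, beta) + cdf_spareto(-x, xm, beta) == 1 for xm, beta > 0.
+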
+
+def get_spareto(p, beta):
+ cdf = cdf_spareto
+
+ def func(x):
+ return ((cdf(1.5, x, beta) - cdf(0.5, x, beta)) /
+ (1 - cdf(0.5, x, beta)) - p)**2
+
+ alpha = scipy.optimize.fminbound(func, 1e-12, 10000, xtol=1e-12)
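+  # Bin edges at 0.5, 1.5, ..., 66.5 mirror the coefficient token set (an
+  # assumption based on the VP9/AV1 token categories): values 0-4 get their
+  # own bins, then CAT1 (5-6), CAT2 (7-10), CAT3 (11-18), CAT4 (19-34),
+  # CAT5 (35-66), and CAT6 (67+).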
+ parray = np.zeros(11)
+ parray[0] = 2 * (cdf(0.5, alpha, beta) - 0.5)
+ parray[1] = (2 * (cdf(1.5, alpha, beta) - cdf(0.5, alpha, beta)))
+ parray[2] = (2 * (cdf(2.5, alpha, beta) - cdf(1.5, alpha, beta)))
+ parray[3] = (2 * (cdf(3.5, alpha, beta) - cdf(2.5, alpha, beta)))
+ parray[4] = (2 * (cdf(4.5, alpha, beta) - cdf(3.5, alpha, beta)))
+ parray[5] = (2 * (cdf(6.5, alpha, beta) - cdf(4.5, alpha, beta)))
+ parray[6] = (2 * (cdf(10.5, alpha, beta) - cdf(6.5, alpha, beta)))
+ parray[7] = (2 * (cdf(18.5, alpha, beta) - cdf(10.5, alpha, beta)))
+ parray[8] = (2 * (cdf(34.5, alpha, beta) - cdf(18.5, alpha, beta)))
+ parray[9] = (2 * (cdf(66.5, alpha, beta) - cdf(34.5, alpha, beta)))
+ parray[10] = 2 * (1. - cdf(66.5, alpha, beta))
+ return parray
+
+
+def quantize_probs(p, save_first_bin, bits):
+ """Quantize probability precisely.
+
+ Quantize probabilities minimizing dH (Kullback-Leibler divergence)
+ approximated by: sum (p_i-q_i)^2/p_i.
+ References:
+ https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
+ https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit
+ """
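+  # Worked example (hand-checked, hypothetical input): with p = [0.7, 0.2,
+  # 0.1] and bits = 4 (L = 16), p * L rounds to q = [11, 3, 2], which already
+  # sums to L, so the heap-based correction loop below is skipped.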
+ num_sym = p.size
+ p = np.clip(p, 1e-16, 1)
+ L = 2**bits
+ pL = p * L
+ ip = 1. / p # inverse probability
+ q = np.clip(np.round(pL), 1, L + 1 - num_sym)
+ quant_err = (pL - q)**2 * ip
+ sgn = np.sign(L - q.sum()) # direction of correction
+ if sgn != 0: # correction is needed
+ v = [] # heap of adjustment results (adjustment err, index) of each symbol
+ for i in range(1 if save_first_bin else 0, num_sym):
+ q_adj = q[i] + sgn
+ if q_adj > 0 and q_adj < L:
+ adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
+ heapq.heappush(v, (adj_err, i))
+ while q.sum() != L:
+ # apply lowest error adjustment
+ (adj_err, i) = heapq.heappop(v)
+ quant_err[i] += adj_err
+ q[i] += sgn
+ # calculate the cost of adjusting this symbol again
+ q_adj = q[i] + sgn
+ if q_adj > 0 and q_adj < L:
+ adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
+ heapq.heappush(v, (adj_err, i))
+ return q
+
+
+def get_quantized_spareto(p, beta, bits, first_token):
+ parray = get_spareto(p, beta)
+ parray = parray[1:] / (1 - parray[0])
+ # CONFIG_NEW_TOKENSET
+ if first_token > 1:
+ parray = parray[1:] / (1 - parray[0])
+ qarray = quantize_probs(parray, first_token == 1, bits)
+  return qarray.astype(int)
+
+
+def main(bits=15, first_token=1):
+ beta = 8
+ for q in range(1, 256):
+ parray = get_quantized_spareto(q / 256., beta, bits, first_token)
+ assert parray.sum() == 2**bits
+ print('{', ', '.join('%d' % i for i in parray), '},')
+
+
+if __name__ == '__main__':
+ if len(sys.argv) > 2:
+ main(int(sys.argv[1]), int(sys.argv[2]))
+ elif len(sys.argv) > 1:
+ main(int(sys.argv[1]))
+ else:
+ main()
diff --git a/third_party/aom/tools/gop_bitrate/analyze_data.py b/third_party/aom/tools/gop_bitrate/analyze_data.py
new file mode 100644
index 0000000000..4e006b9220
--- /dev/null
+++ b/third_party/aom/tools/gop_bitrate/analyze_data.py
@@ -0,0 +1,18 @@
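+# Summarize estimated vs. actual bitrate per input from experiment.txt, as
+# written by encode_all_script.sh ('****' lines separate encodes). The fixed
+# slice offsets below assume the log's field prefixes (e.g. 'input: media/').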
+with open('experiment.txt', 'r') as file:
+ lines = file.readlines()
+ curr_filename = ''
+ actual_value = 0
+ estimate_value = 0
+ print('filename, estimated value (b), actual value (b)')
+ for line in lines:
+ if line.startswith('input:'):
+ curr_filename = line[13:].strip()
+ if line.startswith('estimated'):
+ estimate_value = float(line[19:].strip())
+ if line.startswith('frame:'):
+ actual_value += float(line[line.find('size')+6:line.find('total')-2])
+ if line.startswith('****'):
+ print(f'{curr_filename}, {estimate_value}, {actual_value}')
+ estimate_value = 0
+ actual_value = 0
diff --git a/third_party/aom/tools/gop_bitrate/encode_all_script.sh b/third_party/aom/tools/gop_bitrate/encode_all_script.sh
new file mode 100755
index 0000000000..0689b33138
--- /dev/null
+++ b/third_party/aom/tools/gop_bitrate/encode_all_script.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
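+# Encode every file under media/ with aomenc and append each run's log to
+# experiment.txt in the format analyze_data.py expects. Assumes the script
+# is run from the directory holding the aomenc binary and a media/ directory.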
+#INPUT=media/cheer_sif.y4m
+OUTPUT=test.webm
+LIMIT=17
+CPU_USED=3
+CQ_LEVEL=36
+
+for input in media/*
+do
+ echo "****" >> experiment.txt
+ echo "input: $input" >> experiment.txt
+ ./aomenc --limit=$LIMIT --codec=av1 --cpu-used=$CPU_USED --end-usage=q --cq-level=$CQ_LEVEL --psnr --threads=0 --profile=0 --lag-in-frames=35 --min-q=0 --max-q=63 --auto-alt-ref=1 --passes=2 --kf-max-dist=160 --kf-min-dist=0 --drop-frame=0 --static-thresh=0 --minsection-pct=0 --maxsection-pct=2000 --arnr-maxframes=7 --arnr-strength=5 --sharpness=0 --undershoot-pct=100 --overshoot-pct=100 --frame-parallel=0 --tile-columns=0 -o $OUTPUT $input >> experiment.txt
+done
diff --git a/third_party/aom/tools/gop_bitrate/python/bitrate_accuracy.py b/third_party/aom/tools/gop_bitrate/python/bitrate_accuracy.py
new file mode 100644
index 0000000000..2a5da6a794
--- /dev/null
+++ b/third_party/aom/tools/gop_bitrate/python/bitrate_accuracy.py
@@ -0,0 +1,185 @@
+import numpy as np
+
+# Model A only.
+# Uses least squares regression to find the solution
+# when there is one unknown variable.
+def lstsq_solution(A, B):
+ A_inv = np.linalg.pinv(A)
+ x = np.matmul(A_inv, B)
+ return x[0][0]
+
+# Model B only.
+# Uses the pseudoinverse matrix to find the solution
+# when there are two unknown variables.
+def pinv_solution(A, mv, B):
+ new_A = np.concatenate((A, mv), axis=1)
+ new_A_inv = np.linalg.pinv(new_A)
+ new_x = np.matmul(new_A_inv, B)
+ print("pinv solution:", new_x[0][0], new_x[1][0])
+ return (new_x[0][0], new_x[1][0])
+
+# Model A only.
+# Finds the coefficient to multiply A by to minimize
+# the percentage error between A and B.
+def minimize_percentage_error_model_a(A, B):
+ R = np.divide(A, B)
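+    # Minimizing the summed squared relative error sum_i (1 - x * r_i)^2
+    # over x gives the closed form x = sum(r_i) / sum(r_i^2), which is
+    # accumulated below.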
+ num = 0
+ den = 0
+ for r_i in R:
+ num += r_i
+ den += r_i**2
+ if den == 0:
+ return 0
+ return (num/den)[0]
+
+# Model B only.
+# Finds the coefficients to multiply to the frame bitrate
+# and the motion vector bitrate to minimize the percent error.
+def minimize_percentage_error_model_b(r_e, r_m, r_f):
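+    # Minimizing sum_i (1 - x0*r_ef_i - x1*r_mf_i)^2 yields the normal
+    # equations:
+    #   x0*sum(r_ef^2)    + x1*sum(r_ef*r_mf) = sum(r_ef)
+    #   x0*sum(r_ef*r_mf) + x1*sum(r_mf^2)    = sum(r_mf)
+    # The rows of A below are these equations scaled by 1/sum(r_ef^2) and
+    # 1/sum(r_mf^2) respectively, then solved via the pseudoinverse.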
+ r_ef = np.divide(r_e, r_f)
+ r_mf = np.divide(r_m, r_f)
+ sum_ef = np.sum(r_ef)
+ sum_ef_sq = np.sum(np.square(r_ef))
+ sum_mf = np.sum(r_mf)
+ sum_mf_sq = np.sum(np.square(r_mf))
+ sum_ef_mf = np.sum(np.multiply(r_ef, r_mf))
+ # Divides x by y. If y is zero, returns 0.
+ divide = lambda x, y : 0 if y == 0 else x / y
+ # Set up and solve the matrix equation
+ A = np.array([[1, divide(sum_ef_mf, sum_ef_sq)],[divide(sum_ef_mf, sum_mf_sq), 1]])
+ B = np.array([divide(sum_ef, sum_ef_sq), divide(sum_mf, sum_mf_sq)])
+ A_inv = np.linalg.pinv(A)
+ x = np.matmul(A_inv, B)
+ return x
+
+# Model A only.
+# Calculates the mean squared error between B and x*A,
+# averaged over rows where B is nonzero.
+def average_lstsq_error(A, B, x):
+ error = 0
+ n = 0
+ for i, a in enumerate(A):
+ a = a[0]
+ b = B[i][0]
+ if b == 0:
+ continue
+ n += 1
+ error += (b - x*a)**2
+ if n == 0:
+ return None
+ error /= n
+ return error
+
+# Model A only.
+# Calculates the average percentage error between A and B.
+def average_percent_error_model_a(A, B, x):
+ error = 0
+ n = 0
+ for i, a in enumerate(A):
+ a = a[0]
+ b = B[i][0]
+ if b == 0:
+ continue
+ n += 1
+ error_i = (abs(x*a-b)/b)*100
+ error += error_i
+ if n == 0:
+ return 0
+ error /= n
+ return error
+
+# Model B only.
+# Calculates the average percentage error between A and B.
+def average_percent_error_model_b(A, M, B, x):
+ error = 0
+ n = 0
+ for i, a in enumerate(A):
+ a = a[0]
+ mv = M[i]
+ b = B[i][0]
+ if b == 0:
+ continue
+ n += 1
+ estimate = x[0]*a
+ estimate += x[1]*mv
+ error += abs(estimate - b) / b
+ error *= 100
+ error /= n
+ return error
+
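+# Model A only.
+# Calculates the root-mean-square relative error between x*A and B,
+# returned as a percentage.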
+def average_squared_error_model_a(A, B, x):
+ error = 0
+ n = 0
+ for i, a in enumerate(A):
+ a = a[0]
+ b = B[i][0]
+ if b == 0:
+ continue
+ n += 1
+ error_i = (1 - x*(a/b))**2
+ error += error_i
+ error /= n
+ error = error**0.5
+ return error * 100
+
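+# Model B only.
+# Calculates the root-mean-square relative error of the two-coefficient
+# estimate (x[0]*A + x[1]*M) against B, returned as a percentage.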
+def average_squared_error_model_b(A, M, B, x):
+ error = 0
+ n = 0
+ for i, a in enumerate(A):
+ a = a[0]
+ b = B[i][0]
+ mv = M[i]
+ if b == 0:
+ continue
+ n += 1
+ error_i = 1 - ((x[0]*a + x[1]*mv)/b)
+ error_i = error_i**2
+ error += error_i
+ error /= n
+ error = error**0.5
+ return error * 100
+
+# Traverses the data and prints out one value for
+# each update type.
+def print_solutions(file_path):
+ data = np.genfromtxt(file_path, delimiter="\t")
+ prev_update = 0
+ split_list_indices = list()
+ for i, val in enumerate(data):
+ if prev_update != val[3]:
+ split_list_indices.append(i)
+ prev_update = val[3]
+ split = np.split(data, split_list_indices)
+ for array in split:
+ A, mv, B, update = np.hsplit(array, 4)
+ z = np.where(B == 0)[0]
+ r_e = np.delete(A, z, axis=0)
+ r_m = np.delete(mv, z, axis=0)
+ r_f = np.delete(B, z, axis=0)
+ A = r_e
+ mv = r_m
+ B = r_f
+ all_zeros = not A.any()
+ if all_zeros:
+ continue
+ print("update type:", update[0][0])
+ x_a = minimize_percentage_error_model_a(A, B)
+ x_b = minimize_percentage_error_model_b(A, mv, B)
+ percent_error_a = average_percent_error_model_a(A, B, x_a)
+ percent_error_b = average_percent_error_model_b(A, mv, B, x_b)[0]
+ baseline_percent_error_a = average_percent_error_model_a(A, B, 1)
+ baseline_percent_error_b = average_percent_error_model_b(A, mv, B, [1, 1])[0]
+
+ squared_error_a = average_squared_error_model_a(A, B, x_a)
+ squared_error_b = average_squared_error_model_b(A, mv, B, x_b)[0]
+ baseline_squared_error_a = average_squared_error_model_a(A, B, 1)
+ baseline_squared_error_b = average_squared_error_model_b(A, mv, B, [1, 1])[0]
+
+ print("model,\tframe_coeff,\tmv_coeff,\terror,\tbaseline_error")
+ print("Model A %_error,\t" + str(x_a) + ",\t" + str(0) + ",\t" + str(percent_error_a) + ",\t" + str(baseline_percent_error_a))
+ print("Model A sq_error,\t" + str(x_a) + ",\t" + str(0) + ",\t" + str(squared_error_a) + ",\t" + str(baseline_squared_error_a))
+ print("Model B %_error,\t" + str(x_b[0]) + ",\t" + str(x_b[1]) + ",\t" + str(percent_error_b) + ",\t" + str(baseline_percent_error_b))
+ print("Model B sq_error,\t" + str(x_b[0]) + ",\t" + str(x_b[1]) + ",\t" + str(squared_error_b) + ",\t" + str(baseline_squared_error_b))
+ print()
+
+if __name__ == "__main__":
+ print_solutions("data2/all_lowres_target_lt600_data.txt")
diff --git a/third_party/aom/tools/inspect-cli.js b/third_party/aom/tools/inspect-cli.js
new file mode 100644
index 0000000000..a14c08111a
--- /dev/null
+++ b/third_party/aom/tools/inspect-cli.js
@@ -0,0 +1,39 @@
+/**
+ * This tool lets you test whether the compiled JavaScript decoder is functioning properly. You'll
+ * need to download a SpiderMonkey js-shell to run this script.
+ * https://archive.mozilla.org/pub/firefox/nightly/latest-mozilla-central/
+ *
+ * Example:
+ * js-shell inspect-cli.js video.ivf
+ */
+load("inspect.js");
+var buffer = read(scriptArgs[0], "binary");
+var Module = {
+ noExitRuntime: true,
+ noInitialRun: true,
+ preInit: [],
+ preRun: [],
+ postRun: [function () {
+ printErr(`Loaded JavaScript Decoder OK`);
+ }],
+ memoryInitializerPrefixURL: "bin/",
+ arguments: ['input.ivf', 'output.raw'],
+ on_frame_decoded_json: function (jsonString) {
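+ // Assumes the decoder emits a comma-terminated run of JSON objects, so
+ // wrapping it as "[" + ... + "null]" yields a parseable array; the
+ // trailing null is filtered out below.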
+ let json = JSON.parse("[" + Module.UTF8ToString(jsonString) + "null]");
+ json.forEach(frame => {
+ if (frame) {
+ print(frame.frame);
+ }
+ });
+ }
+};
+DecoderModule(Module);
+Module.FS.writeFile("/tmp/input.ivf", buffer, { encoding: "binary" });
+Module._open_file();
+Module._set_layers(0xFFFFFFFF); // Set this to zero if you want to benchmark decoding.
+while (true) {
+ printErr("Decoding Frame ...");
+ if (Module._read_frame()) {
+ break;
+ }
+}
diff --git a/third_party/aom/tools/inspect-post.js b/third_party/aom/tools/inspect-post.js
new file mode 100644
index 0000000000..31c40bb82c
--- /dev/null
+++ b/third_party/aom/tools/inspect-post.js
@@ -0,0 +1 @@
+Module["FS"] = FS;
diff --git a/third_party/aom/tools/intersect-diffs.py b/third_party/aom/tools/intersect-diffs.py
new file mode 100755
index 0000000000..960183675d
--- /dev/null
+++ b/third_party/aom/tools/intersect-diffs.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+##
+## Copyright (c) 2016, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 2 Clause License and
+## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+## was not distributed with this source code in the LICENSE file, you can
+## obtain it at www.aomedia.org/license/software. If the Alliance for Open
+## Media Patent License 1.0 was not distributed with this source code in the
+## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+##
+"""Calculates the "intersection" of two unified diffs.
+
+Given two diffs, A and B, it finds all hunks in B that had non-context lines
+in A and prints them to stdout. This is useful to determine the hunks in B that
+are relevant to A. The resulting file can be applied with patch(1) on top of A.
+"""
+
+__author__ = "jkoleszar@google.com"
+
+import sys
+
+import diff
+
+
+def FormatDiffHunks(hunks):
+ """Re-serialize a list of DiffHunks."""
+ r = []
+ last_header = None
+ for hunk in hunks:
+ this_header = hunk.header[0:2]
+ if last_header != this_header:
+ r.extend(hunk.header)
+ last_header = this_header
+ else:
+ r.append(hunk.header[2])
+ r.extend(hunk.lines)
+ r.append("\n")
+ return "".join(r)
+
+
+def ZipHunks(rhs_hunks, lhs_hunks):
+ """Join two hunk lists on filename."""
+ for rhs_hunk in rhs_hunks:
+ rhs_file = rhs_hunk.right.filename.split("/")[1:]
+
+ for lhs_hunk in lhs_hunks:
+ lhs_file = lhs_hunk.left.filename.split("/")[1:]
+ if lhs_file != rhs_file:
+ continue
+ yield (rhs_hunk, lhs_hunk)
+
+
+def main():
+ with open(sys.argv[1], "r") as f:
+ old_hunks = list(diff.ParseDiffHunks(f))
+ with open(sys.argv[2], "r") as f:
+ new_hunks = list(diff.ParseDiffHunks(f))
+ out_hunks = []
+
+ # Join the right hand side of the older diff with the left hand side of the
+ # newer diff.
+ for old_hunk, new_hunk in ZipHunks(old_hunks, new_hunks):
+ if new_hunk in out_hunks:
+ continue
+ old_lines = old_hunk.right
+ new_lines = new_hunk.left
+
+ # Determine if this hunk overlaps any non-context line from the other diff.
+ for i in old_lines.delta_line_nums:
+ if i in new_lines:
+ out_hunks.append(new_hunk)
+ break
+
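+ # A nonzero exit status signals that intersecting hunks were found and
+ # printed.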
+ if out_hunks:
+ print(FormatDiffHunks(out_hunks))
+ sys.exit(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/third_party/aom/tools/lint-hunks.py b/third_party/aom/tools/lint-hunks.py
new file mode 100755
index 0000000000..8b3af972fc
--- /dev/null
+++ b/third_party/aom/tools/lint-hunks.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+##
+## Copyright (c) 2016, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 2 Clause License and
+## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+## was not distributed with this source code in the LICENSE file, you can
+## obtain it at www.aomedia.org/license/software. If the Alliance for Open
+## Media Patent License 1.0 was not distributed with this source code in the
+## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+##
+"""Performs style checking on each diff hunk."""
+import getopt
+import io
+import os
+import subprocess
+import sys
+
+import diff
+
+
+SHORT_OPTIONS = "h"
+LONG_OPTIONS = ["help"]
+
+TOPLEVEL_CMD = ["git", "rev-parse", "--show-toplevel"]
+DIFF_CMD = ["git", "diff"]
+DIFF_INDEX_CMD = ["git", "diff-index", "-u", "HEAD", "--"]
+SHOW_CMD = ["git", "show"]
+CPPLINT_FILTERS = ["-readability/casting"]
+
+
+class Usage(Exception):
+ pass
+
+
+class SubprocessException(Exception):
+ def __init__(self, args):
+ msg = "Failed to execute '%s'"%(" ".join(args))
+ super(SubprocessException, self).__init__(msg)
+
+
+class Subprocess(subprocess.Popen):
+ """Adds the notion of an expected returncode to Popen."""
+
+ def __init__(self, args, expected_returncode=0, **kwargs):
+ self._args = args
+ self._expected_returncode = expected_returncode
+ super(Subprocess, self).__init__(args, **kwargs)
+
+ def communicate(self, *args, **kwargs):
+ result = super(Subprocess, self).communicate(*args, **kwargs)
+ if self._expected_returncode is not None:
+ try:
+ ok = self.returncode in self._expected_returncode
+ except TypeError:
+ ok = self.returncode == self._expected_returncode
+ if not ok:
+ raise SubprocessException(self._args)
+ return result
+
+
+def main(argv=None):
+ if argv is None:
+ argv = sys.argv
+ try:
+ try:
+ opts, args = getopt.getopt(argv[1:], SHORT_OPTIONS, LONG_OPTIONS)
+ except getopt.error as msg:
+ raise Usage(msg)
+
+ # process options
+ for o, _ in opts:
+ if o in ("-h", "--help"):
+ print(__doc__)
+ sys.exit(0)
+
+ if len(args) > 1:
+ print(__doc__)
+ sys.exit(0)
+
+ # Find the fully qualified path to the root of the tree
+ tl = Subprocess(TOPLEVEL_CMD, stdout=subprocess.PIPE, text=True)
+ tl = tl.communicate()[0].strip()
+
+ # See if we're working on the index or not.
+ if args:
+ diff_cmd = DIFF_CMD + [args[0] + "^!"]
+ else:
+ diff_cmd = DIFF_INDEX_CMD
+
+ # Build the command line to execute cpplint
+ cpplint_cmd = [os.path.join(tl, "tools", "cpplint.py"),
+ "--filter=" + ",".join(CPPLINT_FILTERS),
+ "-"]
+
+ # Get a list of all affected lines
+ file_affected_line_map = {}
+ p = Subprocess(diff_cmd, stdout=subprocess.PIPE, text=True)
+ stdout = p.communicate()[0]
+ for hunk in diff.ParseDiffHunks(io.StringIO(stdout)):
+ filename = hunk.right.filename[2:]
+ if filename not in file_affected_line_map:
+ file_affected_line_map[filename] = set()
+ file_affected_line_map[filename].update(hunk.right.delta_line_nums)
+
+ # Run each affected file through cpplint
+ lint_failed = False
+ for filename, affected_lines in file_affected_line_map.items():
+ if filename.split(".")[-1] not in ("c", "h", "cc"):
+ continue
+ if filename.startswith("third_party"):
+ continue
+
+ if args:
+ # File contents come from git
+ show_cmd = SHOW_CMD + [args[0] + ":" + filename]
+ show = Subprocess(show_cmd, stdout=subprocess.PIPE, text=True)
+ lint = Subprocess(cpplint_cmd, expected_returncode=(0, 1),
+ stdin=show.stdout, stderr=subprocess.PIPE,
+ text=True)
+ lint_out = lint.communicate()[1]
+ else:
+ # File contents come from the working tree
+ lint = Subprocess(cpplint_cmd, expected_returncode=(0, 1),
+ stdin=subprocess.PIPE, stderr=subprocess.PIPE,
+ text=True)
+ with open(os.path.join(tl, filename)) as f:
+ stdin = f.read()
+ lint_out = lint.communicate(stdin)[1]
+
+ for line in lint_out.split("\n"):
+ fields = line.split(":")
+ if fields[0] != "-":
+ continue
+ warning_line_num = int(fields[1])
+ if warning_line_num in affected_lines:
+ print("%s:%d:%s"%(filename, warning_line_num,
+ ":".join(fields[2:])))
+ lint_failed = True
+
+ # Set exit code if any relevant lint errors seen
+ if lint_failed:
+ return 1
+
+ except Usage as err:
+ print(err, file=sys.stderr)
+ print("for help use --help", file=sys.stderr)
+ return 2
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/third_party/aom/tools/obu_parser.cc b/third_party/aom/tools/obu_parser.cc
new file mode 100644
index 0000000000..5716b46218
--- /dev/null
+++ b/third_party/aom/tools/obu_parser.cc
@@ -0,0 +1,190 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+#include <string.h>
+
+#include <cstdio>
+#include <string>
+
+#include "aom/aom_codec.h"
+#include "aom/aom_integer.h"
+#include "aom_ports/mem_ops.h"
+#include "av1/common/obu_util.h"
+#include "tools/obu_parser.h"
+
+namespace aom_tools {
+
+// Basic OBU syntax
+// 8 bits: Header
+//   7        forbidden bit
+//   6,5,4,3  type bits
+//   2        extension flag bit
+//   1        has size field bit
+//   0        reserved bit
+const uint32_t kObuForbiddenBitMask = 0x1;
+const uint32_t kObuForbiddenBitShift = 7;
+const uint32_t kObuTypeBitsMask = 0xF;
+const uint32_t kObuTypeBitsShift = 3;
+const uint32_t kObuExtensionFlagBitMask = 0x1;
+const uint32_t kObuExtensionFlagBitShift = 2;
+const uint32_t kObuHasSizeFieldBitMask = 0x1;
+const uint32_t kObuHasSizeFieldBitShift = 1;
+
+// When extension flag bit is set:
+// 8 bits: extension header
+//   7,6,5  temporal ID
+//   4,3    spatial ID
+//   2,1,0  reserved bits
+const uint32_t kObuExtTemporalIdBitsMask = 0x7;
+const uint32_t kObuExtTemporalIdBitsShift = 5;
+const uint32_t kObuExtSpatialIdBitsMask = 0x3;
+const uint32_t kObuExtSpatialIdBitsShift = 3;
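+
+// Worked example (hypothetical input): a header byte of 0x32 (0b00110010)
+// parses as forbidden_bit = 0, type = (0x32 >> 3) & 0xF = 6 (OBU_FRAME),
+// extension flag = 0, has size field = 1.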
+
+bool ValidObuType(int obu_type) {
+ switch (obu_type) {
+ case OBU_SEQUENCE_HEADER:
+ case OBU_TEMPORAL_DELIMITER:
+ case OBU_FRAME_HEADER:
+ case OBU_TILE_GROUP:
+ case OBU_METADATA:
+ case OBU_FRAME:
+ case OBU_REDUNDANT_FRAME_HEADER:
+ case OBU_TILE_LIST:
+ case OBU_PADDING: return true;
+ }
+ return false;
+}
+
+bool ParseObuHeader(uint8_t obu_header_byte, ObuHeader *obu_header) {
+ const int forbidden_bit =
+ (obu_header_byte >> kObuForbiddenBitShift) & kObuForbiddenBitMask;
+ if (forbidden_bit) {
+ fprintf(stderr, "Invalid OBU, forbidden bit set.\n");
+ return false;
+ }
+
+ obu_header->type = static_cast<OBU_TYPE>(
+ (obu_header_byte >> kObuTypeBitsShift) & kObuTypeBitsMask);
+ if (!ValidObuType(obu_header->type)) {
+ fprintf(stderr, "Invalid OBU type: %d.\n", obu_header->type);
+ return false;
+ }
+
+ obu_header->has_extension =
+ (obu_header_byte >> kObuExtensionFlagBitShift) & kObuExtensionFlagBitMask;
+ obu_header->has_size_field =
+ (obu_header_byte >> kObuHasSizeFieldBitShift) & kObuHasSizeFieldBitMask;
+ return true;
+}
+
+bool ParseObuExtensionHeader(uint8_t ext_header_byte, ObuHeader *obu_header) {
+ obu_header->temporal_layer_id =
+ (ext_header_byte >> kObuExtTemporalIdBitsShift) &
+ kObuExtTemporalIdBitsMask;
+ obu_header->spatial_layer_id =
+ (ext_header_byte >> kObuExtSpatialIdBitsShift) & kObuExtSpatialIdBitsMask;
+
+ return true;
+}
+
+void PrintObuHeader(const ObuHeader *header) {
+ printf(
+ " OBU type: %s\n"
+ " extension: %s\n",
+ aom_obu_type_to_string(static_cast<OBU_TYPE>(header->type)),
+ header->has_extension ? "yes" : "no");
+ if (header->has_extension) {
+ printf(
+ " temporal_id: %d\n"
+ " spatial_id: %d\n",
+ header->temporal_layer_id, header->spatial_layer_id);
+ }
+}
+
+bool DumpObu(const uint8_t *data, int length, int *obu_overhead_bytes) {
+ const int kObuHeaderSizeBytes = 1;
+ const int kMinimumBytesRequired = 1 + kObuHeaderSizeBytes;
+ int consumed = 0;
+ int obu_overhead = 0;
+ ObuHeader obu_header;
+ while (consumed < length) {
+ const int remaining = length - consumed;
+ if (remaining < kMinimumBytesRequired) {
+ fprintf(stderr,
+ "OBU parse error. Did not consume all data, %d bytes remain.\n",
+ remaining);
+ return false;
+ }
+
+ int obu_header_size = 0;
+
+ memset(&obu_header, 0, sizeof(obu_header));
+ const uint8_t obu_header_byte = *(data + consumed);
+ if (!ParseObuHeader(obu_header_byte, &obu_header)) {
+ fprintf(stderr, "OBU parsing failed at offset %d.\n", consumed);
+ return false;
+ }
+
+ ++obu_overhead;
+ ++obu_header_size;
+
+ if (obu_header.has_extension) {
+ const uint8_t obu_ext_header_byte =
+ *(data + consumed + kObuHeaderSizeBytes);
+ if (!ParseObuExtensionHeader(obu_ext_header_byte, &obu_header)) {
+ fprintf(stderr, "OBU extension parsing failed at offset %d.\n",
+ consumed + kObuHeaderSizeBytes);
+ return false;
+ }
+
+ ++obu_overhead;
+ ++obu_header_size;
+ }
+
+ PrintObuHeader(&obu_header);
+
+ uint64_t obu_size = 0;
+ size_t length_field_size = 0;
+ if (aom_uleb_decode(data + consumed + obu_header_size,
+ remaining - obu_header_size, &obu_size,
+ &length_field_size) != 0) {
+ fprintf(stderr, "OBU size parsing failed at offset %d.\n",
+ consumed + obu_header_size);
+ return false;
+ }
+ int current_obu_length = static_cast<int>(obu_size);
+ if (obu_header_size + static_cast<int>(length_field_size) +
+ current_obu_length >
+ remaining) {
+ fprintf(stderr, "OBU parsing failed: not enough OBU data.\n");
+ return false;
+ }
+ consumed += obu_header_size + static_cast<int>(length_field_size) +
+ current_obu_length;
+ printf(" length: %d\n",
+ static_cast<int>(obu_header_size + length_field_size +
+ current_obu_length));
+ }
+
+ if (obu_overhead_bytes != nullptr) *obu_overhead_bytes = obu_overhead;
+ printf(" TU size: %d\n", consumed);
+
+ return true;
+}
+
+} // namespace aom_tools
diff --git a/third_party/aom/tools/obu_parser.h b/third_party/aom/tools/obu_parser.h
new file mode 100644
index 0000000000..1d7d2d794b
--- /dev/null
+++ b/third_party/aom/tools/obu_parser.h
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2017, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_TOOLS_OBU_PARSER_H_
+#define AOM_TOOLS_OBU_PARSER_H_
+
+#include <cstdint>
+
+namespace aom_tools {
+
+// Print information obtained from OBU(s) in data until data is exhausted or an
+// error occurs. Returns true when all data is consumed successfully, and
+// optionally reports OBU storage overhead via obu_overhead_bytes when the
+// pointer is non-null.
+bool DumpObu(const uint8_t *data, int length, int *obu_overhead_bytes);
+
+} // namespace aom_tools
+
+#endif // AOM_TOOLS_OBU_PARSER_H_
diff --git a/third_party/aom/tools/ratectrl_log_analyzer/analyze_ratectrl_log.py b/third_party/aom/tools/ratectrl_log_analyzer/analyze_ratectrl_log.py
new file mode 100644
index 0000000000..9afb78cbf5
--- /dev/null
+++ b/third_party/aom/tools/ratectrl_log_analyzer/analyze_ratectrl_log.py
@@ -0,0 +1,154 @@
+#!/usr/bin/python3
+##
+## Copyright (c) 2022, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 2 Clause License and
+## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+## was not distributed with this source code in the LICENSE file, you can
+## obtain it at www.aomedia.org/license/software. If the Alliance for Open
+## Media Patent License 1.0 was not distributed with this source code in the
+## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+##
+""" Analyze the log generated by experimental flag CONFIG_RATECTRL_LOG."""
+
+import matplotlib.pyplot as plt
+import os
+
+
+def get_file_basename(filename):
+ return filename.split(".")[0]
+
+
+def parse_log(log_file):
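+ """Parses a log whose lines are whitespace-separated key/value pairs,
+ e.g. "q 32 rate 12345.0" (each value must parse as a float)."""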
+ data_list = []
+ with open(log_file) as fp:
+ for line in fp:
+ dic = {}
+ word_ls = line.split()
+ i = 0
+ while i < len(word_ls):
+ dic[word_ls[i]] = float(word_ls[i + 1])
+ i += 2
+ data_list.append(dic)
+ return data_list
+
+
+def extract_data(data_list, name):
+ arr = []
+ for data in data_list:
+ arr.append(data[name])
+ return arr
+
+
+def visualize_q_indices(exp_summary, exp_list, fig_path=None):
+ for exp in exp_list:
+ data = parse_log(exp["log"])
+ q_indices = extract_data(data, "q")
+ plt.title(exp_summary)
+ plt.xlabel("frame_coding_idx")
+ plt.ylabel("q_index")
+ plt.plot(q_indices, marker=".", label=exp["label"])
+ plt.legend()
+ if fig_path:
+ plt.savefig(fig_path)
+ else:
+ plt.show()
+ plt.clf()
+
+
+def get_rc_type_from_exp_type(exp_type):
+ if exp_type == "Q_3P":
+ return "q"
+ return "vbr"
+
+
+def test_video(exe_name, input, exp_type, level, log=None, limit=150):
+ basic_cmd = ("--test-decode=warn --threads=0 --profile=0 --min-q=0 --max-q=63"
+ " --auto-alt-ref=1 --kf-max-dist=160 --kf-min-dist=0 "
+ "--drop-frame=0 --static-thresh=0 --minsection-pct=0 "
+ "--maxsection-pct=2000 --arnr-maxframes=7 --arnr-strength=5 "
+ "--sharpness=0 --undershoot-pct=100 --overshoot-pct=100 "
+ "--frame-parallel=0 --tile-columns=0 --cpu-used=3 "
+ "--lag-in-frames=48 --psnr")
+ rc_type = get_rc_type_from_exp_type(exp_type)
+ rc_cmd = "--end-usage=" + rc_type
+ level_cmd = ""
+ if rc_type == "q":
+ level_cmd += "--cq-level=" + str(level)
+ elif rc_type == "vbr":
+ level_cmd += "--target-bitrate=" + str(level)
+ limit_cmd = "--limit=" + str(limit)
+ passes_cmd = "--passes=3 --second-pass-log=second_pass_log"
+ output_cmd = "-o test.webm"
+ input_cmd = "~/data/" + input
+ log_cmd = ""
+ if log is not None:
+ log_cmd = ">" + log
+ cmd_ls = [
+ exe_name, basic_cmd, rc_cmd, level_cmd, limit_cmd, passes_cmd, output_cmd,
+ input_cmd, log_cmd
+ ]
+ cmd = " ".join(cmd_ls)
+ os.system(cmd)
+
+
+def gen_ratectrl_log(test_case):
+ exe = test_case["exe"]
+ video = test_case["video"]
+ exp_type = test_case["exp_type"]
+ level = test_case["level"]
+ log = test_case["log"]
+ test_video(exe, video, exp_type, level, log=log, limit=150)
+ return log
+
+
+def gen_test_case(exp_type, dataset, videoname, level, log_dir=None):
+ test_case = {}
+ exe = "./aomenc_bl"
+ if exp_type == "BA_3P":
+ exe = "./aomenc_ba"
+ test_case["exe"] = exe
+
+ video = os.path.join(dataset, videoname)
+ test_case["video"] = video
+ test_case["exp_type"] = exp_type
+ test_case["level"] = level
+
+ video_basename = get_file_basename(videoname)
+ log = ".".join([dataset, video_basename, exp_type, str(level)])
+ if log_dir is not None:
+ log = os.path.join(log_dir, log)
+ test_case["log"] = log
+ return test_case
+
+
+def run_ratectrl_exp(exp_config):
+ fp = open(exp_config)
+ log_dir = "./lowres_rc_log"
+ fig_dir = "./lowres_rc_fig"
+ dataset = "lowres"
+ for line in fp:
+ word_ls = line.split()
+ dataset = word_ls[0]
+ videoname = word_ls[1]
+ exp_type_ls = ["VBR_3P", "BA_3P"]
+ level_ls = [int(v) for v in word_ls[2:4]]
+ exp_ls = []
+ for i in range(len(exp_type_ls)):
+ exp_type = exp_type_ls[i]
+ test_case = gen_test_case(exp_type, dataset, videoname, level_ls[i],
+ log_dir)
+ log = gen_ratectrl_log(test_case)
+ exp = {}
+ exp["log"] = log
+ exp["label"] = exp_type
+ exp_ls.append(exp)
+ video_basename = get_file_basename(videoname)
+ fig_path = os.path.join(fig_dir, video_basename + ".png")
+ visualize_q_indices(video_basename, exp_ls, fig_path)
+ fp.close()
+
+
+if __name__ == "__main__":
+ run_ratectrl_exp("exp_rc_config")
diff --git a/third_party/aom/tools/txfm_analyzer/txfm_gen_code.cc b/third_party/aom/tools/txfm_analyzer/txfm_gen_code.cc
new file mode 100644
index 0000000000..7c5400b91a
--- /dev/null
+++ b/third_party/aom/tools/txfm_analyzer/txfm_gen_code.cc
@@ -0,0 +1,580 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <math.h>
+#include <float.h>
+#include <string.h>
+
+#include "tools/txfm_analyzer/txfm_graph.h"
+
+typedef enum CODE_TYPE {
+ CODE_TYPE_C,
+ CODE_TYPE_SSE2,
+ CODE_TYPE_SSE4_1
+} CODE_TYPE;
+
+int get_cos_idx(double value, int mod) {
+ return static_cast<int>(round(acos(fabs(value)) / PI * mod));
+}
+
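+// Maps a weight to its cospi[] index text: assuming COS_MOD == 64, a weight
+// of cos(11 * PI / 64) prints as " cospi[11]".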
+char *cos_text_arr(double value, int mod, char *text, int size) {
+ int num = get_cos_idx(value, mod);
+ if (value < 0) {
+ snprintf(text, size, "-cospi[%2d]", num);
+ } else {
+ snprintf(text, size, " cospi[%2d]", num);
+ }
+
+ if (num == 0)
+ printf("v: %f -> %d/%d v==-1 is %d\n", value, num, mod, value == -1);
+
+ return text;
+}
+
+char *cos_text_sse2(double w0, double w1, int mod, char *text, int size) {
+ int idx0 = get_cos_idx(w0, mod);
+ int idx1 = get_cos_idx(w1, mod);
+ char p[] = "p";
+ char n[] = "m";
+ char *sgn0 = w0 < 0 ? n : p;
+ char *sgn1 = w1 < 0 ? n : p;
+ snprintf(text, size, "cospi_%s%02d_%s%02d", sgn0, idx0, sgn1, idx1);
+ return text;
+}
+
+char *cos_text_sse4_1(double w, int mod, char *text, int size) {
+ int idx = get_cos_idx(w, mod);
+ char p[] = "p";
+ char n[] = "m";
+ char *sgn = w < 0 ? n : p;
+ snprintf(text, size, "cospi_%s%02d", sgn, idx);
+ return text;
+}
+
+void node_to_code_c(Node *node, const char *buf0, const char *buf1) {
+ int cnt = 0;
+ for (int i = 0; i < 2; i++) {
+ if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++;
+ }
+ if (cnt == 2) {
+ int cnt2 = 0;
+ printf(" %s[%d] =", buf1, node->nodeIdx);
+ for (int i = 0; i < 2; i++) {
+ if (fabs(node->inWeight[i]) == 1) {
+ cnt2++;
+ }
+ }
+ if (cnt2 == 2) {
+ printf(" apply_value(");
+ }
+ int cnt1 = 0;
+ for (int i = 0; i < 2; i++) {
+ if (node->inWeight[i] == 1) {
+ if (cnt1 > 0)
+ printf(" + %s[%d]", buf0, node->inNodeIdx[i]);
+ else
+ printf(" %s[%d]", buf0, node->inNodeIdx[i]);
+ cnt1++;
+ } else if (node->inWeight[i] == -1) {
+ if (cnt1 > 0)
+ printf(" - %s[%d]", buf0, node->inNodeIdx[i]);
+ else
+ printf("-%s[%d]", buf0, node->inNodeIdx[i]);
+ cnt1++;
+ }
+ }
+ if (cnt2 == 2) {
+ printf(", stage_range[stage])");
+ }
+ printf(";\n");
+ } else {
+ char w0[100];
+ char w1[100];
+ printf(
+ " %s[%d] = half_btf(%s, %s[%d], %s, %s[%d], "
+ "cos_bit);\n",
+ buf1, node->nodeIdx, cos_text_arr(node->inWeight[0], COS_MOD, w0, 100),
+ buf0, node->inNodeIdx[0],
+ cos_text_arr(node->inWeight[1], COS_MOD, w1, 100), buf0,
+ node->inNodeIdx[1]);
+ }
+}
+
+void gen_code_c(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
+ char fun_name[100];
+ get_fun_name(fun_name, 100, type, node_num);
+
+ printf("\n");
+ printf(
+ "void av1_%s(const int32_t *input, int32_t *output, int8_t cos_bit, "
+ "const int8_t* stage_range) "
+ "{\n",
+ fun_name);
+ printf(" assert(output != input);\n");
+ printf(" const int32_t size = %d;\n", node_num);
+ printf(" const int32_t *cospi = cospi_arr(cos_bit);\n");
+ printf("\n");
+
+ printf(" int32_t stage = 0;\n");
+ printf(" int32_t *bf0, *bf1;\n");
+ printf(" int32_t step[%d];\n", node_num);
+
+ const char *buf0 = "bf0";
+ const char *buf1 = "bf1";
+ const char *input = "input";
+
+ int si = 0;
+ printf("\n");
+ printf(" // stage %d;\n", si);
+ printf(" apply_range(stage, input, %s, size, stage_range[stage]);\n", input);
+
+ si = 1;
+ printf("\n");
+ printf(" // stage %d;\n", si);
+ printf(" stage++;\n");
+ if (si % 2 == (stage_num - 1) % 2) {
+ printf(" %s = output;\n", buf1);
+ } else {
+ printf(" %s = step;\n", buf1);
+ }
+
+ for (int ni = 0; ni < node_num; ni++) {
+ int idx = get_idx(si, ni, node_num);
+ node_to_code_c(node + idx, input, buf1);
+ }
+
+ printf(" range_check_buf(stage, input, bf1, size, stage_range[stage]);\n");
+
+ for (int si = 2; si < stage_num; si++) {
+ printf("\n");
+ printf(" // stage %d\n", si);
+ printf(" stage++;\n");
+ if (si % 2 == (stage_num - 1) % 2) {
+ printf(" %s = step;\n", buf0);
+ printf(" %s = output;\n", buf1);
+ } else {
+ printf(" %s = output;\n", buf0);
+ printf(" %s = step;\n", buf1);
+ }
+
+ // computation code
+ for (int ni = 0; ni < node_num; ni++) {
+ int idx = get_idx(si, ni, node_num);
+ node_to_code_c(node + idx, buf0, buf1);
+ }
+
+ if (si != stage_num - 1) {
+ printf(
+ " range_check_buf(stage, input, bf1, size, stage_range[stage]);\n");
+ }
+ }
+ printf(" apply_range(stage, input, output, size, stage_range[stage]);\n");
+ printf("}\n");
+}
+
+void single_node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
+ printf(" %s[%2d] =", buf1, node->nodeIdx);
+ if (node->inWeight[0] == 1 && node->inWeight[1] == 1) {
+ printf(" _mm_adds_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0,
+ node->inNodeIdx[1]);
+ } else if (node->inWeight[0] == 1 && node->inWeight[1] == -1) {
+ printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0,
+ node->inNodeIdx[1]);
+ } else if (node->inWeight[0] == -1 && node->inWeight[1] == 1) {
+ printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[1], buf0,
+ node->inNodeIdx[0]);
+ } else if (node->inWeight[0] == 1 && node->inWeight[1] == 0) {
+ printf(" %s[%d]", buf0, node->inNodeIdx[0]);
+ } else if (node->inWeight[0] == 0 && node->inWeight[1] == 1) {
+ printf(" %s[%d]", buf0, node->inNodeIdx[1]);
+ } else if (node->inWeight[0] == -1 && node->inWeight[1] == 0) {
+ printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, node->inNodeIdx[0]);
+ } else if (node->inWeight[0] == 0 && node->inWeight[1] == -1) {
+ printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, node->inNodeIdx[1]);
+ }
+ printf(";\n");
+}
+
+void pair_node_to_code_sse2(Node *node, Node *partnerNode, const char *buf0,
+ const char *buf1) {
+ char temp0[100];
+ char temp1[100];
+ // btf_16_sse2_type0(w0, w1, in0, in1, out0, out1)
+ if (node->inNodeIdx[0] != partnerNode->inNodeIdx[0])
+ printf(" btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n",
+ cos_text_sse2(node->inWeight[0], node->inWeight[1], COS_MOD, temp0,
+ 100),
+ cos_text_sse2(partnerNode->inWeight[1], partnerNode->inWeight[0],
+ COS_MOD, temp1, 100),
+ buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1,
+ node->nodeIdx, buf1, partnerNode->nodeIdx);
+ else
+ printf(" btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n",
+ cos_text_sse2(node->inWeight[0], node->inWeight[1], COS_MOD, temp0,
+ 100),
+ cos_text_sse2(partnerNode->inWeight[0], partnerNode->inWeight[1],
+ COS_MOD, temp1, 100),
+ buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1,
+ node->nodeIdx, buf1, partnerNode->nodeIdx);
+}
+
+Node *get_partner_node(Node *node) {
+ int diff = node->inNode[1]->nodeIdx - node->nodeIdx;
+ return node + diff;
+}
+
+void node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
+ int cnt = 0;
+ int cnt1 = 0;
+ if (node->visited == 0) {
+ node->visited = 1;
+ for (int i = 0; i < 2; i++) {
+ if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++;
+ if (fabs(node->inWeight[i]) == 1) cnt1++;
+ }
+ if (cnt == 2) {
+ if (cnt1 == 2) {
+ // has a partner
+ Node *partnerNode = get_partner_node(node);
+ partnerNode->visited = 1;
+ single_node_to_code_sse2(node, buf0, buf1);
+ single_node_to_code_sse2(partnerNode, buf0, buf1);
+ } else {
+ single_node_to_code_sse2(node, buf0, buf1);
+ }
+ } else {
+ Node *partnerNode = get_partner_node(node);
+ partnerNode->visited = 1;
+ pair_node_to_code_sse2(node, partnerNode, buf0, buf1);
+ }
+ }
+}
+
+void gen_cospi_list_sse2(Node *node, int stage_num, int node_num) {
+ int visited[65][65][2][2];
+ memset(visited, 0, sizeof(visited));
+ char text[100];
+ char text1[100];
+ char text2[100];
+ int size = 100;
+ printf("\n");
+ for (int si = 1; si < stage_num; si++) {
+ for (int ni = 0; ni < node_num; ni++) {
+ int idx = get_idx(si, ni, node_num);
+ int cnt = 0;
+ Node *node0 = node + idx;
+ if (node0->visited == 0) {
+ node0->visited = 1;
+ for (int i = 0; i < 2; i++) {
+ if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0)
+ cnt++;
+ }
+ if (cnt != 2) {
+ {
+ double w0 = node0->inWeight[0];
+ double w1 = node0->inWeight[1];
+ int idx0 = get_cos_idx(w0, COS_MOD);
+ int idx1 = get_cos_idx(w1, COS_MOD);
+ int sgn0 = w0 < 0 ? 1 : 0;
+ int sgn1 = w1 < 0 ? 1 : 0;
+
+ if (!visited[idx0][idx1][sgn0][sgn1]) {
+ visited[idx0][idx1][sgn0][sgn1] = 1;
+ printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
+ cos_text_sse2(w0, w1, COS_MOD, text, size),
+ cos_text_arr(w0, COS_MOD, text1, size),
+ cos_text_arr(w1, COS_MOD, text2, size));
+ }
+ }
+ Node *node1 = get_partner_node(node0);
+ node1->visited = 1;
+ if (node1->inNode[0]->nodeIdx != node0->inNode[0]->nodeIdx) {
+ double w0 = node1->inWeight[0];
+ double w1 = node1->inWeight[1];
+ int idx0 = get_cos_idx(w0, COS_MOD);
+ int idx1 = get_cos_idx(w1, COS_MOD);
+ int sgn0 = w0 < 0 ? 1 : 0;
+ int sgn1 = w1 < 0 ? 1 : 0;
+
+ if (!visited[idx1][idx0][sgn1][sgn0]) {
+ visited[idx1][idx0][sgn1][sgn0] = 1;
+ printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
+ cos_text_sse2(w1, w0, COS_MOD, text, size),
+ cos_text_arr(w1, COS_MOD, text1, size),
+ cos_text_arr(w0, COS_MOD, text2, size));
+ }
+ } else {
+ double w0 = node1->inWeight[0];
+ double w1 = node1->inWeight[1];
+ int idx0 = get_cos_idx(w0, COS_MOD);
+ int idx1 = get_cos_idx(w1, COS_MOD);
+ int sgn0 = w0 < 0 ? 1 : 0;
+ int sgn1 = w1 < 0 ? 1 : 0;
+
+ if (!visited[idx0][idx1][sgn0][sgn1]) {
+ visited[idx0][idx1][sgn0][sgn1] = 1;
+ printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
+ cos_text_sse2(w0, w1, COS_MOD, text, size),
+ cos_text_arr(w0, COS_MOD, text1, size),
+ cos_text_arr(w1, COS_MOD, text2, size));
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+void gen_code_sse2(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
+ char fun_name[100];
+ get_fun_name(fun_name, 100, type, node_num);
+
+ printf("\n");
+ printf(
+ "void %s_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) "
+ "{\n",
+ fun_name);
+
+ printf(" const int32_t* cospi = cospi_arr(cos_bit);\n");
+ printf(" const __m128i __zero = _mm_setzero_si128();\n");
+ printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n");
+
+ graph_reset_visited(node, stage_num, node_num);
+ gen_cospi_list_sse2(node, stage_num, node_num);
+ graph_reset_visited(node, stage_num, node_num);
+ for (int si = 1; si < stage_num; si++) {
+ char in[100];
+ char out[100];
+ printf("\n");
+ printf(" // stage %d\n", si);
+ if (si == 1)
+ snprintf(in, 100, "%s", "input");
+ else
+ snprintf(in, 100, "x%d", si - 1);
+ if (si == stage_num - 1) {
+ snprintf(out, 100, "%s", "output");
+ } else {
+ snprintf(out, 100, "x%d", si);
+ printf(" __m128i %s[%d];\n", out, node_num);
+ }
+ // computation code
+ for (int ni = 0; ni < node_num; ni++) {
+ int idx = get_idx(si, ni, node_num);
+ node_to_code_sse2(node + idx, in, out);
+ }
+ }
+
+ printf("}\n");
+}
+
+void gen_cospi_list_sse4_1(Node *node, int stage_num, int node_num) {
+ int visited[65][2];
+ memset(visited, 0, sizeof(visited));
+ char text[100];
+ char text1[100];
+ int size = 100;
+ printf("\n");
+ for (int si = 1; si < stage_num; si++) {
+ for (int ni = 0; ni < node_num; ni++) {
+ int idx = get_idx(si, ni, node_num);
+ Node *node0 = node + idx;
+ if (node0->visited == 0) {
+ int cnt = 0;
+ node0->visited = 1;
+ for (int i = 0; i < 2; i++) {
+ if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0)
+ cnt++;
+ }
+ if (cnt != 2) {
+ for (int i = 0; i < 2; i++) {
+ if (fabs(node0->inWeight[i]) != 1 &&
+ fabs(node0->inWeight[i]) != 0) {
+ double w = node0->inWeight[i];
+ int idx = get_cos_idx(w, COS_MOD);
+ int sgn = w < 0 ? 1 : 0;
+
+ if (!visited[idx][sgn]) {
+ visited[idx][sgn] = 1;
+ printf(" __m128i %s = _mm_set1_epi32(%s);\n",
+ cos_text_sse4_1(w, COS_MOD, text, size),
+ cos_text_arr(w, COS_MOD, text1, size));
+ }
+ }
+ }
+ Node *node1 = get_partner_node(node0);
+ node1->visited = 1;
+ }
+ }
+ }
+ }
+}
+
+void single_node_to_code_sse4_1(Node *node, const char *buf0,
+ const char *buf1) {
+ printf(" %s[%2d] =", buf1, node->nodeIdx);
+ if (node->inWeight[0] == 1 && node->inWeight[1] == 1) {
+ printf(" _mm_add_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0,
+ node->inNodeIdx[1]);
+ } else if (node->inWeight[0] == 1 && node->inWeight[1] == -1) {
+ printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0,
+ node->inNodeIdx[1]);
+ } else if (node->inWeight[0] == -1 && node->inWeight[1] == 1) {
+ printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[1], buf0,
+ node->inNodeIdx[0]);
+ } else if (node->inWeight[0] == 1 && node->inWeight[1] == 0) {
+ printf(" %s[%d]", buf0, node->inNodeIdx[0]);
+ } else if (node->inWeight[0] == 0 && node->inWeight[1] == 1) {
+ printf(" %s[%d]", buf0, node->inNodeIdx[1]);
+ } else if (node->inWeight[0] == -1 && node->inWeight[1] == 0) {
+ printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, node->inNodeIdx[0]);
+ } else if (node->inWeight[0] == 0 && node->inWeight[1] == -1) {
+ printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, node->inNodeIdx[1]);
+ }
+ printf(";\n");
+}
+
+void pair_node_to_code_sse4_1(Node *node, Node *partnerNode, const char *buf0,
+ const char *buf1) {
+ char temp0[100];
+ char temp1[100];
+ if (node->inWeight[0] * partnerNode->inWeight[0] < 0) {
+ /* type0
+ * cos sin
+ * sin -cos
+ */
+ // btf_32_sse2_type0(w0, w1, in0, in1, out0, out1)
+ // out0 = w0*in0 + w1*in1
+ // out1 = -w0*in1 + w1*in0
+ printf(
+ " btf_32_type0_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], "
+ "__rounding, cos_bit);\n",
+ cos_text_sse4_1(node->inWeight[0], COS_MOD, temp0, 100),
+ cos_text_sse4_1(node->inWeight[1], COS_MOD, temp1, 100), buf0,
+ node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1, node->nodeIdx, buf1,
+ partnerNode->nodeIdx);
+ } else {
+ /* type1
+ * cos sin
+ * -sin cos
+ */
+ // btf_32_sse2_type1(w0, w1, in0, in1, out0, out1)
+ // out0 = w0*in0 + w1*in1
+ // out1 = w0*in1 - w1*in0
+ printf(
+ " btf_32_type1_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], "
+ "__rounding, cos_bit);\n",
+ cos_text_sse4_1(node->inWeight[0], COS_MOD, temp0, 100),
+ cos_text_sse4_1(node->inWeight[1], COS_MOD, temp1, 100), buf0,
+ node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1, node->nodeIdx, buf1,
+ partnerNode->nodeIdx);
+ }
+}
+
+void node_to_code_sse4_1(Node *node, const char *buf0, const char *buf1) {
+ int cnt = 0;
+ int cnt1 = 0;
+ if (node->visited == 0) {
+ node->visited = 1;
+ for (int i = 0; i < 2; i++) {
+ if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++;
+ if (fabs(node->inWeight[i]) == 1) cnt1++;
+ }
+ if (cnt == 2) {
+ if (cnt1 == 2) {
+ // has a partner
+ Node *partnerNode = get_partner_node(node);
+ partnerNode->visited = 1;
+ single_node_to_code_sse4_1(node, buf0, buf1);
+ single_node_to_code_sse4_1(partnerNode, buf0, buf1);
+ } else {
+ single_node_to_code_sse4_1(node, buf0, buf1);
+ }
+ } else {
+ Node *partnerNode = get_partner_node(node);
+ partnerNode->visited = 1;
+ pair_node_to_code_sse4_1(node, partnerNode, buf0, buf1);
+ }
+ }
+}
+
+void gen_code_sse4_1(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
+ char fun_name[100];
+ get_fun_name(fun_name, 100, type, node_num);
+
+ printf("\n");
+ printf(
+ "void %s_sse4_1(const __m128i *input, __m128i *output, int8_t cos_bit) "
+ "{\n",
+ fun_name);
+
+ printf(" const int32_t* cospi = cospi_arr(cos_bit);\n");
+ printf(" const __m128i __zero = _mm_setzero_si128();\n");
+ printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n");
+
+ graph_reset_visited(node, stage_num, node_num);
+ gen_cospi_list_sse4_1(node, stage_num, node_num);
+ graph_reset_visited(node, stage_num, node_num);
+ for (int si = 1; si < stage_num; si++) {
+ char in[100];
+ char out[100];
+ printf("\n");
+ printf(" // stage %d\n", si);
+ if (si == 1)
+ snprintf(in, 100, "%s", "input");
+ else
+ snprintf(in, 100, "x%d", si - 1);
+ if (si == stage_num - 1) {
+ snprintf(out, 100, "%s", "output");
+ } else {
+ snprintf(out, 100, "x%d", si);
+ printf(" __m128i %s[%d];\n", out, node_num);
+ }
+ // computation code
+ for (int ni = 0; ni < node_num; ni++) {
+ int idx = get_idx(si, ni, node_num);
+ node_to_code_sse4_1(node + idx, in, out);
+ }
+ }
+
+ printf("}\n");
+}
+
+void gen_hybrid_code(CODE_TYPE code_type, TYPE_TXFM txfm_type, int node_num) {
+ int stage_num = get_hybrid_stage_num(txfm_type, node_num);
+
+ Node *node = new Node[node_num * stage_num];
+ init_graph(node, stage_num, node_num);
+
+ gen_hybrid_graph_1d(node, stage_num, node_num, 0, 0, node_num, txfm_type);
+
+ switch (code_type) {
+ case CODE_TYPE_C: gen_code_c(node, stage_num, node_num, txfm_type); break;
+ case CODE_TYPE_SSE2:
+ gen_code_sse2(node, stage_num, node_num, txfm_type);
+ break;
+ case CODE_TYPE_SSE4_1:
+ gen_code_sse4_1(node, stage_num, node_num, txfm_type);
+ break;
+ }
+
+ delete[] node;
+}
+
+int main(int argc, char **argv) {
+ CODE_TYPE code_type = CODE_TYPE_SSE4_1;
+ for (int txfm_type = TYPE_DCT; txfm_type < TYPE_LAST; txfm_type++) {
+ for (int node_num = 4; node_num <= 64; node_num *= 2) {
+ gen_hybrid_code(code_type, (TYPE_TXFM)txfm_type, node_num);
+ }
+ }
+ return 0;
+}
diff --git a/third_party/aom/tools/txfm_analyzer/txfm_graph.cc b/third_party/aom/tools/txfm_analyzer/txfm_graph.cc
new file mode 100644
index 0000000000..a249061008
--- /dev/null
+++ b/third_party/aom/tools/txfm_analyzer/txfm_graph.cc
@@ -0,0 +1,943 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "tools/txfm_analyzer/txfm_graph.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <math.h>
+
+typedef struct Node Node;
+
+void get_fun_name(char *str_fun_name, int str_buf_size, const TYPE_TXFM type,
+ const int txfm_size) {
+ if (type == TYPE_DCT)
+ snprintf(str_fun_name, str_buf_size, "fdct%d_new", txfm_size);
+ else if (type == TYPE_ADST)
+ snprintf(str_fun_name, str_buf_size, "fadst%d_new", txfm_size);
+ else if (type == TYPE_IDCT)
+ snprintf(str_fun_name, str_buf_size, "idct%d_new", txfm_size);
+ else if (type == TYPE_IADST)
+ snprintf(str_fun_name, str_buf_size, "iadst%d_new", txfm_size);
+}
+
+void get_txfm_type_name(char *str_fun_name, int str_buf_size,
+ const TYPE_TXFM type, const int txfm_size) {
+ if (type == TYPE_DCT)
+ snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size);
+ else if (type == TYPE_ADST)
+ snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size);
+ else if (type == TYPE_IDCT)
+ snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size);
+ else if (type == TYPE_IADST)
+ snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size);
+}
+
+void get_hybrid_2d_type_name(char *buf, int buf_size, const TYPE_TXFM type0,
+ const TYPE_TXFM type1, const int txfm_size0,
+ const int txfm_size1) {
+ if (type0 == TYPE_DCT && type1 == TYPE_DCT)
+ snprintf(buf, buf_size, "_dct_dct_%dx%d", txfm_size1, txfm_size0);
+ else if (type0 == TYPE_DCT && type1 == TYPE_ADST)
+ snprintf(buf, buf_size, "_dct_adst_%dx%d", txfm_size1, txfm_size0);
+ else if (type0 == TYPE_ADST && type1 == TYPE_ADST)
+ snprintf(buf, buf_size, "_adst_adst_%dx%d", txfm_size1, txfm_size0);
+ else if (type0 == TYPE_ADST && type1 == TYPE_DCT)
+ snprintf(buf, buf_size, "_adst_dct_%dx%d", txfm_size1, txfm_size0);
+}
+
+TYPE_TXFM get_inv_type(TYPE_TXFM type) {
+ if (type == TYPE_DCT)
+ return TYPE_IDCT;
+ else if (type == TYPE_ADST)
+ return TYPE_IADST;
+ else if (type == TYPE_IDCT)
+ return TYPE_DCT;
+ else if (type == TYPE_IADST)
+ return TYPE_ADST;
+ else
+ return TYPE_LAST;
+}
+
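+// Reference (unnormalized) DCT-II:
+//   out[k] = sum_n in[n] * cos(PI * (2n + 1) * k / (2 * size)),
+// with the k == 0 term scaled by 1/sqrt(2).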
+void reference_dct_1d(double *in, double *out, int size) {
+ const double kInvSqrt2 = 0.707106781186547524400844362104;
+ for (int k = 0; k < size; k++) {
+ out[k] = 0; // initialize out[k]
+ for (int n = 0; n < size; n++) {
+ out[k] += in[n] * cos(PI * (2 * n + 1) * k / (2 * size));
+ }
+ if (k == 0) out[k] = out[k] * kInvSqrt2;
+ }
+}
+
+void reference_dct_2d(double *in, double *out, int size) {
+ double *tempOut = new double[size * size];
+ // dct each row: in -> out
+ for (int r = 0; r < size; r++) {
+ reference_dct_1d(in + r * size, out + r * size, size);
+ }
+
+ for (int r = 0; r < size; r++) {
+ // out ->tempOut
+ for (int c = 0; c < size; c++) {
+ tempOut[r * size + c] = out[c * size + r];
+ }
+ }
+ for (int r = 0; r < size; r++) {
+ reference_dct_1d(tempOut + r * size, out + r * size, size);
+ }
+ delete[] tempOut;
+}
+
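+// Reference ADST:
+//   out[k] = sum_n in[n] * sin(PI * (2n + 1) * (2k + 1) / (4 * size)).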
+void reference_adst_1d(double *in, double *out, int size) {
+ for (int k = 0; k < size; k++) {
+ out[k] = 0; // initialize out[k]
+ for (int n = 0; n < size; n++) {
+ out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size));
+ }
+ }
+}
+
+void reference_hybrid_2d(double *in, double *out, int size, int type0,
+ int type1) {
+ double *tempOut = new double[size * size];
+ // dct each row: in -> out
+ for (int r = 0; r < size; r++) {
+ if (type0 == TYPE_DCT)
+ reference_dct_1d(in + r * size, out + r * size, size);
+ else
+ reference_adst_1d(in + r * size, out + r * size, size);
+ }
+
+ for (int r = 0; r < size; r++) {
+ // out ->tempOut
+ for (int c = 0; c < size; c++) {
+ tempOut[r * size + c] = out[c * size + r];
+ }
+ }
+ for (int r = 0; r < size; r++) {
+ if (type1 == TYPE_DCT)
+ reference_dct_1d(tempOut + r * size, out + r * size, size);
+ else
+ reference_adst_1d(tempOut + r * size, out + r * size, size);
+ }
+ delete[] tempOut;
+}
+
+void reference_hybrid_2d_new(double *in, double *out, int size0, int size1,
+ int type0, int type1) {
+ double *tempOut = new double[size0 * size1];
+ // dct each row: in -> out
+ for (int r = 0; r < size1; r++) {
+ if (type0 == TYPE_DCT)
+ reference_dct_1d(in + r * size0, out + r * size0, size0);
+ else
+ reference_adst_1d(in + r * size0, out + r * size0, size0);
+ }
+
+ for (int r = 0; r < size1; r++) {
+ // out ->tempOut
+ for (int c = 0; c < size0; c++) {
+ tempOut[c * size1 + r] = out[r * size0 + c];
+ }
+ }
+ for (int r = 0; r < size0; r++) {
+ if (type1 == TYPE_DCT)
+ reference_dct_1d(tempOut + r * size1, out + r * size1, size1);
+ else
+ reference_adst_1d(tempOut + r * size1, out + r * size1, size1);
+ }
+ delete[] tempOut;
+}
+
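+// Returns the index of the highest set bit, e.g. get_max_bit(8) == 3.
+// Note that x == 0 yields (unsigned)-1.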
+unsigned int get_max_bit(unsigned int x) {
+ int max_bit = -1;
+ while (x) {
+ x = x >> 1;
+ max_bit++;
+ }
+ return max_bit;
+}
+
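+// Reverses the lowest (max_bit + 1) bits of x,
+// e.g. bitwise_reverse(0b001, 2) == 0b100.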
+unsigned int bitwise_reverse(unsigned int x, int max_bit) {
+ x = ((x >> 16) & 0x0000ffff) | ((x & 0x0000ffff) << 16);
+ x = ((x >> 8) & 0x00ff00ff) | ((x & 0x00ff00ff) << 8);
+ x = ((x >> 4) & 0x0f0f0f0f) | ((x & 0x0f0f0f0f) << 4);
+ x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2);
+ x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1);
+ x = x >> (31 - max_bit);
+ return x;
+}
+
+int get_idx(int ri, int ci, int cSize) { return ri * cSize + ci; }
+
+void add_node(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int in, double w) {
+ int outIdx = get_idx(stage_idx, node_idx, node_num);
+ int inIdx = get_idx(stage_idx - 1, in, node_num);
+ int idx = node[outIdx].inNodeNum;
+ if (idx < 2) {
+ node[outIdx].inNode[idx] = &node[inIdx];
+ node[outIdx].inNodeIdx[idx] = in;
+ node[outIdx].inWeight[idx] = w;
+ idx++;
+ node[outIdx].inNodeNum = idx;
+ } else {
+ printf("Error: inNode is full");
+ }
+}
+
+void connect_node(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int in0, double w0, int in1, double w1) {
+ int outIdx = get_idx(stage_idx, node_idx, node_num);
+ int inIdx0 = get_idx(stage_idx - 1, in0, node_num);
+ int inIdx1 = get_idx(stage_idx - 1, in1, node_num);
+
+ int idx = 0;
+ node[outIdx].inNode[idx] = &node[inIdx0];
+ node[outIdx].inNodeIdx[idx] = in0;
+ node[outIdx].inWeight[idx] = w0;
+ idx++;
+
+ node[outIdx].inNode[idx] = &node[inIdx1];
+ node[outIdx].inNodeIdx[idx] = in1;
+ node[outIdx].inWeight[idx] = w1;
+ idx++;
+
+ node[outIdx].inNodeNum = idx;
+}
+
+void propagate(Node *node, int stage_num, int node_num, int stage_idx) {
+ for (int ni = 0; ni < node_num; ni++) {
+ int outIdx = get_idx(stage_idx, ni, node_num);
+ node[outIdx].value = 0;
+ for (int k = 0; k < node[outIdx].inNodeNum; k++) {
+ node[outIdx].value +=
+ node[outIdx].inNode[k]->value * node[outIdx].inWeight[k];
+ }
+ }
+}
+
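+// Rounds value to the nearest multiple of 2^bit and shifts right, rounding
+// halfway cases away from zero: round_shift(5, 1) == 3 and
+// round_shift(-5, 1) == -3. A negative bit shifts left instead.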
+int64_t round_shift(int64_t value, int bit) {
+ if (bit > 0) {
+ if (value < 0) {
+ return -round_shift(-value, bit);
+ } else {
+ return (value + (1 << (bit - 1))) >> bit;
+ }
+ } else {
+ return value << (-bit);
+ }
+}
+
+void round_shift_array(int32_t *arr, int size, int bit) {
+ if (bit == 0) {
+ return;
+ } else {
+ for (int i = 0; i < size; i++) {
+ arr[i] = round_shift(arr[i], bit);
+ }
+ }
+}
+
+void graph_reset_visited(Node *node, int stage_num, int node_num) {
+ for (int si = 0; si < stage_num; si++) {
+ for (int ni = 0; ni < node_num; ni++) {
+ int idx = get_idx(si, ni, node_num);
+ node[idx].visited = 0;
+ }
+ }
+}
+
+void estimate_value(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int estimate_bit) {
+ if (stage_idx > 0) {
+ int outIdx = get_idx(stage_idx, node_idx, node_num);
+ int64_t out = 0;
+ node[outIdx].value = 0;
+ for (int k = 0; k < node[outIdx].inNodeNum; k++) {
+ int64_t w = round(node[outIdx].inWeight[k] * (1 << estimate_bit));
+ int64_t v = round(node[outIdx].inNode[k]->value);
+ out += v * w;
+ }
+ node[outIdx].value = round_shift(out, estimate_bit);
+ }
+}
+
+void amplify_value(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int amplify_bit) {
+ int outIdx = get_idx(stage_idx, node_idx, node_num);
+ node[outIdx].value = round_shift(round(node[outIdx].value), -amplify_bit);
+}
+
+void propagate_estimate_amlify(Node *node, int stage_num, int node_num,
+ int stage_idx, int amplify_bit,
+ int estimate_bit) {
+ for (int ni = 0; ni < node_num; ni++) {
+ estimate_value(node, stage_num, node_num, stage_idx, ni, estimate_bit);
+ amplify_value(node, stage_num, node_num, stage_idx, ni, amplify_bit);
+ }
+}
+
+void init_graph(Node *node, int stage_num, int node_num) {
+ for (int si = 0; si < stage_num; si++) {
+ for (int ni = 0; ni < node_num; ni++) {
+ int outIdx = get_idx(si, ni, node_num);
+ node[outIdx].stageIdx = si;
+ node[outIdx].nodeIdx = ni;
+ node[outIdx].value = 0;
+ node[outIdx].inNodeNum = 0;
+ if (si >= 1) {
+ connect_node(node, stage_num, node_num, si, ni, ni, 1, ni, 0);
+ }
+ }
+ }
+}
+
+void gen_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N, int star) {
+ for (int i = 0; i < N / 2; i++) {
+ int out = node_idx + i;
+ int in1 = node_idx + N - 1 - i;
+ if (star == 1) {
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
+ 1);
+ } else {
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
+ 1);
+ }
+ }
+ for (int i = N / 2; i < N; i++) {
+ int out = node_idx + i;
+ int in1 = node_idx + N - 1 - i;
+ if (star == 1) {
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
+ 1);
+ } else {
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
+ 1);
+ }
+ }
+}
+
+void gen_P_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N) {
+ int max_bit = get_max_bit(N - 1);
+ for (int i = 0; i < N; i++) {
+ int out = node_idx + bitwise_reverse(i, max_bit);
+ int in = node_idx + i;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+}
+
+void gen_type1_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N) {
+ int max_bit = get_max_bit(N);
+ for (int ni = 0; ni < N / 2; ni++) {
+ int ai = bitwise_reverse(N + ni, max_bit);
+ int out = node_idx + ni;
+ int in1 = node_idx + N - ni - 1;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
+ sin(PI * ai / (2 * 2 * N)), in1, cos(PI * ai / (2 * 2 * N)));
+ }
+ for (int ni = N / 2; ni < N; ni++) {
+ int ai = bitwise_reverse(N + ni, max_bit);
+ int out = node_idx + ni;
+ int in1 = node_idx + N - ni - 1;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
+ cos(PI * ai / (2 * 2 * N)), in1, -sin(PI * ai / (2 * 2 * N)));
+ }
+}
+
+void gen_type2_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N) {
+ for (int ni = 0; ni < N / 4; ni++) {
+ int out = node_idx + ni;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
+ }
+
+ for (int ni = N / 4; ni < N / 2; ni++) {
+ int out = node_idx + ni;
+ int in1 = node_idx + N - ni - 1;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
+ -cos(PI / 4), in1, cos(-PI / 4));
+ }
+
+ for (int ni = N / 2; ni < N * 3 / 4; ni++) {
+ int out = node_idx + ni;
+ int in1 = node_idx + N - ni - 1;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
+ cos(-PI / 4), in1, cos(PI / 4));
+ }
+
+ for (int ni = N * 3 / 4; ni < N; ni++) {
+ int out = node_idx + ni;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
+ }
+}
+
+void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int idx, int N) {
+ // TODO(angiebird): Simplify and clarify this function
+
+ int i = 2 * N / (1 << (idx / 2));
+ int max_bit =
+ get_max_bit(i / 2) - 1; // max_bit is computed from i / 2 rather than N here
+ int N_over_i = 2 << (idx / 2);
+
+ for (int nj = 0; nj < N / 2; nj += N_over_i) {
+ int j = nj / (N_over_i);
+ int kj = bitwise_reverse(i / 4 + j, max_bit);
+ // printf("kj = %d\n", kj);
+
+ // I_N/2i --- 0
+ int offset = nj;
+ for (int ni = 0; ni < N_over_i / 4; ni++) {
+ int out = node_idx + offset + ni;
+ int in = out;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+
+ // -C_Kj/i --- S_Kj/i
+ offset += N_over_i / 4;
+ for (int ni = 0; ni < N_over_i / 4; ni++) {
+ int out = node_idx + offset + ni;
+ int in0 = out;
+ double w0 = -cos(kj * PI / i);
+ int in1 = N - (offset + ni) - 1 + node_idx;
+ double w1 = sin(kj * PI / i);
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
+ w1);
+ }
+
+ // S_kj/i --- -C_Kj/i
+ offset += N_over_i / 4;
+ for (int ni = 0; ni < N_over_i / 4; ni++) {
+ int out = node_idx + offset + ni;
+ int in0 = out;
+ double w0 = -sin(kj * PI / i);
+ int in1 = N - (offset + ni) - 1 + node_idx;
+ double w1 = -cos(kj * PI / i);
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
+ w1);
+ }
+
+ // I_N/2i --- 0
+ offset += N_over_i / 4;
+ for (int ni = 0; ni < N_over_i / 4; ni++) {
+ int out = node_idx + offset + ni;
+ int in = out;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+ }
+
+ for (int nj = N / 2; nj < N; nj += N_over_i) {
+ int j = nj / N_over_i;
+ int kj = bitwise_reverse(i / 4 + j, max_bit);
+
+ // I_N/2i --- 0
+ int offset = nj;
+ for (int ni = 0; ni < N_over_i / 4; ni++) {
+ int out = node_idx + offset + ni;
+ int in = out;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+
+ // C_kj/i --- -S_Kj/i
+ offset += N_over_i / 4;
+ for (int ni = 0; ni < N_over_i / 4; ni++) {
+ int out = node_idx + offset + ni;
+ int in0 = out;
+ double w0 = cos(kj * PI / i);
+ int in1 = N - (offset + ni) - 1 + node_idx;
+ double w1 = -sin(kj * PI / i);
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
+ w1);
+ }
+
+ // S_kj/i --- C_Kj/i
+ offset += N_over_i / 4;
+ for (int ni = 0; ni < N_over_i / 4; ni++) {
+ int out = node_idx + offset + ni;
+ int in0 = out;
+ double w0 = sin(kj * PI / i);
+ int in1 = N - (offset + ni) - 1 + node_idx;
+ double w1 = cos(kj * PI / i);
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
+ w1);
+ }
+
+ // I_N/2i --- 0
+ offset += N_over_i / 4;
+ for (int ni = 0; ni < N_over_i / 4; ni++) {
+ int out = node_idx + offset + ni;
+ int in = out;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+ }
+}
+
+void gen_type4_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int idx, int N) {
+ int B_size = 1 << ((idx + 1) / 2);
+ for (int ni = 0; ni < N; ni += B_size) {
+ gen_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni, B_size,
+ (ni / B_size) % 2);
+ }
+}
+
+void gen_R_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N) {
+ int max_idx = 2 * (get_max_bit(N) + 1) - 3;
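+  // For power-of-two N this is 2 * log2(N) - 1 stages; idx walks them in
+  // reverse, so idx == 0 generates the last stage of the R block.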
+ for (int idx = 0; idx < max_idx; idx++) {
+ int s = stage_idx + max_idx - idx - 1;
+ if (idx == 0) {
+ // type 1
+ gen_type1_graph(node, stage_num, node_num, s, node_idx, N);
+ } else if (idx == max_idx - 1) {
+ // type 2
+ gen_type2_graph(node, stage_num, node_num, s, node_idx, N);
+ } else if ((idx + 1) % 2 == 0) {
+ // type 4
+ gen_type4_graph(node, stage_num, node_num, s, node_idx, idx, N);
+    } else {
+      // type 3 ((idx + 1) odd is the only remaining case)
+      gen_type3_graph(node, stage_num, node_num, s, node_idx, idx, N);
+    }
+ }
+}
+
+void gen_DCT_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N) {
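+  // Recursive decomposition: a butterfly stage (B) splits the N-point DCT
+  // into an N/2-point DCT on the top half and an N/2-point R block on the
+  // bottom half; the recursion bottoms out at the 2-point DCT below.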
+ if (N > 2) {
+ gen_B_graph(node, stage_num, node_num, stage_idx, node_idx, N, 0);
+ gen_DCT_graph(node, stage_num, node_num, stage_idx + 1, node_idx, N / 2);
+ gen_R_graph(node, stage_num, node_num, stage_idx + 1, node_idx + N / 2,
+ N / 2);
+ } else {
+ // generate dct_2
+ connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
+ cos(PI / 4), node_idx + 1, cos(PI / 4));
+ connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
+ node_idx + 1, -cos(PI / 4), node_idx, cos(PI / 4));
+ }
+}
+
+int get_dct_stage_num(int size) { return 2 * get_max_bit(size); }
+
+void gen_DCT_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int dct_node_num) {
+ gen_DCT_graph(node, stage_num, node_num, stage_idx, node_idx, dct_node_num);
+ int dct_stage_num = get_dct_stage_num(dct_node_num);
+ gen_P_graph(node, stage_num, node_num, stage_idx + dct_stage_num - 2,
+ node_idx, dct_node_num);
+}
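+
+// Minimal usage sketch for the 1-D DCT graph (illustrative only; it assumes
+// stage-0 inputs are seeded through Node::value and that results are read
+// back from the last stage once every stage has been propagated):
+//
+//   const int n = 8;
+//   const int stage_num = get_dct_stage_num(n);
+//   double in[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
+//   Node *graph = new Node[stage_num * n];
+//   init_graph(graph, stage_num, n);
+//   gen_DCT_graph_1d(graph, stage_num, n, /*stage_idx=*/0, /*node_idx=*/0, n);
+//   for (int ni = 0; ni < n; ni++) graph[get_idx(0, ni, n)].value = in[ni];
+//   for (int si = 1; si < stage_num; si++) propagate(graph, stage_num, n, si);
+//   delete[] graph;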
+
+void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_idx) {
+ int size = 1 << (adst_idx + 1);
+ for (int ni = 0; ni < size / 2; ni++) {
+ int nOut = node_idx + ni;
+ int nIn = nOut + size / 2;
+ connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, 1, nIn,
+ 1);
+ // printf("nOut: %d nIn: %d\n", nOut, nIn);
+ }
+ for (int ni = size / 2; ni < size; ni++) {
+ int nOut = node_idx + ni;
+ int nIn = nOut - size / 2;
+ connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, -1, nIn,
+ 1);
+ // printf("ndctOut: %d nIn: %d\n", nOut, nIn);
+ }
+}
+
+void gen_adst_U_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_idx, int adst_node_num) {
+ int size = 1 << (adst_idx + 1);
+ for (int ni = 0; ni < adst_node_num; ni += size) {
+ gen_adst_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
+ adst_idx);
+ }
+}
+
+void gen_adst_T_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, double freq) {
+ connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
+ cos(freq * PI), node_idx + 1, sin(freq * PI));
+ connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
+ node_idx + 1, -cos(freq * PI), node_idx, sin(freq * PI));
+}
+
+void gen_adst_E_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_idx) {
+ int size = 1 << (adst_idx);
+ for (int i = 0; i < size / 2; i++) {
+ int ni = i * 2;
+ double fi = (1 + 4 * i) * 1.0 / (1 << (adst_idx + 1));
+ gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
+ }
+}
+
+void gen_adst_V_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_idx, int adst_node_num) {
+ int size = 1 << (adst_idx);
+ for (int i = 0; i < adst_node_num / size; i++) {
+ if (i % 2 == 1) {
+ int ni = i * size;
+ gen_adst_E_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
+ adst_idx);
+ }
+ }
+}
+void gen_adst_VJ_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num) {
+ for (int i = 0; i < adst_node_num / 2; i++) {
+ int ni = i * 2;
+ double fi = (1 + 4 * i) * 1.0 / (4 * adst_node_num);
+ gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
+ }
+}
+void gen_adst_Q_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num) {
+  // mirror odd positions: output ni takes input (adst_node_num - ni) when ni is odd
+ // example of adst_node_num = 8:
+ // 0 1 2 3 4 5 6 7
+ // --> 0 7 2 5 4 3 6 1
+ for (int ni = 0; ni < adst_node_num; ni++) {
+ if (ni % 2 == 0) {
+ int out = node_idx + ni;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out,
+ 0);
+ } else {
+ int out = node_idx + ni;
+ int in = node_idx + adst_node_num - ni;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+ }
+}
+void gen_adst_Ibar_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num) {
+ // reverse order
+ // 0 1 2 3 --> 3 2 1 0
+ for (int ni = 0; ni < adst_node_num; ni++) {
+ int out = node_idx + ni;
+ int in = node_idx + adst_node_num - ni - 1;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+}
+
+int get_Q_out2in(int adst_node_num, int out) {
+ int in;
+ if (out % 2 == 0) {
+ in = out;
+ } else {
+ in = adst_node_num - out;
+ }
+ return in;
+}
+
+int get_Ibar_out2in(int adst_node_num, int out) {
+ return adst_node_num - out - 1;
+}
+
+void gen_adst_IbarQ_graph(Node *node, int stage_num, int node_num,
+ int stage_idx, int node_idx, int adst_node_num) {
+ // in -> Ibar -> Q -> out
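+  // Composed permutation for adst_node_num = 8, written as out <- in:
+  // 0<-7  1<-0  2<-5  3<-2  4<-3  5<-4  6<-1  7<-6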
+ for (int ni = 0; ni < adst_node_num; ni++) {
+ int out = node_idx + ni;
+ int in = node_idx +
+ get_Ibar_out2in(adst_node_num, get_Q_out2in(adst_node_num, ni));
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+}
+
+void gen_adst_D_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num) {
+  // negate every odd-indexed entry (the diagonal sign matrix D)
+ for (int ni = 0; ni < adst_node_num; ni++) {
+ int out = node_idx + ni;
+ int in = out;
+ if (ni % 2 == 0) {
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ } else {
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, -1, in,
+ 0);
+ }
+ }
+}
+
+int get_hadamard_idx(int x, int adst_node_num) {
+ int max_bit = get_max_bit(adst_node_num - 1);
+ x = bitwise_reverse(x, max_bit);
+
+  // Gray-code transform: bit 0 of y is bit 0 of x; bit i of y is x_i ^ x_(i-1)
+ int c = x & 1;
+ int p = x & 1;
+ int y = c;
+
+ for (int i = 1; i <= max_bit; i++) {
+ p = c;
+ c = (x >> i) & 1;
+ y += (c ^ p) << i;
+ }
+ return y;
+}
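+
+// Illustrative sanity check (not part of the tool): the bit reversal and the
+// Gray-code step are both bijections, so get_hadamard_idx() permutes
+// [0, adst_node_num) when adst_node_num is a power of two.
+//
+//   bool seen[8] = { false, false, false, false, false, false, false, false };
+//   for (int x = 0; x < 8; x++) seen[get_hadamard_idx(x, 8)] = true;
+//   for (int x = 0; x < 8; x++) assert(seen[x]);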
+
+void gen_adst_Ht_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num) {
+ for (int ni = 0; ni < adst_node_num; ni++) {
+ int out = node_idx + ni;
+ int in = node_idx + get_hadamard_idx(ni, adst_node_num);
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
+ }
+}
+
+void gen_adst_HtD_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num) {
+ for (int ni = 0; ni < adst_node_num; ni++) {
+ int out = node_idx + ni;
+ int in = node_idx + get_hadamard_idx(ni, adst_node_num);
+ double inW;
+ if (ni % 2 == 0)
+ inW = 1;
+ else
+ inW = -1;
+ connect_node(node, stage_num, node_num, stage_idx + 1, out, in, inW, in, 0);
+ }
+}
+
+int get_adst_stage_num(int adst_node_num) {
+ return 2 * get_max_bit(adst_node_num) + 2;
+}
+
+int gen_iadst_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num) {
+ int max_bit = get_max_bit(adst_node_num);
+ int si = 0;
+ gen_adst_IbarQ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
+ adst_node_num);
+ si++;
+ gen_adst_VJ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
+ adst_node_num);
+ si++;
+ for (int adst_idx = max_bit - 1; adst_idx >= 1; adst_idx--) {
+ gen_adst_U_graph(node, stage_num, node_num, stage_idx + si, node_idx,
+ adst_idx, adst_node_num);
+ si++;
+ gen_adst_V_graph(node, stage_num, node_num, stage_idx + si, node_idx,
+ adst_idx, adst_node_num);
+ si++;
+ }
+ gen_adst_HtD_graph(node, stage_num, node_num, stage_idx + si, node_idx,
+ adst_node_num);
+ si++;
+ return si + 1;
+}
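+
+// Stage accounting for gen_iadst_graph(): 1 (IbarQ) + 1 (VJ) +
+// 2 * (max_bit - 1) (U/V pairs) + 1 (HtD) + 1 == 2 * max_bit + 2,
+// which matches get_adst_stage_num() above.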
+
+int gen_adst_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num) {
+ int hybrid_stage_num = get_hybrid_stage_num(TYPE_ADST, adst_node_num);
+  // generate an ADST graph in tempNode
+ Node *tempNode = new Node[hybrid_stage_num * adst_node_num];
+ init_graph(tempNode, hybrid_stage_num, adst_node_num);
+ int si = gen_iadst_graph(tempNode, hybrid_stage_num, adst_node_num, 0, 0,
+ adst_node_num);
+
+  // write tempNode's inverse graph into node[stage_idx][node_idx]
+ gen_inv_graph(tempNode, hybrid_stage_num, adst_node_num, node, stage_num,
+ node_num, stage_idx, node_idx);
+ delete[] tempNode;
+ return si;
+}
+
+void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int dct_node_num) {
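+  // Transpose step: node (first, second) of the current stage feeds node
+  // (second, first) of the next stage with weight 1.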
+ for (int first = 0; first < dct_node_num; first++) {
+ for (int second = 0; second < dct_node_num; second++) {
+ // int sIn = stage_idx;
+ int sOut = stage_idx + 1;
+ int nIn = node_idx + first * dct_node_num + second;
+ int nOut = node_idx + second * dct_node_num + first;
+
+ // printf("sIn: %d nIn: %d sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut);
+
+ connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
+ }
+ }
+}
+
+void connect_layer_2d_new(Node *node, int stage_num, int node_num,
+ int stage_idx, int node_idx, int dct_node_num0,
+ int dct_node_num1) {
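+  // Rectangular transpose: the dct_node_num1 x dct_node_num0 input layout is
+  // remapped to dct_node_num0 x dct_node_num1 in the next stage.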
+ for (int i = 0; i < dct_node_num1; i++) {
+ for (int j = 0; j < dct_node_num0; j++) {
+ // int sIn = stage_idx;
+ int sOut = stage_idx + 1;
+ int nIn = node_idx + i * dct_node_num0 + j;
+ int nOut = node_idx + j * dct_node_num1 + i;
+
+ // printf("sIn: %d nIn: %d sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut);
+
+ connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
+ }
+ }
+}
+
+void gen_DCT_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int dct_node_num) {
+ int dct_stage_num = get_dct_stage_num(dct_node_num);
+ // put 2 layers of dct_node_num DCTs on the graph
+ for (int ni = 0; ni < dct_node_num; ni++) {
+ gen_DCT_graph_1d(node, stage_num, node_num, stage_idx,
+ node_idx + ni * dct_node_num, dct_node_num);
+ gen_DCT_graph_1d(node, stage_num, node_num, stage_idx + dct_stage_num,
+ node_idx + ni * dct_node_num, dct_node_num);
+ }
+ // connect first layer and second layer
+ connect_layer_2d(node, stage_num, node_num, stage_idx + dct_stage_num - 1,
+ node_idx, dct_node_num);
+}
+
+int get_hybrid_stage_num(int type, int hybrid_node_num) {
+ if (type == TYPE_DCT || type == TYPE_IDCT) {
+ return get_dct_stage_num(hybrid_node_num);
+ } else if (type == TYPE_ADST || type == TYPE_IADST) {
+ return get_adst_stage_num(hybrid_node_num);
+ }
+ return 0;
+}
+
+int get_hybrid_2d_stage_num(int type0, int type1, int hybrid_node_num) {
+ int stage_num = 0;
+ stage_num += get_hybrid_stage_num(type0, hybrid_node_num);
+ stage_num += get_hybrid_stage_num(type1, hybrid_node_num);
+ return stage_num;
+}
+
+int get_hybrid_2d_stage_num_new(int type0, int type1, int hybrid_node_num0,
+ int hybrid_node_num1) {
+ int stage_num = 0;
+ stage_num += get_hybrid_stage_num(type0, hybrid_node_num0);
+ stage_num += get_hybrid_stage_num(type1, hybrid_node_num1);
+ return stage_num;
+}
+
+int get_hybrid_amplify_factor(int type, int hybrid_node_num) {
+ return get_max_bit(hybrid_node_num) - 1;
+}
+
+void gen_hybrid_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int hybrid_node_num, int type) {
+ if (type == TYPE_DCT) {
+ gen_DCT_graph_1d(node, stage_num, node_num, stage_idx, node_idx,
+ hybrid_node_num);
+ } else if (type == TYPE_ADST) {
+ gen_adst_graph(node, stage_num, node_num, stage_idx, node_idx,
+ hybrid_node_num);
+ } else if (type == TYPE_IDCT) {
+ int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
+    // generate a DCT graph in tempNode
+ Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
+ init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
+ gen_DCT_graph_1d(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
+ hybrid_node_num);
+
+    // write tempNode's inverse graph into node[stage_idx][node_idx]
+ gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
+ node_num, stage_idx, node_idx);
+ delete[] tempNode;
+ } else if (type == TYPE_IADST) {
+ int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
+    // generate an ADST graph in tempNode
+ Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
+ init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
+ gen_adst_graph(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
+ hybrid_node_num);
+
+    // write tempNode's inverse graph into node[stage_idx][node_idx]
+ gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
+ node_num, stage_idx, node_idx);
+ delete[] tempNode;
+ }
+}
+
+void gen_hybrid_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int hybrid_node_num, int type0,
+ int type1) {
+ int hybrid_stage_num = get_hybrid_stage_num(type0, hybrid_node_num);
+
+ for (int ni = 0; ni < hybrid_node_num; ni++) {
+ gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
+ node_idx + ni * hybrid_node_num, hybrid_node_num,
+ type0);
+ gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx + hybrid_stage_num,
+ node_idx + ni * hybrid_node_num, hybrid_node_num,
+ type1);
+ }
+
+ // connect first layer and second layer
+ connect_layer_2d(node, stage_num, node_num, stage_idx + hybrid_stage_num - 1,
+ node_idx, hybrid_node_num);
+}
+
+void gen_hybrid_graph_2d_new(Node *node, int stage_num, int node_num,
+ int stage_idx, int node_idx, int hybrid_node_num0,
+ int hybrid_node_num1, int type0, int type1) {
+ int hybrid_stage_num0 = get_hybrid_stage_num(type0, hybrid_node_num0);
+
+ for (int ni = 0; ni < hybrid_node_num1; ni++) {
+ gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
+ node_idx + ni * hybrid_node_num0, hybrid_node_num0,
+ type0);
+ }
+ for (int ni = 0; ni < hybrid_node_num0; ni++) {
+ gen_hybrid_graph_1d(
+ node, stage_num, node_num, stage_idx + hybrid_stage_num0,
+ node_idx + ni * hybrid_node_num1, hybrid_node_num1, type1);
+ }
+
+ // connect first layer and second layer
+ connect_layer_2d_new(node, stage_num, node_num,
+ stage_idx + hybrid_stage_num0 - 1, node_idx,
+ hybrid_node_num0, hybrid_node_num1);
+}
+
+void gen_inv_graph(Node *node, int stage_num, int node_num, Node *invNode,
+ int inv_stage_num, int inv_node_num, int inv_stage_idx,
+ int inv_node_idx) {
+  // reset inNodeNum in invNode, since add_node() appends to existing inputs
+ for (int si = 1 + inv_stage_idx; si < inv_stage_idx + stage_num; si++) {
+ for (int ni = inv_node_idx; ni < inv_node_idx + node_num; ni++) {
+ int idx = get_idx(si, ni, inv_node_num);
+ invNode[idx].inNodeNum = 0;
+ }
+ }
+ // generate inverse graph of node on invNode
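+  // Each forward edge (in -> out, weight w) in stage si becomes the edge
+  // (out -> in, weight w) in stage (stage_num - si) of the inverse graph.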
+ for (int si = 1; si < stage_num; si++) {
+ for (int ni = 0; ni < node_num; ni++) {
+ int invSi = stage_num - si;
+ int idx = get_idx(si, ni, node_num);
+ for (int k = 0; k < node[idx].inNodeNum; k++) {
+ int invNi = node[idx].inNodeIdx[k];
+ add_node(invNode, inv_stage_num, inv_node_num, invSi + inv_stage_idx,
+ invNi + inv_node_idx, ni + inv_node_idx,
+ node[idx].inWeight[k]);
+ }
+ }
+ }
+}
diff --git a/third_party/aom/tools/txfm_analyzer/txfm_graph.h b/third_party/aom/tools/txfm_analyzer/txfm_graph.h
new file mode 100644
index 0000000000..8dc36146dd
--- /dev/null
+++ b/third_party/aom/tools/txfm_analyzer/txfm_graph.h
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_TOOLS_TXFM_ANALYZER_TXFM_GRAPH_H_
+#define AOM_TOOLS_TXFM_ANALYZER_TXFM_GRAPH_H_
+
+struct Node {
+  Node *inNode[2];     // up to two input nodes feeding this node
+  int inNodeNum;       // number of active inputs (0, 1, or 2)
+  int inNodeIdx[2];    // within-stage indices of the input nodes
+  double inWeight[2];  // weight applied to each input
+  double value;        // value propagated to this node
+  int nodeIdx;         // index of this node within its stage
+  int stageIdx;        // index of the stage containing this node
+  int visited;         // traversal flag for graph walks
+};
+
+#define STAGENUM (10)
+#define NODENUM (32)
+#define COS_MOD (128)
+
+typedef enum {
+ TYPE_DCT = 0,
+ TYPE_ADST,
+ TYPE_IDCT,
+ TYPE_IADST,
+ TYPE_LAST
+} TYPE_TXFM;
+
+TYPE_TXFM get_inv_type(TYPE_TXFM type);
+void get_fun_name(char *str_fun_name, int str_buf_size, const TYPE_TXFM type,
+ const int txfm_size);
+
+void get_txfm_type_name(char *str_fun_name, int str_buf_size,
+ const TYPE_TXFM type, const int txfm_size);
+void get_hybrid_2d_type_name(char *buf, int buf_size, const TYPE_TXFM type0,
+ const TYPE_TXFM type1, const int txfm_size0,
+ const int txfm_size1);
+unsigned int get_max_bit(unsigned int x);
+unsigned int bitwise_reverse(unsigned int x, int max_bit);
+int get_idx(int ri, int ci, int cSize);
+
+int get_dct_stage_num(int size);
+void reference_dct_1d(double *in, double *out, int size);
+void reference_dct_2d(double *in, double *out, int size);
+void connect_node(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int in0, double w0, int in1, double w1);
+void propagate(Node *node, int stage_num, int node_num, int stage);
+void init_graph(Node *node, int stage_num, int node_num);
+void graph_reset_visited(Node *node, int stage_num, int node_num);
+void gen_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N, int star);
+void gen_P_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N);
+
+void gen_type1_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N);
+void gen_type2_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N);
+void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int idx, int N);
+void gen_type4_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int idx, int N);
+
+void gen_R_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N);
+
+void gen_DCT_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int N);
+
+void gen_DCT_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int dct_node_num);
+void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int dct_node_num);
+
+void gen_DCT_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int dct_node_num);
+
+void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_idx);
+
+void gen_adst_U_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_idx, int adst_node_num);
+void gen_adst_T_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, double freq);
+
+void gen_adst_E_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_idx);
+
+void gen_adst_V_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_idx, int adst_node_num);
+
+void gen_adst_VJ_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num);
+void gen_adst_Q_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num);
+void gen_adst_Ibar_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num);
+
+void gen_adst_D_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num);
+
+int get_hadamard_idx(int x, int adst_node_num);
+void gen_adst_Ht_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num);
+
+int gen_adst_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num);
+int gen_iadst_graph(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int adst_node_num);
+void reference_adst_1d(double *in, double *out, int size);
+
+int get_adst_stage_num(int adst_node_num);
+int get_hybrid_stage_num(int type, int hybrid_node_num);
+int get_hybrid_2d_stage_num(int type0, int type1, int hybrid_node_num);
+int get_hybrid_2d_stage_num_new(int type0, int type1, int hybrid_node_num0,
+ int hybrid_node_num1);
+int get_hybrid_amplify_factor(int type, int hybrid_node_num);
+void gen_hybrid_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int hybrid_node_num, int type);
+void gen_hybrid_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int hybrid_node_num, int type0,
+ int type1);
+void gen_hybrid_graph_2d_new(Node *node, int stage_num, int node_num,
+ int stage_idx, int node_idx, int hybrid_node_num0,
+ int hybrid_node_num1, int type0, int type1);
+
+void reference_hybrid_2d(double *in, double *out, int size, int type0,
+ int type1);
+
+void reference_hybrid_2d_new(double *in, double *out, int size0, int size1,
+ int type0, int type1);
+void reference_adst_dct_2d(double *in, double *out, int size);
+
+void gen_code(Node *node, int stage_num, int node_num, TYPE_TXFM type);
+
+void gen_inv_graph(Node *node, int stage_num, int node_num, Node *invNode,
+ int inv_stage_num, int inv_node_num, int inv_stage_idx,
+ int inv_node_idx);
+
+TYPE_TXFM hybrid_char_to_int(char ctype);
+
+int64_t round_shift(int64_t value, int bit);
+void round_shift_array(int32_t *arr, int size, int bit);
+void estimate_value(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int estimate_bit);
+void amplify_value(Node *node, int stage_num, int node_num, int stage_idx,
+ int node_idx, int estimate_bit);
+void propagate_estimate_amlify(Node *node, int stage_num, int node_num,
+ int stage_idx, int amplify_bit,
+ int estimate_bit);
+#endif // AOM_TOOLS_TXFM_ANALYZER_TXFM_GRAPH_H_
diff --git a/third_party/aom/tools/wrap-commit-msg.py b/third_party/aom/tools/wrap-commit-msg.py
new file mode 100755
index 0000000000..c51ed093d3
--- /dev/null
+++ b/third_party/aom/tools/wrap-commit-msg.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+##
+## Copyright (c) 2016, Alliance for Open Media. All rights reserved
+##
+## This source code is subject to the terms of the BSD 2 Clause License and
+## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+## was not distributed with this source code in the LICENSE file, you can
+## obtain it at www.aomedia.org/license/software. If the Alliance for Open
+## Media Patent License 1.0 was not distributed with this source code in the
+## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+##
+"""Wraps paragraphs of text, preserving manual formatting
+
+This is like fold(1), but has the special convention of not modifying lines
+that start with whitespace. This allows you to intersperse blocks with
+special formatting, like code blocks, with written prose. The prose will
+be word-wrapped, and the manual formatting will be preserved.
+
+ * This won't handle the case of a bulleted (or ordered) list specially, so
+ manual wrapping must be done.
+
+Occasionally it's useful to put something with explicit formatting that
+doesn't look at all like a block of text inline.
+
+ indicator = has_leading_whitespace(line);
+ if (indicator)
+ preserve_formatting(line);
+
+The intent is that this docstring would make it through the transform
+and still be legible and presented as it is in the source. If additional
+cases are handled, update this doc to describe the effect.
+"""
+
+__author__ = "jkoleszar@google.com"
+
+import sys
+import textwrap
+
+def wrap(text):
+ if text:
+ return textwrap.fill(text, break_long_words=False) + '\n'
+ return ""
+
+
+def main(fileobj):
+ text = ""
+ output = ""
+ while True:
+ line = fileobj.readline()
+ if not line:
+ break
+
+ if line.lstrip() == line:
+ text += line
+ else:
+ output += wrap(text)
+ text=""
+ output += line
+ output += wrap(text)
+
+ # Replace the file or write to stdout.
+ if fileobj == sys.stdin:
+ fileobj = sys.stdout
+ else:
+ fileobj.seek(0)
+ fileobj.truncate(0)
+ fileobj.write(output)
+
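+# Illustrative usage: `wrap-commit-msg.py COMMIT_EDITMSG` rewraps the file in
+# place, while `wrap-commit-msg.py < msg.txt` wraps stdin to stdout.
+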
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        with open(sys.argv[1], "r+") as fileobj:
+            main(fileobj)
+    else:
+        main(sys.stdin)