summaryrefslogtreecommitdiffstats
path: root/build/clang-plugin/OverrideBaseCallChecker.cpp
blob: 600d4313354e32cfbbb9710f1f7d0236fd26c34f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "OverrideBaseCallChecker.h"
#include "CustomMatchers.h"

#include <algorithm>

// Registers this checker for C++ record declarations accepted by the project
// matcher hasBaseClasses() (presumably: classes with at least one base class;
// see CustomMatchers.h). Each match binds the record under the "class" key.
void OverrideBaseCallChecker::registerMatchers(MatchFinder *AstMatcher) {
  auto DerivedClassMatcher = cxxRecordDecl(hasBaseClasses()).bind("class");
  AstMatcher->addMatcher(DerivedClassMatcher, this);
}

// Returns true when Method carries the moz_required_base_method custom
// attribute, i.e. it is a method that every override is required to call.
bool OverrideBaseCallChecker::isRequiredBaseMethod(
    const CXXMethodDecl *Method) {
  const bool HasRequiredAttr =
      hasCustomAttribute<moz_required_base_method>(Method);
  return HasRequiredAttr;
}

void OverrideBaseCallChecker::evaluateExpression(
    const Stmt *StmtExpr, std::list<const CXXMethodDecl *> &MethodList) {
  // Continue while we have methods in our list
  if (!MethodList.size()) {
    return;
  }

  if (auto MemberFuncCall = dyn_cast<CXXMemberCallExpr>(StmtExpr)) {
    if (auto Method =
            dyn_cast<CXXMethodDecl>(MemberFuncCall->getDirectCallee())) {
      findBaseMethodCall(Method, MethodList);
    }
  }

  for (auto S : StmtExpr->children()) {
    if (S) {
      evaluateExpression(S, MethodList);
    }
  }
}

void OverrideBaseCallChecker::getRequiredBaseMethod(
    const CXXMethodDecl *Method,
    std::list<const CXXMethodDecl *> &MethodsList) {

  if (isRequiredBaseMethod(Method)) {
    MethodsList.push_back(Method);
  } else {
    // Loop through all it's base methods.
    for (auto BaseMethod = Method->begin_overridden_methods();
         BaseMethod != Method->end_overridden_methods(); BaseMethod++) {
      getRequiredBaseMethod(*BaseMethod, MethodsList);
    }
  }
}

// Records that Method was called. It removes Method from MethodsList, then
// repeats for every method Method overrides, directly or transitively. A call
// to an intermediate override therefore also satisfies the requirement for
// the base methods beneath it.
void OverrideBaseCallChecker::findBaseMethodCall(
    const CXXMethodDecl *Method,
    std::list<const CXXMethodDecl *> &MethodsList) {
  MethodsList.remove(Method);

  const auto LastOverridden = Method->end_overridden_methods();
  for (auto Overridden = Method->begin_overridden_methods();
       Overridden != LastOverridden; ++Overridden) {
    findBaseMethodCall(*Overridden, MethodsList);
  }
}

// Handles each record bound as "class" by registerMatchers(). For every method
// of that record that overrides something and has a body, it:
//   1. collects the moz_required_base_method methods it overrides,
//   2. scans its body for calls reaching them,
//   3. emits an error for each required method that is never called.
void OverrideBaseCallChecker::check(const MatchFinder::MatchResult &Result) {
  const char *Error =
      "Method %0 must be called in all overrides, but is not called in "
      "this override defined for class %1";
  const CXXRecordDecl *ClassDecl =
      Result.Nodes.getNodeAs<CXXRecordDecl>("class");

  // Loop through the methods and look for the ones that are overridden.
  for (auto Method : ClassDecl->methods()) {
    // If this method doesn't override other methods or it doesn't have a body,
    // continue to the next declaration.
    if (!Method->size_overridden_methods() || !Method->hasBody()) {
      continue;
    }

    // A list is used instead of a vector so that entries can be removed by
    // value (std::list::remove) without the erase-remove idiom.
    std::list<const CXXMethodDecl *> MethodsList;
    // Collect the annotated base methods this override is required to call.
    for (auto BaseMethod = Method->begin_overridden_methods();
         BaseMethod != Method->end_overridden_methods(); ++BaseMethod) {
      getRequiredBaseMethod(*BaseMethod, MethodsList);
    }

    // No annotation was found, so there is nothing to enforce here.
    if (MethodsList.empty()) {
      continue;
    }

    // hasBody() can report true while getBody() still returns null (e.g. a
    // late-parsed template whose body has not been parsed). Walking a null
    // body would crash, so skip such methods.
    const Stmt *Body = Method->getBody();
    if (!Body) {
      continue;
    }

    // Search the body for calls to the required base methods. Every method
    // that is reached is removed from MethodsList.
    evaluateExpression(Body, MethodsList);

    // Whatever remains was never called: report each one.
    for (const CXXMethodDecl *BaseMethod : MethodsList) {
      diag(Method->getLocation(), Error, DiagnosticIDs::Error)
          << BaseMethod->getQualifiedNameAsString() << ClassDecl->getName();
    }
  }
}