summaryrefslogtreecommitdiffstats
path: root/third_party/aom/tools/auto_refactor
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--third_party/aom/tools/auto_refactor/auto_refactor.py919
-rw-r--r--third_party/aom/tools/auto_refactor/av1_preprocess.py113
-rw-r--r--third_party/aom/tools/auto_refactor/c_files/decl_status_code.c31
-rw-r--r--third_party/aom/tools/auto_refactor/c_files/func_in_out.c208
-rw-r--r--third_party/aom/tools/auto_refactor/c_files/global_variable.c27
-rw-r--r--third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c46
-rw-r--r--third_party/aom/tools/auto_refactor/c_files/simple_code.c64
-rw-r--r--third_party/aom/tools/auto_refactor/c_files/struct_code.c49
-rw-r--r--third_party/aom/tools/auto_refactor/test_auto_refactor.py675
9 files changed, 2132 insertions, 0 deletions
diff --git a/third_party/aom/tools/auto_refactor/auto_refactor.py b/third_party/aom/tools/auto_refactor/auto_refactor.py
new file mode 100644
index 0000000000..dd0d4415f9
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/auto_refactor.py
@@ -0,0 +1,919 @@
+# Copyright (c) 2021, Alliance for Open Media. All rights reserved
+#
+# This source code is subject to the terms of the BSD 2 Clause License and
+# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+# was not distributed with this source code in the LICENSE file, you can
+# obtain it at www.aomedia.org/license/software. If the Alliance for Open
+# Media Patent License 1.0 was not distributed with this source code in the
+# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+#
+
+from __future__ import print_function
+import sys
+import os
+import operator
+from pycparser import c_parser, c_ast, parse_file
+from math import *
+
+from inspect import currentframe, getframeinfo
+from collections import deque
+
+
+def debug_print(frameinfo):
+ print('******** ERROR:', frameinfo.filename, frameinfo.lineno, '********')
+
+
+class StructItem():
+
+ def __init__(self,
+ typedef_name=None,
+ struct_name=None,
+ struct_node=None,
+ is_union=False):
+ self.typedef_name = typedef_name
+ self.struct_name = struct_name
+ self.struct_node = struct_node
+ self.is_union = is_union
+ self.child_decl_map = None
+
+ def __str__(self):
+ return str(self.typedef_name) + ' ' + str(self.struct_name) + ' ' + str(
+ self.is_union)
+
+ def compute_child_decl_map(self, struct_info):
+ self.child_decl_map = {}
+ if self.struct_node != None and self.struct_node.decls != None:
+ for decl_node in self.struct_node.decls:
+ if decl_node.name == None:
+ for sub_decl_node in decl_node.type.decls:
+ sub_decl_status = parse_decl_node(struct_info, sub_decl_node)
+ self.child_decl_map[sub_decl_node.name] = sub_decl_status
+ else:
+ decl_status = parse_decl_node(struct_info, decl_node)
+ self.child_decl_map[decl_status.name] = decl_status
+
+ def get_child_decl_status(self, decl_name):
+ if self.child_decl_map == None:
+ debug_print(getframeinfo(currentframe()))
+ print('child_decl_map is None')
+ return None
+ if decl_name not in self.child_decl_map:
+ debug_print(getframeinfo(currentframe()))
+ print(decl_name, 'does not exist ')
+ return None
+ return self.child_decl_map[decl_name]
+
+
+class StructInfo():
+
+ def __init__(self):
+ self.struct_name_dic = {}
+ self.typedef_name_dic = {}
+ self.enum_value_dic = {} # enum value -> enum_node
+ self.enum_name_dic = {} # enum name -> enum_node
+ self.struct_item_list = []
+
+ def get_struct_by_typedef_name(self, typedef_name):
+ if typedef_name in self.typedef_name_dic:
+ return self.typedef_name_dic[typedef_name]
+ else:
+ return None
+
+ def get_struct_by_struct_name(self, struct_name):
+ if struct_name in self.struct_name_dic:
+ return self.struct_name_dic[struct_name]
+ else:
+ debug_print(getframeinfo(currentframe()))
+ print('Cant find', struct_name)
+ return None
+
+ def update_struct_item_list(self):
+ # Collect all struct_items from struct_name_dic and typedef_name_dic
+ # Compute child_decl_map for each struct item.
+ for struct_name in self.struct_name_dic.keys():
+ struct_item = self.struct_name_dic[struct_name]
+ struct_item.compute_child_decl_map(self)
+ self.struct_item_list.append(struct_item)
+
+ for typedef_name in self.typedef_name_dic.keys():
+ struct_item = self.typedef_name_dic[typedef_name]
+ if struct_item.struct_name not in self.struct_name_dic:
+ struct_item.compute_child_decl_map(self)
+ self.struct_item_list.append(struct_item)
+
+ def update_enum(self, enum_node):
+ if enum_node.name != None:
+ self.enum_name_dic[enum_node.name] = enum_node
+
+ if enum_node.values != None:
+ enumerator_list = enum_node.values.enumerators
+ for enumerator in enumerator_list:
+ self.enum_value_dic[enumerator.name] = enum_node
+
+ def update(self,
+ typedef_name=None,
+ struct_name=None,
+ struct_node=None,
+ is_union=False):
+ """T: typedef_name S: struct_name N: struct_node
+
+ T S N
+ case 1: o o o
+ typedef struct P {
+ int u;
+ } K;
+ T S N
+ case 2: o o x
+ typedef struct P K;
+
+ T S N
+ case 3: x o o
+ struct P {
+ int u;
+ };
+
+ T S N
+ case 4: o x o
+ typedef struct {
+ int u;
+ } K;
+ """
+ struct_item = None
+
+ # Check whether struct_name or typedef_name is already in the dictionary
+ if struct_name in self.struct_name_dic:
+ struct_item = self.struct_name_dic[struct_name]
+
+ if typedef_name in self.typedef_name_dic:
+ struct_item = self.typedef_name_dic[typedef_name]
+
+ if struct_item == None:
+ struct_item = StructItem(typedef_name, struct_name, struct_node, is_union)
+
+ if struct_node.decls != None:
+ struct_item.struct_node = struct_node
+
+ if struct_name != None:
+ self.struct_name_dic[struct_name] = struct_item
+
+ if typedef_name != None:
+ self.typedef_name_dic[typedef_name] = struct_item
+
+
+class StructDefVisitor(c_ast.NodeVisitor):
+
+ def __init__(self):
+ self.struct_info = StructInfo()
+
+ def visit_Struct(self, node):
+ if node.decls != None:
+ self.struct_info.update(None, node.name, node)
+ self.generic_visit(node)
+
+ def visit_Union(self, node):
+ if node.decls != None:
+ self.struct_info.update(None, node.name, node, True)
+ self.generic_visit(node)
+
+ def visit_Enum(self, node):
+ self.struct_info.update_enum(node)
+ self.generic_visit(node)
+
+ def visit_Typedef(self, node):
+ if node.type.__class__.__name__ == 'TypeDecl':
+ typedecl = node.type
+ if typedecl.type.__class__.__name__ == 'Struct':
+ struct_node = typedecl.type
+ typedef_name = node.name
+ struct_name = struct_node.name
+ self.struct_info.update(typedef_name, struct_name, struct_node)
+ elif typedecl.type.__class__.__name__ == 'Union':
+ union_node = typedecl.type
+ typedef_name = node.name
+ union_name = union_node.name
+ self.struct_info.update(typedef_name, union_name, union_node, True)
+ # TODO(angiebird): Do we need to deal with enum here?
+ self.generic_visit(node)
+
+
+def build_struct_info(ast):
+ v = StructDefVisitor()
+ v.visit(ast)
+ struct_info = v.struct_info
+ struct_info.update_struct_item_list()
+ return v.struct_info
+
+
+class DeclStatus():
+
+ def __init__(self, name, struct_item=None, is_ptr_decl=False):
+ self.name = name
+ self.struct_item = struct_item
+ self.is_ptr_decl = is_ptr_decl
+
+ def get_child_decl_status(self, decl_name):
+ if self.struct_item != None:
+ return self.struct_item.get_child_decl_status(decl_name)
+ else:
+ #TODO(angiebird): 2. Investigage the situation when a struct's definition can't be found.
+ return None
+
+ def __str__(self):
+ return str(self.struct_item) + ' ' + str(self.name) + ' ' + str(
+ self.is_ptr_decl)
+
+
+def peel_ptr_decl(decl_type_node):
+ """ Remove PtrDecl and ArrayDecl layer """
+ is_ptr_decl = False
+ peeled_decl_type_node = decl_type_node
+ while peeled_decl_type_node.__class__.__name__ == 'PtrDecl' or peeled_decl_type_node.__class__.__name__ == 'ArrayDecl':
+ is_ptr_decl = True
+ peeled_decl_type_node = peeled_decl_type_node.type
+ return is_ptr_decl, peeled_decl_type_node
+
+
+def parse_peeled_decl_type_node(struct_info, node):
+ struct_item = None
+ if node.__class__.__name__ == 'TypeDecl':
+ if node.type.__class__.__name__ == 'IdentifierType':
+ identifier_type_node = node.type
+ typedef_name = identifier_type_node.names[0]
+ struct_item = struct_info.get_struct_by_typedef_name(typedef_name)
+ elif node.type.__class__.__name__ == 'Struct':
+ struct_node = node.type
+ if struct_node.name != None:
+ struct_item = struct_info.get_struct_by_struct_name(struct_node.name)
+ else:
+ struct_item = StructItem(None, None, struct_node, False)
+ struct_item.compute_child_decl_map(struct_info)
+ elif node.type.__class__.__name__ == 'Union':
+ # TODO(angiebird): Special treatment for Union?
+ struct_node = node.type
+ if struct_node.name != None:
+ struct_item = struct_info.get_struct_by_struct_name(struct_node.name)
+ else:
+ struct_item = StructItem(None, None, struct_node, True)
+ struct_item.compute_child_decl_map(struct_info)
+ elif node.type.__class__.__name__ == 'Enum':
+ # TODO(angiebird): Special treatment for Union?
+ struct_node = node.type
+ struct_item = None
+ else:
+ print('Unrecognized peeled_decl_type_node.type',
+ node.type.__class__.__name__)
+ else:
+ # debug_print(getframeinfo(currentframe()))
+ # print(node.__class__.__name__)
+ #TODO(angiebird): Do we need to take care of this part?
+ pass
+
+ return struct_item
+
+
+def parse_decl_node(struct_info, decl_node):
+ # struct_item is None if this decl_node is not a struct_item
+ decl_node_name = decl_node.name
+ decl_type_node = decl_node.type
+ is_ptr_decl, peeled_decl_type_node = peel_ptr_decl(decl_type_node)
+ struct_item = parse_peeled_decl_type_node(struct_info, peeled_decl_type_node)
+ return DeclStatus(decl_node_name, struct_item, is_ptr_decl)
+
+
+def get_lvalue_lead(lvalue_node):
+ """return '&' or '*' of lvalue if available"""
+ if lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&':
+ return '&'
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*':
+ return '*'
+ return None
+
+
+def parse_lvalue(lvalue_node):
+ """get id_chain from lvalue"""
+ id_chain = parse_lvalue_recursive(lvalue_node, [])
+ return id_chain
+
+
+def parse_lvalue_recursive(lvalue_node, id_chain):
+ """cpi->rd->u -> (cpi->rd)->u"""
+ if lvalue_node.__class__.__name__ == 'ID':
+ id_chain.append(lvalue_node.name)
+ id_chain.reverse()
+ return id_chain
+ elif lvalue_node.__class__.__name__ == 'StructRef':
+ id_chain.append(lvalue_node.field.name)
+ return parse_lvalue_recursive(lvalue_node.name, id_chain)
+ elif lvalue_node.__class__.__name__ == 'ArrayRef':
+ return parse_lvalue_recursive(lvalue_node.name, id_chain)
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&':
+ return parse_lvalue_recursive(lvalue_node.expr, id_chain)
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*':
+ return parse_lvalue_recursive(lvalue_node.expr, id_chain)
+ else:
+ return None
+
+
+class FuncDefVisitor(c_ast.NodeVisitor):
+ func_dictionary = {}
+
+ def visit_FuncDef(self, node):
+ func_name = node.decl.name
+ self.func_dictionary[func_name] = node
+
+
+def build_func_dictionary(ast):
+ v = FuncDefVisitor()
+ v.visit(ast)
+ return v.func_dictionary
+
+
+def get_func_start_coord(func_node):
+ return func_node.coord
+
+
+def find_end_node(node):
+ node_list = []
+ for c in node:
+ node_list.append(c)
+ if len(node_list) == 0:
+ return node
+ else:
+ return find_end_node(node_list[-1])
+
+
+def get_func_end_coord(func_node):
+ return find_end_node(func_node).coord
+
+
+def get_func_size(func_node):
+ start_coord = get_func_start_coord(func_node)
+ end_coord = get_func_end_coord(func_node)
+ if start_coord.file == end_coord.file:
+ return end_coord.line - start_coord.line + 1
+ else:
+ return None
+
+
+def save_object(obj, filename):
+ with open(filename, 'wb') as obj_fp:
+ pickle.dump(obj, obj_fp, protocol=-1)
+
+
+def load_object(filename):
+ obj = None
+ with open(filename, 'rb') as obj_fp:
+ obj = pickle.load(obj_fp)
+ return obj
+
+
+def get_av1_ast(gen_ast=False):
+ # TODO(angiebird): Generalize this path
+ c_filename = './av1_pp.c'
+ print('generate ast')
+ ast = parse_file(c_filename)
+ #save_object(ast, ast_file)
+ print('finished generate ast')
+ return ast
+
+
+def get_func_param_id_map(func_def_node):
+ param_id_map = {}
+ func_decl = func_def_node.decl.type
+ param_list = func_decl.args.params
+ for decl in param_list:
+ param_id_map[decl.name] = decl
+ return param_id_map
+
+
+class IDTreeStack():
+
+ def __init__(self, global_id_tree):
+ self.stack = deque()
+ self.global_id_tree = global_id_tree
+
+ def add_link_node(self, node, link_id_chain):
+ link_node = self.add_id_node(link_id_chain)
+ node.link_node = link_node
+ node.link_id_chain = link_id_chain
+
+ def push_id_tree(self, id_tree=None):
+ if id_tree == None:
+ id_tree = IDStatusNode()
+ self.stack.append(id_tree)
+ return id_tree
+
+ def pop_id_tree(self):
+ return self.stack.pop()
+
+ def add_id_seed_node(self, id_seed, decl_status):
+ return self.stack[-1].add_child(id_seed, decl_status)
+
+ def get_id_seed_node(self, id_seed):
+ idx = len(self.stack) - 1
+ while idx >= 0:
+ id_node = self.stack[idx].get_child(id_seed)
+ if id_node != None:
+ return id_node
+ idx -= 1
+
+ id_node = self.global_id_tree.get_child(id_seed)
+ if id_node != None:
+ return id_node
+ return None
+
+ def add_id_node(self, id_chain):
+ id_seed = id_chain[0]
+ id_seed_node = self.get_id_seed_node(id_seed)
+ if id_seed_node == None:
+ return None
+ if len(id_chain) == 1:
+ return id_seed_node
+ return id_seed_node.add_descendant(id_chain[1:])
+
+ def get_id_node(self, id_chain):
+ id_seed = id_chain[0]
+ id_seed_node = self.get_id_seed_node(id_seed)
+ if id_seed_node == None:
+ return None
+ if len(id_chain) == 1:
+ return id_seed_node
+ return id_seed_node.get_descendant(id_chain[1:])
+
+ def top(self):
+ return self.stack[-1]
+
+
+class IDStatusNode():
+
+ def __init__(self, name=None, root=None):
+ if root is None:
+ self.root = self
+ else:
+ self.root = root
+
+ self.name = name
+
+ self.parent = None
+ self.children = {}
+
+ self.assign = False
+ self.last_assign_coord = None
+ self.refer = False
+ self.last_refer_coord = None
+
+ self.decl_status = None
+
+ self.link_id_chain = None
+ self.link_node = None
+
+ self.visit = False
+
+ def set_link_id_chain(self, link_id_chain):
+ self.set_assign(False)
+ self.link_id_chain = link_id_chain
+ self.link_node = self.root.get_descendant(link_id_chain)
+
+ def set_link_node(self, link_node):
+ self.set_assign(False)
+ self.link_id_chain = ['*']
+ self.link_node = link_node
+
+ def get_link_id_chain(self):
+ return self.link_id_chain
+
+ def get_concrete_node(self):
+ if self.visit == True:
+ # return None when there is a loop
+ return None
+ self.visit = True
+ if self.link_node == None:
+ self.visit = False
+ return self
+ else:
+ concrete_node = self.link_node.get_concrete_node()
+ self.visit = False
+ if concrete_node == None:
+ return self
+ return concrete_node
+
+ def set_assign(self, assign, coord=None):
+ concrete_node = self.get_concrete_node()
+ concrete_node.assign = assign
+ concrete_node.last_assign_coord = coord
+
+ def get_assign(self):
+ concrete_node = self.get_concrete_node()
+ return concrete_node.assign
+
+ def set_refer(self, refer, coord=None):
+ concrete_node = self.get_concrete_node()
+ concrete_node.refer = refer
+ concrete_node.last_refer_coord = coord
+
+ def get_refer(self):
+ concrete_node = self.get_concrete_node()
+ return concrete_node.refer
+
+ def set_parent(self, parent):
+ concrete_node = self.get_concrete_node()
+ concrete_node.parent = parent
+
+ def add_child(self, name, decl_status=None):
+ concrete_node = self.get_concrete_node()
+ if name not in concrete_node.children:
+ child_id_node = IDStatusNode(name, concrete_node.root)
+ concrete_node.children[name] = child_id_node
+ if decl_status == None:
+ # Check if the child decl_status can be inferred from its parent's
+ # decl_status
+ if self.decl_status != None:
+ decl_status = self.decl_status.get_child_decl_status(name)
+ child_id_node.set_decl_status(decl_status)
+ return concrete_node.children[name]
+
+ def get_child(self, name):
+ concrete_node = self.get_concrete_node()
+ if name in concrete_node.children:
+ return concrete_node.children[name]
+ else:
+ return None
+
+ def add_descendant(self, id_chain):
+ current_node = self.get_concrete_node()
+ for name in id_chain:
+ current_node.add_child(name)
+ parent_node = current_node
+ current_node = current_node.get_child(name)
+ current_node.set_parent(parent_node)
+ return current_node
+
+ def get_descendant(self, id_chain):
+ current_node = self.get_concrete_node()
+ for name in id_chain:
+ current_node = current_node.get_child(name)
+ if current_node == None:
+ return None
+ return current_node
+
+ def get_children(self):
+ current_node = self.get_concrete_node()
+ return current_node.children
+
+ def set_decl_status(self, decl_status):
+ current_node = self.get_concrete_node()
+ current_node.decl_status = decl_status
+
+ def get_decl_status(self):
+ current_node = self.get_concrete_node()
+ return current_node.decl_status
+
+ def __str__(self):
+ if self.link_id_chain is None:
+ return str(self.name) + ' a: ' + str(int(self.assign)) + ' r: ' + str(
+ int(self.refer))
+ else:
+ return str(self.name) + ' -> ' + ' '.join(self.link_id_chain)
+
+ def collect_assign_refer_status(self,
+ id_chain=None,
+ assign_ls=None,
+ refer_ls=None):
+ if id_chain == None:
+ id_chain = []
+ if assign_ls == None:
+ assign_ls = []
+ if refer_ls == None:
+ refer_ls = []
+ id_chain.append(self.name)
+ if self.assign:
+ info_str = ' '.join([
+ ' '.join(id_chain[1:]), 'a:',
+ str(int(self.assign)), 'r:',
+ str(int(self.refer)),
+ str(self.last_assign_coord)
+ ])
+ assign_ls.append(info_str)
+ if self.refer:
+ info_str = ' '.join([
+ ' '.join(id_chain[1:]), 'a:',
+ str(int(self.assign)), 'r:',
+ str(int(self.refer)),
+ str(self.last_refer_coord)
+ ])
+ refer_ls.append(info_str)
+ for c in self.children:
+ self.children[c].collect_assign_refer_status(id_chain, assign_ls,
+ refer_ls)
+ id_chain.pop()
+ return assign_ls, refer_ls
+
+ def show(self):
+ assign_ls, refer_ls = self.collect_assign_refer_status()
+ print('---- assign ----')
+ for item in assign_ls:
+ print(item)
+ print('---- refer ----')
+ for item in refer_ls:
+ print(item)
+
+
+class FuncInOutVisitor(c_ast.NodeVisitor):
+
+ def __init__(self,
+ func_def_node,
+ struct_info,
+ func_dictionary,
+ keep_body_id_tree=True,
+ call_param_map=None,
+ global_id_tree=None,
+ func_history=None,
+ unknown=None):
+ self.func_dictionary = func_dictionary
+ self.struct_info = struct_info
+ self.param_id_map = get_func_param_id_map(func_def_node)
+ self.parent_node = None
+ self.global_id_tree = global_id_tree
+ self.body_id_tree = None
+ self.keep_body_id_tree = keep_body_id_tree
+ if func_history == None:
+ self.func_history = {}
+ else:
+ self.func_history = func_history
+
+ if unknown == None:
+ self.unknown = []
+ else:
+ self.unknown = unknown
+
+ self.id_tree_stack = IDTreeStack(global_id_tree)
+ self.id_tree_stack.push_id_tree()
+
+ #TODO move this part into a function
+ for param in self.param_id_map:
+ decl_node = self.param_id_map[param]
+ decl_status = parse_decl_node(self.struct_info, decl_node)
+ descendant = self.id_tree_stack.add_id_seed_node(decl_status.name,
+ decl_status)
+ if call_param_map is not None and param in call_param_map:
+ # This is a function call.
+ # Map the input parameter to the caller's nodes
+ # TODO(angiebird): Can we use add_link_node here?
+ descendant.set_link_node(call_param_map[param])
+
+ def get_id_tree_stack(self):
+ return self.id_tree_stack
+
+ def generic_visit(self, node):
+ prev_parent = self.parent_node
+ self.parent_node = node
+ for c in node:
+ self.visit(c)
+ self.parent_node = prev_parent
+
+ # TODO rename
+ def add_new_id_tree(self, node):
+ self.id_tree_stack.push_id_tree()
+ self.generic_visit(node)
+ id_tree = self.id_tree_stack.pop_id_tree()
+ if self.parent_node == None and self.keep_body_id_tree == True:
+ # this is function body
+ self.body_id_tree = id_tree
+
+ def visit_For(self, node):
+ self.add_new_id_tree(node)
+
+ def visit_Compound(self, node):
+ self.add_new_id_tree(node)
+
+ def visit_Decl(self, node):
+ if node.type.__class__.__name__ != 'FuncDecl':
+ decl_status = parse_decl_node(self.struct_info, node)
+ descendant = self.id_tree_stack.add_id_seed_node(decl_status.name,
+ decl_status)
+ if node.init is not None:
+ init_id_chain = self.process_lvalue(node.init)
+ if init_id_chain != None:
+ if decl_status.struct_item is None:
+ init_descendant = self.id_tree_stack.add_id_node(init_id_chain)
+ if init_descendant != None:
+ init_descendant.set_refer(True, node.coord)
+ else:
+ self.unknown.append(node)
+ descendant.set_assign(True, node.coord)
+ else:
+ self.id_tree_stack.add_link_node(descendant, init_id_chain)
+ else:
+ self.unknown.append(node)
+ else:
+ descendant.set_assign(True, node.coord)
+ self.generic_visit(node)
+
+ def is_lvalue(self, node):
+ if self.parent_node is None:
+ # TODO(angiebird): Do every lvalue has parent_node != None?
+ return False
+ if self.parent_node.__class__.__name__ == 'StructRef':
+ return False
+ if self.parent_node.__class__.__name__ == 'ArrayRef' and node == self.parent_node.name:
+ # if node == self.parent_node.subscript, the node could be lvalue
+ return False
+ if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '&':
+ return False
+ if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '*':
+ return False
+ return True
+
+ def process_lvalue(self, node):
+ id_chain = parse_lvalue(node)
+ if id_chain == None:
+ return id_chain
+ elif id_chain[0] in self.struct_info.enum_value_dic:
+ return None
+ else:
+ return id_chain
+
+ def process_possible_lvalue(self, node):
+ if self.is_lvalue(node):
+ id_chain = self.process_lvalue(node)
+ lead_char = get_lvalue_lead(node)
+ # make sure the id is not an enum value
+ if id_chain == None:
+ self.unknown.append(node)
+ return
+ descendant = self.id_tree_stack.add_id_node(id_chain)
+ if descendant == None:
+ self.unknown.append(node)
+ return
+ decl_status = descendant.get_decl_status()
+ if decl_status == None:
+ descendant.set_assign(True, node.coord)
+ descendant.set_refer(True, node.coord)
+ self.unknown.append(node)
+ return
+ if self.parent_node.__class__.__name__ == 'Assignment':
+ if node is self.parent_node.lvalue:
+ if decl_status.struct_item != None:
+ if len(id_chain) > 1:
+ descendant.set_assign(True, node.coord)
+ elif len(id_chain) == 1:
+ if lead_char == '*':
+ descendant.set_assign(True, node.coord)
+ else:
+ right_id_chain = self.process_lvalue(self.parent_node.rvalue)
+ if right_id_chain != None:
+ self.id_tree_stack.add_link_node(descendant, right_id_chain)
+ else:
+ #TODO(angiebird): 1.Find a better way to deal with this case.
+ descendant.set_assign(True, node.coord)
+ else:
+ debug_print(getframeinfo(currentframe()))
+ else:
+ descendant.set_assign(True, node.coord)
+ elif node is self.parent_node.rvalue:
+ if decl_status.struct_item is None:
+ descendant.set_refer(True, node.coord)
+ if lead_char == '&':
+ descendant.set_assign(True, node.coord)
+ else:
+ left_id_chain = self.process_lvalue(self.parent_node.lvalue)
+ left_lead_char = get_lvalue_lead(self.parent_node.lvalue)
+ if left_id_chain != None:
+ if len(left_id_chain) > 1:
+ descendant.set_refer(True, node.coord)
+ elif len(left_id_chain) == 1:
+ if left_lead_char == '*':
+ descendant.set_refer(True, node.coord)
+ else:
+ #TODO(angiebird): Check whether the other node is linked to this node.
+ pass
+ else:
+ self.unknown.append(self.parent_node.lvalue)
+ debug_print(getframeinfo(currentframe()))
+ else:
+ self.unknown.append(self.parent_node.lvalue)
+ debug_print(getframeinfo(currentframe()))
+ else:
+ debug_print(getframeinfo(currentframe()))
+ elif self.parent_node.__class__.__name__ == 'UnaryOp':
+ # TODO(angiebird): Consider +=, *=, -=, /= etc
+ if self.parent_node.op == '--' or self.parent_node.op == '++' or\
+ self.parent_node.op == 'p--' or self.parent_node.op == 'p++':
+ descendant.set_assign(True, node.coord)
+ descendant.set_refer(True, node.coord)
+ else:
+ descendant.set_refer(True, node.coord)
+ elif self.parent_node.__class__.__name__ == 'Decl':
+ #The logic is at visit_Decl
+ pass
+ elif self.parent_node.__class__.__name__ == 'ExprList':
+ #The logic is at visit_FuncCall
+ pass
+ else:
+ descendant.set_refer(True, node.coord)
+
+ def visit_ID(self, node):
+ # If the parent is a FuncCall, this ID is a function name.
+ if self.parent_node.__class__.__name__ != 'FuncCall':
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_StructRef(self, node):
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_ArrayRef(self, node):
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_UnaryOp(self, node):
+ if node.op == '&' or node.op == '*':
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_FuncCall(self, node):
+ if node.name.__class__.__name__ == 'ID':
+ if node.name.name in self.func_dictionary:
+ if node.name.name not in self.func_history:
+ self.func_history[node.name.name] = True
+ func_def_node = self.func_dictionary[node.name.name]
+ call_param_map = self.process_func_call(node, func_def_node)
+
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary, False,
+ call_param_map, self.global_id_tree,
+ self.func_history, self.unknown)
+ visitor.visit(func_def_node.body)
+ else:
+ self.unknown.append(node)
+ self.generic_visit(node)
+
+ def process_func_call(self, func_call_node, func_def_node):
+ # set up a refer/assign for func parameters
+ # return call_param_map
+ call_param_ls = func_call_node.args.exprs
+ call_param_map = {}
+
+ func_decl = func_def_node.decl.type
+ decl_param_ls = func_decl.args.params
+ for param_node, decl_node in zip(call_param_ls, decl_param_ls):
+ id_chain = self.process_lvalue(param_node)
+ if id_chain != None:
+ descendant = self.id_tree_stack.add_id_node(id_chain)
+ if descendant == None:
+ self.unknown.append(param_node)
+ else:
+ decl_status = descendant.get_decl_status()
+ if decl_status != None:
+ if decl_status.struct_item == None:
+ if decl_status.is_ptr_decl == True:
+ descendant.set_assign(True, param_node.coord)
+ descendant.set_refer(True, param_node.coord)
+ else:
+ descendant.set_refer(True, param_node.coord)
+ else:
+ call_param_map[decl_node.name] = descendant
+ else:
+ self.unknown.append(param_node)
+ else:
+ self.unknown.append(param_node)
+ return call_param_map
+
+
+def build_global_id_tree(ast, struct_info):
+ global_id_tree = IDStatusNode()
+ for node in ast.ext:
+ if node.__class__.__name__ == 'Decl':
+ # id tree is for tracking assign/refer status
+ # we don't care about function id because they can't be changed
+ if node.type.__class__.__name__ != 'FuncDecl':
+ decl_status = parse_decl_node(struct_info, node)
+ descendant = global_id_tree.add_child(decl_status.name, decl_status)
+ return global_id_tree
+
+
+class FuncAnalyzer():
+
+ def __init__(self):
+ self.ast = get_av1_ast()
+ self.struct_info = build_struct_info(self.ast)
+ self.func_dictionary = build_func_dictionary(self.ast)
+ self.global_id_tree = build_global_id_tree(self.ast, self.struct_info)
+
+ def analyze(self, func_name):
+ if func_name in self.func_dictionary:
+ func_def_node = self.func_dictionary[func_name]
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary, True, None,
+ self.global_id_tree)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ root.top().show()
+ else:
+ print(func_name, "doesn't exist")
+
+
+if __name__ == '__main__':
+ fa = FuncAnalyzer()
+ fa.analyze('tpl_get_satd_cost')
+ pass
diff --git a/third_party/aom/tools/auto_refactor/av1_preprocess.py b/third_party/aom/tools/auto_refactor/av1_preprocess.py
new file mode 100644
index 0000000000..ea76912cf1
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/av1_preprocess.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2021, Alliance for Open Media. All rights reserved
+#
+# This source code is subject to the terms of the BSD 2 Clause License and
+# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+# was not distributed with this source code in the LICENSE file, you can
+# obtain it at www.aomedia.org/license/software. If the Alliance for Open
+# Media Patent License 1.0 was not distributed with this source code in the
+# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+#
+
+import os
+import sys
+
+
+def is_code_file(filename):
+ return filename.endswith(".c") or filename.endswith(".h")
+
+
+def is_simd_file(filename):
+ simd_keywords = [
+ "avx2", "sse2", "sse3", "ssse3", "sse4", "dspr2", "neon", "msa", "simd",
+ "x86"
+ ]
+ for keyword in simd_keywords:
+ if filename.find(keyword) >= 0:
+ return True
+ return False
+
+
+def get_code_file_list(path, exclude_file_set):
+ code_file_list = []
+ for cur_dir, sub_dir, file_list in os.walk(path):
+ for filename in file_list:
+ if is_code_file(filename) and not is_simd_file(
+ filename) and filename not in exclude_file_set:
+ file_path = os.path.join(cur_dir, filename)
+ code_file_list.append(file_path)
+ return code_file_list
+
+
+def av1_exclude_file_set():
+ exclude_file_set = {
+ "cfl_ppc.c",
+ "ppc_cpudetect.c",
+ }
+ return exclude_file_set
+
+
+def get_av1_pp_command(fake_header_dir, code_file_list):
+ pre_command = "gcc -w -nostdinc -E -I./ -I../ -I" + fake_header_dir + (" "
+ "-D'ATTRIBUTE_PACKED='"
+ " "
+ "-D'__attribute__(x)='"
+ " "
+ "-D'__inline__='"
+ " "
+ "-D'float_t=float'"
+ " "
+ "-D'DECLARE_ALIGNED(n,"
+ " typ,"
+ " "
+ "val)=typ"
+ " val'"
+ " "
+ "-D'volatile='"
+ " "
+ "-D'AV1_K_MEANS_DIM=2'"
+ " "
+ "-D'INLINE='"
+ " "
+ "-D'AOM_INLINE='"
+ " "
+ "-D'AOM_FORCE_INLINE='"
+ " "
+ "-D'inline='"
+ )
+ return pre_command + " " + " ".join(code_file_list)
+
+
+def modify_av1_rtcd(build_dir):
+ av1_rtcd = os.path.join(build_dir, "config/av1_rtcd.h")
+ fp = open(av1_rtcd)
+ string = fp.read()
+ fp.close()
+ new_string = string.replace("#ifdef RTCD_C", "#if 0")
+ fp = open(av1_rtcd, "w")
+ fp.write(new_string)
+ fp.close()
+
+
+def preprocess_av1(aom_dir, build_dir, fake_header_dir):
+ cur_dir = os.getcwd()
+ output = os.path.join(cur_dir, "av1_pp.c")
+ path_list = [
+ os.path.join(aom_dir, "av1/encoder"),
+ os.path.join(aom_dir, "av1/common")
+ ]
+ code_file_list = []
+ for path in path_list:
+ path = os.path.realpath(path)
+ code_file_list.extend(get_code_file_list(path, av1_exclude_file_set()))
+ modify_av1_rtcd(build_dir)
+ cmd = get_av1_pp_command(fake_header_dir, code_file_list) + " >" + output
+ os.chdir(build_dir)
+ os.system(cmd)
+ os.chdir(cur_dir)
+
+
+if __name__ == "__main__":
+ aom_dir = sys.argv[1]
+ build_dir = sys.argv[2]
+ fake_header_dir = sys.argv[3]
+ preprocess_av1(aom_dir, build_dir, fake_header_dir)
diff --git a/third_party/aom/tools/auto_refactor/c_files/decl_status_code.c b/third_party/aom/tools/auto_refactor/c_files/decl_status_code.c
new file mode 100644
index 0000000000..a444553bb1
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/decl_status_code.c
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct S1 {
+ int x;
+} T1;
+
+int parse_decl_node_2(void) { int arr[3]; }
+
+int parse_decl_node_3(void) { int *a; }
+
+int parse_decl_node_4(void) { T1 t1[3]; }
+
+int parse_decl_node_5(void) { T1 *t2[3]; }
+
+int parse_decl_node_6(void) { T1 t3[3][3]; }
+
+int main(void) {
+ int a;
+ T1 t1;
+ struct S1 s1;
+ T1 *t2;
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/func_in_out.c b/third_party/aom/tools/auto_refactor/c_files/func_in_out.c
new file mode 100644
index 0000000000..7f37bbae7e
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/func_in_out.c
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct XD {
+ int u;
+ int v;
+} XD;
+
+typedef struct RD {
+ XD *xd;
+ int u;
+ int v;
+} RD;
+
+typedef struct VP9_COMP {
+ int y;
+ RD *rd;
+ RD rd2;
+ int arr[3];
+ union {
+ int z;
+ };
+ struct {
+ int w;
+ };
+} VP9_COMP;
+
+int sub_func(VP9_COMP *cpi, int b) {
+ int d;
+ cpi->y += 1;
+ cpi->y -= b;
+ d = cpi->y * 2;
+ return d;
+}
+
+int func_id_forrest_show(VP9_COMP *cpi, int b) {
+ int c = 2;
+ int x = cpi->y + c * 2 + 1;
+ int y;
+ RD *rd = cpi->rd;
+ y = cpi->rd->u;
+ return x + y;
+}
+
+int func_link_id_chain_1(VP9_COMP *cpi) {
+ RD *rd = cpi->rd;
+ rd->u = 0;
+}
+
+int func_link_id_chain_2(VP9_COMP *cpi) {
+ RD *rd = cpi->rd;
+ XD *xd = rd->xd;
+ xd->u = 0;
+}
+
+int func_assign_refer_status_1(VP9_COMP *cpi) { RD *rd = cpi->rd; }
+
+int func_assign_refer_status_2(VP9_COMP *cpi) {
+ RD *rd2;
+ rd2 = cpi->rd;
+}
+
+int func_assign_refer_status_3(VP9_COMP *cpi) {
+ int a;
+ a = cpi->y;
+}
+
+int func_assign_refer_status_4(VP9_COMP *cpi) {
+ int *b;
+ b = &cpi->y;
+}
+
+int func_assign_refer_status_5(VP9_COMP *cpi) {
+ RD *rd5;
+ rd5 = &cpi->rd2;
+}
+
+int func_assign_refer_status_6(VP9_COMP *cpi, VP9_COMP *cpi2) {
+ cpi->rd = cpi2->rd;
+}
+
+int func_assign_refer_status_7(VP9_COMP *cpi, VP9_COMP *cpi2) {
+ cpi->arr[3] = 0;
+}
+
+int func_assign_refer_status_8(VP9_COMP *cpi, VP9_COMP *cpi2) {
+ int x = cpi->arr[3];
+}
+
+int func_assign_refer_status_9(VP9_COMP *cpi) {
+ {
+ RD *rd = cpi->rd;
+ { rd->u = 0; }
+ }
+}
+
+int func_assign_refer_status_10(VP9_COMP *cpi) { cpi->arr[cpi->rd->u] = 0; }
+
+int func_assign_refer_status_11(VP9_COMP *cpi) {
+ RD *rd11 = &cpi->rd2;
+ rd11->v = 1;
+}
+
+int func_assign_refer_status_12(VP9_COMP *cpi, VP9_COMP *cpi2) {
+ *cpi->rd = *cpi2->rd;
+}
+
+int func_assign_refer_status_13(VP9_COMP *cpi) {
+ cpi->z = 0;
+ cpi->w = 0;
+}
+
+int func(VP9_COMP *cpi, int x) {
+ int a;
+ cpi->y = 4;
+ a = 3 + cpi->y;
+ a = a * x;
+ cpi->y *= 4;
+ RD *ref_rd = cpi->rd;
+ ref_rd->u = 0;
+ cpi->rd2.v = 1;
+ cpi->rd->v = 1;
+ RD *ref_rd2 = &cpi->rd2;
+ RD **ref_rd3 = &(&cpi->rd2);
+ int b = sub_func(cpi, a);
+ cpi->rd->v++;
+ return b;
+}
+
+int func_sub_call_1(VP9_COMP *cpi2, int x) { cpi2->y = 4; }
+
+int func_call_1(VP9_COMP *cpi, int y) { func_sub_call_1(cpi, y); }
+
+int func_sub_call_2(VP9_COMP *cpi2, RD *rd, int x) { rd->u = 0; }
+
+int func_call_2(VP9_COMP *cpi, int y) { func_sub_call_2(cpi, &cpi->rd, y); }
+
+int func_sub_call_3(VP9_COMP *cpi2, int x) {}
+
+int func_call_3(VP9_COMP *cpi, int y) { func_sub_call_3(cpi, ++cpi->y); }
+
+int func_sub_sub_call_4(VP9_COMP *cpi3, XD *xd) {
+ cpi3->rd.u = 0;
+ xd->u = 0;
+}
+
+int func_sub_call_4(VP9_COMP *cpi2, RD *rd) {
+ func_sub_sub_call_4(cpi2, rd->xd);
+}
+
+int func_call_4(VP9_COMP *cpi, int y) { func_sub_call_4(cpi, &cpi->rd); }
+
+int func_sub_call_5(VP9_COMP *cpi) {
+ cpi->y = 2;
+ func_call_5(cpi);
+}
+
+int func_call_5(VP9_COMP *cpi) { func_sub_call_5(cpi); }
+
+int func_compound_1(VP9_COMP *cpi) {
+ for (int i = 0; i < 10; ++i) {
+ cpi->y++;
+ }
+}
+
+int func_compound_2(VP9_COMP *cpi) {
+ for (int i = 0; i < cpi->y; ++i) {
+ cpi->rd->u = i;
+ }
+}
+
+int func_compound_3(VP9_COMP *cpi) {
+ int i = 3;
+ while (i > 0) {
+ cpi->rd->u = i;
+ i--;
+ }
+}
+
+int func_compound_4(VP9_COMP *cpi) {
+ while (cpi->y-- >= 0) {
+ }
+}
+
+int func_compound_5(VP9_COMP *cpi) {
+ do {
+ } while (cpi->y-- >= 0);
+}
+
+int func_compound_6(VP9_COMP *cpi) {
+ for (int i = 0; i < 10; ++i) cpi->y--;
+}
+
+int main(void) {
+ int x;
+ VP9_COMP cpi;
+ RD rd;
+ cpi->rd = rd;
+ func(&cpi, x);
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/global_variable.c b/third_party/aom/tools/auto_refactor/c_files/global_variable.c
new file mode 100644
index 0000000000..26d5385e97
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/global_variable.c
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+extern const int global_a[13];
+
+const int global_b = 0;
+
+typedef struct S1 {
+ int x;
+} T1;
+
+struct S3 {
+ int x;
+} s3;
+
+int func_global_1(int *a) {
+ *a = global_a[3];
+ return 0;
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c b/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c
new file mode 100644
index 0000000000..fa44d72381
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/parse_lvalue.c
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct RD {
+ int u;
+ int v;
+ int arr[3];
+} RD;
+
+typedef struct VP9_COMP {
+ int y;
+ RD *rd;
+ RD rd2;
+ RD rd3[2];
+} VP9_COMP;
+
+int parse_lvalue_2(VP9_COMP *cpi) { RD *rd2 = &cpi->rd2; }
+
+int func(VP9_COMP *cpi, int x) {
+ cpi->rd->u = 0;
+
+ int y;
+ y = 0;
+
+ cpi->rd2.v = 0;
+
+ cpi->rd->arr[2] = 0;
+
+ cpi->rd3[1]->arr[2] = 0;
+
+ return 0;
+}
+
+int main(void) {
+ int x = 0;
+ VP9_COMP cpi;
+ func(&cpi, x);
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/simple_code.c b/third_party/aom/tools/auto_refactor/c_files/simple_code.c
new file mode 100644
index 0000000000..902cd1d826
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/simple_code.c
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct S {
+ int x;
+ int y;
+ int z;
+} S;
+
+typedef struct T {
+ S s;
+} T;
+
+int d(S *s) {
+ ++s->x;
+ s->x--;
+ s->y = s->y + 1;
+ int *c = &s->x;
+ S ss;
+ ss.x = 1;
+ ss.x += 2;
+ ss.z *= 2;
+ return 0;
+}
+int b(S *s) {
+ d(s);
+ return 0;
+}
+int c(int x) {
+ if (x) {
+ c(x - 1);
+ } else {
+ S s;
+ d(&s);
+ }
+ return 0;
+}
+int a(S *s) {
+ b(s);
+ c(1);
+ return 0;
+}
+int e(void) {
+ c(0);
+ return 0;
+}
+int main(void) {
+ int p = 3;
+ S s;
+ s.x = p + 1;
+ s.y = 2;
+ s.z = 3;
+ a(&s);
+ T t;
+ t.s.x = 3;
+}
diff --git a/third_party/aom/tools/auto_refactor/c_files/struct_code.c b/third_party/aom/tools/auto_refactor/c_files/struct_code.c
new file mode 100644
index 0000000000..7f24d41075
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/c_files/struct_code.c
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2021, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+typedef struct S1 {
+ int x;
+} T1;
+
+struct S3 {
+ int x;
+};
+
+typedef struct {
+ int x;
+ struct S3 s3;
+} T4;
+
+typedef union U5 {
+ int x;
+ double y;
+} T5;
+
+typedef struct S6 {
+ struct {
+ int x;
+ };
+ union {
+ int y;
+ int z;
+ };
+} T6;
+
+typedef struct S7 {
+ struct {
+ int x;
+ } y;
+ union {
+ int w;
+ } z;
+} T7;
+
+int main(void) {}
diff --git a/third_party/aom/tools/auto_refactor/test_auto_refactor.py b/third_party/aom/tools/auto_refactor/test_auto_refactor.py
new file mode 100644
index 0000000000..6b1e269efa
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/test_auto_refactor.py
@@ -0,0 +1,675 @@
+#!/usr/bin/env python
+# Copyright (c) 2021, Alliance for Open Media. All rights reserved
+#
+# This source code is subject to the terms of the BSD 2 Clause License and
+# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+# was not distributed with this source code in the LICENSE file, you can
+# obtain it at www.aomedia.org/license/software. If the Alliance for Open
+# Media Patent License 1.0 was not distributed with this source code in the
+# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+#
+
+import pprint
+import re
+import os, sys
+import io
+import unittest as googletest
+
+sys.path[0:0] = ['.', '..']
+
+from pycparser import c_parser, parse_file
+from pycparser.c_ast import *
+from pycparser.c_parser import CParser, Coord, ParseError
+
+from auto_refactor import *
+
+
+def get_c_file_path(filename):
+ return os.path.join('c_files', filename)
+
+
+class TestStructInfo(googletest.TestCase):
+
+ def setUp(self):
+ filename = get_c_file_path('struct_code.c')
+ self.ast = parse_file(filename)
+
+ def test_build_struct_info(self):
+ struct_info = build_struct_info(self.ast)
+ typedef_name_dic = struct_info.typedef_name_dic
+ self.assertEqual('T1' in typedef_name_dic, True)
+ self.assertEqual('T4' in typedef_name_dic, True)
+ self.assertEqual('T5' in typedef_name_dic, True)
+
+ struct_name_dic = struct_info.struct_name_dic
+ struct_name = 'S1'
+ self.assertEqual(struct_name in struct_name_dic, True)
+ struct_item = struct_name_dic[struct_name]
+ self.assertEqual(struct_item.is_union, False)
+
+ struct_name = 'S3'
+ self.assertEqual(struct_name in struct_name_dic, True)
+ struct_item = struct_name_dic[struct_name]
+ self.assertEqual(struct_item.is_union, False)
+
+ struct_name = 'U5'
+ self.assertEqual(struct_name in struct_name_dic, True)
+ struct_item = struct_name_dic[struct_name]
+ self.assertEqual(struct_item.is_union, True)
+
+ self.assertEqual(len(struct_info.struct_item_list), 6)
+
+ def test_get_child_decl_status(self):
+ struct_info = build_struct_info(self.ast)
+ struct_item = struct_info.typedef_name_dic['T4']
+
+ decl_status = struct_item.child_decl_map['x']
+ self.assertEqual(decl_status.struct_item, None)
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = struct_item.child_decl_map['s3']
+ self.assertEqual(decl_status.struct_item.struct_name, 'S3')
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ struct_item = struct_info.typedef_name_dic['T6']
+ decl_status = struct_item.child_decl_map['x']
+ self.assertEqual(decl_status.struct_item, None)
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = struct_item.child_decl_map['y']
+ self.assertEqual(decl_status.struct_item, None)
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = struct_item.child_decl_map['z']
+ self.assertEqual(decl_status.struct_item, None)
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ struct_item = struct_info.typedef_name_dic['T7']
+ decl_status = struct_item.child_decl_map['y']
+ self.assertEqual('x' in decl_status.struct_item.child_decl_map, True)
+
+ struct_item = struct_info.typedef_name_dic['T7']
+ decl_status = struct_item.child_decl_map['z']
+ self.assertEqual('w' in decl_status.struct_item.child_decl_map, True)
+
+
+class TestParseLvalue(googletest.TestCase):
+
+ def setUp(self):
+ filename = get_c_file_path('parse_lvalue.c')
+ self.ast = parse_file(filename)
+ self.func_dictionary = build_func_dictionary(self.ast)
+
+ def test_parse_lvalue(self):
+ func_node = self.func_dictionary['func']
+ func_body_items = func_node.body.block_items
+ id_list = parse_lvalue(func_body_items[0].lvalue)
+ ref_id_list = ['cpi', 'rd', 'u']
+ self.assertEqual(id_list, ref_id_list)
+
+ id_list = parse_lvalue(func_body_items[2].lvalue)
+ ref_id_list = ['y']
+ self.assertEqual(id_list, ref_id_list)
+
+ id_list = parse_lvalue(func_body_items[3].lvalue)
+ ref_id_list = ['cpi', 'rd2', 'v']
+ self.assertEqual(id_list, ref_id_list)
+
+ id_list = parse_lvalue(func_body_items[4].lvalue)
+ ref_id_list = ['cpi', 'rd', 'arr']
+ self.assertEqual(id_list, ref_id_list)
+
+ id_list = parse_lvalue(func_body_items[5].lvalue)
+ ref_id_list = ['cpi', 'rd3', 'arr']
+ self.assertEqual(id_list, ref_id_list)
+
+ def test_parse_lvalue_2(self):
+ func_node = self.func_dictionary['parse_lvalue_2']
+ func_body_items = func_node.body.block_items
+ id_list = parse_lvalue(func_body_items[0].init)
+ ref_id_list = ['cpi', 'rd2']
+ self.assertEqual(id_list, ref_id_list)
+
+
+class TestIDStatusNode(googletest.TestCase):
+
+ def test_add_descendant(self):
+ root = IDStatusNode('root')
+ id_chain1 = ['cpi', 'rd', 'u']
+ id_chain2 = ['cpi', 'rd', 'v']
+ root.add_descendant(id_chain1)
+ root.add_descendant(id_chain2)
+
+ ref_children_list1 = ['cpi']
+ children_list1 = list(root.children.keys())
+ self.assertEqual(children_list1, ref_children_list1)
+
+ ref_children_list2 = ['rd']
+ children_list2 = list(root.children['cpi'].children.keys())
+ self.assertEqual(children_list2, ref_children_list2)
+
+ ref_children_list3 = ['u', 'v']
+ children_list3 = list(root.children['cpi'].children['rd'].children.keys())
+ self.assertEqual(children_list3, ref_children_list3)
+
+ def test_get_descendant(self):
+ root = IDStatusNode('root')
+ id_chain1 = ['cpi', 'rd', 'u']
+ id_chain2 = ['cpi', 'rd', 'v']
+ ref_descendant_1 = root.add_descendant(id_chain1)
+ ref_descendant_2 = root.add_descendant(id_chain2)
+
+ descendant_1 = root.get_descendant(id_chain1)
+ self.assertEqual(descendant_1 is ref_descendant_1, True)
+
+ descendant_2 = root.get_descendant(id_chain2)
+ self.assertEqual(descendant_2 is ref_descendant_2, True)
+
+ id_chain3 = ['cpi', 'rd', 'h']
+ descendant_3 = root.get_descendant(id_chain3)
+ self.assertEqual(descendant_3, None)
+
+
+class TestFuncInOut(googletest.TestCase):
+
+ def setUp(self):
+ c_filename = get_c_file_path('func_in_out.c')
+ self.ast = parse_file(c_filename)
+ self.func_dictionary = build_func_dictionary(self.ast)
+ self.struct_info = build_struct_info(self.ast)
+
+ def test_get_func_param_id_map(self):
+ func_def_node = self.func_dictionary['func']
+ param_id_map = get_func_param_id_map(func_def_node)
+ ref_param_id_map_keys = ['cpi', 'x']
+ self.assertEqual(list(param_id_map.keys()), ref_param_id_map_keys)
+
+ def test_assign_refer_status_1(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_1']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['rd']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+ ref_link_id_chain = ['cpi', 'rd']
+ self.assertEqual(ref_link_id_chain, descendant.get_link_id_chain())
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_2(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_2']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['rd2']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+
+ ref_link_id_chain = ['cpi', 'rd']
+ self.assertEqual(ref_link_id_chain, descendant.get_link_id_chain())
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_3(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_3']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['a']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_4(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_4']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['b']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_5(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_5']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ body_id_tree = visitor.body_id_tree
+
+ id_chain = ['rd5']
+ descendant = body_id_tree.get_descendant(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi', 'rd2']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_6(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_6']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ id_chain = ['cpi2', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+ self.assertEqual(None, descendant.get_link_id_chain())
+
+ def test_assign_refer_status_7(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_7']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'arr']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_assign_refer_status_8(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_8']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'arr']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_assign_refer_status_9(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_9']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_assign_refer_status_10(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_10']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ id_chain = ['cpi', 'arr']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_assign_refer_status_11(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_11']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd2', 'v']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_assign_refer_status_12(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_12']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi2', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_assign_refer_status_13(self):
+ func_def_node = self.func_dictionary['func_assign_refer_status_13']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'z']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi', 'w']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_id_status_forrest_1(self):
+ func_def_node = self.func_dictionary['func']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack().top()
+ children_names = set(root.get_children().keys())
+ ref_children_names = set(['cpi', 'x'])
+ self.assertEqual(children_names, ref_children_names)
+
+ root = visitor.body_id_tree
+ children_names = set(root.get_children().keys())
+ ref_children_names = set(['a', 'ref_rd', 'ref_rd2', 'ref_rd3', 'b'])
+ self.assertEqual(children_names, ref_children_names)
+
+ def test_id_status_forrest_show(self):
+ func_def_node = self.func_dictionary['func_id_forrest_show']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ visitor.get_id_tree_stack().top().show()
+
+ def test_id_status_forrest_2(self):
+ func_def_node = self.func_dictionary['func_id_forrest_show']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack().top()
+ self.assertEqual(root, root.root)
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_descendant(id_chain)
+ self.assertEqual(root, descendant.root)
+
+ id_chain = ['b']
+ descendant = root.get_descendant(id_chain)
+ self.assertEqual(root, descendant.root)
+
+ def test_link_id_chain_1(self):
+ func_def_node = self.func_dictionary['func_link_id_chain_1']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+
+ def test_link_id_chain_2(self):
+ func_def_node = self.func_dictionary['func_link_id_chain_2']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'xd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+
+ def test_func_call_1(self):
+ func_def_node = self.func_dictionary['func_call_1']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_call_2(self):
+ func_def_node = self.func_dictionary['func_call_2']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi', 'rd']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_call_3(self):
+ func_def_node = self.func_dictionary['func_call_3']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_call_4(self):
+ func_def_node = self.func_dictionary['func_call_4']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ id_chain = ['cpi', 'rd', 'xd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_call_5(self):
+ func_def_node = self.func_dictionary['func_call_5']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_compound_1(self):
+ func_def_node = self.func_dictionary['func_compound_1']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_compound_2(self):
+ func_def_node = self.func_dictionary['func_compound_2']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), False)
+ self.assertEqual(descendant.get_refer(), True)
+
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_compound_3(self):
+ func_def_node = self.func_dictionary['func_compound_3']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+
+ id_chain = ['cpi', 'rd', 'u']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), False)
+
+ def test_func_compound_4(self):
+ func_def_node = self.func_dictionary['func_compound_4']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_compound_5(self):
+ func_def_node = self.func_dictionary['func_compound_5']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+ def test_func_compound_6(self):
+ func_def_node = self.func_dictionary['func_compound_6']
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ id_chain = ['cpi', 'y']
+ descendant = root.get_id_node(id_chain)
+ self.assertEqual(descendant.get_assign(), True)
+ self.assertEqual(descendant.get_refer(), True)
+
+
+class TestDeclStatus(googletest.TestCase):
+
+ def setUp(self):
+ filename = get_c_file_path('decl_status_code.c')
+ self.ast = parse_file(filename)
+ self.func_dictionary = build_func_dictionary(self.ast)
+ self.struct_info = build_struct_info(self.ast)
+
+ def test_parse_decl_node(self):
+ func_def_node = self.func_dictionary['main']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 'a')
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = parse_decl_node(self.struct_info, decl_list[1])
+ self.assertEqual(decl_status.name, 't1')
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = parse_decl_node(self.struct_info, decl_list[2])
+ self.assertEqual(decl_status.name, 's1')
+ self.assertEqual(decl_status.is_ptr_decl, False)
+
+ decl_status = parse_decl_node(self.struct_info, decl_list[3])
+ self.assertEqual(decl_status.name, 't2')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+
+ def test_parse_decl_node_2(self):
+ func_def_node = self.func_dictionary['parse_decl_node_2']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 'arr')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item, None)
+
+ def test_parse_decl_node_3(self):
+ func_def_node = self.func_dictionary['parse_decl_node_3']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 'a')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item, None)
+
+ def test_parse_decl_node_4(self):
+ func_def_node = self.func_dictionary['parse_decl_node_4']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 't1')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item.typedef_name, 'T1')
+ self.assertEqual(decl_status.struct_item.struct_name, 'S1')
+
+ def test_parse_decl_node_5(self):
+ func_def_node = self.func_dictionary['parse_decl_node_5']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 't2')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item.typedef_name, 'T1')
+ self.assertEqual(decl_status.struct_item.struct_name, 'S1')
+
+ def test_parse_decl_node_6(self):
+ func_def_node = self.func_dictionary['parse_decl_node_6']
+ decl_list = func_def_node.body.block_items
+ decl_status = parse_decl_node(self.struct_info, decl_list[0])
+ self.assertEqual(decl_status.name, 't3')
+ self.assertEqual(decl_status.is_ptr_decl, True)
+ self.assertEqual(decl_status.struct_item.typedef_name, 'T1')
+ self.assertEqual(decl_status.struct_item.struct_name, 'S1')
+
+
+if __name__ == '__main__':
+ googletest.main()