summaryrefslogtreecommitdiffstats
path: root/third_party/aom/tools/auto_refactor/auto_refactor.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--third_party/aom/tools/auto_refactor/auto_refactor.py919
1 files changed, 919 insertions, 0 deletions
diff --git a/third_party/aom/tools/auto_refactor/auto_refactor.py b/third_party/aom/tools/auto_refactor/auto_refactor.py
new file mode 100644
index 0000000000..dd0d4415f9
--- /dev/null
+++ b/third_party/aom/tools/auto_refactor/auto_refactor.py
@@ -0,0 +1,919 @@
+# Copyright (c) 2021, Alliance for Open Media. All rights reserved
+#
+# This source code is subject to the terms of the BSD 2 Clause License and
+# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+# was not distributed with this source code in the LICENSE file, you can
+# obtain it at www.aomedia.org/license/software. If the Alliance for Open
+# Media Patent License 1.0 was not distributed with this source code in the
+# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+#
+
+from __future__ import print_function
+import sys
+import os
+import operator
+from pycparser import c_parser, c_ast, parse_file
+from math import *
+
+from inspect import currentframe, getframeinfo
+from collections import deque
+
+
+def debug_print(frameinfo):
+ print('******** ERROR:', frameinfo.filename, frameinfo.lineno, '********')
+
+
+class StructItem():
+
+ def __init__(self,
+ typedef_name=None,
+ struct_name=None,
+ struct_node=None,
+ is_union=False):
+ self.typedef_name = typedef_name
+ self.struct_name = struct_name
+ self.struct_node = struct_node
+ self.is_union = is_union
+ self.child_decl_map = None
+
+ def __str__(self):
+ return str(self.typedef_name) + ' ' + str(self.struct_name) + ' ' + str(
+ self.is_union)
+
+ def compute_child_decl_map(self, struct_info):
+ self.child_decl_map = {}
+ if self.struct_node != None and self.struct_node.decls != None:
+ for decl_node in self.struct_node.decls:
+ if decl_node.name == None:
+ for sub_decl_node in decl_node.type.decls:
+ sub_decl_status = parse_decl_node(struct_info, sub_decl_node)
+ self.child_decl_map[sub_decl_node.name] = sub_decl_status
+ else:
+ decl_status = parse_decl_node(struct_info, decl_node)
+ self.child_decl_map[decl_status.name] = decl_status
+
+ def get_child_decl_status(self, decl_name):
+ if self.child_decl_map == None:
+ debug_print(getframeinfo(currentframe()))
+ print('child_decl_map is None')
+ return None
+ if decl_name not in self.child_decl_map:
+ debug_print(getframeinfo(currentframe()))
+ print(decl_name, 'does not exist ')
+ return None
+ return self.child_decl_map[decl_name]
+
+
+class StructInfo():
+
+ def __init__(self):
+ self.struct_name_dic = {}
+ self.typedef_name_dic = {}
+ self.enum_value_dic = {} # enum value -> enum_node
+ self.enum_name_dic = {} # enum name -> enum_node
+ self.struct_item_list = []
+
+ def get_struct_by_typedef_name(self, typedef_name):
+ if typedef_name in self.typedef_name_dic:
+ return self.typedef_name_dic[typedef_name]
+ else:
+ return None
+
+ def get_struct_by_struct_name(self, struct_name):
+ if struct_name in self.struct_name_dic:
+ return self.struct_name_dic[struct_name]
+ else:
+ debug_print(getframeinfo(currentframe()))
+ print('Cant find', struct_name)
+ return None
+
+ def update_struct_item_list(self):
+ # Collect all struct_items from struct_name_dic and typedef_name_dic
+ # Compute child_decl_map for each struct item.
+ for struct_name in self.struct_name_dic.keys():
+ struct_item = self.struct_name_dic[struct_name]
+ struct_item.compute_child_decl_map(self)
+ self.struct_item_list.append(struct_item)
+
+ for typedef_name in self.typedef_name_dic.keys():
+ struct_item = self.typedef_name_dic[typedef_name]
+ if struct_item.struct_name not in self.struct_name_dic:
+ struct_item.compute_child_decl_map(self)
+ self.struct_item_list.append(struct_item)
+
+ def update_enum(self, enum_node):
+ if enum_node.name != None:
+ self.enum_name_dic[enum_node.name] = enum_node
+
+ if enum_node.values != None:
+ enumerator_list = enum_node.values.enumerators
+ for enumerator in enumerator_list:
+ self.enum_value_dic[enumerator.name] = enum_node
+
+ def update(self,
+ typedef_name=None,
+ struct_name=None,
+ struct_node=None,
+ is_union=False):
+ """T: typedef_name S: struct_name N: struct_node
+
+ T S N
+ case 1: o o o
+ typedef struct P {
+ int u;
+ } K;
+ T S N
+ case 2: o o x
+ typedef struct P K;
+
+ T S N
+ case 3: x o o
+ struct P {
+ int u;
+ };
+
+ T S N
+ case 4: o x o
+ typedef struct {
+ int u;
+ } K;
+ """
+ struct_item = None
+
+ # Check whether struct_name or typedef_name is already in the dictionary
+ if struct_name in self.struct_name_dic:
+ struct_item = self.struct_name_dic[struct_name]
+
+ if typedef_name in self.typedef_name_dic:
+ struct_item = self.typedef_name_dic[typedef_name]
+
+ if struct_item == None:
+ struct_item = StructItem(typedef_name, struct_name, struct_node, is_union)
+
+ if struct_node.decls != None:
+ struct_item.struct_node = struct_node
+
+ if struct_name != None:
+ self.struct_name_dic[struct_name] = struct_item
+
+ if typedef_name != None:
+ self.typedef_name_dic[typedef_name] = struct_item
+
+
+class StructDefVisitor(c_ast.NodeVisitor):
+
+ def __init__(self):
+ self.struct_info = StructInfo()
+
+ def visit_Struct(self, node):
+ if node.decls != None:
+ self.struct_info.update(None, node.name, node)
+ self.generic_visit(node)
+
+ def visit_Union(self, node):
+ if node.decls != None:
+ self.struct_info.update(None, node.name, node, True)
+ self.generic_visit(node)
+
+ def visit_Enum(self, node):
+ self.struct_info.update_enum(node)
+ self.generic_visit(node)
+
+ def visit_Typedef(self, node):
+ if node.type.__class__.__name__ == 'TypeDecl':
+ typedecl = node.type
+ if typedecl.type.__class__.__name__ == 'Struct':
+ struct_node = typedecl.type
+ typedef_name = node.name
+ struct_name = struct_node.name
+ self.struct_info.update(typedef_name, struct_name, struct_node)
+ elif typedecl.type.__class__.__name__ == 'Union':
+ union_node = typedecl.type
+ typedef_name = node.name
+ union_name = union_node.name
+ self.struct_info.update(typedef_name, union_name, union_node, True)
+ # TODO(angiebird): Do we need to deal with enum here?
+ self.generic_visit(node)
+
+
+def build_struct_info(ast):
+ v = StructDefVisitor()
+ v.visit(ast)
+ struct_info = v.struct_info
+ struct_info.update_struct_item_list()
+ return v.struct_info
+
+
+class DeclStatus():
+
+ def __init__(self, name, struct_item=None, is_ptr_decl=False):
+ self.name = name
+ self.struct_item = struct_item
+ self.is_ptr_decl = is_ptr_decl
+
+ def get_child_decl_status(self, decl_name):
+ if self.struct_item != None:
+ return self.struct_item.get_child_decl_status(decl_name)
+ else:
+ #TODO(angiebird): 2. Investigage the situation when a struct's definition can't be found.
+ return None
+
+ def __str__(self):
+ return str(self.struct_item) + ' ' + str(self.name) + ' ' + str(
+ self.is_ptr_decl)
+
+
+def peel_ptr_decl(decl_type_node):
+ """ Remove PtrDecl and ArrayDecl layer """
+ is_ptr_decl = False
+ peeled_decl_type_node = decl_type_node
+ while peeled_decl_type_node.__class__.__name__ == 'PtrDecl' or peeled_decl_type_node.__class__.__name__ == 'ArrayDecl':
+ is_ptr_decl = True
+ peeled_decl_type_node = peeled_decl_type_node.type
+ return is_ptr_decl, peeled_decl_type_node
+
+
+def parse_peeled_decl_type_node(struct_info, node):
+ struct_item = None
+ if node.__class__.__name__ == 'TypeDecl':
+ if node.type.__class__.__name__ == 'IdentifierType':
+ identifier_type_node = node.type
+ typedef_name = identifier_type_node.names[0]
+ struct_item = struct_info.get_struct_by_typedef_name(typedef_name)
+ elif node.type.__class__.__name__ == 'Struct':
+ struct_node = node.type
+ if struct_node.name != None:
+ struct_item = struct_info.get_struct_by_struct_name(struct_node.name)
+ else:
+ struct_item = StructItem(None, None, struct_node, False)
+ struct_item.compute_child_decl_map(struct_info)
+ elif node.type.__class__.__name__ == 'Union':
+ # TODO(angiebird): Special treatment for Union?
+ struct_node = node.type
+ if struct_node.name != None:
+ struct_item = struct_info.get_struct_by_struct_name(struct_node.name)
+ else:
+ struct_item = StructItem(None, None, struct_node, True)
+ struct_item.compute_child_decl_map(struct_info)
+ elif node.type.__class__.__name__ == 'Enum':
+ # TODO(angiebird): Special treatment for Union?
+ struct_node = node.type
+ struct_item = None
+ else:
+ print('Unrecognized peeled_decl_type_node.type',
+ node.type.__class__.__name__)
+ else:
+ # debug_print(getframeinfo(currentframe()))
+ # print(node.__class__.__name__)
+ #TODO(angiebird): Do we need to take care of this part?
+ pass
+
+ return struct_item
+
+
+def parse_decl_node(struct_info, decl_node):
+ # struct_item is None if this decl_node is not a struct_item
+ decl_node_name = decl_node.name
+ decl_type_node = decl_node.type
+ is_ptr_decl, peeled_decl_type_node = peel_ptr_decl(decl_type_node)
+ struct_item = parse_peeled_decl_type_node(struct_info, peeled_decl_type_node)
+ return DeclStatus(decl_node_name, struct_item, is_ptr_decl)
+
+
+def get_lvalue_lead(lvalue_node):
+ """return '&' or '*' of lvalue if available"""
+ if lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&':
+ return '&'
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*':
+ return '*'
+ return None
+
+
+def parse_lvalue(lvalue_node):
+ """get id_chain from lvalue"""
+ id_chain = parse_lvalue_recursive(lvalue_node, [])
+ return id_chain
+
+
+def parse_lvalue_recursive(lvalue_node, id_chain):
+ """cpi->rd->u -> (cpi->rd)->u"""
+ if lvalue_node.__class__.__name__ == 'ID':
+ id_chain.append(lvalue_node.name)
+ id_chain.reverse()
+ return id_chain
+ elif lvalue_node.__class__.__name__ == 'StructRef':
+ id_chain.append(lvalue_node.field.name)
+ return parse_lvalue_recursive(lvalue_node.name, id_chain)
+ elif lvalue_node.__class__.__name__ == 'ArrayRef':
+ return parse_lvalue_recursive(lvalue_node.name, id_chain)
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&':
+ return parse_lvalue_recursive(lvalue_node.expr, id_chain)
+ elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*':
+ return parse_lvalue_recursive(lvalue_node.expr, id_chain)
+ else:
+ return None
+
+
+class FuncDefVisitor(c_ast.NodeVisitor):
+ func_dictionary = {}
+
+ def visit_FuncDef(self, node):
+ func_name = node.decl.name
+ self.func_dictionary[func_name] = node
+
+
+def build_func_dictionary(ast):
+ v = FuncDefVisitor()
+ v.visit(ast)
+ return v.func_dictionary
+
+
+def get_func_start_coord(func_node):
+ return func_node.coord
+
+
+def find_end_node(node):
+ node_list = []
+ for c in node:
+ node_list.append(c)
+ if len(node_list) == 0:
+ return node
+ else:
+ return find_end_node(node_list[-1])
+
+
+def get_func_end_coord(func_node):
+ return find_end_node(func_node).coord
+
+
+def get_func_size(func_node):
+ start_coord = get_func_start_coord(func_node)
+ end_coord = get_func_end_coord(func_node)
+ if start_coord.file == end_coord.file:
+ return end_coord.line - start_coord.line + 1
+ else:
+ return None
+
+
+def save_object(obj, filename):
+ with open(filename, 'wb') as obj_fp:
+ pickle.dump(obj, obj_fp, protocol=-1)
+
+
+def load_object(filename):
+ obj = None
+ with open(filename, 'rb') as obj_fp:
+ obj = pickle.load(obj_fp)
+ return obj
+
+
+def get_av1_ast(gen_ast=False):
+ # TODO(angiebird): Generalize this path
+ c_filename = './av1_pp.c'
+ print('generate ast')
+ ast = parse_file(c_filename)
+ #save_object(ast, ast_file)
+ print('finished generate ast')
+ return ast
+
+
+def get_func_param_id_map(func_def_node):
+ param_id_map = {}
+ func_decl = func_def_node.decl.type
+ param_list = func_decl.args.params
+ for decl in param_list:
+ param_id_map[decl.name] = decl
+ return param_id_map
+
+
+class IDTreeStack():
+
+ def __init__(self, global_id_tree):
+ self.stack = deque()
+ self.global_id_tree = global_id_tree
+
+ def add_link_node(self, node, link_id_chain):
+ link_node = self.add_id_node(link_id_chain)
+ node.link_node = link_node
+ node.link_id_chain = link_id_chain
+
+ def push_id_tree(self, id_tree=None):
+ if id_tree == None:
+ id_tree = IDStatusNode()
+ self.stack.append(id_tree)
+ return id_tree
+
+ def pop_id_tree(self):
+ return self.stack.pop()
+
+ def add_id_seed_node(self, id_seed, decl_status):
+ return self.stack[-1].add_child(id_seed, decl_status)
+
+ def get_id_seed_node(self, id_seed):
+ idx = len(self.stack) - 1
+ while idx >= 0:
+ id_node = self.stack[idx].get_child(id_seed)
+ if id_node != None:
+ return id_node
+ idx -= 1
+
+ id_node = self.global_id_tree.get_child(id_seed)
+ if id_node != None:
+ return id_node
+ return None
+
+ def add_id_node(self, id_chain):
+ id_seed = id_chain[0]
+ id_seed_node = self.get_id_seed_node(id_seed)
+ if id_seed_node == None:
+ return None
+ if len(id_chain) == 1:
+ return id_seed_node
+ return id_seed_node.add_descendant(id_chain[1:])
+
+ def get_id_node(self, id_chain):
+ id_seed = id_chain[0]
+ id_seed_node = self.get_id_seed_node(id_seed)
+ if id_seed_node == None:
+ return None
+ if len(id_chain) == 1:
+ return id_seed_node
+ return id_seed_node.get_descendant(id_chain[1:])
+
+ def top(self):
+ return self.stack[-1]
+
+
+class IDStatusNode():
+
+ def __init__(self, name=None, root=None):
+ if root is None:
+ self.root = self
+ else:
+ self.root = root
+
+ self.name = name
+
+ self.parent = None
+ self.children = {}
+
+ self.assign = False
+ self.last_assign_coord = None
+ self.refer = False
+ self.last_refer_coord = None
+
+ self.decl_status = None
+
+ self.link_id_chain = None
+ self.link_node = None
+
+ self.visit = False
+
+ def set_link_id_chain(self, link_id_chain):
+ self.set_assign(False)
+ self.link_id_chain = link_id_chain
+ self.link_node = self.root.get_descendant(link_id_chain)
+
+ def set_link_node(self, link_node):
+ self.set_assign(False)
+ self.link_id_chain = ['*']
+ self.link_node = link_node
+
+ def get_link_id_chain(self):
+ return self.link_id_chain
+
+ def get_concrete_node(self):
+ if self.visit == True:
+ # return None when there is a loop
+ return None
+ self.visit = True
+ if self.link_node == None:
+ self.visit = False
+ return self
+ else:
+ concrete_node = self.link_node.get_concrete_node()
+ self.visit = False
+ if concrete_node == None:
+ return self
+ return concrete_node
+
+ def set_assign(self, assign, coord=None):
+ concrete_node = self.get_concrete_node()
+ concrete_node.assign = assign
+ concrete_node.last_assign_coord = coord
+
+ def get_assign(self):
+ concrete_node = self.get_concrete_node()
+ return concrete_node.assign
+
+ def set_refer(self, refer, coord=None):
+ concrete_node = self.get_concrete_node()
+ concrete_node.refer = refer
+ concrete_node.last_refer_coord = coord
+
+ def get_refer(self):
+ concrete_node = self.get_concrete_node()
+ return concrete_node.refer
+
+ def set_parent(self, parent):
+ concrete_node = self.get_concrete_node()
+ concrete_node.parent = parent
+
+ def add_child(self, name, decl_status=None):
+ concrete_node = self.get_concrete_node()
+ if name not in concrete_node.children:
+ child_id_node = IDStatusNode(name, concrete_node.root)
+ concrete_node.children[name] = child_id_node
+ if decl_status == None:
+ # Check if the child decl_status can be inferred from its parent's
+ # decl_status
+ if self.decl_status != None:
+ decl_status = self.decl_status.get_child_decl_status(name)
+ child_id_node.set_decl_status(decl_status)
+ return concrete_node.children[name]
+
+ def get_child(self, name):
+ concrete_node = self.get_concrete_node()
+ if name in concrete_node.children:
+ return concrete_node.children[name]
+ else:
+ return None
+
+ def add_descendant(self, id_chain):
+ current_node = self.get_concrete_node()
+ for name in id_chain:
+ current_node.add_child(name)
+ parent_node = current_node
+ current_node = current_node.get_child(name)
+ current_node.set_parent(parent_node)
+ return current_node
+
+ def get_descendant(self, id_chain):
+ current_node = self.get_concrete_node()
+ for name in id_chain:
+ current_node = current_node.get_child(name)
+ if current_node == None:
+ return None
+ return current_node
+
+ def get_children(self):
+ current_node = self.get_concrete_node()
+ return current_node.children
+
+ def set_decl_status(self, decl_status):
+ current_node = self.get_concrete_node()
+ current_node.decl_status = decl_status
+
+ def get_decl_status(self):
+ current_node = self.get_concrete_node()
+ return current_node.decl_status
+
+ def __str__(self):
+ if self.link_id_chain is None:
+ return str(self.name) + ' a: ' + str(int(self.assign)) + ' r: ' + str(
+ int(self.refer))
+ else:
+ return str(self.name) + ' -> ' + ' '.join(self.link_id_chain)
+
+ def collect_assign_refer_status(self,
+ id_chain=None,
+ assign_ls=None,
+ refer_ls=None):
+ if id_chain == None:
+ id_chain = []
+ if assign_ls == None:
+ assign_ls = []
+ if refer_ls == None:
+ refer_ls = []
+ id_chain.append(self.name)
+ if self.assign:
+ info_str = ' '.join([
+ ' '.join(id_chain[1:]), 'a:',
+ str(int(self.assign)), 'r:',
+ str(int(self.refer)),
+ str(self.last_assign_coord)
+ ])
+ assign_ls.append(info_str)
+ if self.refer:
+ info_str = ' '.join([
+ ' '.join(id_chain[1:]), 'a:',
+ str(int(self.assign)), 'r:',
+ str(int(self.refer)),
+ str(self.last_refer_coord)
+ ])
+ refer_ls.append(info_str)
+ for c in self.children:
+ self.children[c].collect_assign_refer_status(id_chain, assign_ls,
+ refer_ls)
+ id_chain.pop()
+ return assign_ls, refer_ls
+
+ def show(self):
+ assign_ls, refer_ls = self.collect_assign_refer_status()
+ print('---- assign ----')
+ for item in assign_ls:
+ print(item)
+ print('---- refer ----')
+ for item in refer_ls:
+ print(item)
+
+
+class FuncInOutVisitor(c_ast.NodeVisitor):
+
+ def __init__(self,
+ func_def_node,
+ struct_info,
+ func_dictionary,
+ keep_body_id_tree=True,
+ call_param_map=None,
+ global_id_tree=None,
+ func_history=None,
+ unknown=None):
+ self.func_dictionary = func_dictionary
+ self.struct_info = struct_info
+ self.param_id_map = get_func_param_id_map(func_def_node)
+ self.parent_node = None
+ self.global_id_tree = global_id_tree
+ self.body_id_tree = None
+ self.keep_body_id_tree = keep_body_id_tree
+ if func_history == None:
+ self.func_history = {}
+ else:
+ self.func_history = func_history
+
+ if unknown == None:
+ self.unknown = []
+ else:
+ self.unknown = unknown
+
+ self.id_tree_stack = IDTreeStack(global_id_tree)
+ self.id_tree_stack.push_id_tree()
+
+ #TODO move this part into a function
+ for param in self.param_id_map:
+ decl_node = self.param_id_map[param]
+ decl_status = parse_decl_node(self.struct_info, decl_node)
+ descendant = self.id_tree_stack.add_id_seed_node(decl_status.name,
+ decl_status)
+ if call_param_map is not None and param in call_param_map:
+ # This is a function call.
+ # Map the input parameter to the caller's nodes
+ # TODO(angiebird): Can we use add_link_node here?
+ descendant.set_link_node(call_param_map[param])
+
+ def get_id_tree_stack(self):
+ return self.id_tree_stack
+
+ def generic_visit(self, node):
+ prev_parent = self.parent_node
+ self.parent_node = node
+ for c in node:
+ self.visit(c)
+ self.parent_node = prev_parent
+
+ # TODO rename
+ def add_new_id_tree(self, node):
+ self.id_tree_stack.push_id_tree()
+ self.generic_visit(node)
+ id_tree = self.id_tree_stack.pop_id_tree()
+ if self.parent_node == None and self.keep_body_id_tree == True:
+ # this is function body
+ self.body_id_tree = id_tree
+
+ def visit_For(self, node):
+ self.add_new_id_tree(node)
+
+ def visit_Compound(self, node):
+ self.add_new_id_tree(node)
+
+ def visit_Decl(self, node):
+ if node.type.__class__.__name__ != 'FuncDecl':
+ decl_status = parse_decl_node(self.struct_info, node)
+ descendant = self.id_tree_stack.add_id_seed_node(decl_status.name,
+ decl_status)
+ if node.init is not None:
+ init_id_chain = self.process_lvalue(node.init)
+ if init_id_chain != None:
+ if decl_status.struct_item is None:
+ init_descendant = self.id_tree_stack.add_id_node(init_id_chain)
+ if init_descendant != None:
+ init_descendant.set_refer(True, node.coord)
+ else:
+ self.unknown.append(node)
+ descendant.set_assign(True, node.coord)
+ else:
+ self.id_tree_stack.add_link_node(descendant, init_id_chain)
+ else:
+ self.unknown.append(node)
+ else:
+ descendant.set_assign(True, node.coord)
+ self.generic_visit(node)
+
+ def is_lvalue(self, node):
+ if self.parent_node is None:
+ # TODO(angiebird): Do every lvalue has parent_node != None?
+ return False
+ if self.parent_node.__class__.__name__ == 'StructRef':
+ return False
+ if self.parent_node.__class__.__name__ == 'ArrayRef' and node == self.parent_node.name:
+ # if node == self.parent_node.subscript, the node could be lvalue
+ return False
+ if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '&':
+ return False
+ if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '*':
+ return False
+ return True
+
+ def process_lvalue(self, node):
+ id_chain = parse_lvalue(node)
+ if id_chain == None:
+ return id_chain
+ elif id_chain[0] in self.struct_info.enum_value_dic:
+ return None
+ else:
+ return id_chain
+
+ def process_possible_lvalue(self, node):
+ if self.is_lvalue(node):
+ id_chain = self.process_lvalue(node)
+ lead_char = get_lvalue_lead(node)
+ # make sure the id is not an enum value
+ if id_chain == None:
+ self.unknown.append(node)
+ return
+ descendant = self.id_tree_stack.add_id_node(id_chain)
+ if descendant == None:
+ self.unknown.append(node)
+ return
+ decl_status = descendant.get_decl_status()
+ if decl_status == None:
+ descendant.set_assign(True, node.coord)
+ descendant.set_refer(True, node.coord)
+ self.unknown.append(node)
+ return
+ if self.parent_node.__class__.__name__ == 'Assignment':
+ if node is self.parent_node.lvalue:
+ if decl_status.struct_item != None:
+ if len(id_chain) > 1:
+ descendant.set_assign(True, node.coord)
+ elif len(id_chain) == 1:
+ if lead_char == '*':
+ descendant.set_assign(True, node.coord)
+ else:
+ right_id_chain = self.process_lvalue(self.parent_node.rvalue)
+ if right_id_chain != None:
+ self.id_tree_stack.add_link_node(descendant, right_id_chain)
+ else:
+ #TODO(angiebird): 1.Find a better way to deal with this case.
+ descendant.set_assign(True, node.coord)
+ else:
+ debug_print(getframeinfo(currentframe()))
+ else:
+ descendant.set_assign(True, node.coord)
+ elif node is self.parent_node.rvalue:
+ if decl_status.struct_item is None:
+ descendant.set_refer(True, node.coord)
+ if lead_char == '&':
+ descendant.set_assign(True, node.coord)
+ else:
+ left_id_chain = self.process_lvalue(self.parent_node.lvalue)
+ left_lead_char = get_lvalue_lead(self.parent_node.lvalue)
+ if left_id_chain != None:
+ if len(left_id_chain) > 1:
+ descendant.set_refer(True, node.coord)
+ elif len(left_id_chain) == 1:
+ if left_lead_char == '*':
+ descendant.set_refer(True, node.coord)
+ else:
+ #TODO(angiebird): Check whether the other node is linked to this node.
+ pass
+ else:
+ self.unknown.append(self.parent_node.lvalue)
+ debug_print(getframeinfo(currentframe()))
+ else:
+ self.unknown.append(self.parent_node.lvalue)
+ debug_print(getframeinfo(currentframe()))
+ else:
+ debug_print(getframeinfo(currentframe()))
+ elif self.parent_node.__class__.__name__ == 'UnaryOp':
+ # TODO(angiebird): Consider +=, *=, -=, /= etc
+ if self.parent_node.op == '--' or self.parent_node.op == '++' or\
+ self.parent_node.op == 'p--' or self.parent_node.op == 'p++':
+ descendant.set_assign(True, node.coord)
+ descendant.set_refer(True, node.coord)
+ else:
+ descendant.set_refer(True, node.coord)
+ elif self.parent_node.__class__.__name__ == 'Decl':
+ #The logic is at visit_Decl
+ pass
+ elif self.parent_node.__class__.__name__ == 'ExprList':
+ #The logic is at visit_FuncCall
+ pass
+ else:
+ descendant.set_refer(True, node.coord)
+
+ def visit_ID(self, node):
+ # If the parent is a FuncCall, this ID is a function name.
+ if self.parent_node.__class__.__name__ != 'FuncCall':
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_StructRef(self, node):
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_ArrayRef(self, node):
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_UnaryOp(self, node):
+ if node.op == '&' or node.op == '*':
+ self.process_possible_lvalue(node)
+ self.generic_visit(node)
+
+ def visit_FuncCall(self, node):
+ if node.name.__class__.__name__ == 'ID':
+ if node.name.name in self.func_dictionary:
+ if node.name.name not in self.func_history:
+ self.func_history[node.name.name] = True
+ func_def_node = self.func_dictionary[node.name.name]
+ call_param_map = self.process_func_call(node, func_def_node)
+
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary, False,
+ call_param_map, self.global_id_tree,
+ self.func_history, self.unknown)
+ visitor.visit(func_def_node.body)
+ else:
+ self.unknown.append(node)
+ self.generic_visit(node)
+
+ def process_func_call(self, func_call_node, func_def_node):
+ # set up a refer/assign for func parameters
+ # return call_param_map
+ call_param_ls = func_call_node.args.exprs
+ call_param_map = {}
+
+ func_decl = func_def_node.decl.type
+ decl_param_ls = func_decl.args.params
+ for param_node, decl_node in zip(call_param_ls, decl_param_ls):
+ id_chain = self.process_lvalue(param_node)
+ if id_chain != None:
+ descendant = self.id_tree_stack.add_id_node(id_chain)
+ if descendant == None:
+ self.unknown.append(param_node)
+ else:
+ decl_status = descendant.get_decl_status()
+ if decl_status != None:
+ if decl_status.struct_item == None:
+ if decl_status.is_ptr_decl == True:
+ descendant.set_assign(True, param_node.coord)
+ descendant.set_refer(True, param_node.coord)
+ else:
+ descendant.set_refer(True, param_node.coord)
+ else:
+ call_param_map[decl_node.name] = descendant
+ else:
+ self.unknown.append(param_node)
+ else:
+ self.unknown.append(param_node)
+ return call_param_map
+
+
+def build_global_id_tree(ast, struct_info):
+ global_id_tree = IDStatusNode()
+ for node in ast.ext:
+ if node.__class__.__name__ == 'Decl':
+ # id tree is for tracking assign/refer status
+ # we don't care about function id because they can't be changed
+ if node.type.__class__.__name__ != 'FuncDecl':
+ decl_status = parse_decl_node(struct_info, node)
+ descendant = global_id_tree.add_child(decl_status.name, decl_status)
+ return global_id_tree
+
+
+class FuncAnalyzer():
+
+ def __init__(self):
+ self.ast = get_av1_ast()
+ self.struct_info = build_struct_info(self.ast)
+ self.func_dictionary = build_func_dictionary(self.ast)
+ self.global_id_tree = build_global_id_tree(self.ast, self.struct_info)
+
+ def analyze(self, func_name):
+ if func_name in self.func_dictionary:
+ func_def_node = self.func_dictionary[func_name]
+ visitor = FuncInOutVisitor(func_def_node, self.struct_info,
+ self.func_dictionary, True, None,
+ self.global_id_tree)
+ visitor.visit(func_def_node.body)
+ root = visitor.get_id_tree_stack()
+ root.top().show()
+ else:
+ print(func_name, "doesn't exist")
+
+
+if __name__ == '__main__':
+ fa = FuncAnalyzer()
+ fa.analyze('tpl_get_satd_cost')
+ pass