import sqlparse
from collections import namedtuple
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation

# A table (or table-like function) referenced by a statement.
TableReference = namedtuple(
    "TableReference", ["schema", "name", "alias", "is_function"]
)
# ``ref`` is the text used to refer to the table: the alias when one is set,
# otherwise the name, wrapped in double quotes unless it is all lowercase or
# already starts with a double quote.
TableReference.ref = property(
    lambda self: self.alias
    or (
        self.name
        if self.name.islower() or self.name[0] == '"'
        else '"' + self.name + '"'
    )
)


# This code is borrowed from the sqlparse example script.
def is_subselect(parsed):
    """Return True if ``parsed`` is a token group containing a DML keyword.

    Only groups are inspected; a plain token is never a sub-select. The
    group qualifies when any direct child token has type ``DML`` and its
    value is one of SELECT, INSERT, UPDATE, CREATE or DELETE.
    """
    if not parsed.is_group:
        return False
    for item in parsed.tokens:
        if item.ttype is DML and item.value.upper() in (
            "SELECT",
            "INSERT",
            "UPDATE",
            "CREATE",
            "DELETE",
        ):
            return True
    return False


def _identifier_is_function(identifier):
    """Return True if any direct child token of ``identifier`` is a Function."""
    return any(isinstance(t, Function) for t in identifier.tokens)


def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens that follow a table-introducing keyword.

    Walks the direct children of ``parsed``. Once a keyword such as COPY,
    FROM, INTO, UPDATE, TABLE or any ``... JOIN`` has been seen, subsequent
    tokens are yielded until another keyword ends the table list.
    Sub-selects are descended into recursively.

    :param parsed: a sqlparse token group exposing ``tokens``
    :param stop_at_punctuation: when true, stop the whole generator at the
        first Punctuation token found after a table-introducing keyword
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                yield from extract_from_part(item, stop_at_punctuation)
            elif stop_at_punctuation and item.ttype is Punctuation:
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
            # the second FROM to trigger this elif condition resulting in a
            # `return`. So we need to ignore the keyword if the keyword is
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its variants
            # INNER JOIN, FULL OUTER JOIN, etc.
            elif (
                item.ttype is Keyword
                and item.value.upper() != "FROM"
                and not item.value.upper().endswith("JOIN")
            ):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if item_val in (
                "COPY",
                "FROM",
                "INTO",
                "UPDATE",
                "TABLE",
            ) or item_val.endswith("JOIN"):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (
                    identifier.ttype is Keyword
                    and identifier.value.upper() == "FROM"
                ):
                    tbl_prefix_seen = True
                    break


def extract_table_identifiers(token_stream, allow_functions=True):
    """Yield a TableReference for each table-like token in ``token_stream``.

    :param token_stream: iterable of sqlparse tokens, e.g. the output of
        ``extract_from_part``
    :param allow_functions: when false, every yielded reference has
        ``is_function`` set to False
    """

    # We need to do some massaging of the names because postgres is case-
    # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
    def parse_identifier(item):
        """Return ``(schema_name, name, alias)`` for a single identifier."""
        name = item.get_real_name()
        schema_name = item.get_parent_name()
        alias = item.get_alias()
        if not name:
            # No real name: fall back to get_name() and drop the schema.
            schema_name = None
            name = item.get_name()
            alias = alias or name
        schema_quoted = schema_name and item.value[0] == '"'
        if schema_name and not schema_quoted:
            schema_name = schema_name.lower()
        quote_count = item.value.count('"')
        # More than two quotes cannot all belong to the schema; any quotes at
        # all belong to the name when the schema is not the quoted part.
        name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
        alias_quoted = alias and item.value[-1] == '"'
        if alias_quoted or (name_quoted and not alias and name.islower()):
            alias = '"' + (alias or name) + '"'
        if name and not name_quoted and not name.islower():
            # Unquoted mixed-case name: keep the spelling as the alias and
            # lower-case the name itself.
            if not alias:
                alias = name
            name = name.lower()
        return schema_name, name, alias

    try:
        for item in token_stream:
            if isinstance(item, IdentifierList):
                for identifier in item.get_identifiers():
                    # Sometimes Keywords (such as FROM ) are classified as
                    # identifiers which don't have the get_real_name() method.
                    try:
                        schema_name = identifier.get_parent_name()
                        real_name = identifier.get_real_name()
                        is_function = (
                            allow_functions
                            and _identifier_is_function(identifier)
                        )
                    except AttributeError:
                        continue
                    if real_name:
                        yield TableReference(
                            schema_name,
                            real_name,
                            identifier.get_alias(),
                            is_function,
                        )
            elif isinstance(item, Identifier):
                schema_name, real_name, alias = parse_identifier(item)
                is_function = allow_functions and _identifier_is_function(item)
                yield TableReference(schema_name, real_name, alias, is_function)
            elif isinstance(item, Function):
                # A bare function call: the schema is always reported as None.
                schema_name, real_name, alias = parse_identifier(item)
                yield TableReference(None, real_name, alias, allow_functions)
    except StopIteration:
        return


# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    Only the first statement parsed from ``sql`` is examined.

    :param sql: SQL text
    :return: a tuple of TableReference namedtuples; empty when nothing could
        be parsed or the statement contains no non-whitespace token
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return ()

    # token_first() returns None when the statement holds nothing but
    # whitespace; treat that as "not an insert" instead of raising
    # AttributeError on None.
    first_token = parsed[0].token_first()
    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
    # abc is the table name, but if we don't stop at the first lparen, then
    # we'll identify abc, col1 and col2 as table names.
    insert_stmt = (
        first_token is not None and first_token.value.lower() == "insert"
    )
    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)

    # Kludge: sqlparse mistakenly identifies insert statements as
    # function calls due to the parenthesized column list, e.g. interprets
    # "insert into foo (bar, baz)" as a function call to foo with arguments
    # (bar, baz). So don't allow any identifiers in insert statements
    # to have is_function=True
    identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt)
    # In the case 'sche.', we get an empty TableReference; remove that
    return tuple(i for i in identifiers if i.name)