From 28cc22419e32a65fea2d1678400265b8cabc3aff Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Thu, 15 Sep 2022 18:46:17 +0200
Subject: Adding upstream version 6.0.4.

Signed-off-by: Daniel Baumann
---
 sqlglot/optimizer/qualify_columns.py | 422 +++++++++++++++++++++++++++++++++++
 1 file changed, 422 insertions(+)
 create mode 100644 sqlglot/optimizer/qualify_columns.py

(limited to 'sqlglot/optimizer/qualify_columns.py')

diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
new file mode 100644
index 0000000..394f49e
--- /dev/null
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -0,0 +1,422 @@
+import itertools
+
+from sqlglot import alias, exp
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.schema import ensure_schema
+from sqlglot.optimizer.scope import traverse_scope
+
+SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
+
+
+def qualify_columns(expression, schema):
+    """
+    Rewrite sqlglot AST to have fully qualified columns.
+
+    Example:
+        >>> import sqlglot
+        >>> schema = {"tbl": {"col": "INT"}}
+        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+        >>> qualify_columns(expression, schema).sql()
+        'SELECT tbl.col AS col FROM tbl'
+
+    Args:
+        expression (sqlglot.Expression): expression to qualify
+        schema (dict|sqlglot.optimizer.Schema): Database schema
+    Returns:
+        sqlglot.Expression: qualified expression
+    """
+    schema = ensure_schema(schema)
+
+    for scope in traverse_scope(expression):
+        resolver = _Resolver(scope, schema)
+        _pop_table_column_aliases(scope.ctes)
+        _pop_table_column_aliases(scope.derived_tables)
+        _expand_using(scope, resolver)
+        _expand_group_by(scope, resolver)
+        _expand_order_by(scope)
+        _qualify_columns(scope, resolver)
+        if not isinstance(scope.expression, SKIP_QUALIFY):
+            _expand_stars(scope, resolver)
+            _qualify_outputs(scope)
+        _check_unknown_tables(scope)
+
+    return expression
+
+
+def _pop_table_column_aliases(derived_tables):
+    """
+    Remove table column aliases.
+
+    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
+    """
+    for derived_table in derived_tables:
+        if isinstance(derived_table, SKIP_QUALIFY):
+            continue
+        table_alias = derived_table.args.get("alias")
+        if table_alias:
+            table_alias.args.pop("columns", None)
+
+
+def _expand_using(scope, resolver):
+    joins = list(scope.expression.find_all(exp.Join))
+    names = {join.this.alias for join in joins}
+    ordered = [key for key in scope.selected_sources if key not in names]
+
+    # Mapping of automatically joined column names to source names
+    column_tables = {}
+
+    for join in joins:
+        using = join.args.get("using")
+
+        if not using:
+            continue
+
+        join_table = join.this.alias_or_name
+
+        columns = {}
+
+        for k in scope.selected_sources:
+            if k in ordered:
+                for column in resolver.get_source_columns(k):
+                    if column not in columns:
+                        columns[column] = k
+
+        ordered.append(join_table)
+        join_columns = resolver.get_source_columns(join_table)
+        conditions = []
+
+        for identifier in using:
+            identifier = identifier.name
+            table = columns.get(identifier)
+
+            if not table or identifier not in join_columns:
+                raise OptimizeError(f"Cannot automatically join: {identifier}")
+
+            conditions.append(
+                exp.condition(
+                    exp.EQ(
+                        this=exp.column(identifier, table=table),
+                        expression=exp.column(identifier, table=join_table),
+                    )
+                )
+            )
+
+            tables = column_tables.setdefault(identifier, [])
+            if table not in tables:
+                tables.append(table)
+            if join_table not in tables:
+                tables.append(join_table)
+
+        join.args.pop("using")
+        join.set("on", exp.and_(*conditions))
+
+    if column_tables:
+        for column in scope.columns:
+            if not column.table and column.name in column_tables:
+                tables = column_tables[column.name]
+                coalesce = [exp.column(column.name, table=table) for table in tables]
+                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
+
+                # Ensure selects keep their output name
+                if isinstance(column.parent, exp.Select):
+                    replacement = exp.alias_(replacement, alias=column.name)
+
+                scope.replace(column, replacement)
+
+
+def _expand_group_by(scope, resolver):
+    group = scope.expression.args.get("group")
+    if not group:
+        return
+
+    # Replace references to select aliases
+    def transform(node, *_):
+        if isinstance(node, exp.Column) and not node.table:
+            table = resolver.get_table(node.name)
+
+            # Source columns get priority over select aliases
+            if table:
+                node.set("table", exp.to_identifier(table))
+                return node
+
+            selects = {s.alias_or_name: s for s in scope.selects}
+
+            select = selects.get(node.name)
+            if select:
+                scope.clear_cache()
+                if isinstance(select, exp.Alias):
+                    select = select.this
+                return select.copy()
+
+        return node
+
+    group.transform(transform, copy=False)
+    group.set("expressions", _expand_positional_references(scope, group.expressions))
+    scope.expression.set("group", group)
+
+
+def _expand_order_by(scope):
+    order = scope.expression.args.get("order")
+    if not order:
+        return
+
+    ordereds = order.expressions
+    for ordered, new_expression in zip(
+        ordereds,
+        _expand_positional_references(scope, (o.this for o in ordereds)),
+    ):
+        ordered.set("this", new_expression)
+
+
+def _expand_positional_references(scope, expressions):
+    new_nodes = []
+    for node in expressions:
+        if node.is_int:
+            try:
+                select = scope.selects[int(node.name) - 1]
+            except IndexError:
+                raise OptimizeError(f"Unknown output column: {node.name}")
+            if isinstance(select, exp.Alias):
+                select = select.this
+            new_nodes.append(select.copy())
+            scope.clear_cache()
+        else:
+            new_nodes.append(node)
+
+    return new_nodes
+
+
+def _qualify_columns(scope, resolver):
+    """Disambiguate columns, ensuring each column specifies a source"""
+    for column in scope.columns:
+        column_table = column.table
+        column_name = column.name
+
+        if (
+            column_table
+            and column_table in scope.sources
+            and column_name not in resolver.get_source_columns(column_table)
+        ):
+            raise OptimizeError(f"Unknown column: {column_name}")
+
+        if not column_table:
+            column_table = resolver.get_table(column_name)
+
+            if not scope.is_subquery and not scope.is_unnest:
+                if column_name not in resolver.all_columns:
+                    raise OptimizeError(f"Unknown column: {column_name}")
+
+                if column_table is None:
+                    raise OptimizeError(f"Ambiguous column: {column_name}")
+
+            # column_table can be a '' because bigquery unnest has no table alias
+            if column_table:
+                column.set("table", exp.to_identifier(column_table))
+
+
+def _expand_stars(scope, resolver):
+    """Expand stars to lists of column selections"""
+
+    new_selections = []
+    except_columns = {}
+    replace_columns = {}
+
+    for expression in scope.selects:
+        if isinstance(expression, exp.Star):
+            tables = list(scope.selected_sources)
+            _add_except_columns(expression, tables, except_columns)
+            _add_replace_columns(expression, tables, replace_columns)
+        elif isinstance(expression, exp.Column) and isinstance(
+            expression.this, exp.Star
+        ):
+            tables = [expression.table]
+            _add_except_columns(expression.this, tables, except_columns)
+            _add_replace_columns(expression.this, tables, replace_columns)
+        else:
+            new_selections.append(expression)
+            continue
+
+        for table in tables:
+            if table not in scope.sources:
+                raise OptimizeError(f"Unknown table: {table}")
+            columns = resolver.get_source_columns(table)
+            table_id = id(table)
+            for name in columns:
+                if name not in except_columns.get(table_id, set()):
+                    alias_ = replace_columns.get(table_id, {}).get(name, name)
+                    column = exp.column(name, table)
+                    new_selections.append(
+                        alias(column, alias_) if alias_ != name else column
+                    )
+
+    scope.expression.set("expressions", new_selections)
+
+
+def _add_except_columns(expression, tables, except_columns):
+    except_ = expression.args.get("except")
+
+    if not except_:
+        return
+
+    columns = {e.name for e in except_}
+
+    for table in tables:
+        except_columns[id(table)] = columns
+
+
+def _add_replace_columns(expression, tables, replace_columns):
+    replace = expression.args.get("replace")
+
+    if not replace:
+        return
+
+    columns = {e.this.name: e.alias for e in replace}
+
+    for table in tables:
+        replace_columns[id(table)] = columns
+
+
+def _qualify_outputs(scope):
+    """Ensure all output columns are aliased"""
+    new_selections = []
+
+    for i, (selection, aliased_column) in enumerate(
+        itertools.zip_longest(scope.selects, scope.outer_column_list)
+    ):
+        if isinstance(selection, exp.Column):
+            # convoluted setter because a simple selection.replace(alias) would require a copy
+            alias_ = alias(exp.column(""), alias=selection.name)
+            alias_.set("this", selection)
+            selection = alias_
+        elif not isinstance(selection, exp.Alias):
+            alias_ = alias(exp.column(""), f"_col_{i}")
+            alias_.set("this", selection)
+            selection = alias_
+
+        if aliased_column:
+            selection.set("alias", exp.to_identifier(aliased_column))
+
+        new_selections.append(selection)
+
+    scope.expression.set("expressions", new_selections)
+
+
+def _check_unknown_tables(scope):
+    if (
+        scope.external_columns
+        and not scope.is_unnest
+        and not scope.is_correlated_subquery
+    ):
+        raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
+
+
+class _Resolver:
+    """
+    Helper for resolving columns.
+
+    This is a class so we can lazily load some things and easily share them across functions.
+    """
+
+    def __init__(self, scope, schema):
+        self.scope = scope
+        self.schema = schema
+        self._source_columns = None
+        self._unambiguous_columns = None
+        self._all_columns = None
+
+    def get_table(self, column_name):
+        """
+        Get the table for a column name.
+
+        Args:
+            column_name (str)
+        Returns:
+            (str) table name
+        """
+        if self._unambiguous_columns is None:
+            self._unambiguous_columns = self._get_unambiguous_columns(
+                self._get_all_source_columns()
+            )
+        return self._unambiguous_columns.get(column_name)
+
+    @property
+    def all_columns(self):
+        """All available columns of all sources in this scope"""
+        if self._all_columns is None:
+            self._all_columns = set(
+                column
+                for columns in self._get_all_source_columns().values()
+                for column in columns
+            )
+        return self._all_columns
+
+    def get_source_columns(self, name):
+        """Resolve the source columns for a given source `name`"""
+        if name not in self.scope.sources:
+            raise OptimizeError(f"Unknown table: {name}")
+
+        source = self.scope.sources[name]
+
+        # If referencing a table, return the columns from the schema
+        if isinstance(source, exp.Table):
+            try:
+                return self.schema.column_names(source)
+            except Exception as e:
+                raise OptimizeError(str(e)) from e
+
+        # Otherwise, if referencing another scope, return that scope's named selects
+        return source.expression.named_selects
+
+    def _get_all_source_columns(self):
+        if self._source_columns is None:
+            self._source_columns = {
+                k: self.get_source_columns(k) for k in self.scope.selected_sources
+            }
+        return self._source_columns
+
+    def _get_unambiguous_columns(self, source_columns):
+        """
+        Find all the unambiguous columns in sources.
+
+        Args:
+            source_columns (dict): Mapping of names to source columns
+        Returns:
+            dict: Mapping of column name to source name
+        """
+        if not source_columns:
+            return {}
+
+        source_columns = list(source_columns.items())
+
+        first_table, first_columns = source_columns[0]
+        unambiguous_columns = {
+            col: first_table for col in self._find_unique_columns(first_columns)
+        }
+        all_columns = set(unambiguous_columns)
+
+        for table, columns in source_columns[1:]:
+            unique = self._find_unique_columns(columns)
+            ambiguous = set(all_columns).intersection(unique)
+            all_columns.update(columns)
+            for column in ambiguous:
+                unambiguous_columns.pop(column, None)
+            for column in unique.difference(ambiguous):
+                unambiguous_columns[column] = table
+
+        return unambiguous_columns
+
+    @staticmethod
+    def _find_unique_columns(columns):
+        """
+        Find the unique columns in a list of columns.
+
+        Example:
+            >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
+            ['a', 'c']
+
+        This is necessary because duplicate column names are ambiguous.
+        """
+        counts = {}
+        for column in columns:
+            counts[column] = counts.get(column, 0) + 1
+        return {column for column, count in counts.items() if count == 1}
--
cgit v1.2.3
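
The pass added above can be exercised on its own, outside the full optimizer pipeline. A minimal sketch, assuming sqlglot 6.0.4 is installed and using a hypothetical one-table schema named "x" (the exact SQL strings produced may differ slightly between versions):

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Hypothetical schema: any {table: {column: type}} mapping is accepted by ensure_schema.
schema = {"x": {"a": "INT", "b": "INT"}}

# Star expansion plus output aliasing (_expand_stars and _qualify_outputs).
expression = sqlglot.parse_one("SELECT * FROM x")
print(qualify_columns(expression, schema).sql())
# Expected, roughly: SELECT x.a AS a, x.b AS b FROM x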
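
USING joins are handled by _expand_using, which rewrites the join into an explicit ON condition and turns unqualified references to the shared column into a COALESCE over the joined tables. A sketch under the same assumptions, now with a hypothetical two-table schema; the output is approximate:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}}

expression = sqlglot.parse_one("SELECT b, c FROM x JOIN y USING (b)")
print(qualify_columns(expression, schema).sql())
# Expected, roughly:
# SELECT COALESCE(x.b, y.b) AS b, y.c AS c FROM x JOIN y ON x.b = y.b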
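
Positional GROUP BY and ORDER BY references are expanded by _expand_positional_references into copies of the corresponding select expressions before qualification, and unaliased selections pick up generated _col_<i> names from _qualify_outputs. A sketch with the same hypothetical "x" schema and an approximate expected output:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"x": {"a": "INT", "b": "INT"}}

sql = "SELECT a AS d, SUM(b) FROM x GROUP BY 1 ORDER BY 1"
print(qualify_columns(sqlglot.parse_one(sql), schema).sql())
# Expected, roughly:
# SELECT x.a AS d, SUM(x.b) AS _col_1 FROM x GROUP BY x.a ORDER BY x.a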
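
When a column cannot be resolved to a single source, _qualify_columns and _Resolver raise OptimizeError rather than guessing. A sketch of the ambiguous-column path, reusing the hypothetical two-table schema:

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}}

try:
    # "b" exists in both x and y, so the resolver cannot pick a single table for it.
    qualify_columns(sqlglot.parse_one("SELECT b FROM x JOIN y ON x.a = y.c"), schema)
except OptimizeError as e:
    print(e)  # roughly: Ambiguous column: b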