summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r--sqlglot/optimizer/qualify_columns.py422
1 files changed, 422 insertions, 0 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
new file mode 100644
index 0000000..394f49e
--- /dev/null
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -0,0 +1,422 @@
+import itertools
+
+from sqlglot import alias, exp
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.schema import ensure_schema
+from sqlglot.optimizer.scope import traverse_scope
+
+SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
+
+
+def qualify_columns(expression, schema):
+ """
+ Rewrite sqlglot AST to have fully qualified columns.
+
+ Example:
+ >>> import sqlglot
+ >>> schema = {"tbl": {"col": "INT"}}
+ >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+ >>> qualify_columns(expression, schema).sql()
+ 'SELECT tbl.col AS col FROM tbl'
+
+ Args:
+ expression (sqlglot.Expression): expression to qualify
+ schema (dict|sqlglot.optimizer.Schema): Database schema
+ Returns:
+ sqlglot.Expression: qualified expression
+ """
+ schema = ensure_schema(schema)
+
+ for scope in traverse_scope(expression):
+ resolver = _Resolver(scope, schema)
+ _pop_table_column_aliases(scope.ctes)
+ _pop_table_column_aliases(scope.derived_tables)
+ _expand_using(scope, resolver)
+ _expand_group_by(scope, resolver)
+ _expand_order_by(scope)
+ _qualify_columns(scope, resolver)
+ if not isinstance(scope.expression, SKIP_QUALIFY):
+ _expand_stars(scope, resolver)
+ _qualify_outputs(scope)
+ _check_unknown_tables(scope)
+
+ return expression
+
+
+def _pop_table_column_aliases(derived_tables):
+ """
+ Remove table column aliases.
+
+ (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
+ """
+ for derived_table in derived_tables:
+ if isinstance(derived_table, SKIP_QUALIFY):
+ continue
+ table_alias = derived_table.args.get("alias")
+ if table_alias:
+ table_alias.args.pop("columns", None)
+
+
+def _expand_using(scope, resolver):
+ joins = list(scope.expression.find_all(exp.Join))
+ names = {join.this.alias for join in joins}
+ ordered = [key for key in scope.selected_sources if key not in names]
+
+ # Mapping of automatically joined column names to source names
+ column_tables = {}
+
+ for join in joins:
+ using = join.args.get("using")
+
+ if not using:
+ continue
+
+ join_table = join.this.alias_or_name
+
+ columns = {}
+
+ for k in scope.selected_sources:
+ if k in ordered:
+ for column in resolver.get_source_columns(k):
+ if column not in columns:
+ columns[column] = k
+
+ ordered.append(join_table)
+ join_columns = resolver.get_source_columns(join_table)
+ conditions = []
+
+ for identifier in using:
+ identifier = identifier.name
+ table = columns.get(identifier)
+
+ if not table or identifier not in join_columns:
+ raise OptimizeError(f"Cannot automatically join: {identifier}")
+
+ conditions.append(
+ exp.condition(
+ exp.EQ(
+ this=exp.column(identifier, table=table),
+ expression=exp.column(identifier, table=join_table),
+ )
+ )
+ )
+
+ tables = column_tables.setdefault(identifier, [])
+ if table not in tables:
+ tables.append(table)
+ if join_table not in tables:
+ tables.append(join_table)
+
+ join.args.pop("using")
+ join.set("on", exp.and_(*conditions))
+
+ if column_tables:
+ for column in scope.columns:
+ if not column.table and column.name in column_tables:
+ tables = column_tables[column.name]
+ coalesce = [exp.column(column.name, table=table) for table in tables]
+ replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
+
+ # Ensure selects keep their output name
+ if isinstance(column.parent, exp.Select):
+ replacement = exp.alias_(replacement, alias=column.name)
+
+ scope.replace(column, replacement)
+
+
+def _expand_group_by(scope, resolver):
+ group = scope.expression.args.get("group")
+ if not group:
+ return
+
+ # Replace references to select aliases
+ def transform(node, *_):
+ if isinstance(node, exp.Column) and not node.table:
+ table = resolver.get_table(node.name)
+
+ # Source columns get priority over select aliases
+ if table:
+ node.set("table", exp.to_identifier(table))
+ return node
+
+ selects = {s.alias_or_name: s for s in scope.selects}
+
+ select = selects.get(node.name)
+ if select:
+ scope.clear_cache()
+ if isinstance(select, exp.Alias):
+ select = select.this
+ return select.copy()
+
+ return node
+
+ group.transform(transform, copy=False)
+ group.set("expressions", _expand_positional_references(scope, group.expressions))
+ scope.expression.set("group", group)
+
+
+def _expand_order_by(scope):
+ order = scope.expression.args.get("order")
+ if not order:
+ return
+
+ ordereds = order.expressions
+ for ordered, new_expression in zip(
+ ordereds,
+ _expand_positional_references(scope, (o.this for o in ordereds)),
+ ):
+ ordered.set("this", new_expression)
+
+
+def _expand_positional_references(scope, expressions):
+ new_nodes = []
+ for node in expressions:
+ if node.is_int:
+ try:
+ select = scope.selects[int(node.name) - 1]
+ except IndexError:
+ raise OptimizeError(f"Unknown output column: {node.name}")
+ if isinstance(select, exp.Alias):
+ select = select.this
+ new_nodes.append(select.copy())
+ scope.clear_cache()
+ else:
+ new_nodes.append(node)
+
+ return new_nodes
+
+
+def _qualify_columns(scope, resolver):
+ """Disambiguate columns, ensuring each column specifies a source"""
+ for column in scope.columns:
+ column_table = column.table
+ column_name = column.name
+
+ if (
+ column_table
+ and column_table in scope.sources
+ and column_name not in resolver.get_source_columns(column_table)
+ ):
+ raise OptimizeError(f"Unknown column: {column_name}")
+
+ if not column_table:
+ column_table = resolver.get_table(column_name)
+
+ if not scope.is_subquery and not scope.is_unnest:
+ if column_name not in resolver.all_columns:
+ raise OptimizeError(f"Unknown column: {column_name}")
+
+ if column_table is None:
+ raise OptimizeError(f"Ambiguous column: {column_name}")
+
+ # column_table can be a '' because bigquery unnest has no table alias
+ if column_table:
+ column.set("table", exp.to_identifier(column_table))
+
+
+def _expand_stars(scope, resolver):
+ """Expand stars to lists of column selections"""
+
+ new_selections = []
+ except_columns = {}
+ replace_columns = {}
+
+ for expression in scope.selects:
+ if isinstance(expression, exp.Star):
+ tables = list(scope.selected_sources)
+ _add_except_columns(expression, tables, except_columns)
+ _add_replace_columns(expression, tables, replace_columns)
+ elif isinstance(expression, exp.Column) and isinstance(
+ expression.this, exp.Star
+ ):
+ tables = [expression.table]
+ _add_except_columns(expression.this, tables, except_columns)
+ _add_replace_columns(expression.this, tables, replace_columns)
+ else:
+ new_selections.append(expression)
+ continue
+
+ for table in tables:
+ if table not in scope.sources:
+ raise OptimizeError(f"Unknown table: {table}")
+ columns = resolver.get_source_columns(table)
+ table_id = id(table)
+ for name in columns:
+ if name not in except_columns.get(table_id, set()):
+ alias_ = replace_columns.get(table_id, {}).get(name, name)
+ column = exp.column(name, table)
+ new_selections.append(
+ alias(column, alias_) if alias_ != name else column
+ )
+
+ scope.expression.set("expressions", new_selections)
+
+
+def _add_except_columns(expression, tables, except_columns):
+ except_ = expression.args.get("except")
+
+ if not except_:
+ return
+
+ columns = {e.name for e in except_}
+
+ for table in tables:
+ except_columns[id(table)] = columns
+
+
+def _add_replace_columns(expression, tables, replace_columns):
+ replace = expression.args.get("replace")
+
+ if not replace:
+ return
+
+ columns = {e.this.name: e.alias for e in replace}
+
+ for table in tables:
+ replace_columns[id(table)] = columns
+
+
+def _qualify_outputs(scope):
+ """Ensure all output columns are aliased"""
+ new_selections = []
+
+ for i, (selection, aliased_column) in enumerate(
+ itertools.zip_longest(scope.selects, scope.outer_column_list)
+ ):
+ if isinstance(selection, exp.Column):
+ # convoluted setter because a simple selection.replace(alias) would require a copy
+ alias_ = alias(exp.column(""), alias=selection.name)
+ alias_.set("this", selection)
+ selection = alias_
+ elif not isinstance(selection, exp.Alias):
+ alias_ = alias(exp.column(""), f"_col_{i}")
+ alias_.set("this", selection)
+ selection = alias_
+
+ if aliased_column:
+ selection.set("alias", exp.to_identifier(aliased_column))
+
+ new_selections.append(selection)
+
+ scope.expression.set("expressions", new_selections)
+
+
+def _check_unknown_tables(scope):
+ if (
+ scope.external_columns
+ and not scope.is_unnest
+ and not scope.is_correlated_subquery
+ ):
+ raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
+
+
+class _Resolver:
+ """
+ Helper for resolving columns.
+
+ This is a class so we can lazily load some things and easily share them across functions.
+ """
+
+ def __init__(self, scope, schema):
+ self.scope = scope
+ self.schema = schema
+ self._source_columns = None
+ self._unambiguous_columns = None
+ self._all_columns = None
+
+ def get_table(self, column_name):
+ """
+ Get the table for a column name.
+
+ Args:
+ column_name (str)
+ Returns:
+ (str) table name
+ """
+ if self._unambiguous_columns is None:
+ self._unambiguous_columns = self._get_unambiguous_columns(
+ self._get_all_source_columns()
+ )
+ return self._unambiguous_columns.get(column_name)
+
+ @property
+ def all_columns(self):
+ """All available columns of all sources in this scope"""
+ if self._all_columns is None:
+ self._all_columns = set(
+ column
+ for columns in self._get_all_source_columns().values()
+ for column in columns
+ )
+ return self._all_columns
+
+ def get_source_columns(self, name):
+ """Resolve the source columns for a given source `name`"""
+ if name not in self.scope.sources:
+ raise OptimizeError(f"Unknown table: {name}")
+
+ source = self.scope.sources[name]
+
+ # If referencing a table, return the columns from the schema
+ if isinstance(source, exp.Table):
+ try:
+ return self.schema.column_names(source)
+ except Exception as e:
+ raise OptimizeError(str(e)) from e
+
+ # Otherwise, if referencing another scope, return that scope's named selects
+ return source.expression.named_selects
+
+ def _get_all_source_columns(self):
+ if self._source_columns is None:
+ self._source_columns = {
+ k: self.get_source_columns(k) for k in self.scope.selected_sources
+ }
+ return self._source_columns
+
+ def _get_unambiguous_columns(self, source_columns):
+ """
+ Find all the unambiguous columns in sources.
+
+ Args:
+ source_columns (dict): Mapping of names to source columns
+ Returns:
+ dict: Mapping of column name to source name
+ """
+ if not source_columns:
+ return {}
+
+ source_columns = list(source_columns.items())
+
+ first_table, first_columns = source_columns[0]
+ unambiguous_columns = {
+ col: first_table for col in self._find_unique_columns(first_columns)
+ }
+ all_columns = set(unambiguous_columns)
+
+ for table, columns in source_columns[1:]:
+ unique = self._find_unique_columns(columns)
+ ambiguous = set(all_columns).intersection(unique)
+ all_columns.update(columns)
+ for column in ambiguous:
+ unambiguous_columns.pop(column, None)
+ for column in unique.difference(ambiguous):
+ unambiguous_columns[column] = table
+
+ return unambiguous_columns
+
+ @staticmethod
+ def _find_unique_columns(columns):
+ """
+ Find the unique columns in a list of columns.
+
+ Example:
+ >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
+ ['a', 'c']
+
+ This is necessary because duplicate column names are ambiguous.
+ """
+ counts = {}
+ for column in columns:
+ counts[column] = counts.get(column, 0) + 1
+ return {column for column, count in counts.items() if count == 1}