diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-04-03 07:31:50 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-04-03 07:31:50 +0000 |
commit | 1fb60a37d31eacbac62ddafde51b829522925694 (patch) | |
tree | 5c04a33630f7a2cd4cff248e965053f97ec3e4ac /sqlglot/optimizer/qualify_columns.py | |
parent | Adding upstream version 11.4.1. (diff) | |
download | sqlglot-1fb60a37d31eacbac62ddafde51b829522925694.tar.xz sqlglot-1fb60a37d31eacbac62ddafde51b829522925694.zip |
Adding upstream version 11.4.5.upstream/11.4.5
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 88 |
1 files changed, 63 insertions, 25 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 66b3170..5e40cf3 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -30,11 +30,12 @@ def qualify_columns(expression, schema): resolver = Resolver(scope, schema) _pop_table_column_aliases(scope.ctes) _pop_table_column_aliases(scope.derived_tables) - _expand_using(scope, resolver) + using_column_tables = _expand_using(scope, resolver) _qualify_columns(scope, resolver) if not isinstance(scope.expression, exp.UDTF): - _expand_stars(scope, resolver) + _expand_stars(scope, resolver, using_column_tables) _qualify_outputs(scope) + _expand_alias_refs(scope, resolver) _expand_group_by(scope, resolver) _expand_order_by(scope) @@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables): def _expand_using(scope, resolver): - joins = list(scope.expression.find_all(exp.Join)) + joins = list(scope.find_all(exp.Join)) names = {join.this.alias for join in joins} ordered = [key for key in scope.selected_sources if key not in names] - # Mapping of automatically joined column names to source names + # Mapping of automatically joined column names to an ordered set of source names (dict). column_tables = {} for join in joins: @@ -112,11 +113,12 @@ def _expand_using(scope, resolver): ) ) - tables = column_tables.setdefault(identifier, []) + # Set all values in the dict to None, because we only care about the key ordering + tables = column_tables.setdefault(identifier, {}) if table not in tables: - tables.append(table) + tables[table] = None if join_table not in tables: - tables.append(join_table) + tables[join_table] = None join.args.pop("using") join.set("on", exp.and_(*conditions)) @@ -134,11 +136,11 @@ def _expand_using(scope, resolver): scope.replace(column, replacement) + return column_tables -def _expand_group_by(scope, resolver): - group = scope.expression.args.get("group") - if not group: - return + +def _expand_alias_refs(scope, resolver): + selects = {} # Replace references to select aliases def transform(node, *_): @@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver): node.set("table", table) return node - selects = {s.alias_or_name: s for s in scope.selects} - + if not selects: + for s in scope.selects: + selects[s.alias_or_name] = s select = selects.get(node.name) + if select: scope.clear_cache() if isinstance(select, exp.Alias): @@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver): return node - group.transform(transform, copy=False) + for select in scope.expression.selects: + select.transform(transform, copy=False) + + for modifier in ("where", "group"): + part = scope.expression.args.get(modifier) + + if part: + part.transform(transform, copy=False) + + +def _expand_group_by(scope, resolver): + group = scope.expression.args.get("group") + if not group: + return + group.set("expressions", _expand_positional_references(scope, group.expressions)) scope.expression.set("group", group) @@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver): column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) columns_missing_from_scope = [] + # Determine whether each reference in the order by clause is to a column or an alias. - for ordered in scope.find_all(exp.Ordered): - for column in ordered.find_all(exp.Column): - if ( - not column.table - and column.parent is not ordered - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) + order = scope.expression.args.get("order") + + if order: + for ordered in order.expressions: + for column in ordered.find_all(exp.Column): + if ( + not column.table + and column.parent is not ordered + and column.name in resolver.all_columns + ): + columns_missing_from_scope.append(column) # Determine whether each reference in the having clause is to a column or an alias. - for having in scope.find_all(exp.Having): + having = scope.expression.args.get("having") + + if having: for column in having.find_all(exp.Column): if ( not column.table @@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver): column.set("table", column_table) -def _expand_stars(scope, resolver): +def _expand_stars(scope, resolver, using_column_tables): """Expand stars to lists of column selections""" new_selections = [] except_columns = {} replace_columns = {} + coalesced_columns = set() for expression in scope.selects: if isinstance(expression, exp.Star): @@ -286,7 +311,20 @@ def _expand_stars(scope, resolver): if columns and "*" not in columns: table_id = id(table) for name in columns: - if name not in except_columns.get(table_id, set()): + if name in using_column_tables and table in using_column_tables[name]: + if name in coalesced_columns: + continue + + coalesced_columns.add(name) + tables = using_column_tables[name] + coalesce = [exp.column(name, table=table) for table in tables] + + new_selections.append( + exp.alias_( + exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name + ) + ) + elif name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table) new_selections.append(alias(column, alias_) if alias_ != name else column) |