diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-02-12 10:06:28 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-02-12 10:06:28 +0000 |
commit | 918abde014f9e5c75dfbe21110c379f7f70435c9 (patch) | |
tree | 3419a01e34958bffbd917fa9e600eda126ea3a87 /sqlglot/optimizer | |
parent | Releasing debian version 10.6.3-1. (diff) | |
download | sqlglot-918abde014f9e5c75dfbe21110c379f7f70435c9.tar.xz sqlglot-918abde014f9e5c75dfbe21110c379f7f70435c9.zip |
Merging upstream version 11.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 17 | ||||
-rw-r--r-- | sqlglot/optimizer/expand_laterals.py | 34 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 5 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 30 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 13 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 20 |
7 files changed, 101 insertions, 24 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index bfb2bb8..66f97a9 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -255,12 +255,23 @@ class TypeAnnotator: for name, source in scope.sources.items(): if not isinstance(source, Scope): continue - if isinstance(source.expression, exp.Values): + if isinstance(source.expression, exp.UDTF): + values = [] + + if isinstance(source.expression, exp.Lateral): + if isinstance(source.expression.this, exp.Explode): + values = [source.expression.this.this] + else: + values = source.expression.expressions[0].expressions + + if not values: + continue + selects[name] = { alias: column for alias, column in zip( source.expression.alias_column_names, - source.expression.expressions[0].expressions, + values, ) } else: @@ -272,7 +283,7 @@ class TypeAnnotator: source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) - elif source: + elif source and col.table in selects: col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py new file mode 100644 index 0000000..59f3fec --- /dev/null +++ b/sqlglot/optimizer/expand_laterals.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp + + +def expand_laterals(expression: exp.Expression) -> exp.Expression: + """ + Expand lateral column alias references. + + This assumes `qualify_columns` as already run. + + Example: + >>> import sqlglot + >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x" + >>> expression = sqlglot.parse_one(sql) + >>> expand_laterals(expression).sql() + 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x' + + Args: + expression: expression to optimize + Returns: + optimized expression + """ + for select in expression.find_all(exp.Select): + alias_to_expression: t.Dict[str, exp.Expression] = {} + for projection in select.expressions: + for column in projection.find_all(exp.Column): + if not column.table and column.name in alias_to_expression: + column.replace(alias_to_expression[column.name].copy()) + if isinstance(projection, exp.Alias): + alias_to_expression[projection.alias] = projection.this + return expression diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 766e059..96fd56b 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries +from sqlglot.optimizer.expand_laterals import expand_laterals from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects from sqlglot.optimizer.lower_identities import lower_identities @@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections -from sqlglot.optimizer.qualify_columns import qualify_columns +from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema @@ -22,6 +23,8 @@ RULES = ( qualify_tables, isolate_table_selects, qualify_columns, + expand_laterals, + validate_qualify_columns, pushdown_projections, normalize, unnest_subqueries, diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index a73647c..54c5021 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope SELECT_ALL = object() # Selection to use if selection list is empty -DEFAULT_SELECTION = alias("1", "_") +DEFAULT_SELECTION = lambda: alias("1", "_") def pushdown_projections(expression): @@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections): # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION.copy()) + new_selections.append(DEFAULT_SELECTION()) scope.expression.set("expressions", new_selections) if removed: @@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove): selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove ] if not new_selections: - new_selections.append(DEFAULT_SELECTION.copy()) + new_selections.append(DEFAULT_SELECTION()) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 54425a8..ab13d01 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -37,11 +37,24 @@ def qualify_columns(expression, schema): if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver) _qualify_outputs(scope) - _check_unknown_tables(scope) return expression +def validate_qualify_columns(expression): + """Raise an `OptimizeError` if any columns aren't qualified""" + unqualified_columns = [] + for scope in traverse_scope(expression): + if isinstance(scope.expression, exp.Select): + unqualified_columns.extend(scope.unqualified_columns) + if scope.external_columns and not scope.is_correlated_subquery: + raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}") + + if unqualified_columns: + raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") + return expression + + def _pop_table_column_aliases(derived_tables): """ Remove table column aliases. @@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver): if not column_table: column_table = resolver.get_table(column_name) - if not scope.is_subquery and not scope.is_udtf: - if column_table is None: - raise OptimizeError(f"Ambiguous column: {column_name}") - # column_table can be a '' because bigquery unnest has no table alias if column_table: column.set("table", exp.to_identifier(column_table)) @@ -231,10 +240,8 @@ def _qualify_columns(scope, resolver): for column in columns_missing_from_scope: column_table = resolver.get_table(column.name) - if column_table is None: - raise OptimizeError(f"Ambiguous column: {column.name}") - - column.set("table", exp.to_identifier(column_table)) + if column_table: + column.set("table", exp.to_identifier(column_table)) def _expand_stars(scope, resolver): @@ -322,11 +329,6 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) -def _check_unknown_tables(scope): - if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery: - raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") - - class _Resolver: """ Helper for resolving columns. diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 5d8e0d9..65593bd 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -2,7 +2,7 @@ import itertools from sqlglot import alias, exp from sqlglot.helper import csv_reader -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import Scope, traverse_scope def qualify_tables(expression, db=None, catalog=None, schema=None): @@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): """ sequence = itertools.count() + next_name = lambda: f"_q_{next(sequence)}" + for scope in traverse_scope(expression): for derived_table in scope.ctes + scope.derived_tables: if not derived_table.args.get("alias"): @@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): source = source.replace( alias( source.copy(), - source.this if identifier else f"_q_{next(sequence)}", + source.this if identifier else next_name(), table=True, ) ) @@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): schema.add_table( source, {k: type(v).__name__ for k, v in zip(header, columns)} ) + elif isinstance(source, Scope) and source.is_udtf: + udtf = source.expression + table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name()) + udtf.set("alias", table_alias) + + if not table_alias.name: + table_alias.set("this", next_name()) return expression diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index badbb87..8565c64 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -237,6 +237,8 @@ class Scope: ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) if ( not ancestor + # Window functions can have an ORDER BY clause + or not isinstance(ancestor.parent, exp.Select) or column.table or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) ): @@ -479,7 +481,7 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.UDTF): - pass + _set_udtf_scope(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) else: @@ -509,6 +511,22 @@ def _traverse_union(scope): scope.union_scopes = [left, right] +def _set_udtf_scope(scope): + parent = scope.expression.parent + from_ = parent.args.get("from") + + if not from_: + return + + for table in from_.expressions: + if isinstance(table, exp.Table): + scope.tables.append(table) + elif isinstance(table, exp.Subquery): + scope.subqueries.append(table) + _add_table_sources(scope) + _traverse_subqueries(scope) + + def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} is_cte = scope_type == ScopeType.CTE |