summary refs log tree commit diff stats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- sqlglot/optimizer/annotate_types.py 17
-rw-r--r-- sqlglot/optimizer/expand_laterals.py 34
-rw-r--r-- sqlglot/optimizer/optimizer.py 5
-rw-r--r-- sqlglot/optimizer/pushdown_projections.py 6
-rw-r--r-- sqlglot/optimizer/qualify_columns.py 30
-rw-r--r-- sqlglot/optimizer/qualify_tables.py 13
-rw-r--r-- sqlglot/optimizer/scope.py 20
7 files changed, 101 insertions, 24 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index bfb2bb8..66f97a9 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -255,12 +255,23 @@ class TypeAnnotator:
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
- if isinstance(source.expression, exp.Values):
+ if isinstance(source.expression, exp.UDTF):
+ values = []
+
+ if isinstance(source.expression, exp.Lateral):
+ if isinstance(source.expression.this, exp.Explode):
+ values = [source.expression.this.this]
+ else:
+ values = source.expression.expressions[0].expressions
+
+ if not values:
+ continue
+
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
- source.expression.expressions[0].expressions,
+ values,
)
}
else:
@@ -272,7 +283,7 @@ class TypeAnnotator:
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
- elif source:
+ elif source and col.table in selects:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py
new file mode 100644
index 0000000..59f3fec
--- /dev/null
+++ b/sqlglot/optimizer/expand_laterals.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
+
+
+def expand_laterals(expression: exp.Expression) -> exp.Expression:
+ """
+ Expand lateral column alias references.
+
+ This assumes `qualify_columns` has already run.
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> expand_laterals(expression).sql()
+ 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
+
+ Args:
+ expression: expression to optimize
+ Returns:
+ optimized expression
+ """
+ for select in expression.find_all(exp.Select):
+ alias_to_expression: t.Dict[str, exp.Expression] = {}
+ for projection in select.expressions:
+ for column in projection.find_all(exp.Column):
+ if not column.table and column.name in alias_to_expression:
+ column.replace(alias_to_expression[column.name].copy())
+ if isinstance(projection, exp.Alias):
+ alias_to_expression[projection.alias] = projection.this
+ return expression
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 766e059..96fd56b 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
+from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
-from sqlglot.optimizer.qualify_columns import qualify_columns
+from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@@ -22,6 +23,8 @@ RULES = (
qualify_tables,
isolate_table_selects,
qualify_columns,
+ expand_laterals,
+ validate_qualify_columns,
pushdown_projections,
normalize,
unnest_subqueries,
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index a73647c..54c5021 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
SELECT_ALL = object()
# Selection to use if selection list is empty
-DEFAULT_SELECTION = alias("1", "_")
+DEFAULT_SELECTION = lambda: alias("1", "_")
def pushdown_projections(expression):
@@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(DEFAULT_SELECTION.copy())
+ new_selections.append(DEFAULT_SELECTION())
scope.expression.set("expressions", new_selections)
if removed:
@@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
- new_selections.append(DEFAULT_SELECTION.copy())
+ new_selections.append(DEFAULT_SELECTION())
scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 54425a8..ab13d01 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -37,11 +37,24 @@ def qualify_columns(expression, schema):
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
- _check_unknown_tables(scope)
return expression
+def validate_qualify_columns(expression):
+ """Raise an `OptimizeError` if any columns aren't qualified"""
+ unqualified_columns = []
+ for scope in traverse_scope(expression):
+ if isinstance(scope.expression, exp.Select):
+ unqualified_columns.extend(scope.unqualified_columns)
+ if scope.external_columns and not scope.is_correlated_subquery:
+ raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
+
+ if unqualified_columns:
+ raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
+ return expression
+
+
def _pop_table_column_aliases(derived_tables):
"""
Remove table column aliases.
@@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
- if not scope.is_subquery and not scope.is_udtf:
- if column_table is None:
- raise OptimizeError(f"Ambiguous column: {column_name}")
-
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", exp.to_identifier(column_table))
@@ -231,10 +240,8 @@ def _qualify_columns(scope, resolver):
for column in columns_missing_from_scope:
column_table = resolver.get_table(column.name)
- if column_table is None:
- raise OptimizeError(f"Ambiguous column: {column.name}")
-
- column.set("table", exp.to_identifier(column_table))
+ if column_table:
+ column.set("table", exp.to_identifier(column_table))
def _expand_stars(scope, resolver):
@@ -322,11 +329,6 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
-def _check_unknown_tables(scope):
- if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
- raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
-
-
class _Resolver:
"""
Helper for resolving columns.
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 5d8e0d9..65593bd 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -2,7 +2,7 @@ import itertools
from sqlglot import alias, exp
from sqlglot.helper import csv_reader
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import Scope, traverse_scope
def qualify_tables(expression, db=None, catalog=None, schema=None):
@@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
sequence = itertools.count()
+ next_name = lambda: f"_q_{next(sequence)}"
+
for scope in traverse_scope(expression):
for derived_table in scope.ctes + scope.derived_tables:
if not derived_table.args.get("alias"):
@@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source.copy(),
- source.this if identifier else f"_q_{next(sequence)}",
+ source.this if identifier else next_name(),
table=True,
)
)
@@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
schema.add_table(
source, {k: type(v).__name__ for k, v in zip(header, columns)}
)
+ elif isinstance(source, Scope) and source.is_udtf:
+ udtf = source.expression
+ table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
+ udtf.set("alias", table_alias)
+
+ if not table_alias.name:
+ table_alias.set("this", next_name())
return expression
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index badbb87..8565c64 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -237,6 +237,8 @@ class Scope:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
if (
not ancestor
+ # Window functions can have an ORDER BY clause
+ or not isinstance(ancestor.parent, exp.Select)
or column.table
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
@@ -479,7 +481,7 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.UDTF):
- pass
+ _set_udtf_scope(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
else:
@@ -509,6 +511,22 @@ def _traverse_union(scope):
scope.union_scopes = [left, right]
+def _set_udtf_scope(scope):
+ parent = scope.expression.parent
+ from_ = parent.args.get("from")
+
+ if not from_:
+ return
+
+ for table in from_.expressions:
+ if isinstance(table, exp.Table):
+ scope.tables.append(table)
+ elif isinstance(table, exp.Subquery):
+ scope.subqueries.append(table)
+ _add_table_sources(scope)
+ _traverse_subqueries(scope)
+
+
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
is_cte = scope_type == ScopeType.CTE