summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py6
-rw-r--r--sqlglot/optimizer/normalize_identifiers.py18
-rw-r--r--sqlglot/optimizer/qualify_columns.py18
-rw-r--r--sqlglot/optimizer/qualify_tables.py4
-rw-r--r--sqlglot/optimizer/scope.py72
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py24
6 files changed, 107 insertions, 35 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 728493d..af42f25 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -136,8 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
- # This ensures we don't drop the "pivot" arg from a pivoted subquery
- if scope.parent.pivots:
+ # This makes sure that we don't:
+ # - drop the "pivot" arg from a pivoted subquery
+ # - eliminate a lateral correlated subquery
+ if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
return None
parent = scope.expression.parent
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 99e605d..9d4860e 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -1,8 +1,23 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
+@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+ ...
+
+
+@t.overload
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
+ ...
+
+
+def normalize_identifiers(expression, dialect=None):
"""
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.
@@ -16,6 +31,8 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> normalize_identifiers(expression).sql()
'SELECT bar.a AS a FROM "Foo".bar'
+ >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake")
+ 'FOO'
Args:
expression: The expression to transform.
@@ -24,4 +41,5 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
Returns:
The transformed expression.
"""
+ expression = exp.maybe_parse(expression, dialect=dialect)
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 2657188..9c34cef 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -39,6 +39,7 @@ def qualify_columns(
"""
schema = ensure_schema(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
+ pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema, infer_schema=infer_schema)
@@ -55,7 +56,7 @@ def qualify_columns(
_expand_alias_refs(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
- _expand_stars(scope, resolver, using_column_tables)
+ _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
_qualify_outputs(scope)
_expand_group_by(scope)
_expand_order_by(scope, resolver)
@@ -326,7 +327,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
def _expand_stars(
- scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
+ scope: Scope,
+ resolver: Resolver,
+ using_column_tables: t.Dict[str, t.Any],
+ pseudocolumns: t.Set[str],
) -> None:
"""Expand stars to lists of column selections"""
@@ -367,14 +371,8 @@ def _expand_stars(
columns = resolver.get_source_columns(table, only_visible=True)
- # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
- # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
- if resolver.schema.dialect == "bigquery":
- columns = [
- name
- for name in columns
- if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
- ]
+ if pseudocolumns:
+ columns = [name for name in columns if name.upper() not in pseudocolumns]
if columns and "*" not in columns:
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 31c9cc0..68aebdb 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -80,7 +80,9 @@ def qualify_tables(
header = next(reader)
columns = next(reader)
schema.add_table(
- source, {k: type(v).__name__ for k, v in zip(header, columns)}
+ source,
+ {k: type(v).__name__ for k, v in zip(header, columns)},
+ match_depth=False,
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index a7dab35..fb12384 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -435,7 +435,10 @@ class Scope:
@property
def is_correlated_subquery(self):
"""Determine if this scope is a correlated subquery"""
- return bool(self.is_subquery and self.external_columns)
+ return bool(
+ (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
+ and self.external_columns
+ )
def rename_source(self, old_name, new_name):
"""Rename a source in this scope"""
@@ -486,7 +489,7 @@ class Scope:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
- Traverse an expression by it's "scopes".
+ Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
@@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
list[Scope]: scope instances
"""
- if not isinstance(expression, exp.Unionable):
- return []
- return list(_traverse_scope(Scope(expression)))
+ if isinstance(expression, exp.Unionable) or (
+ isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+ ):
+ return list(_traverse_scope(Scope(expression)))
+
+ return []
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
@@ -539,7 +545,9 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
- pass
+ yield from _traverse_udtfs(scope)
+ elif isinstance(scope.expression, exp.DDL):
+ yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@@ -576,10 +584,10 @@ def _traverse_ctes(scope):
for cte in scope.ctes:
recursive_scope = None
- # if the scope is a recursive cte, it must be in the form of
- # base_case UNION recursive. thus the recursive scope is the first
- # section of the union.
- if scope.expression.args["with"].recursive:
+ # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
+ # thus the recursive scope is the first section of the union.
+ with_ = scope.expression.args.get("with")
+ if with_ and with_.recursive:
union = cte.this
if isinstance(union, exp.Union):
@@ -692,8 +700,7 @@ def _traverse_tables(scope):
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins.
- alias = expression.alias
- sources[alias] = child_scope
+ sources[expression.alias] = child_scope
# append the final child_scope yielded
scopes.append(child_scope)
@@ -711,6 +718,47 @@ def _traverse_subqueries(scope):
scope.subquery_scopes.append(top)
+def _traverse_udtfs(scope):
+ if isinstance(scope.expression, exp.Unnest):
+ expressions = scope.expression.expressions
+ elif isinstance(scope.expression, exp.Lateral):
+ expressions = [scope.expression.this]
+ else:
+ expressions = []
+
+ sources = {}
+ for expression in expressions:
+ if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+ top = None
+ for child_scope in _traverse_scope(
+ scope.branch(
+ expression,
+ scope_type=ScopeType.DERIVED_TABLE,
+ outer_column_list=expression.alias_column_names,
+ )
+ ):
+ yield child_scope
+ top = child_scope
+ sources[expression.alias] = child_scope
+
+ scope.derived_table_scopes.append(top)
+ scope.table_scopes.append(top)
+
+ scope.sources.update(sources)
+
+
+def _traverse_ddl(scope):
+ yield from _traverse_ctes(scope)
+
+ query_scope = scope.branch(
+ scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+ )
+ query_scope._collect()
+ query_scope._ctes = scope.ctes + query_scope._ctes
+
+ yield from _traverse_scope(query_scope)
+
+
def walk_in_scope(expression, bfs=True):
"""
Returns a generator object which visits all nodes in the syntrax tree, stopping at
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 09e3f2a..816f5fb 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name):
if not predicate or parent_select is not predicate.parent_select:
return
- # this subquery returns a scalar and can just be converted to a cross join
+ # This subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
- having = predicate.find_ancestor(exp.Having)
column = exp.column(select.selects[0].alias_or_name, alias)
- if having and having.parent_select is parent_select:
+
+ clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
+ clause_parent_select = clause.parent_select if clause else None
+
+ if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
+ (not clause or clause_parent_select is not parent_select)
+ and (
+ parent_select.args.get("group")
+ or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
+ )
+ ):
column = exp.Max(this=column)
- _replace(select.parent, column)
- parent_select.join(
- select,
- join_type="CROSS",
- join_alias=alias,
- copy=False,
- )
+ _replace(select.parent, column)
+ parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
return
if select.find(exp.Limit, exp.Offset):