diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/normalize_identifiers.py | 18 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 18 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 72 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 24 |
6 files changed, 107 insertions, 35 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 728493d..af42f25 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -136,8 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): - # This ensures we don't drop the "pivot" arg from a pivoted subquery - if scope.parent.pivots: + # This makes sure that we don't: + # - drop the "pivot" arg from a pivoted subquery + # - eliminate a lateral correlated subquery + if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral): return None parent = scope.expression.parent diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 99e605d..9d4860e 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -1,8 +1,23 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType +@t.overload def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: + ... + + +@t.overload +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression: + ... + + +def normalize_identifiers(expression, dialect=None): """ Normalize all unquoted identifiers to either lower or upper case, depending on the dialect. This essentially makes those identifiers case-insensitive. @@ -16,6 +31,8 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') >>> normalize_identifiers(expression).sql() 'SELECT bar.a AS a FROM "Foo".bar' + >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake") + 'FOO' Args: expression: The expression to transform. @@ -24,4 +41,5 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: Returns: The transformed expression. """ + expression = exp.maybe_parse(expression, dialect=dialect) return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 2657188..9c34cef 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -39,6 +39,7 @@ def qualify_columns( """ schema = ensure_schema(schema) infer_schema = schema.empty if infer_schema is None else infer_schema + pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS for scope in traverse_scope(expression): resolver = Resolver(scope, schema, infer_schema=infer_schema) @@ -55,7 +56,7 @@ def qualify_columns( _expand_alias_refs(scope, resolver) if not isinstance(scope.expression, exp.UDTF): - _expand_stars(scope, resolver, using_column_tables) + _expand_stars(scope, resolver, using_column_tables, pseudocolumns) _qualify_outputs(scope) _expand_group_by(scope) _expand_order_by(scope, resolver) @@ -326,7 +327,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None: def _expand_stars( - scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any] + scope: Scope, + resolver: Resolver, + using_column_tables: t.Dict[str, t.Any], + pseudocolumns: t.Set[str], ) -> None: """Expand stars to lists of column selections""" @@ -367,14 +371,8 @@ def _expand_stars( columns = resolver.get_source_columns(table, only_visible=True) - # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement - # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table - if resolver.schema.dialect == "bigquery": - columns = [ - name - for name in columns - if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE") - ] + if pseudocolumns: + columns = [name for name in columns if name.upper() not in pseudocolumns] if columns and "*" not in columns: if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 31c9cc0..68aebdb 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -80,7 +80,9 @@ def qualify_tables( header = next(reader) columns = next(reader) schema.add_table( - source, {k: type(v).__name__ for k, v in zip(header, columns)} + source, + {k: type(v).__name__ for k, v in zip(header, columns)}, + match_depth=False, ) elif isinstance(source, Scope) and source.is_udtf: udtf = source.expression diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index a7dab35..fb12384 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -435,7 +435,10 @@ class Scope: @property def is_correlated_subquery(self): """Determine if this scope is a correlated subquery""" - return bool(self.is_subquery and self.external_columns) + return bool( + (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral))) + and self.external_columns + ) def rename_source(self, old_name, new_name): """Rename a source in this scope""" @@ -486,7 +489,7 @@ class Scope: def traverse_scope(expression: exp.Expression) -> t.List[Scope]: """ - Traverse an expression by it's "scopes". + Traverse an expression by its "scopes". "Scope" represents the current context of a Select statement. @@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: Returns: list[Scope]: scope instances """ - if not isinstance(expression, exp.Unionable): - return [] - return list(_traverse_scope(Scope(expression))) + if isinstance(expression, exp.Unionable) or ( + isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable) + ): + return list(_traverse_scope(Scope(expression))) + + return [] def build_scope(expression: exp.Expression) -> t.Optional[Scope]: @@ -539,7 +545,9 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.Table): yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): - pass + yield from _traverse_udtfs(scope) + elif isinstance(scope.expression, exp.DDL): + yield from _traverse_ddl(scope) else: logger.warning( "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) @@ -576,10 +584,10 @@ def _traverse_ctes(scope): for cte in scope.ctes: recursive_scope = None - # if the scope is a recursive cte, it must be in the form of - # base_case UNION recursive. thus the recursive scope is the first - # section of the union. - if scope.expression.args["with"].recursive: + # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. + # thus the recursive scope is the first section of the union. + with_ = scope.expression.args.get("with") + if with_ and with_.recursive: union = cte.this if isinstance(union, exp.Union): @@ -692,8 +700,7 @@ def _traverse_tables(scope): # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. - alias = expression.alias - sources[alias] = child_scope + sources[expression.alias] = child_scope # append the final child_scope yielded scopes.append(child_scope) @@ -711,6 +718,47 @@ def _traverse_subqueries(scope): scope.subquery_scopes.append(top) +def _traverse_udtfs(scope): + if isinstance(scope.expression, exp.Unnest): + expressions = scope.expression.expressions + elif isinstance(scope.expression, exp.Lateral): + expressions = [scope.expression.this] + else: + expressions = [] + + sources = {} + for expression in expressions: + if isinstance(expression, exp.Subquery) and _is_derived_table(expression): + top = None + for child_scope in _traverse_scope( + scope.branch( + expression, + scope_type=ScopeType.DERIVED_TABLE, + outer_column_list=expression.alias_column_names, + ) + ): + yield child_scope + top = child_scope + sources[expression.alias] = child_scope + + scope.derived_table_scopes.append(top) + scope.table_scopes.append(top) + + scope.sources.update(sources) + + +def _traverse_ddl(scope): + yield from _traverse_ctes(scope) + + query_scope = scope.branch( + scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources + ) + query_scope._collect() + query_scope._ctes = scope.ctes + query_scope._ctes + + yield from _traverse_scope(query_scope) + + def walk_in_scope(expression, bfs=True): """ Returns a generator object which visits all nodes in the syntrax tree, stopping at diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 09e3f2a..816f5fb 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name): if not predicate or parent_select is not predicate.parent_select: return - # this subquery returns a scalar and can just be converted to a cross join + # This subquery returns a scalar and can just be converted to a cross join if not isinstance(predicate, (exp.In, exp.Any)): - having = predicate.find_ancestor(exp.Having) column = exp.column(select.selects[0].alias_or_name, alias) - if having and having.parent_select is parent_select: + + clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) + clause_parent_select = clause.parent_select if clause else None + + if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or ( + (not clause or clause_parent_select is not parent_select) + and ( + parent_select.args.get("group") + or any(projection.find(exp.AggFunc) for projection in parent_select.selects) + ) + ): column = exp.Max(this=column) - _replace(select.parent, column) - parent_select.join( - select, - join_type="CROSS", - join_alias=alias, - copy=False, - ) + _replace(select.parent, column) + parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False) return if select.find(exp.Limit, exp.Offset): |