From 376de8b6892deca7dc5d83035c047f1e13eb67ea Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 31 Jan 2024 06:44:41 +0100 Subject: Merging upstream version 20.11.0. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/annotate_types.py | 30 +++++++++++++++++++---------- sqlglot/optimizer/normalize_identifiers.py | 10 +++++----- sqlglot/optimizer/qualify_columns.py | 7 +++++-- sqlglot/optimizer/qualify_tables.py | 31 ++++++++++++++++++++++++------ sqlglot/optimizer/scope.py | 14 ++++++++++---- sqlglot/optimizer/simplify.py | 8 +++++--- sqlglot/optimizer/unnest_subqueries.py | 15 +++++++++++---- 7 files changed, 81 insertions(+), 34 deletions(-) (limited to 'sqlglot/optimizer') diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index d0168d5..a2a86cd 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -4,7 +4,6 @@ import functools import typing as t from sqlglot import exp -from sqlglot._typing import E from sqlglot.helper import ( ensure_list, is_date_unit, @@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema, ensure_schema if t.TYPE_CHECKING: - B = t.TypeVar("B", bound=exp.Binary) + from sqlglot._typing import B, E BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] BinaryCoercions = t.Dict[ @@ -479,6 +478,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, target_type) return self._annotate_args(expression) + @t.no_type_check + def _annotate_struct_value( + self, expression: exp.Expression + ) -> t.Optional[exp.DataType] | exp.ColumnDef: + alias = expression.args.get("alias") + if alias: + return exp.ColumnDef(this=alias.copy(), kind=expression.type) + + # Case: key = value or key := value + if expression.expression: + return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type) + + return expression.type + @t.no_type_check def _annotate_by_args( self, @@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator): ) if struct: - expressions = [ - expr.type - if not expr.args.get("alias") - else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type) - for expr in expressions - ] - self._set_type( expression, - exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True), + exp.DataType( + this=exp.DataType.Type.STRUCT, + expressions=[self._annotate_struct_value(expr) for expr in expressions], + nested=True, + ), ) return expression diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 3361a33..f2a0990 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -3,18 +3,18 @@ from __future__ import annotations import typing as t from sqlglot import exp -from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType +if t.TYPE_CHECKING: + from sqlglot._typing import E + @t.overload -def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: - ... +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ... @t.overload -def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: - ... +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ... def normalize_identifiers(expression, dialect=None): diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index a6397ae..1656727 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -4,7 +4,6 @@ import itertools import typing as t from sqlglot import alias, exp -from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError from sqlglot.helper import seq_get @@ -12,6 +11,9 @@ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_ from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema +if t.TYPE_CHECKING: + from sqlglot._typing import E + def qualify_columns( expression: exp.Expression, @@ -210,7 +212,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if not node: return - for column, *_ in walk_in_scope(node): + for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star): if not isinstance(column, exp.Column): continue @@ -525,6 +527,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: selection = alias( selection, alias=selection.output_name or f"_col_{i}", + copy=False, ) if aliased_column: selection.set("alias", exp.to_identifier(aliased_column)) diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index e0fe641..d460e81 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -4,12 +4,14 @@ import itertools import typing as t from sqlglot import alias, exp -from sqlglot._typing import E from sqlglot.dialects.dialect import DialectType from sqlglot.helper import csv_reader, name_sequence from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema +if t.TYPE_CHECKING: + from sqlglot._typing import E + def qualify_tables( expression: E, @@ -46,6 +48,18 @@ def qualify_tables( db = exp.parse_identifier(db, dialect=dialect) if db else None catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None + def _qualify(table: exp.Table) -> None: + if isinstance(table.this, exp.Identifier): + if not table.args.get("db"): + table.set("db", db) + if not table.args.get("catalog") and table.args.get("db"): + table.set("catalog", catalog) + + if not isinstance(expression, exp.Subqueryable): + for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)): + if isinstance(node, exp.Table): + _qualify(node) + for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): if isinstance(derived_table, exp.Subquery): @@ -66,11 +80,7 @@ def qualify_tables( for name, source in scope.sources.items(): if isinstance(source, exp.Table): - if isinstance(source.this, exp.Identifier): - if not source.args.get("db"): - source.set("db", db) - if not source.args.get("catalog") and source.args.get("db"): - source.set("catalog", catalog) + _qualify(source) pivots = pivots = source.args.get("pivots") if not source.alias: @@ -107,5 +117,14 @@ def qualify_tables( if isinstance(udtf, exp.Values) and not table_alias.columns: for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) + else: + for node, parent, _ in scope.walk(): + if ( + isinstance(node, exp.Table) + and not node.alias + and isinstance(parent, (exp.From, exp.Join)) + ): + # Mutates the table by attaching an alias to it + alias(node, node.name, copy=False, table=True) return expression diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index a3f08d5..16cd548 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -323,9 +323,14 @@ class Scope: sources in the current scope. """ if self._external_columns is None: - self._external_columns = [ - c for c in self.columns if c.table not in self.selected_sources - ] + if isinstance(self.expression, exp.Union): + left, right = self.union_scopes + self._external_columns = left.external_columns + right.external_columns + else: + self._external_columns = [ + c for c in self.columns if c.table not in self.selected_sources + ] + return self._external_columns @property @@ -477,11 +482,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: Args: expression (exp.Expression): expression to traverse + Returns: list[Scope]: scope instances """ if isinstance(expression, exp.Unionable) or ( - isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable) + isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable) ): return list(_traverse_scope(Scope(expression))) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 25d4e75..d5b9119 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1068,9 +1068,11 @@ def extract_interval(expression): def date_literal(date): return exp.cast( exp.Literal.string(date), - exp.DataType.Type.DATETIME - if isinstance(date, datetime.datetime) - else exp.DataType.Type.DATE, + ( + exp.DataType.Type.DATETIME + if isinstance(date, datetime.datetime) + else exp.DataType.Type.DATE + ), ) diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 4d35175..26f4159 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -50,11 +50,12 @@ def unnest(select, parent_select, next_alias_name): ): return + clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) + # This subquery returns a scalar and can just be converted to a cross join if not isinstance(predicate, (exp.In, exp.Any)): column = exp.column(select.selects[0].alias_or_name, alias) - clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) clause_parent_select = clause.parent_select if clause else None if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or ( @@ -84,12 +85,18 @@ def unnest(select, parent_select, next_alias_name): column = _other_operand(predicate) value = select.selects[0] - on = exp.condition(f'{column} = "{alias}"."{value.alias}"') - _replace(predicate, f"NOT {on.right} IS NULL") + join_key = exp.column(value.alias, alias) + join_key_not_null = join_key.is_(exp.null()).not_() + + if isinstance(clause, exp.Join): + _replace(predicate, exp.true()) + parent_select.where(join_key_not_null, copy=False) + else: + _replace(predicate, join_key_not_null) parent_select.join( select.group_by(value.this, copy=False), - on=on, + on=column.eq(join_key), join_type="LEFT", join_alias=alias, copy=False, -- cgit v1.2.3