summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-01-31 05:44:41 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-01-31 05:44:41 +0000
commit: 376de8b6892deca7dc5d83035c047f1e13eb67ea (patch)
tree: 334a1753cd914294aa99128fac3fb59bf14dc10f /sqlglot/optimizer
parent: Releasing debian version 20.9.0-1. (diff)
download: sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.tar.xz
download: sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.zip
Merging upstream version 20.11.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- sqlglot/optimizer/annotate_types.py | 30
-rw-r--r-- sqlglot/optimizer/normalize_identifiers.py | 10
-rw-r--r-- sqlglot/optimizer/qualify_columns.py | 7
-rw-r--r-- sqlglot/optimizer/qualify_tables.py | 31
-rw-r--r-- sqlglot/optimizer/scope.py | 14
-rw-r--r-- sqlglot/optimizer/simplify.py | 8
-rw-r--r-- sqlglot/optimizer/unnest_subqueries.py | 15
7 files changed, 81 insertions, 34 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index d0168d5..a2a86cd 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -4,7 +4,6 @@ import functools
import typing as t
from sqlglot import exp
-from sqlglot._typing import E
from sqlglot.helper import (
ensure_list,
is_date_unit,
@@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
- B = t.TypeVar("B", bound=exp.Binary)
+ from sqlglot._typing import B, E
BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
BinaryCoercions = t.Dict[
@@ -480,6 +479,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return self._annotate_args(expression)
@t.no_type_check
+ def _annotate_struct_value(
+ self, expression: exp.Expression
+ ) -> t.Optional[exp.DataType] | exp.ColumnDef:
+ alias = expression.args.get("alias")
+ if alias:
+ return exp.ColumnDef(this=alias.copy(), kind=expression.type)
+
+ # Case: key = value or key := value
+ if expression.expression:
+ return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
+
+ return expression.type
+
+ @t.no_type_check
def _annotate_by_args(
self,
expression: E,
@@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
if struct:
- expressions = [
- expr.type
- if not expr.args.get("alias")
- else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
- for expr in expressions
- ]
-
self._set_type(
expression,
- exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
+ exp.DataType(
+ this=exp.DataType.Type.STRUCT,
+ expressions=[self._annotate_struct_value(expr) for expr in expressions],
+ nested=True,
+ ),
)
return expression
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 3361a33..f2a0990 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -3,18 +3,18 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
-from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
@t.overload
-def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
- ...
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
@t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
- ...
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
def normalize_identifiers(expression, dialect=None):
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index a6397ae..1656727 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -4,7 +4,6 @@ import itertools
import typing as t
from sqlglot import alias, exp
-from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
@@ -12,6 +11,9 @@ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
def qualify_columns(
expression: exp.Expression,
@@ -210,7 +212,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not node:
return
- for column, *_ in walk_in_scope(node):
+ for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
if not isinstance(column, exp.Column):
continue
@@ -525,6 +527,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
+ copy=False,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index e0fe641..d460e81 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -4,12 +4,14 @@ import itertools
import typing as t
from sqlglot import alias, exp
-from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
def qualify_tables(
expression: E,
@@ -46,6 +48,18 @@ def qualify_tables(
db = exp.parse_identifier(db, dialect=dialect) if db else None
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
+ def _qualify(table: exp.Table) -> None:
+ if isinstance(table.this, exp.Identifier):
+ if not table.args.get("db"):
+ table.set("db", db)
+ if not table.args.get("catalog") and table.args.get("db"):
+ table.set("catalog", catalog)
+
+ if not isinstance(expression, exp.Subqueryable):
+ for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
+ if isinstance(node, exp.Table):
+ _qualify(node)
+
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
if isinstance(derived_table, exp.Subquery):
@@ -66,11 +80,7 @@ def qualify_tables(
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
- if isinstance(source.this, exp.Identifier):
- if not source.args.get("db"):
- source.set("db", db)
- if not source.args.get("catalog") and source.args.get("db"):
- source.set("catalog", catalog)
+ _qualify(source)
pivots = pivots = source.args.get("pivots")
if not source.alias:
@@ -107,5 +117,14 @@ def qualify_tables(
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
+ else:
+ for node, parent, _ in scope.walk():
+ if (
+ isinstance(node, exp.Table)
+ and not node.alias
+ and isinstance(parent, (exp.From, exp.Join))
+ ):
+ # Mutates the table by attaching an alias to it
+ alias(node, node.name, copy=False, table=True)
return expression
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index a3f08d5..16cd548 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -323,9 +323,14 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
- self._external_columns = [
- c for c in self.columns if c.table not in self.selected_sources
- ]
+ if isinstance(self.expression, exp.Union):
+ left, right = self.union_scopes
+ self._external_columns = left.external_columns + right.external_columns
+ else:
+ self._external_columns = [
+ c for c in self.columns if c.table not in self.selected_sources
+ ]
+
return self._external_columns
@property
@@ -477,11 +482,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Args:
expression (exp.Expression): expression to traverse
+
Returns:
list[Scope]: scope instances
"""
if isinstance(expression, exp.Unionable) or (
- isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+ isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
):
return list(_traverse_scope(Scope(expression)))
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 25d4e75..d5b9119 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -1068,9 +1068,11 @@ def extract_interval(expression):
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
- exp.DataType.Type.DATETIME
- if isinstance(date, datetime.datetime)
- else exp.DataType.Type.DATE,
+ (
+ exp.DataType.Type.DATETIME
+ if isinstance(date, datetime.datetime)
+ else exp.DataType.Type.DATE
+ ),
)
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 4d35175..26f4159 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -50,11 +50,12 @@ def unnest(select, parent_select, next_alias_name):
):
return
+ clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
+
# This subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
column = exp.column(select.selects[0].alias_or_name, alias)
- clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
clause_parent_select = clause.parent_select if clause else None
if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
@@ -84,12 +85,18 @@ def unnest(select, parent_select, next_alias_name):
column = _other_operand(predicate)
value = select.selects[0]
- on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
- _replace(predicate, f"NOT {on.right} IS NULL")
+ join_key = exp.column(value.alias, alias)
+ join_key_not_null = join_key.is_(exp.null()).not_()
+
+ if isinstance(clause, exp.Join):
+ _replace(predicate, exp.true())
+ parent_select.where(join_key_not_null, copy=False)
+ else:
+ _replace(predicate, join_key_not_null)
parent_select.join(
select.group_by(value.this, copy=False),
- on=on,
+ on=column.eq(join_key),
join_type="LEFT",
join_alias=alias,
copy=False,