author     Daniel Baumann <daniel.baumann@progress-linux.org>   2023-06-30 08:03:58 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>   2023-06-30 08:03:58 +0000
commit     9f19773cebdc9476f2a3266d3c01c967c38fcd1e
tree       a60f607ba2bb64fb45da86c297ff29ffc9b92f58 /sqlglot
parent     Releasing debian version 16.7.3-1.
Merging upstream version 16.7.7.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dialects/bigquery.py           |  9
-rw-r--r--  sqlglot/dialects/snowflake.py          | 10
-rw-r--r--  sqlglot/executor/python.py             |  5
-rw-r--r--  sqlglot/expressions.py                 |  2
-rw-r--r--  sqlglot/generator.py                   |  9
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py  | 22
-rw-r--r--  sqlglot/optimizer/qualify_columns.py   | 50
-rw-r--r--  sqlglot/optimizer/scope.py             | 58
-rw-r--r--  sqlglot/optimizer/simplify.py          |  1
-rw-r--r--  sqlglot/parser.py                      | 18
-rw-r--r--  sqlglot/schema.py                      | 23
-rw-r--r--  sqlglot/transforms.py                  |  8
12 files changed, 146 insertions(+), 69 deletions(-)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 8786063..c9d6c79 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -96,13 +96,14 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
These are added by the optimizer's qualify_column step.
"""
+ from sqlglot.optimizer.scope import Scope
+
if isinstance(expression, exp.Select):
for unnest in expression.find_all(exp.Unnest):
if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
- for select in expression.selects:
- for column in select.find_all(exp.Column):
- if column.table == unnest.alias:
- column.set("table", None)
+ for column in Scope(expression).find_all(exp.Column):
+ if column.table == unnest.alias:
+ column.set("table", None)
return expression
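
A quick illustration of why the switch to Scope matters (the query below is mine, not from the commit): Scope(...).find_all walks the columns of every clause in the current scope, whereas the replaced loop only visited the SELECT projections, so unnest-alias references in clauses such as WHERE were previously left qualified.

import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import Scope

# Columns are found in both the projection list and the WHERE clause.
tree = sqlglot.parse_one("SELECT a.x FROM a WHERE a.y > 1")
print([col.sql() for col in Scope(tree).find_all(exp.Column)])  # ['a.x', 'a.y']
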
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index a2dbfd9..34e4dd0 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -86,6 +86,10 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
)
+def _parse_datediff(args: t.List) -> exp.DateDiff:
+ return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
+
+
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
@@ -214,15 +218,15 @@ class Snowflake(Dialect):
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
- "DATEDIFF": lambda args: exp.DateDiff(
- this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
- ),
+ "DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
+ "TIMEDIFF": _parse_datediff,
+ "TIMESTAMPDIFF": _parse_datediff,
"TO_ARRAY": exp.Array.from_arg_list,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index a927181..34a380e 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -19,7 +19,6 @@ class PythonExecutor:
self.tables = tables or {}
def execute(self, plan):
- running = set()
finished = set()
queue = set(plan.leaves)
contexts = {}
@@ -34,7 +33,6 @@ class PythonExecutor:
for name, table in contexts[dep].tables.items()
}
)
- running.add(node)
if isinstance(node, planner.Scan):
contexts[node] = self.scan(node, context)
@@ -49,11 +47,10 @@ class PythonExecutor:
else:
raise NotImplementedError
- running.remove(node)
finished.add(node)
for dep in node.dependents:
- if dep not in running and all(d in contexts for d in dep.dependencies):
+ if all(d in contexts for d in dep.dependencies):
queue.add(dep)
for dep in node.dependencies:
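
For context, a minimal, self-contained sketch of the scheduling invariant the executor keeps after this cleanup (plan names and shapes are hypothetical, not taken from sqlglot): a dependent is queued as soon as every one of its dependencies has produced a context, so a separate running set is no longer consulted.

# Hypothetical three-node plan: two leaves feeding one join.
contexts, finished = {}, set()
dependencies = {"leaf_a": set(), "leaf_b": set(), "join": {"leaf_a", "leaf_b"}}
dependents = {"leaf_a": {"join"}, "leaf_b": {"join"}, "join": set()}
queue = {"leaf_a", "leaf_b"}  # plan.leaves

while queue:
    node = queue.pop()
    contexts[node] = f"result of {node}"  # stand-in for scan/join/aggregate/sort
    finished.add(node)
    for dep in dependents[node]:
        if all(d in contexts for d in dependencies[dep]):
            queue.add(dep)

print(sorted(finished))  # ['join', 'leaf_a', 'leaf_b']
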
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index e01cc1a..cdb93db 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -3922,7 +3922,7 @@ class Avg(AggFunc):
class AnyValue(AggFunc):
- pass
+ arg_types = {"this": True, "having": False, "max": False}
class Case(Func):
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 5d8a4ca..a41af12 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -2423,6 +2423,15 @@ class Generator:
buckets = self.sql(expression, "buckets")
return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS"
+ def anyvalue_sql(self, expression: exp.AnyValue) -> str:
+ this = self.sql(expression, "this")
+ having = self.sql(expression, "having")
+
+ if having:
+ this = f"{this} HAVING {'MAX' if expression.args.get('max') else 'MIN'} {having}"
+
+ return self.func("ANY_VALUE", this)
+
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index fefe96e..e156d5e 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -47,6 +47,17 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
}
+# Projections in the outer query that are instances of these types can be replaced
+# without getting wrapped in parentheses, because the precedence won't be altered.
+SAFE_TO_REPLACE_UNWRAPPED = (
+ exp.Column,
+ exp.EQ,
+ exp.Func,
+ exp.NEQ,
+ exp.Paren,
+)
+
+
def merge_ctes(expression, leave_tables_isolated=False):
scopes = traverse_scope(expression)
@@ -293,8 +304,17 @@ def _merge_expressions(outer_scope, inner_scope, alias):
if not projection_name:
continue
columns_to_replace = outer_columns.get(projection_name, [])
+
+ expression = expression.unalias()
+ must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)
+
for column in columns_to_replace:
- column.replace(expression.unalias().copy())
+ # Ensures we don't alter the intended operator precedence if there's additional
+ # context surrounding the outer expression (i.e. it's not a simple projection).
+ if isinstance(column.parent, (exp.Unary, exp.Binary)) and must_wrap_expression:
+ expression = exp.paren(expression, copy=False)
+
+ column.replace(expression.copy())
def _merge_where(outer_scope, inner_scope, from_or_join):
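
A hedged end-to-end check of the precedence guard (query and schema are illustrative; exact output formatting may differ): when the inner projection is a binary expression like a - b and the outer reference sits under a unary or binary operator, the merged expression should come back parenthesized, e.g. -(a - b) rather than -a - b.

import sqlglot
from sqlglot.optimizer import optimize

expression = sqlglot.parse_one("SELECT -x FROM (SELECT a - b AS x FROM t) AS sub")
print(optimize(expression, schema={"t": {"a": "INT", "b": "INT"}}).sql())
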
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index ef8aeb1..8c3f599 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -170,9 +170,11 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not isinstance(expression, exp.Select):
return
- alias_to_expression: t.Dict[str, exp.Expression] = {}
+ alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
- def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None:
+ def replace_columns(
+ node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
+ ) -> None:
if not node:
return
@@ -180,7 +182,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not isinstance(column, exp.Column):
continue
table = resolver.get_table(column.name) if resolve_table and not column.table else None
- alias_expr = alias_to_expression.get(column.name)
+ alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
if alias_expr
@@ -190,16 +192,20 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if table and (not alias_expr or double_agg):
column.set("table", table)
elif not column.table and alias_expr and not double_agg:
- column.replace(alias_expr.copy())
+ if isinstance(alias_expr, exp.Literal):
+ if literal_index:
+ column.replace(exp.Literal.number(i))
+ else:
+ column.replace(alias_expr.copy())
- for projection in scope.selects:
+ for i, projection in enumerate(scope.selects):
replace_columns(projection)
if isinstance(projection, exp.Alias):
- alias_to_expression[projection.alias] = projection.this
+ alias_to_expression[projection.alias] = (projection.this, i + 1)
replace_columns(expression.args.get("where"))
- replace_columns(expression.args.get("group"))
+ replace_columns(expression.args.get("group"), literal_index=True)
replace_columns(expression.args.get("having"), resolve_table=True)
replace_columns(expression.args.get("qualify"), resolve_table=True)
scope.clear_cache()
@@ -255,27 +261,39 @@ def _expand_order_by(scope: Scope, resolver: Resolver):
selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
for ordered in ordereds:
- ordered.set("this", selects.get(ordered.this, ordered.this))
+ ordered = ordered.this
+
+ ordered.replace(
+ exp.to_identifier(_select_by_pos(scope, ordered).alias)
+ if ordered.is_int
+ else selects.get(ordered, ordered)
+ )
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
new_nodes = []
for node in expressions:
if node.is_int:
- try:
- select = scope.selects[int(node.name) - 1]
- except IndexError:
- raise OptimizeError(f"Unknown output column: {node.name}")
- if isinstance(select, exp.Alias):
- select = select.this
- new_nodes.append(select.copy())
- scope.clear_cache()
+ select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
+
+ if isinstance(select, exp.Literal):
+ new_nodes.append(node)
+ else:
+ new_nodes.append(select.copy())
+ scope.clear_cache()
else:
new_nodes.append(node)
return new_nodes
+def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
+ try:
+ return scope.selects[int(node.this) - 1].assert_is(exp.Alias)
+ except IndexError:
+ raise OptimizeError(f"Unknown output column: {node.name}")
+
+
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
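
A small sketch of the new literal-index behavior (the query is mine): when a projection alias is bound to a literal, expanding that alias inside GROUP BY now substitutes the 1-based projection position instead of the literal itself, so the grouping key is preserved rather than degenerating into a constant.

import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.qualify_columns import qualify_columns

tree = sqlglot.parse_one("SELECT x, 2 AS y FROM t GROUP BY x, y")
# Expected shape (may vary by version): ... GROUP BY t.x, 2  -- '2' is the position of y
print(qualify_columns(qualify_tables(tree), schema={"t": {"x": "INT"}}).sql())
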
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index aa56b83..bc649e4 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,4 +1,5 @@
import itertools
+import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto
@@ -7,6 +8,8 @@ from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name
+logger = logging.getLogger("sqlglot")
+
class ScopeType(Enum):
ROOT = auto()
@@ -85,6 +88,7 @@ class Scope:
self._external_columns = None
self._join_hints = None
self._pivots = None
+ self._references = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
@@ -264,14 +268,19 @@ class Scope:
self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(
- exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
+ exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
)
if (
not ancestor
or column.table
or isinstance(ancestor, exp.Select)
- or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
- or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
+ or (
+ isinstance(ancestor, exp.Order)
+ and (
+ isinstance(ancestor.parent, exp.Window)
+ or column.name not in named_selects
+ )
+ )
):
self._columns.append(column)
@@ -289,15 +298,9 @@ class Scope:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
"""
if self._selected_sources is None:
- referenced_names = []
-
- for table in self.tables:
- referenced_names.append((table.alias_or_name, table))
- for expression in itertools.chain(self.derived_tables, self.udtfs):
- referenced_names.append((expression.alias, expression.unnest()))
result = {}
- for name, node in referenced_names:
+ for name, node in self.references:
if name in result:
raise OptimizeError(f"Alias already used: {name}")
if name in self.sources:
@@ -307,6 +310,23 @@ class Scope:
return self._selected_sources
@property
+ def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
+ if self._references is None:
+ self._references = []
+
+ for table in self.tables:
+ self._references.append((table.alias_or_name, table))
+ for expression in itertools.chain(self.derived_tables, self.udtfs):
+ self._references.append(
+ (
+ expression.alias,
+ expression if expression.args.get("pivots") else expression.unnest(),
+ )
+ )
+
+ return self._references
+
+ @property
def cte_sources(self):
"""
Sources that are CTEs.
@@ -378,9 +398,7 @@ class Scope:
def pivots(self):
if not self._pivots:
self._pivots = [
- pivot
- for node in self.tables + self.derived_tables
- for pivot in node.args.get("pivots") or []
+ pivot for _, node in self.references for pivot in node.args.get("pivots") or []
]
return self._pivots
@@ -536,7 +554,11 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.UDTF):
pass
else:
- raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
+ logger.warning(
+ "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
+ )
+ return
+
yield scope
@@ -576,6 +598,8 @@ def _traverse_ctes(scope):
if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
+ child_scope = None
+
for child_scope in _traverse_scope(
scope.branch(
cte.this,
@@ -593,7 +617,8 @@ def _traverse_ctes(scope):
child_scope.add_source(alias, recursive_scope)
# append the final child_scope yielded
- scope.cte_scopes.append(child_scope)
+ if child_scope:
+ scope.cte_scopes.append(child_scope)
scope.sources.update(sources)
@@ -634,6 +659,9 @@ def _traverse_tables(scope):
sources[source_name] = expression
continue
+ if not isinstance(expression, exp.DerivedTable):
+ continue
+
if isinstance(expression, exp.UDTF):
lateral_sources = sources
scope_type = ScopeType.UDTF
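
An illustrative peek at the new references property (the query is mine): it pairs every source alias in the scope with the node that introduces it, covering plain tables, derived tables and UDTFs, and keeps pivoted sources un-unnested.

import sqlglot
from sqlglot.optimizer.scope import build_scope

scope = build_scope(sqlglot.parse_one("SELECT * FROM a JOIN (SELECT 1 AS x) AS b ON TRUE"))
print([(name, type(node).__name__) for name, node in scope.references])  # e.g. [('a', 'Table'), ('b', 'Select')]
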
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 34005d9..1a2d82c 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -400,6 +400,7 @@ def simplify_parens(expression):
or not isinstance(this, exp.Binary)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
+ or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
):
return expression.this
return expression
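
A quick check of the added case (the expression is mine): multiplication binds tighter than addition and subtraction, so parentheses around a product inside an Add/Sub can be dropped without changing the result.

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("1 + (x * 2)")).sql())  # expected: 1 + x * 2
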
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index e5bd4ae..79e7cac 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -717,6 +717,7 @@ class Parser(metaclass=_Parser):
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
+ "ANY_VALUE": lambda self: self._parse_any_value(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONCAT": lambda self: self._parse_concat(),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
@@ -3321,11 +3322,6 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_select_or_expression(alias=alias)
- if isinstance(this, exp.EQ):
- left = this.this
- if isinstance(left, exp.Column):
- left.replace(exp.var(left.text("this")))
-
return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this)))
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
@@ -3678,6 +3674,18 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
+ def _parse_any_value(self) -> exp.AnyValue:
+ this = self._parse_lambda()
+ is_max = None
+ having = None
+
+ if self._match(TokenType.HAVING):
+ self._match_texts(("MAX", "MIN"))
+ is_max = self._prev.text == "MAX"
+ having = self._parse_column()
+
+ return self.expression(exp.AnyValue, this=this, having=having, max=is_max)
+
def _parse_cast(self, strict: bool) -> exp.Expression:
this = self._parse_conjunction()
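
A hedged round-trip sketch tying the expressions.py, generator.py and parser.py pieces together. BigQuery is my choice of dialect here simply because it offers the ANY_VALUE ... HAVING MAX syntax; the base parser handles it regardless of dialect.

import sqlglot

node = sqlglot.parse_one("SELECT ANY_VALUE(fruit HAVING MAX sold) FROM fruits", read="bigquery")
print(node.sql(dialect="bigquery"))  # the HAVING MAX clause should survive the round trip
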
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index b8560a1..12cf0b1 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -226,9 +226,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
- normalized_table = self._normalize_table(
- self._ensure_table(table, dialect=dialect), dialect=dialect
- )
+ normalized_table = self._normalize_table(table, dialect=dialect)
+
normalized_column_mapping = {
self._normalize_name(key, dialect=dialect): value
for key, value in ensure_column_mapping(column_mapping).items()
@@ -249,9 +248,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
only_visible: bool = False,
dialect: DialectType = None,
) -> t.List[str]:
- normalized_table = self._normalize_table(
- self._ensure_table(table, dialect=dialect), dialect=dialect
- )
+ normalized_table = self._normalize_table(table, dialect=dialect)
schema = self.find(normalized_table)
if schema is None:
@@ -269,9 +266,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
column: exp.Column,
dialect: DialectType = None,
) -> exp.DataType:
- normalized_table = self._normalize_table(
- self._ensure_table(table, dialect=dialect), dialect=dialect
- )
+ normalized_table = self._normalize_table(table, dialect=dialect)
+
normalized_column_name = self._normalize_name(
column if isinstance(column, str) else column.this, dialect=dialect
)
@@ -316,8 +312,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return normalized_mapping
- def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
- normalized_table = table.copy()
+ def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
+ normalized_table = exp.maybe_parse(
+ table, into=exp.Table, dialect=dialect or self.dialect, copy=True
+ )
for arg in TABLE_ARGS:
value = normalized_table.args.get(arg)
@@ -351,9 +349,6 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
- def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
- return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)
-
def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
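
A minimal usage sketch (table and column names are made up): with _ensure_table folded into _normalize_table, the public MappingSchema methods accept either an exp.Table or a plain string and parse/normalize it in one place.

from sqlglot.schema import MappingSchema

schema = MappingSchema()
schema.add_table("db.items", {"id": "INT", "name": "TEXT"})  # table passed as a string
print(schema.column_names("db.items"))  # expected: ['id', 'name']
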
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 1f30f96..33a1bc0 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -157,14 +157,10 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
- from sqlglot.optimizer.scope import build_scope
-
- scope = build_scope(expression)
- if not scope:
- return expression
+ from sqlglot.optimizer.scope import Scope
taken_select_names = set(expression.named_selects)
- taken_source_names = set(scope.selected_sources)
+ taken_source_names = {name for name, _ in Scope(expression).references}
for select in expression.selects:
to_replace = select