diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-30 08:03:58 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-30 08:03:58 +0000 |
commit | 9f19773cebdc9476f2a3266d3c01c967c38fcd1e (patch) | |
tree | a60f607ba2bb64fb45da86c297ff29ffc9b92f58 /sqlglot | |
parent | Releasing debian version 16.7.3-1. (diff) | |
download | sqlglot-9f19773cebdc9476f2a3266d3c01c967c38fcd1e.tar.xz sqlglot-9f19773cebdc9476f2a3266d3c01c967c38fcd1e.zip |
Merging upstream version 16.7.7.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 10 | ||||
-rw-r--r-- | sqlglot/executor/python.py | 5 | ||||
-rw-r--r-- | sqlglot/expressions.py | 2 | ||||
-rw-r--r-- | sqlglot/generator.py | 9 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 22 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 50 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 58 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 1 | ||||
-rw-r--r-- | sqlglot/parser.py | 18 | ||||
-rw-r--r-- | sqlglot/schema.py | 23 | ||||
-rw-r--r-- | sqlglot/transforms.py | 8 |
12 files changed, 146 insertions, 69 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 8786063..c9d6c79 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -96,13 +96,14 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: These are added by the optimizer's qualify_column step. """ + from sqlglot.optimizer.scope import Scope + if isinstance(expression, exp.Select): for unnest in expression.find_all(exp.Unnest): if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias: - for select in expression.selects: - for column in select.find_all(exp.Column): - if column.table == unnest.alias: - column.set("table", None) + for column in Scope(expression).find_all(exp.Column): + if column.table == unnest.alias: + column.set("table", None) return expression diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index a2dbfd9..34e4dd0 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -86,6 +86,10 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: ) +def _parse_datediff(args: t.List) -> exp.DateDiff: + return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) + + def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") @@ -214,15 +218,15 @@ class Snowflake(Dialect): "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), - "DATEDIFF": lambda args: exp.DateDiff( - this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) - ), + "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), + "TIMEDIFF": _parse_datediff, + "TIMESTAMPDIFF": _parse_datediff, "TO_ARRAY": exp.Array.from_arg_list, "TO_VARCHAR": exp.ToChar.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index a927181..34a380e 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -19,7 +19,6 @@ class PythonExecutor: self.tables = tables or {} def execute(self, plan): - running = set() finished = set() queue = set(plan.leaves) contexts = {} @@ -34,7 +33,6 @@ class PythonExecutor: for name, table in contexts[dep].tables.items() } ) - running.add(node) if isinstance(node, planner.Scan): contexts[node] = self.scan(node, context) @@ -49,11 +47,10 @@ class PythonExecutor: else: raise NotImplementedError - running.remove(node) finished.add(node) for dep in node.dependents: - if dep not in running and all(d in contexts for d in dep.dependencies): + if all(d in contexts for d in dep.dependencies): queue.add(dep) for dep in node.dependencies: diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index e01cc1a..cdb93db 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3922,7 +3922,7 @@ class Avg(AggFunc): class AnyValue(AggFunc): - pass + arg_types = {"this": True, "having": False, "max": False} class Case(Func): diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 5d8a4ca..a41af12 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -2423,6 +2423,15 @@ class Generator: buckets = self.sql(expression, "buckets") return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS" + def anyvalue_sql(self, expression: exp.AnyValue) -> str: + this = self.sql(expression, "this") + having = self.sql(expression, "having") + + if having: + this = f"{this} HAVING {'MAX' if expression.args.get('max') else 'MIN'} {having}" + + return self.func("ANY_VALUE", this) + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index fefe96e..e156d5e 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -47,6 +47,17 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - { } +# Projections in the outer query that are instances of these types can be replaced +# without getting wrapped in parentheses, because the precedence won't be altered. +SAFE_TO_REPLACE_UNWRAPPED = ( + exp.Column, + exp.EQ, + exp.Func, + exp.NEQ, + exp.Paren, +) + + def merge_ctes(expression, leave_tables_isolated=False): scopes = traverse_scope(expression) @@ -293,8 +304,17 @@ def _merge_expressions(outer_scope, inner_scope, alias): if not projection_name: continue columns_to_replace = outer_columns.get(projection_name, []) + + expression = expression.unalias() + must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED) + for column in columns_to_replace: - column.replace(expression.unalias().copy()) + # Ensures we don't alter the intended operator precedence if there's additional + # context surrounding the outer expression (i.e. it's not a simple projection). + if isinstance(column.parent, (exp.Unary, exp.Binary)) and must_wrap_expression: + expression = exp.paren(expression, copy=False) + + column.replace(expression.copy()) def _merge_where(outer_scope, inner_scope, from_or_join): diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ef8aeb1..8c3f599 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -170,9 +170,11 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if not isinstance(expression, exp.Select): return - alias_to_expression: t.Dict[str, exp.Expression] = {} + alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} - def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None: + def replace_columns( + node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False + ) -> None: if not node: return @@ -180,7 +182,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if not isinstance(column, exp.Column): continue table = resolver.get_table(column.name) if resolve_table and not column.table else None - alias_expr = alias_to_expression.get(column.name) + alias_expr, i = alias_to_expression.get(column.name, (None, 1)) double_agg = ( (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) if alias_expr @@ -190,16 +192,20 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if table and (not alias_expr or double_agg): column.set("table", table) elif not column.table and alias_expr and not double_agg: - column.replace(alias_expr.copy()) + if isinstance(alias_expr, exp.Literal): + if literal_index: + column.replace(exp.Literal.number(i)) + else: + column.replace(alias_expr.copy()) - for projection in scope.selects: + for i, projection in enumerate(scope.selects): replace_columns(projection) if isinstance(projection, exp.Alias): - alias_to_expression[projection.alias] = projection.this + alias_to_expression[projection.alias] = (projection.this, i + 1) replace_columns(expression.args.get("where")) - replace_columns(expression.args.get("group")) + replace_columns(expression.args.get("group"), literal_index=True) replace_columns(expression.args.get("having"), resolve_table=True) replace_columns(expression.args.get("qualify"), resolve_table=True) scope.clear_cache() @@ -255,27 +261,39 @@ def _expand_order_by(scope: Scope, resolver: Resolver): selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects} for ordered in ordereds: - ordered.set("this", selects.get(ordered.this, ordered.this)) + ordered = ordered.this + + ordered.replace( + exp.to_identifier(_select_by_pos(scope, ordered).alias) + if ordered.is_int + else selects.get(ordered, ordered) + ) def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]: new_nodes = [] for node in expressions: if node.is_int: - try: - select = scope.selects[int(node.name) - 1] - except IndexError: - raise OptimizeError(f"Unknown output column: {node.name}") - if isinstance(select, exp.Alias): - select = select.this - new_nodes.append(select.copy()) - scope.clear_cache() + select = _select_by_pos(scope, t.cast(exp.Literal, node)).this + + if isinstance(select, exp.Literal): + new_nodes.append(node) + else: + new_nodes.append(select.copy()) + scope.clear_cache() else: new_nodes.append(node) return new_nodes +def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: + try: + return scope.selects[int(node.this) - 1].assert_is(exp.Alias) + except IndexError: + raise OptimizeError(f"Unknown output column: {node.name}") + + def _qualify_columns(scope: Scope, resolver: Resolver) -> None: """Disambiguate columns, ensuring each column specifies a source""" for column in scope.columns: diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index aa56b83..bc649e4 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,4 +1,5 @@ import itertools +import logging import typing as t from collections import defaultdict from enum import Enum, auto @@ -7,6 +8,8 @@ from sqlglot import exp from sqlglot.errors import OptimizeError from sqlglot.helper import find_new_name +logger = logging.getLogger("sqlglot") + class ScopeType(Enum): ROOT = auto() @@ -85,6 +88,7 @@ class Scope: self._external_columns = None self._join_hints = None self._pivots = None + self._references = None def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" @@ -264,14 +268,19 @@ class Scope: self._columns = [] for column in columns + external_columns: ancestor = column.find_ancestor( - exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint + exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table ) if ( not ancestor or column.table or isinstance(ancestor, exp.Select) - or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window)) - or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) + or ( + isinstance(ancestor, exp.Order) + and ( + isinstance(ancestor.parent, exp.Window) + or column.name not in named_selects + ) + ) ): self._columns.append(column) @@ -289,15 +298,9 @@ class Scope: dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes """ if self._selected_sources is None: - referenced_names = [] - - for table in self.tables: - referenced_names.append((table.alias_or_name, table)) - for expression in itertools.chain(self.derived_tables, self.udtfs): - referenced_names.append((expression.alias, expression.unnest())) result = {} - for name, node in referenced_names: + for name, node in self.references: if name in result: raise OptimizeError(f"Alias already used: {name}") if name in self.sources: @@ -307,6 +310,23 @@ class Scope: return self._selected_sources @property + def references(self) -> t.List[t.Tuple[str, exp.Expression]]: + if self._references is None: + self._references = [] + + for table in self.tables: + self._references.append((table.alias_or_name, table)) + for expression in itertools.chain(self.derived_tables, self.udtfs): + self._references.append( + ( + expression.alias, + expression if expression.args.get("pivots") else expression.unnest(), + ) + ) + + return self._references + + @property def cte_sources(self): """ Sources that are CTEs. @@ -378,9 +398,7 @@ class Scope: def pivots(self): if not self._pivots: self._pivots = [ - pivot - for node in self.tables + self.derived_tables - for pivot in node.args.get("pivots") or [] + pivot for _, node in self.references for pivot in node.args.get("pivots") or [] ] return self._pivots @@ -536,7 +554,11 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.UDTF): pass else: - raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") + logger.warning( + "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) + ) + return + yield scope @@ -576,6 +598,8 @@ def _traverse_ctes(scope): if isinstance(union, exp.Union): recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) + child_scope = None + for child_scope in _traverse_scope( scope.branch( cte.this, @@ -593,7 +617,8 @@ def _traverse_ctes(scope): child_scope.add_source(alias, recursive_scope) # append the final child_scope yielded - scope.cte_scopes.append(child_scope) + if child_scope: + scope.cte_scopes.append(child_scope) scope.sources.update(sources) @@ -634,6 +659,9 @@ def _traverse_tables(scope): sources[source_name] = expression continue + if not isinstance(expression, exp.DerivedTable): + continue + if isinstance(expression, exp.UDTF): lateral_sources = sources scope_type = ScopeType.UDTF diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 34005d9..1a2d82c 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -400,6 +400,7 @@ def simplify_parens(expression): or not isinstance(this, exp.Binary) or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) + or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) ): return expression.this return expression diff --git a/sqlglot/parser.py b/sqlglot/parser.py index e5bd4ae..79e7cac 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -717,6 +717,7 @@ class Parser(metaclass=_Parser): FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"} FUNCTION_PARSERS: t.Dict[str, t.Callable] = { + "ANY_VALUE": lambda self: self._parse_any_value(), "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONCAT": lambda self: self._parse_concat(), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), @@ -3321,11 +3322,6 @@ class Parser(metaclass=_Parser): else: this = self._parse_select_or_expression(alias=alias) - if isinstance(this, exp.EQ): - left = this.this - if isinstance(left, exp.Column): - left.replace(exp.var(left.text("this"))) - return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this))) def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: @@ -3678,6 +3674,18 @@ class Parser(metaclass=_Parser): return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) + def _parse_any_value(self) -> exp.AnyValue: + this = self._parse_lambda() + is_max = None + having = None + + if self._match(TokenType.HAVING): + self._match_texts(("MAX", "MIN")) + is_max = self._prev.text == "MAX" + having = self._parse_column() + + return self.expression(exp.AnyValue, this=this, having=having, max=is_max) + def _parse_cast(self, strict: bool) -> exp.Expression: this = self._parse_conjunction() diff --git a/sqlglot/schema.py b/sqlglot/schema.py index b8560a1..12cf0b1 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -226,9 +226,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): column_mapping: a column mapping that describes the structure of the table. dialect: the SQL dialect that will be used to parse `table` if it's a string. """ - normalized_table = self._normalize_table( - self._ensure_table(table, dialect=dialect), dialect=dialect - ) + normalized_table = self._normalize_table(table, dialect=dialect) + normalized_column_mapping = { self._normalize_name(key, dialect=dialect): value for key, value in ensure_column_mapping(column_mapping).items() @@ -249,9 +248,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): only_visible: bool = False, dialect: DialectType = None, ) -> t.List[str]: - normalized_table = self._normalize_table( - self._ensure_table(table, dialect=dialect), dialect=dialect - ) + normalized_table = self._normalize_table(table, dialect=dialect) schema = self.find(normalized_table) if schema is None: @@ -269,9 +266,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): column: exp.Column, dialect: DialectType = None, ) -> exp.DataType: - normalized_table = self._normalize_table( - self._ensure_table(table, dialect=dialect), dialect=dialect - ) + normalized_table = self._normalize_table(table, dialect=dialect) + normalized_column_name = self._normalize_name( column if isinstance(column, str) else column.this, dialect=dialect ) @@ -316,8 +312,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return normalized_mapping - def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table: - normalized_table = table.copy() + def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table: + normalized_table = exp.maybe_parse( + table, into=exp.Table, dialect=dialect or self.dialect, copy=True + ) for arg in TABLE_ARGS: value = normalized_table.args.get(arg) @@ -351,9 +349,6 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): # The columns themselves are a mapping, but we don't want to include those return super()._depth() - 1 - def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table: - return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect) - def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: """ Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 1f30f96..33a1bc0 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -157,14 +157,10 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: def explode_to_unnest(expression: exp.Expression) -> exp.Expression: """Convert explode/posexplode into unnest (used in hive -> presto).""" if isinstance(expression, exp.Select): - from sqlglot.optimizer.scope import build_scope - - scope = build_scope(expression) - if not scope: - return expression + from sqlglot.optimizer.scope import Scope taken_select_names = set(expression.named_selects) - taken_source_names = set(scope.selected_sources) + taken_source_names = {name for name, _ in Scope(expression).references} for select in expression.selects: to_replace = select |