diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-04 07:24:08 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-04 07:24:08 +0000 |
commit | 7a2201963d5b03bd1828d350ccaecb4eda30d30c (patch) | |
tree | 19effbe90b8d78fdcb5f7d4bd0dd46b177ffdaab /sqlglot/optimizer | |
parent | Releasing debian version 10.2.9-1. (diff) | |
download | sqlglot-7a2201963d5b03bd1828d350ccaecb4eda30d30c.tar.xz sqlglot-7a2201963d5b03bd1828d350ccaecb4eda30d30c.zip |
Merging upstream version 10.4.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 3 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 49 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/normalize.py | 5 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 14 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 3 | ||||
-rw-r--r-- | sqlglot/optimizer/quote_identities.py | 25 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 30 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 27 |
11 files changed, 108 insertions, 60 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 33529a5..fc37a54 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression = coerce_type(expression) expression = remove_redundant_casts(expression) + if isinstance(expression, exp.Identifier): + expression.set("quoted", True) + return expression diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index de4e011..3b40710 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -129,10 +129,23 @@ def join_condition(join): """ name = join.this.alias_or_name on = (join.args.get("on") or exp.true()).copy() - on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) source_key = [] join_key = [] + def extract_condition(condition): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.true()) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.true()) + # find the join keys # SELECT # FROM x @@ -141,20 +154,30 @@ def join_condition(join): # # should pull y.b as the join key and x.a as the source key if normalized(on): + on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) + for condition in on.flatten(): if isinstance(condition, exp.EQ): - left, right = condition.unnest_operands() - left_tables = exp.column_table_names(left) - right_tables = exp.column_table_names(right) - - if name in left_tables and name not in right_tables: - join_key.append(left) - source_key.append(right) - condition.replace(exp.true()) - elif name in right_tables and name not in left_tables: - join_key.append(right) - source_key.append(left) - condition.replace(exp.true()) + extract_condition(condition) + elif normalized(on, dnf=True): + conditions = None + + for condition in on.flatten(): + parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)] + if conditions is None: + conditions = parts + else: + temp = [] + for p in parts: + cs = [c for c in conditions if p == c] + + if cs: + temp.append(p) + temp.extend(cs) + conditions = temp + + for condition in conditions: + extract_condition(condition) on = simplify(on) remaining_condition = None if on == exp.true() else on diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 39e252c..2245cc2 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -58,7 +58,9 @@ def eliminate_subqueries(expression): existing_ctes = {} with_ = root.expression.args.get("with") + recursive = False if with_: + recursive = with_.args.get("recursive") for cte in with_.expressions: existing_ctes[cte.this] = cte.alias new_ctes = [] @@ -88,7 +90,7 @@ def eliminate_subqueries(expression): new_ctes.append(new_cte) if new_ctes: - expression.set("with", exp.With(expressions=new_ctes)) + expression.set("with", exp.With(expressions=new_ctes, recursive=recursive)) return expression diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index db538ef..f16f519 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf): left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)] - return x + return [ + a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) + ] return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 6819717..72e67d4 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables -from sqlglot.optimizer.quote_identities import quote_identities from sqlglot.optimizer.unnest_subqueries import unnest_subqueries RULES = ( @@ -34,7 +33,6 @@ RULES = ( eliminate_ctes, annotate_types, canonicalize, - quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index f92e5c3..ba5c8b5 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -27,7 +27,14 @@ def pushdown_predicates(expression): select = scope.expression where = select.args.get("where") if where: - pushdown(where.this, scope.selected_sources, scope_ref_count) + selected_sources = scope.selected_sources + # a right join can only push down to itself and not the source FROM table + for k, (node, source) in selected_sources.items(): + parent = node.find_ancestor(exp.Join, exp.From) + if isinstance(parent, exp.Join) and parent.side == "RIGHT": + selected_sources = {k: (node, source)} + break + pushdown(where.this, selected_sources, scope_ref_count) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself @@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): + with_ = source.parent.expression.args.get("with") + if with_ and with_.recursive: + return {} node = source.expression if isinstance(node, exp.Join): - if node.side: + if node.side and node.side != "RIGHT": return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index abd9492..49789ac 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() -# SELECTION TO USE IF SELECTION LIST IS EMPTY +# Selection to use if selection list is empty DEFAULT_SELECTION = alias("1", "_") @@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections): # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION) + new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) return removed_indexes @@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove): selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove ] if not new_selections: - new_selections.append(DEFAULT_SELECTION) + new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e6e6dc9..e16a635 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -311,6 +311,9 @@ def _qualify_outputs(scope): alias_ = alias(exp.column(""), alias=selection.name) alias_.set("this", selection) selection = alias_ + elif isinstance(selection, exp.Subquery): + if not selection.alias: + selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) elif not isinstance(selection, exp.Alias): alias_ = alias(exp.column(""), f"_col_{i}") alias_.set("this", selection) diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py deleted file mode 100644 index 17623cc..0000000 --- a/sqlglot/optimizer/quote_identities.py +++ /dev/null @@ -1,25 +0,0 @@ -from sqlglot import exp - - -def quote_identities(expression): - """ - Rewrite sqlglot AST to ensure all identities are quoted. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x") - >>> quote_identities(expression).sql() - 'SELECT "x"."a" AS "a" FROM "db"."x"' - - Args: - expression (sqlglot.Expression): expression to quote - Returns: - sqlglot.Expression: quoted expression - """ - - def qualify(node): - if isinstance(node, exp.Identifier): - node.set("quoted", True) - return node - - return expression.transform(qualify, copy=False) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 18848f3..6125e4e 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -511,9 +511,20 @@ def _traverse_union(scope): def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} + is_cte = scope_type == ScopeType.CTE for derived_table in derived_tables: - top = None + recursive_scope = None + + # if the scope is a recursive cte, it must be in the form of + # base_case UNION recursive. thus the recursive scope is the first + # section of the union. + if is_cte and scope.expression.args["with"].recursive: + union = derived_table.this + + if isinstance(union, exp.Union): + recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) + for child_scope in _traverse_scope( scope.branch( derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, @@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): ) ): yield child_scope - top = child_scope + # Tables without aliases will be set as "" # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. - sources[derived_table.alias] = child_scope - if scope_type == ScopeType.CTE: - scope.cte_scopes.append(top) + alias = derived_table.alias + sources[alias] = child_scope + + if recursive_scope: + child_scope.add_source(alias, recursive_scope) + + # append the final child_scope yielded + if is_cte: + scope.cte_scopes.append(child_scope) else: - scope.derived_table_scopes.append(top) + scope.derived_table_scopes.append(child_scope) + scope.sources.update(sources) diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 2046917..8d78294 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -16,7 +16,7 @@ def unnest_subqueries(expression): >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ - AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)' + AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)' Args: expression (sqlglot.Expression): expression to unnest @@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence): table_alias = _alias(sequence) keys = [] - # for all external columns in the where statement, - # split out the relevant data to convert it into a join + # for all external columns in the where statement, find the relevant predicate + # keys to convert it into a join for column in external_columns: if column.find_ancestor(exp.Where) is not where: return @@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence): if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): return + is_subquery_projection = any( + node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery) + ) + value = select.selects[0] key_aliases = {} group_by = [] @@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence): parent_predicate = select.find_ancestor(exp.Predicate) # if the value of the subquery is not an agg or a key, we need to collect it into an array - # so that it can be grouped + # so that it can be grouped. For subquery projections, we use a MAX aggregation instead. + agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg if not value.find(exp.AggFunc) and value.this not in group_by: - select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False) + select.select( + exp.alias_(agg_func(this=value.this), value.alias, quoted=False), + append=False, + copy=False, + ) # exists queries should not have any selects as it only checks if there are any rows # all selects will be added by the optimizer and only used for join keys @@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence): if isinstance(parent_predicate, exp.Exists) or key != value.this: select.select(f"{key} AS {alias}", copy=False) else: - select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False) + select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False) alias = exp.column(value.alias, table_alias) other = _other_operand(parent_predicate) @@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence): f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", ) else: + if is_subquery_projection: + alias = exp.alias_(alias, select.parent.alias) select.parent.replace(alias) for key, column, predicate in keys: predicate.replace(exp.true()) nested = exp.column(key_aliases[key], table_alias) + if is_subquery_projection: + key.replace(nested) + continue + if key in group_by: key.replace(nested) parent_predicate = _replace( |