From 684905e3de7854a3806ffa55e0d1a09431ba5a19 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 15 Oct 2022 15:53:00 +0200 Subject: Merging upstream version 7.1.3. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/eliminate_ctes.py | 42 ++++++++ sqlglot/optimizer/eliminate_joins.py | 160 ++++++++++++++++++++++++++++++ sqlglot/optimizer/eliminate_subqueries.py | 2 +- sqlglot/optimizer/merge_subqueries.py | 18 ++++ sqlglot/optimizer/optimizer.py | 4 + sqlglot/optimizer/pushdown_predicates.py | 19 ++-- sqlglot/optimizer/scope.py | 26 +++++ 7 files changed, 258 insertions(+), 13 deletions(-) create mode 100644 sqlglot/optimizer/eliminate_ctes.py create mode 100644 sqlglot/optimizer/eliminate_joins.py (limited to 'sqlglot/optimizer') diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py new file mode 100644 index 0000000..7b862c6 --- /dev/null +++ b/sqlglot/optimizer/eliminate_ctes.py @@ -0,0 +1,42 @@ +from sqlglot.optimizer.scope import Scope, build_scope + + +def eliminate_ctes(expression): + """ + Remove unused CTEs from an expression. + + Example: + >>> import sqlglot + >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z" + >>> expression = sqlglot.parse_one(sql) + >>> eliminate_ctes(expression).sql() + 'SELECT a FROM z' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + root = build_scope(expression) + + ref_count = root.ref_count() + + # Traverse the scope tree in reverse so we can remove chains of unused CTEs + for scope in reversed(list(root.traverse())): + if scope.is_cte: + count = ref_count[id(scope)] + if count <= 0: + cte_node = scope.expression.parent + with_node = cte_node.parent + cte_node.pop() + + # Pop the entire WITH clause if this is the last CTE + if len(with_node.expressions) <= 0: + with_node.pop() + + # Decrement the ref count for all sources this CTE selects from + for _, source in scope.selected_sources.values(): + if isinstance(source, Scope): + ref_count[id(source)] -= 1 + + return expression diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py new file mode 100644 index 0000000..0854336 --- /dev/null +++ b/sqlglot/optimizer/eliminate_joins.py @@ -0,0 +1,160 @@ +from sqlglot import expressions as exp +from sqlglot.optimizer.normalize import normalized +from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def eliminate_joins(expression): + """ + Remove unused joins from an expression. + + This only removes joins when we know that the join condition doesn't produce duplicate rows. + + Example: + >>> import sqlglot + >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b" + >>> expression = sqlglot.parse_one(sql) + >>> eliminate_joins(expression).sql() + 'SELECT x.a FROM x' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for scope in traverse_scope(expression): + # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used. + # It's probably possible to infer this from the outputs of derived tables. + # But for now, let's just skip this rule. + if scope.unqualified_columns: + continue + + joins = scope.expression.args.get("joins", []) + + # Reverse the joins so we can remove chains of unused joins + for join in reversed(joins): + alias = join.this.alias_or_name + if _should_eliminate_join(scope, join, alias): + join.pop() + scope.remove_source(alias) + return expression + + +def _should_eliminate_join(scope, join, alias): + inner_source = scope.sources.get(alias) + return ( + isinstance(inner_source, Scope) + and not _join_is_used(scope, join, alias) + and ( + (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join)) + or (not join.args.get("on") and _has_single_output_row(inner_source)) + ) + ) + + +def _join_is_used(scope, join, alias): + # We need to find all columns that reference this join. + # But columns in the ON clause shouldn't count. + on = join.args.get("on") + if on: + on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) + else: + on_clause_columns = set() + return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns) + + +def _is_joined_on_all_unique_outputs(scope, join): + unique_outputs = _unique_outputs(scope) + if not unique_outputs: + return False + + _, join_keys, _ = join_condition(join) + remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys) + return not remaining_unique_outputs + + +def _unique_outputs(scope): + """Determine output columns of `scope` that must have a unique combination per row""" + if scope.expression.args.get("distinct"): + return set(scope.expression.named_selects) + + group = scope.expression.args.get("group") + if group: + grouped_expressions = set(group.expressions) + grouped_outputs = set() + + unique_outputs = set() + for select in scope.selects: + output = select.unalias() + if output in grouped_expressions: + grouped_outputs.add(output) + unique_outputs.add(select.alias_or_name) + + # All the grouped expressions must be in the output + if not grouped_expressions.difference(grouped_outputs): + return unique_outputs + else: + return set() + + if _has_single_output_row(scope): + return set(scope.expression.named_selects) + + return set() + + +def _has_single_output_row(scope): + return isinstance(scope.expression, exp.Select) and ( + all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects) + or _is_limit_1(scope) + or not scope.expression.args.get("from") + ) + + +def _is_limit_1(scope): + limit = scope.expression.args.get("limit") + return limit and limit.expression.this == "1" + + +def join_condition(join): + """ + Extract the join condition from a join expression. + + Args: + join (exp.Join) + Returns: + tuple[list[str], list[str], exp.Expression]: + Tuple of (source key, join key, remaining predicate) + """ + name = join.this.alias_or_name + on = join.args.get("on") or exp.TRUE + on = on.copy() + source_key = [] + join_key = [] + + # find the join keys + # SELECT + # FROM x + # JOIN y + # ON x.a = y.b AND y.b > 1 + # + # should pull y.b as the join key and x.a as the source key + if normalized(on): + for condition in on.flatten() if isinstance(on, exp.And) else [on]: + if isinstance(condition, exp.EQ): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.TRUE) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.TRUE) + + on = simplify(on) + remaining_condition = None if on == exp.TRUE else on + + return source_key, join_key, remaining_condition diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 38e1299..44cdc94 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -8,7 +8,7 @@ from sqlglot.optimizer.simplify import simplify def eliminate_subqueries(expression): """ - Rewrite subqueries as CTES, deduplicating if possible. + Rewrite derived tables as CTES, deduplicating if possible. Example: >>> import sqlglot diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 3e435f5..3c51c18 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -119,6 +119,23 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): Returns: bool: True if can be merged """ + + def _is_a_window_expression_in_unmergable_operation(): + window_expressions = inner_select.find_all(exp.Window) + window_alias_names = {window.parent.alias_or_name for window in window_expressions} + inner_select_name = inner_select.parent.alias_or_name + unmergable_window_columns = [ + column + for column in outer_scope.columns + if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc) + ] + window_expressions_in_unmergable = [ + column + for column in unmergable_window_columns + if column.table == inner_select_name and column.name in window_alias_names + ] + return any(window_expressions_in_unmergable) + return ( isinstance(outer_scope.expression, exp.Select) and isinstance(inner_select, exp.Select) @@ -137,6 +154,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): and inner_select.args.get("where") and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) ) + and not _is_a_window_expression_in_unmergable_operation() ) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 9a09327..2c28ab8 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,3 +1,5 @@ +from sqlglot.optimizer.eliminate_ctes import eliminate_ctes +from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects @@ -23,6 +25,8 @@ RULES = ( optimize_joins, eliminate_subqueries, merge_subqueries, + eliminate_joins, + eliminate_ctes, quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 9c8d71d..583d059 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,8 +1,6 @@ -from collections import defaultdict - from sqlglot import exp from sqlglot.optimizer.normalize import normalized -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import build_scope from sqlglot.optimizer.simplify import simplify @@ -22,15 +20,10 @@ def pushdown_predicates(expression): Returns: sqlglot.Expression: optimized expression """ - scope_ref_count = defaultdict(lambda: 0) - scopes = traverse_scope(expression) - scopes.reverse() - - for scope in scopes: - for _, source in scope.selected_sources.values(): - scope_ref_count[id(source)] += 1 + root = build_scope(expression) + scope_ref_count = root.ref_count() - for scope in scopes: + for scope in reversed(list(root.traverse())): select = scope.expression where = select.args.get("where") if where: @@ -152,9 +145,11 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: + # We can't push down window expressions + has_window_expression = any(select for select in node.selects if select.find(exp.Window)) # we can't push down predicates to select statements if they are referenced in # multiple places. - if not node.args.get("group") and scope_ref_count[id(source)] < 2: + if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression: nodes[table] = node return nodes diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 89de517..68298a0 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,4 +1,5 @@ import itertools +from collections import defaultdict from enum import Enum, auto from sqlglot import exp @@ -314,6 +315,16 @@ class Scope: self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns + @property + def unqualified_columns(self): + """ + Unqualified columns in the current scope. + + Returns: + list[exp.Column]: Unqualified columns + """ + return [c for c in self.columns if not c.table] + @property def join_hints(self): """ @@ -403,6 +414,21 @@ class Scope: yield from child_scope.traverse() yield self + def ref_count(self): + """ + Count the number of times each scope in this tree is referenced. + + Returns: + dict[int, int]: Mapping of Scope instance ID to reference count + """ + scope_ref_count = defaultdict(lambda: 0) + + for scope in self.traverse(): + for _, source in scope.selected_sources.values(): + scope_ref_count[id(source)] += 1 + + return scope_ref_count + def traverse_scope(expression): """ -- cgit v1.2.3