summary refs log tree commit diff stats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
author Daniel Baumann <daniel.baumann@progress-linux.org> 2022-10-15 13:53:00 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org> 2022-10-15 13:53:00 +0000
commit 684905e3de7854a3806ffa55e0d1a09431ba5a19 (patch)
tree 127ebd7d051f15fb8f8cf36cfd04a8a65a4d9680 /sqlglot/optimizer
parent Releasing debian version 6.3.1-1. (diff)
download sqlglot-684905e3de7854a3806ffa55e0d1a09431ba5a19.tar.xz
sqlglot-684905e3de7854a3806ffa55e0d1a09431ba5a19.zip
Merging upstream version 7.1.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- sqlglot/optimizer/eliminate_ctes.py 42
-rw-r--r-- sqlglot/optimizer/eliminate_joins.py 160
-rw-r--r-- sqlglot/optimizer/eliminate_subqueries.py 2
-rw-r--r-- sqlglot/optimizer/merge_subqueries.py 18
-rw-r--r-- sqlglot/optimizer/optimizer.py 4
-rw-r--r-- sqlglot/optimizer/pushdown_predicates.py 19
-rw-r--r-- sqlglot/optimizer/scope.py 26
7 files changed, 258 insertions, 13 deletions
diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py
new file mode 100644
index 0000000..7b862c6
--- /dev/null
+++ b/sqlglot/optimizer/eliminate_ctes.py
@@ -0,0 +1,42 @@
+from sqlglot.optimizer.scope import Scope, build_scope
+
+
+def eliminate_ctes(expression):
+ """
+ Remove unused CTEs from an expression.
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> eliminate_ctes(expression).sql()
+ 'SELECT a FROM z'
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ root = build_scope(expression)
+
+ ref_count = root.ref_count()
+
+ # Traverse the scope tree in reverse so we can remove chains of unused CTEs
+ for scope in reversed(list(root.traverse())):
+ if scope.is_cte:
+ count = ref_count[id(scope)]
+ if count <= 0:
+ cte_node = scope.expression.parent
+ with_node = cte_node.parent
+ cte_node.pop()
+
+ # Pop the entire WITH clause if this is the last CTE
+ if len(with_node.expressions) <= 0:
+ with_node.pop()
+
+ # Decrement the ref count for all sources this CTE selects from
+ for _, source in scope.selected_sources.values():
+ if isinstance(source, Scope):
+ ref_count[id(source)] -= 1
+
+ return expression
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
new file mode 100644
index 0000000..0854336
--- /dev/null
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -0,0 +1,160 @@
+from sqlglot import expressions as exp
+from sqlglot.optimizer.normalize import normalized
+from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def eliminate_joins(expression):
+ """
+ Remove unused joins from an expression.
+
+ This only removes joins when we know that the join condition doesn't produce duplicate rows.
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> eliminate_joins(expression).sql()
+ 'SELECT x.a FROM x'
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ for scope in traverse_scope(expression):
+ # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
+ # It's probably possible to infer this from the outputs of derived tables.
+ # But for now, let's just skip this rule.
+ if scope.unqualified_columns:
+ continue
+
+ joins = scope.expression.args.get("joins", [])
+
+ # Reverse the joins so we can remove chains of unused joins
+ for join in reversed(joins):
+ alias = join.this.alias_or_name
+ if _should_eliminate_join(scope, join, alias):
+ join.pop()
+ scope.remove_source(alias)
+ return expression
+
+
+def _should_eliminate_join(scope, join, alias):
+ inner_source = scope.sources.get(alias)
+ return (
+ isinstance(inner_source, Scope)
+ and not _join_is_used(scope, join, alias)
+ and (
+ (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
+ or (not join.args.get("on") and _has_single_output_row(inner_source))
+ )
+ )
+
+
+def _join_is_used(scope, join, alias):
+ # We need to find all columns that reference this join.
+ # But columns in the ON clause shouldn't count.
+ on = join.args.get("on")
+ if on:
+ on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
+ else:
+ on_clause_columns = set()
+ return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
+
+
+def _is_joined_on_all_unique_outputs(scope, join):
+ unique_outputs = _unique_outputs(scope)
+ if not unique_outputs:
+ return False
+
+ _, join_keys, _ = join_condition(join)
+ remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys)
+ return not remaining_unique_outputs
+
+
+def _unique_outputs(scope):
+ """Determine output columns of `scope` that must have a unique combination per row"""
+ if scope.expression.args.get("distinct"):
+ return set(scope.expression.named_selects)
+
+ group = scope.expression.args.get("group")
+ if group:
+ grouped_expressions = set(group.expressions)
+ grouped_outputs = set()
+
+ unique_outputs = set()
+ for select in scope.selects:
+ output = select.unalias()
+ if output in grouped_expressions:
+ grouped_outputs.add(output)
+ unique_outputs.add(select.alias_or_name)
+
+ # All the grouped expressions must be in the output
+ if not grouped_expressions.difference(grouped_outputs):
+ return unique_outputs
+ else:
+ return set()
+
+ if _has_single_output_row(scope):
+ return set(scope.expression.named_selects)
+
+ return set()
+
+
+def _has_single_output_row(scope):
+ return isinstance(scope.expression, exp.Select) and (
+ all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects)
+ or _is_limit_1(scope)
+ or not scope.expression.args.get("from")
+ )
+
+
+def _is_limit_1(scope):
+ limit = scope.expression.args.get("limit")
+ return limit and limit.expression.this == "1"
+
+
+def join_condition(join):
+ """
+ Extract the join condition from a join expression.
+
+ Args:
+ join (exp.Join)
+ Returns:
+ tuple[list[str], list[str], exp.Expression]:
+ Tuple of (source key, join key, remaining predicate)
+ """
+ name = join.this.alias_or_name
+ on = join.args.get("on") or exp.TRUE
+ on = on.copy()
+ source_key = []
+ join_key = []
+
+ # find the join keys
+ # SELECT
+ # FROM x
+ # JOIN y
+ # ON x.a = y.b AND y.b > 1
+ #
+ # should pull y.b as the join key and x.a as the source key
+ if normalized(on):
+ for condition in on.flatten() if isinstance(on, exp.And) else [on]:
+ if isinstance(condition, exp.EQ):
+ left, right = condition.unnest_operands()
+ left_tables = exp.column_table_names(left)
+ right_tables = exp.column_table_names(right)
+
+ if name in left_tables and name not in right_tables:
+ join_key.append(left)
+ source_key.append(right)
+ condition.replace(exp.TRUE)
+ elif name in right_tables and name not in left_tables:
+ join_key.append(right)
+ source_key.append(left)
+ condition.replace(exp.TRUE)
+
+ on = simplify(on)
+ remaining_condition = None if on == exp.TRUE else on
+
+ return source_key, join_key, remaining_condition
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 38e1299..44cdc94 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -8,7 +8,7 @@ from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
"""
- Rewrite subqueries as CTES, deduplicating if possible.
+ Rewrite derived tables as CTES, deduplicating if possible.
Example:
>>> import sqlglot
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 3e435f5..3c51c18 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -119,6 +119,23 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
Returns:
bool: True if can be merged
"""
+
+ def _is_a_window_expression_in_unmergable_operation():
+ window_expressions = inner_select.find_all(exp.Window)
+ window_alias_names = {window.parent.alias_or_name for window in window_expressions}
+ inner_select_name = inner_select.parent.alias_or_name
+ unmergable_window_columns = [
+ column
+ for column in outer_scope.columns
+ if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
+ ]
+ window_expressions_in_unmergable = [
+ column
+ for column in unmergable_window_columns
+ if column.table == inner_select_name and column.name in window_alias_names
+ ]
+ return any(window_expressions_in_unmergable)
+
return (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
@@ -137,6 +154,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
and inner_select.args.get("where")
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
)
+ and not _is_a_window_expression_in_unmergable_operation()
)
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 9a09327..2c28ab8 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,3 +1,5 @@
+from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
+from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
@@ -23,6 +25,8 @@ RULES = (
optimize_joins,
eliminate_subqueries,
merge_subqueries,
+ eliminate_joins,
+ eliminate_ctes,
quote_identities,
)
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 9c8d71d..583d059 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -1,8 +1,6 @@
-from collections import defaultdict
-
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify
@@ -22,15 +20,10 @@ def pushdown_predicates(expression):
Returns:
sqlglot.Expression: optimized expression
"""
- scope_ref_count = defaultdict(lambda: 0)
- scopes = traverse_scope(expression)
- scopes.reverse()
-
- for scope in scopes:
- for _, source in scope.selected_sources.values():
- scope_ref_count[id(source)] += 1
+ root = build_scope(expression)
+ scope_ref_count = root.ref_count()
- for scope in scopes:
+ for scope in reversed(list(root.traverse())):
select = scope.expression
where = select.args.get("where")
if where:
@@ -152,9 +145,11 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
+ # We can't push down window expressions
+ has_window_expression = any(select for select in node.selects if select.find(exp.Window))
# we can't push down predicates to select statements if they are referenced in
# multiple places.
- if not node.args.get("group") and scope_ref_count[id(source)] < 2:
+ if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
nodes[table] = node
return nodes
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 89de517..68298a0 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,4 +1,5 @@
import itertools
+from collections import defaultdict
from enum import Enum, auto
from sqlglot import exp
@@ -315,6 +316,16 @@ class Scope:
return self._external_columns
@property
+ def unqualified_columns(self):
+ """
+ Unqualified columns in the current scope.
+
+ Returns:
+ list[exp.Column]: Unqualified columns
+ """
+ return [c for c in self.columns if not c.table]
+
+ @property
def join_hints(self):
"""
Hints that exist in the scope that reference tables
@@ -403,6 +414,21 @@ class Scope:
yield from child_scope.traverse()
yield self
+ def ref_count(self):
+ """
+ Count the number of times each scope in this tree is referenced.
+
+ Returns:
+ dict[int, int]: Mapping of Scope instance ID to reference count
+ """
+ scope_ref_count = defaultdict(lambda: 0)
+
+ for scope in self.traverse():
+ for _, source in scope.selected_sources.values():
+ scope_ref_count[id(source)] += 1
+
+ return scope_ref_count
+
def traverse_scope(expression):
"""