summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py33
-rw-r--r--sqlglot/optimizer/merge_subqueries.py17
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py13
-rw-r--r--sqlglot/optimizer/scope.py39
-rw-r--r--sqlglot/optimizer/simplify.py153
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py2
6 files changed, 140 insertions, 117 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 1ab7768..1230cea 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -95,9 +95,6 @@ def eliminate_subqueries(expression):
def _eliminate(scope, existing_ctes, taken):
- if scope.is_union:
- return _eliminate_union(scope, existing_ctes, taken)
-
if scope.is_derived_table:
return _eliminate_derived_table(scope, existing_ctes, taken)
@@ -105,36 +102,6 @@ def _eliminate(scope, existing_ctes, taken):
return _eliminate_cte(scope, existing_ctes, taken)
-def _eliminate_union(scope, existing_ctes, taken):
- duplicate_cte_alias = existing_ctes.get(scope.expression)
-
- alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
-
- taken[alias] = scope
-
- # Try to maintain the selections
- expressions = scope.expression.selects
- selects = [
- exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
- for e in expressions
- if e.alias_or_name
- ]
- # If not all selections have an alias, just select *
- if len(selects) != len(expressions):
- selects = ["*"]
-
- scope.expression.replace(
- exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
- )
-
- if not duplicate_cte_alias:
- existing_ctes[scope.expression] = alias
- return exp.CTE(
- this=scope.expression,
- alias=exp.TableAlias(this=exp.to_identifier(alias)),
- )
-
-
def _eliminate_derived_table(scope, existing_ctes, taken):
# This makes sure that we don't:
# - drop the "pivot" arg from a pivoted subquery
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index a74bea7..ea148cc 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -174,6 +174,22 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
for col in inner_projections[selection].find_all(exp.Column)
)
+ def _is_recursive():
+ # Recursive CTEs look like this:
+ # WITH RECURSIVE cte AS (
+ # SELECT * FROM x <-- inner scope
+ # UNION ALL
+ # SELECT * FROM cte <-- outer scope
+ # )
+ cte = inner_scope.expression.parent
+ node = outer_scope.expression.parent
+
+ while node:
+ if node is cte:
+ return True
+ node = node.parent
+ return False
+
return (
isinstance(outer_scope.expression, exp.Select)
and not outer_scope.expression.is_star
@@ -197,6 +213,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
)
and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
+ and not _is_recursive()
)
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index f7348b5..10ff13a 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -4,7 +4,7 @@ from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
-def pushdown_predicates(expression):
+def pushdown_predicates(expression, dialect=None):
"""
Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
@@ -36,7 +36,7 @@ def pushdown_predicates(expression):
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
- pushdown(where.this, selected_sources, scope_ref_count)
+ pushdown(where.this, selected_sources, scope_ref_count, dialect)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@@ -44,17 +44,20 @@ def pushdown_predicates(expression):
name = join.alias_or_name
if name in scope.selected_sources:
pushdown(
- join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
+ join.args.get("on"),
+ {name: scope.selected_sources[name]},
+ scope_ref_count,
+ dialect,
)
return expression
-def pushdown(condition, sources, scope_ref_count):
+def pushdown(condition, sources, scope_ref_count, dialect):
if not condition:
return
- condition = condition.replace(simplify(condition))
+ condition = condition.replace(simplify(condition, dialect=dialect))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
predicates = list(
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index b7e527e..d34857d 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -37,6 +37,7 @@ class Scope:
For example:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
+ cte_sources (dict[str, Scope]): Sources from CTES
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
For example:
@@ -61,11 +62,14 @@ class Scope:
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
+ cte_sources=None,
):
self.expression = expression
self.sources = sources or {}
- self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
+ self.lateral_sources = lateral_sources or {}
+ self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
+ self.sources.update(self.cte_sources)
self.outer_column_list = outer_column_list or []
self.parent = parent
self.scope_type = scope_type
@@ -92,13 +96,17 @@ class Scope:
self._pivots = None
self._references = None
- def branch(self, expression, scope_type, chain_sources=None, **kwargs):
+ def branch(
+ self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
+ ):
"""Branch from the current scope to a new, inner scope"""
return Scope(
expression=expression.unnest(),
- sources={**self.cte_sources, **(chain_sources or {})},
+ sources=sources.copy() if sources else None,
parent=self,
scope_type=scope_type,
+ cte_sources={**self.cte_sources, **(cte_sources or {})},
+ lateral_sources=lateral_sources.copy() if lateral_sources else None,
**kwargs,
)
@@ -306,20 +314,6 @@ class Scope:
return self._references
@property
- def cte_sources(self):
- """
- Sources that are CTEs.
-
- Returns:
- dict[str, Scope]: Mapping of source alias to Scope
- """
- return {
- alias: scope
- for alias, scope in self.sources.items()
- if isinstance(scope, Scope) and scope.is_cte
- }
-
- @property
def external_columns(self):
"""
Columns that appear to reference sources in outer scopes.
@@ -515,7 +509,10 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
- yield from _traverse_subqueries(scope)
+ if scope.is_root:
+ yield from _traverse_select(scope)
+ else:
+ yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
@@ -572,7 +569,7 @@ def _traverse_ctes(scope):
for child_scope in _traverse_scope(
scope.branch(
cte.this,
- chain_sources=sources,
+ cte_sources=sources,
outer_column_list=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
@@ -584,12 +581,14 @@ def _traverse_ctes(scope):
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
+ child_scope.cte_sources[alias] = recursive_scope
# append the final child_scope yielded
if child_scope:
scope.cte_scopes.append(child_scope)
scope.sources.update(sources)
+ scope.cte_sources.update(sources)
def _is_derived_table(expression: exp.Subquery) -> bool:
@@ -725,7 +724,7 @@ def _traverse_ddl(scope):
yield from _traverse_ctes(scope)
query_scope = scope.branch(
- scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+ scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
)
query_scope._collect()
query_scope._ctes = scope.ctes + query_scope._ctes
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d4e2e60..6ae08d0 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import functools
import itertools
@@ -6,10 +8,17 @@ from collections import deque
from decimal import Decimal
import sqlglot
-from sqlglot import exp
+from sqlglot import Dialect, exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import DialectType
+
+ DateTruncBinaryTransform = t.Callable[
+ [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
+ ]
+
# Final means that an expression should not be simplified
FINAL = "final"
@@ -18,7 +27,9 @@ class UnsupportedUnit(Exception):
pass
-def simplify(expression, constant_propagation=False):
+def simplify(
+ expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
+):
"""
Rewrite sqlglot AST to simplify expressions.
@@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
+ dialect = Dialect.get_or_raise(dialect)
+
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
for group in expression.find_all(exp.Group):
select = group.parent
+ assert select
groups = set(group.expressions)
group.meta[FINAL] = True
- for e in select.selects:
+ for e in select.expressions:
for node, *_ in e.walk():
if node in groups:
e.meta[FINAL] = True
@@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False):
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
- node = simplify_datetrunc_predicate(node)
+ node = simplify_datetrunc(node, dialect)
+ node = sort_comparison(node)
if root:
expression.replace(node)
@@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
This is done because comparison simplification is only done on lt/lte/gt/gte.
"""
if isinstance(expression, exp.Between):
- return exp.and_(
+ negate = isinstance(expression.parent, exp.Not)
+
+ expression = exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
+
+ if negate:
+ expression = exp.paren(expression, copy=False)
+
return expression
+COMPLEMENT_COMPARISONS = {
+ exp.LT: exp.GTE,
+ exp.GT: exp.LTE,
+ exp.LTE: exp.GT,
+ exp.GTE: exp.LT,
+ exp.EQ: exp.NEQ,
+ exp.NEQ: exp.EQ,
+}
+
+
def simplify_not(expression):
"""
Demorgan's Law
@@ -132,10 +163,15 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
- if is_null(expression.this):
+ this = expression.this
+ if is_null(this):
return exp.null()
- if isinstance(expression.this, exp.Paren):
- condition = expression.this.unnest()
+ if this.__class__ in COMPLEMENT_COMPARISONS:
+ return COMPLEMENT_COMPARISONS[this.__class__](
+ this=this.this, expression=this.expression
+ )
+ if isinstance(this, exp.Paren):
+ condition = this.unnest()
if isinstance(condition, exp.And):
return exp.or_(
exp.not_(condition.left, copy=False),
@@ -150,14 +186,14 @@ def simplify_not(expression):
)
if is_null(condition):
return exp.null()
- if always_true(expression.this):
+ if always_true(this):
return exp.false()
- if is_false(expression.this):
+ if is_false(this):
return exp.true()
- if isinstance(expression.this, exp.Not):
+ if isinstance(this, exp.Not):
# double negation
# NOT NOT x -> x
- return expression.this.this
+ return this.this
return expression
@@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False):
except StopIteration:
return expression
- # make sure the comparison is always of the form x > 1 instead of 1 < x
- if left.__class__ in INVERSE_COMPARISONS and l == ll:
- left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
- if right.__class__ in INVERSE_COMPARISONS and r == rl:
- right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
-
if l.is_number and r.is_number:
l = float(l.name)
r = float(r.name)
@@ -397,13 +427,7 @@ def propagate_constants(expression, root=True):
# TODO: create a helper that can be used to detect nested literal expressions such
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
- pass
- elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
- l, r = r, l
- else:
- continue
-
- constant_mapping[l] = (id(l), r)
+ constant_mapping[l] = (id(l), r)
if constant_mapping:
for column in find_all_in_scope(expression, exp.Column):
@@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right
- if l.__class__ in INVERSE_OPS:
- pass
- elif r.__class__ in INVERSE_OPS:
- l, r = r, l
- else:
+ if not l.__class__ in INVERSE_OPS:
return expression
if r.is_number:
@@ -650,7 +670,7 @@ def simplify_coalesce(expression):
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
- if _is_constant(other):
+ if _is_constant(arg):
break
else:
return expression
@@ -752,7 +772,7 @@ def simplify_conditionals(expression):
DateRange = t.Tuple[datetime.date, datetime.date]
-def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
+def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
"""
Get the date range for a DATE_TRUNC equality comparison:
@@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
Returns:
tuple of [min, max) or None if a value can never be equal to `date` for `unit`
"""
- floor = date_floor(date, unit)
+ floor = date_floor(date, unit, dialect)
if date != floor:
# This will always be False, except for NULL values.
@@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp
def _datetrunc_eq(
- left: exp.Expression, date: datetime.date, unit: str
+ left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
@@ -790,9 +810,9 @@ def _datetrunc_eq(
def _datetrunc_neq(
- left: exp.Expression, date: datetime.date, unit: str
+ left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
@@ -803,41 +823,39 @@ def _datetrunc_neq(
)
-DateTruncBinaryTransform = t.Callable[
- [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
-]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
- exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
- exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
- exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
- exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
+ exp.LT: lambda l, dt, u, d: l
+ < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
+ exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
+ exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
+ exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
exp.EQ: _datetrunc_eq,
exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
+DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
- return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
+ return isinstance(left, DATETRUNCS) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
-def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
+def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
comparison = expression.__class__
- if comparison not in DATETRUNC_COMPARISONS:
+ if isinstance(expression, DATETRUNCS):
+ date = extract_date(expression.this)
+ if date and expression.unit:
+ return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
+ elif comparison not in DATETRUNC_COMPARISONS:
return expression
if isinstance(expression, exp.Binary):
l, r = expression.left, expression.right
- if _is_datetrunc_predicate(l, r):
- pass
- elif _is_datetrunc_predicate(r, l):
- comparison = INVERSE_COMPARISONS.get(comparison, comparison)
- l, r = r, l
- else:
+ if not _is_datetrunc_predicate(l, r):
return expression
l = t.cast(exp.DateTrunc, l)
@@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
if not date:
return expression
- return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
+ return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
elif isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
@@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
date = extract_date(r)
if not date:
return expression
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if drange:
ranges.append(drange)
@@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
return expression
+def sort_comparison(expression: exp.Expression) -> exp.Expression:
+ if expression.__class__ in COMPLEMENT_COMPARISONS:
+ l, r = expression.this, expression.expression
+ l_column = isinstance(l, exp.Column)
+ r_column = isinstance(r, exp.Column)
+ l_const = _is_constant(l)
+ r_const = _is_constant(r)
+
+ if (l_column and not r_column) or (r_const and not l_const):
+ return expression
+ if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
+ return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
+ this=r, expression=l
+ )
+ return expression
+
+
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
@@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1):
raise UnsupportedUnit(f"Unsupported unit: {unit}")
-def date_floor(d: datetime.date, unit: str) -> datetime.date:
+def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
if unit == "year":
return d.replace(month=1, day=1)
if unit == "quarter":
@@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date:
return d.replace(month=d.month, day=1)
if unit == "week":
# Assuming week starts on Monday (0) and ends on Sunday (6)
- return d - datetime.timedelta(days=d.weekday())
+ return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
if unit == "day":
return d
raise UnsupportedUnit(f"Unsupported unit: {unit}")
-def date_ceil(d: datetime.date, unit: str) -> datetime.date:
- floor = date_floor(d, unit)
+def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
+ floor = date_floor(d, unit, dialect)
if floor == d:
return d
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 242fc87..4d35175 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -65,6 +65,8 @@ def unnest(select, parent_select, next_alias_name):
)
):
column = exp.Max(this=column)
+ elif not isinstance(select.parent, exp.Subquery):
+ return
_replace(select.parent, column)
parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)