summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-12-19 11:01:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-12-19 11:01:55 +0000
commitf1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5 (patch)
tree5dce0fe2a11381761496eb973c20750f44db56d5 /sqlglot/optimizer/simplify.py
parentReleasing debian version 20.1.0-1. (diff)
downloadsqlglot-f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5.tar.xz
sqlglot-f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5.zip
Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r--sqlglot/optimizer/simplify.py153
1 files changed, 94 insertions, 59 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d4e2e60..6ae08d0 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import functools
import itertools
@@ -6,10 +8,17 @@ from collections import deque
from decimal import Decimal
import sqlglot
-from sqlglot import exp
+from sqlglot import Dialect, exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import DialectType
+
+ DateTruncBinaryTransform = t.Callable[
+ [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
+ ]
+
# Final means that an expression should not be simplified
FINAL = "final"
@@ -18,7 +27,9 @@ class UnsupportedUnit(Exception):
pass
-def simplify(expression, constant_propagation=False):
+def simplify(
+ expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
+):
"""
Rewrite sqlglot AST to simplify expressions.
@@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
+ dialect = Dialect.get_or_raise(dialect)
+
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
for group in expression.find_all(exp.Group):
select = group.parent
+ assert select
groups = set(group.expressions)
group.meta[FINAL] = True
- for e in select.selects:
+ for e in select.expressions:
for node, *_ in e.walk():
if node in groups:
e.meta[FINAL] = True
@@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False):
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
- node = simplify_datetrunc_predicate(node)
+ node = simplify_datetrunc(node, dialect)
+ node = sort_comparison(node)
if root:
expression.replace(node)
@@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
This is done because comparison simplification is only done on lt/lte/gt/gte.
"""
if isinstance(expression, exp.Between):
- return exp.and_(
+ negate = isinstance(expression.parent, exp.Not)
+
+ expression = exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
+
+ if negate:
+ expression = exp.paren(expression, copy=False)
+
return expression
+COMPLEMENT_COMPARISONS = {
+ exp.LT: exp.GTE,
+ exp.GT: exp.LTE,
+ exp.LTE: exp.GT,
+ exp.GTE: exp.LT,
+ exp.EQ: exp.NEQ,
+ exp.NEQ: exp.EQ,
+}
+
+
def simplify_not(expression):
"""
Demorgan's Law
@@ -132,10 +163,15 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
- if is_null(expression.this):
+ this = expression.this
+ if is_null(this):
return exp.null()
- if isinstance(expression.this, exp.Paren):
- condition = expression.this.unnest()
+ if this.__class__ in COMPLEMENT_COMPARISONS:
+ return COMPLEMENT_COMPARISONS[this.__class__](
+ this=this.this, expression=this.expression
+ )
+ if isinstance(this, exp.Paren):
+ condition = this.unnest()
if isinstance(condition, exp.And):
return exp.or_(
exp.not_(condition.left, copy=False),
@@ -150,14 +186,14 @@ def simplify_not(expression):
)
if is_null(condition):
return exp.null()
- if always_true(expression.this):
+ if always_true(this):
return exp.false()
- if is_false(expression.this):
+ if is_false(this):
return exp.true()
- if isinstance(expression.this, exp.Not):
+ if isinstance(this, exp.Not):
# double negation
# NOT NOT x -> x
- return expression.this.this
+ return this.this
return expression
@@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False):
except StopIteration:
return expression
- # make sure the comparison is always of the form x > 1 instead of 1 < x
- if left.__class__ in INVERSE_COMPARISONS and l == ll:
- left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
- if right.__class__ in INVERSE_COMPARISONS and r == rl:
- right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
-
if l.is_number and r.is_number:
l = float(l.name)
r = float(r.name)
@@ -397,13 +427,7 @@ def propagate_constants(expression, root=True):
# TODO: create a helper that can be used to detect nested literal expressions such
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
- pass
- elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
- l, r = r, l
- else:
- continue
-
- constant_mapping[l] = (id(l), r)
+ constant_mapping[l] = (id(l), r)
if constant_mapping:
for column in find_all_in_scope(expression, exp.Column):
@@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right
- if l.__class__ in INVERSE_OPS:
- pass
- elif r.__class__ in INVERSE_OPS:
- l, r = r, l
- else:
+ if not l.__class__ in INVERSE_OPS:
return expression
if r.is_number:
@@ -650,7 +670,7 @@ def simplify_coalesce(expression):
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
- if _is_constant(other):
+ if _is_constant(arg):
break
else:
return expression
@@ -752,7 +772,7 @@ def simplify_conditionals(expression):
DateRange = t.Tuple[datetime.date, datetime.date]
-def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
+def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
"""
Get the date range for a DATE_TRUNC equality comparison:
@@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
Returns:
tuple of [min, max) or None if a value can never be equal to `date` for `unit`
"""
- floor = date_floor(date, unit)
+ floor = date_floor(date, unit, dialect)
if date != floor:
# This will always be False, except for NULL values.
@@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp
def _datetrunc_eq(
- left: exp.Expression, date: datetime.date, unit: str
+ left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
@@ -790,9 +810,9 @@ def _datetrunc_eq(
def _datetrunc_neq(
- left: exp.Expression, date: datetime.date, unit: str
+ left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
@@ -803,41 +823,39 @@ def _datetrunc_neq(
)
-DateTruncBinaryTransform = t.Callable[
- [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
-]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
- exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
- exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
- exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
- exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
+ exp.LT: lambda l, dt, u, d: l
+ < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
+ exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
+ exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
+ exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
exp.EQ: _datetrunc_eq,
exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
+DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
- return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
+ return isinstance(left, DATETRUNCS) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
-def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
+def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
comparison = expression.__class__
- if comparison not in DATETRUNC_COMPARISONS:
+ if isinstance(expression, DATETRUNCS):
+ date = extract_date(expression.this)
+ if date and expression.unit:
+ return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
+ elif comparison not in DATETRUNC_COMPARISONS:
return expression
if isinstance(expression, exp.Binary):
l, r = expression.left, expression.right
- if _is_datetrunc_predicate(l, r):
- pass
- elif _is_datetrunc_predicate(r, l):
- comparison = INVERSE_COMPARISONS.get(comparison, comparison)
- l, r = r, l
- else:
+ if not _is_datetrunc_predicate(l, r):
return expression
l = t.cast(exp.DateTrunc, l)
@@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
if not date:
return expression
- return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
+ return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
elif isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
@@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
date = extract_date(r)
if not date:
return expression
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if drange:
ranges.append(drange)
@@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
return expression
+def sort_comparison(expression: exp.Expression) -> exp.Expression:
+ if expression.__class__ in COMPLEMENT_COMPARISONS:
+ l, r = expression.this, expression.expression
+ l_column = isinstance(l, exp.Column)
+ r_column = isinstance(r, exp.Column)
+ l_const = _is_constant(l)
+ r_const = _is_constant(r)
+
+ if (l_column and not r_column) or (r_const and not l_const):
+ return expression
+ if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
+ return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
+ this=r, expression=l
+ )
+ return expression
+
+
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
@@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1):
raise UnsupportedUnit(f"Unsupported unit: {unit}")
-def date_floor(d: datetime.date, unit: str) -> datetime.date:
+def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
if unit == "year":
return d.replace(month=1, day=1)
if unit == "quarter":
@@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date:
return d.replace(month=d.month, day=1)
if unit == "week":
# Assuming week starts on Monday (0) and ends on Sunday (6)
- return d - datetime.timedelta(days=d.weekday())
+ return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
if unit == "day":
return d
raise UnsupportedUnit(f"Unsupported unit: {unit}")
-def date_ceil(d: datetime.date, unit: str) -> datetime.date:
- floor = date_floor(d, unit)
+def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
+ floor = date_floor(d, unit, dialect)
if floor == d:
return d