summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-09-25 08:20:09 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-09-25 08:20:09 +0000
commit: 4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6 (patch)
tree: 8f4f60a82ab9cd6dcd41397e4ecb2960c332b209 /sqlglot/optimizer/simplify.py
parent: Releasing debian version 18.5.1-1. (diff)
download: sqlglot-4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6.tar.xz
sqlglot-4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6.zip
Merging upstream version 18.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- sqlglot/optimizer/simplify.py 299
1 files changed, 282 insertions, 17 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 3974ea4..d08c692 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -1,17 +1,22 @@
import datetime
import functools
import itertools
+import typing as t
from collections import deque
from decimal import Decimal
from sqlglot import exp
from sqlglot.generator import cached_generator
-from sqlglot.helper import first, while_changing
+from sqlglot.helper import first, merge_ranges, while_changing
# Final means that an expression should not be simplified
FINAL = "final"
+class UnsupportedUnit(Exception):
+    """Raised by the date helpers in this module when a date/time unit is not handled."""
+
+
def simplify(expression):
"""
Rewrite sqlglot AST to simplify expressions.
@@ -72,7 +77,9 @@ def simplify(expression):
node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
+ node = simplify_equality(node)
node = simplify_parens(node)
+ node = simplify_datetrunc_predicate(node)
if root:
expression.replace(node)
@@ -84,6 +91,21 @@ def simplify(expression):
return expression
+def catch(*exceptions):
+    """
+    Decorator that ignores a simplification function if any of `exceptions` are raised.
+
+    Args:
+        *exceptions: exception classes that should abort the simplification.
+
+    Returns:
+        A decorator. The decorated function returns its first positional argument
+        (the expression) unchanged whenever one of `exceptions` is raised.
+    """
+
+    def decorator(func):
+        # Keep the wrapped function's name and docstring visible on the wrapper.
+        @functools.wraps(func)
+        def wrapped(expression, *args, **kwargs):
+            try:
+                return func(expression, *args, **kwargs)
+            except exceptions:
+                return expression
+
+        return wrapped
+
+    return decorator
+
+
def rewrite_between(expression: exp.Expression) -> exp.Expression:
"""Rewrite x between y and z to x >= y AND x <= z.
@@ -196,7 +218,7 @@ COMPARISONS = (
exp.Is,
)
-INVERSE_COMPARISONS = {
+INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.LT: exp.GT,
exp.GT: exp.LT,
exp.LTE: exp.GTE,
@@ -347,6 +369,87 @@ def absorb_and_eliminate(expression, root=True):
return expression
+# Maps a date arithmetic node to the plain arithmetic operator that undoes it when the
+# operation is moved to the other side of a comparison (see simplify_equality).
+INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
+ exp.DateAdd: exp.Sub,
+ exp.DateSub: exp.Add,
+ exp.DatetimeAdd: exp.Sub,
+ exp.DatetimeSub: exp.Add,
+}
+
+# The date mapping above, extended with ordinary addition and subtraction.
+INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
+ **INVERSE_DATE_OPS,
+ exp.Add: exp.Sub,
+ exp.Sub: exp.Add,
+}
+
+
+def _is_number(expression: exp.Expression) -> bool:
+ """Return the `is_number` attribute of `expression`."""
+ return expression.is_number
+
+
+def _is_date(expression: exp.Expression) -> bool:
+    """Return True if `expression` is a cast from which `extract_date` can read a date."""
+    if not isinstance(expression, exp.Cast):
+        return False
+    return extract_date(expression) is not None
+
+
+def _is_interval(expression: exp.Expression) -> bool:
+    """Return True if `expression` is an interval that `extract_interval` can evaluate."""
+    if not isinstance(expression, exp.Interval):
+        return False
+    return extract_interval(expression) is not None
+
+
+@catch(ModuleNotFoundError, UnsupportedUnit)
+def simplify_equality(expression: exp.Expression) -> exp.Expression:
+    """
+    Use the subtraction and addition properties of equality to simplify expressions:
+
+        x + 1 = 3 becomes x = 2
+
+    There are two binary operations in the above expression: + and =
+    Here's how we reference all the operands in the code below:
+
+          l     r
+        x + 1 = 3
+        a   b
+
+    Args:
+        expression: the node to simplify.
+
+    Returns:
+        A new comparison with the constant moved to the right-hand side, or `expression`
+        itself when the rewrite does not apply.
+    """
+    if not isinstance(expression, COMPARISONS):
+        return expression
+
+    comparison = expression.__class__
+    l, r = expression.left, expression.right
+
+    if l.__class__ in INVERSE_OPS:
+        pass
+    elif r.__class__ in INVERSE_OPS:
+        # The arithmetic is on the right (e.g. `3 < x + 1`). Swapping the operands flips
+        # the direction of an inequality, so the comparison has to be mirrored as well.
+        comparison = INVERSE_COMPARISONS.get(comparison, comparison)
+        l, r = r, l
+    else:
+        return expression
+
+    if r.is_number:
+        a_predicate = _is_number
+        b_predicate = _is_number
+    elif _is_date(r):
+        a_predicate = _is_date
+        b_predicate = _is_interval
+    else:
+        return expression
+
+    if l.__class__ in INVERSE_DATE_OPS:
+        a = l.this
+        b = exp.Interval(
+            this=l.expression.copy(),
+            unit=l.unit.copy(),
+        )
+    else:
+        a, b = l.left, l.right
+
+    if not a_predicate(a) and b_predicate(b):
+        pass
+    elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
+        # Only plain addition commutes: `1 + x = 3` can be read as `x + 1 = 3`, but
+        # `1 - x = 3` is not `x - 1 = 3`, so every other operation is left untouched.
+        a, b = b, a
+    else:
+        return expression
+
+    return comparison(this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b))
+
+
def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
@@ -530,6 +633,123 @@ def simplify_concat(expression):
return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
+DateRange = t.Tuple[datetime.date, datetime.date]
+
+
+def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
+    """
+    Compute the half-open range of values whose truncation to `unit` equals `date`.
+
+    Example:
+        _datetrunc_range(date(2021, 1, 1), "year") == (date(2021, 1, 1), date(2022, 1, 1))
+
+    Returns:
+        A `(start, end)` tuple describing [start, end), or None when `date` is not itself
+        a `unit` boundary and therefore can never be the result of a truncation.
+    """
+    start = date_floor(date, unit)
+
+    if start == date:
+        return start, start + interval(unit)
+
+    # Off-boundary value: the equality is always False, except for NULL values.
+    return None
+
+
+def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
+    """Build `left >= start AND left < end` for the half-open date range `drange`."""
+    start, end = drange
+    lower_bound = left >= date_literal(start)
+    upper_bound = left < date_literal(end)
+    return exp.and_(lower_bound, upper_bound, copy=False)
+
+
+def _datetrunc_eq(
+    left: exp.Expression, date: datetime.date, unit: str
+) -> t.Optional[exp.Expression]:
+    """Rewrite `DATE_TRUNC(unit, left) = date` as a range check, or None if no range exists."""
+    drange = _datetrunc_range(date, unit)
+    return _datetrunc_eq_expression(left, drange) if drange else None
+
+
+def _datetrunc_neq(
+    left: exp.Expression, date: datetime.date, unit: str
+) -> t.Optional[exp.Expression]:
+    """
+    Rewrite `DATE_TRUNC(unit, left) <> date` as a check that `left` is outside the range.
+
+    Returns:
+        `left < start OR left >= end` for the range [start, end) that truncates to `date`,
+        or None when `date` is not a `unit` boundary and no such range exists.
+    """
+    drange = _datetrunc_range(date, unit)
+    if not drange:
+        return None
+
+    # The two bounds are mutually exclusive, so they must be joined with OR: a value is
+    # outside [start, end) when it is before the start or at/after the end.
+    return exp.or_(
+        left < date_literal(drange[0]),
+        left >= date_literal(drange[1]),
+        copy=False,
+    )
+
+
+# Signature shared by every rewrite below: (truncated operand, literal date, unit).
+DateTruncBinaryTransform = t.Callable[
+    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
+]
+
+# Rewrites `DATE_TRUNC(unit, l) <op> d` into a comparison on `l` itself.
+# A truncated value always sits on a unit boundary, so:
+#   trunc(l) <  d  <=>  l <  ceil(d)
+#   trunc(l) >  d  <=>  l >= floor(d) + 1 unit
+#   trunc(l) <= d  <=>  l <  floor(d) + 1 unit
+#   trunc(l) >= d  <=>  l >= ceil(d)
+DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
+    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
+    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
+    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
+    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
+    exp.EQ: _datetrunc_eq,
+    exp.NEQ: _datetrunc_neq,
+}
+DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
+
+
+def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
+    """Return whether `left` is a date/timestamp truncation and `right` a temporal cast."""
+    if not isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)):
+        return False
+    if not isinstance(right, exp.Cast):
+        return False
+    return right.is_type(*exp.DataType.TEMPORAL_TYPES)
+
+
+@catch(ModuleNotFoundError, UnsupportedUnit)
+def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
+    """
+    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`.
+
+    Binary comparisons and `IN` lists whose operands are a date truncation and temporal
+    casts are rewritten into range checks on the truncated operand.
+
+    Returns:
+        The rewritten expression, or `expression` itself when the rewrite does not apply.
+    """
+    comparison = expression.__class__
+
+    if comparison not in DATETRUNC_COMPARISONS:
+        return expression
+
+    if isinstance(expression, exp.Binary):
+        l, r = expression.left, expression.right
+
+        if _is_datetrunc_predicate(l, r):
+            pass
+        elif _is_datetrunc_predicate(r, l):
+            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
+            l, r = r, l
+        else:
+            return expression
+
+        date = extract_date(r)
+        if date is None:
+            # The cast does not hold a parseable date literal, so there is nothing to fold.
+            return expression
+
+        unit = l.unit.name.lower()
+        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
+
+    if isinstance(expression, exp.In):
+        l = expression.this
+        rs = expression.expressions
+
+        # `all` is vacuously true for an empty list, so require at least one value
+        # before treating `l` as a truncation.
+        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
+            dates = [extract_date(r) for r in rs]
+            if any(date is None for date in dates):
+                return expression
+
+            unit = l.unit.name.lower()
+
+            # Values that are not on a unit boundary can never match and yield no range.
+            ranges = [drange for drange in (_datetrunc_range(d, unit) for d in dates) if drange]
+            if not ranges:
+                return expression
+
+            ranges = merge_ranges(ranges)
+
+            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
+
+    return expression
+
+
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
@@ -603,31 +823,76 @@ def extract_date(cast):
return None
-def extract_interval(interval):
+def extract_interval(expression):
+ n = int(expression.name)
+ unit = expression.text("unit").lower()
+
try:
- from dateutil.relativedelta import relativedelta # type: ignore
- except ModuleNotFoundError:
+ return interval(unit, n)
+ except (UnsupportedUnit, ModuleNotFoundError):
return None
- n = int(interval.name)
- unit = interval.text("unit").lower()
+
+def date_literal(date):
+    """Wrap `date` in a string literal cast to DATETIME (for datetimes) or DATE (otherwise)."""
+    if isinstance(date, datetime.datetime):
+        target_type = "DATETIME"
+    else:
+        target_type = "DATE"
+
+    return exp.cast(exp.Literal.string(date), target_type)
+
+
+def interval(unit: str, n: int = 1):
+ from dateutil.relativedelta import relativedelta
if unit == "year":
- return relativedelta(years=n)
+ return relativedelta(years=1 * n)
+ if unit == "quarter":
+ return relativedelta(months=3 * n)
if unit == "month":
- return relativedelta(months=n)
+ return relativedelta(months=1 * n)
if unit == "week":
- return relativedelta(weeks=n)
+ return relativedelta(weeks=1 * n)
if unit == "day":
- return relativedelta(days=n)
- return None
+ return relativedelta(days=1 * n)
+ if unit == "hour":
+ return relativedelta(hours=1 * n)
+ if unit == "minute":
+ return relativedelta(minutes=1 * n)
+ if unit == "second":
+ return relativedelta(seconds=1 * n)
+ raise UnsupportedUnit(f"Unsupported unit: {unit}")
-def date_literal(date):
- return exp.cast(
- exp.Literal.string(date),
- "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
- )
+
+def date_floor(d: datetime.date, unit: str) -> datetime.date:
+    """
+    Round `d` down to the start of the `unit` period that contains it.
+
+    Args:
+        d: a date or datetime.
+        unit: one of "year", "quarter", "month", "week" or "day".
+
+    Returns:
+        The first day of the period, of the same type as `d`. For a datetime the
+        result is at midnight.
+
+    Raises:
+        UnsupportedUnit: if `unit` is not one of the units listed above.
+    """
+    if isinstance(d, datetime.datetime):
+        # Every supported unit is at least a day long, so the floor always falls on
+        # midnight; otherwise the time of day would leak into the result and an
+        # off-boundary datetime would compare equal to its own "floor".
+        d = d.replace(hour=0, minute=0, second=0, microsecond=0)
+
+    if unit == "year":
+        return d.replace(month=1, day=1)
+    if unit == "quarter":
+        # Quarters start in months 1, 4, 7 and 10.
+        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
+    if unit == "month":
+        return d.replace(day=1)
+    if unit == "week":
+        # Assuming week starts on Monday (0) and ends on Sunday (6)
+        return d - datetime.timedelta(days=d.weekday())
+    if unit == "day":
+        return d
+
+    raise UnsupportedUnit(f"Unsupported unit: {unit}")
+
+
+def date_ceil(d: datetime.date, unit: str) -> datetime.date:
+    """Round `d` up to the next `unit` boundary; a value already on a boundary is kept."""
+    lower = date_floor(d, unit)
+    return d if lower == d else lower + interval(unit)
def boolean_literal(condition):