diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 153 |
1 files changed, 94 insertions, 59 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index d4e2e60..6ae08d0 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import functools import itertools @@ -6,10 +8,17 @@ from collections import deque from decimal import Decimal import sqlglot -from sqlglot import exp +from sqlglot import Dialect, exp from sqlglot.helper import first, is_iterable, merge_ranges, while_changing from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + + DateTruncBinaryTransform = t.Callable[ + [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] + ] + # Final means that an expression should not be simplified FINAL = "final" @@ -18,7 +27,9 @@ class UnsupportedUnit(Exception): pass -def simplify(expression, constant_propagation=False): +def simplify( + expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None +): """ Rewrite sqlglot AST to simplify expressions. @@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False): sqlglot.Expression: simplified expression """ + dialect = Dialect.get_or_raise(dialect) + # group by expressions cannot be simplified, for example # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 # the projection must exactly match the group by key for group in expression.find_all(exp.Group): select = group.parent + assert select groups = set(group.expressions) group.meta[FINAL] = True - for e in select.selects: + for e in select.expressions: for node, *_ in e.walk(): if node in groups: e.meta[FINAL] = True @@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False): node = simplify_literals(node, root) node = simplify_equality(node) node = simplify_parens(node) - node = simplify_datetrunc_predicate(node) + node = simplify_datetrunc(node, dialect) + node = sort_comparison(node) if root: expression.replace(node) @@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression: This is done because comparison simplification is only done on lt/lte/gt/gte. """ if isinstance(expression, exp.Between): - return exp.and_( + negate = isinstance(expression.parent, exp.Not) + + expression = exp.and_( exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), copy=False, ) + + if negate: + expression = exp.paren(expression, copy=False) + return expression +COMPLEMENT_COMPARISONS = { + exp.LT: exp.GTE, + exp.GT: exp.LTE, + exp.LTE: exp.GT, + exp.GTE: exp.LT, + exp.EQ: exp.NEQ, + exp.NEQ: exp.EQ, +} + + def simplify_not(expression): """ Demorgan's Law @@ -132,10 +163,15 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): - if is_null(expression.this): + this = expression.this + if is_null(this): return exp.null() - if isinstance(expression.this, exp.Paren): - condition = expression.this.unnest() + if this.__class__ in COMPLEMENT_COMPARISONS: + return COMPLEMENT_COMPARISONS[this.__class__]( + this=this.this, expression=this.expression + ) + if isinstance(this, exp.Paren): + condition = this.unnest() if isinstance(condition, exp.And): return exp.or_( exp.not_(condition.left, copy=False), @@ -150,14 +186,14 @@ def simplify_not(expression): ) if is_null(condition): return exp.null() - if always_true(expression.this): + if always_true(this): return exp.false() - if is_false(expression.this): + if is_false(this): return exp.true() - if isinstance(expression.this, exp.Not): + if isinstance(this, exp.Not): # double negation # NOT NOT x -> x - return expression.this.this + return this.this return expression @@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False): except StopIteration: return expression - # make sure the comparison is always of the form x > 1 instead of 1 < x - if left.__class__ in INVERSE_COMPARISONS and l == ll: - left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) - if right.__class__ in INVERSE_COMPARISONS and r == rl: - right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) - if l.is_number and r.is_number: l = float(l.name) r = float(r.name) @@ -397,13 +427,7 @@ def propagate_constants(expression, root=True): # TODO: create a helper that can be used to detect nested literal expressions such # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too if isinstance(l, exp.Column) and isinstance(r, exp.Literal): - pass - elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): - l, r = r, l - else: - continue - - constant_mapping[l] = (id(l), r) + constant_mapping[l] = (id(l), r) if constant_mapping: for column in find_all_in_scope(expression, exp.Column): @@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: if isinstance(expression, COMPARISONS): l, r = expression.left, expression.right - if l.__class__ in INVERSE_OPS: - pass - elif r.__class__ in INVERSE_OPS: - l, r = r, l - else: + if not l.__class__ in INVERSE_OPS: return expression if r.is_number: @@ -650,7 +670,7 @@ def simplify_coalesce(expression): # Find the first constant arg for arg_index, arg in enumerate(coalesce.expressions): - if _is_constant(other): + if _is_constant(arg): break else: return expression @@ -752,7 +772,7 @@ def simplify_conditionals(expression): DateRange = t.Tuple[datetime.date, datetime.date] -def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: +def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: """ Get the date range for a DATE_TRUNC equality comparison: @@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: Returns: tuple of [min, max) or None if a value can never be equal to `date` for `unit` """ - floor = date_floor(date, unit) + floor = date_floor(date, unit, dialect) if date != floor: # This will always be False, except for NULL values. @@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp def _datetrunc_eq( - left: exp.Expression, date: datetime.date, unit: str + left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect ) -> t.Optional[exp.Expression]: - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if not drange: return None @@ -790,9 +810,9 @@ def _datetrunc_eq( def _datetrunc_neq( - left: exp.Expression, date: datetime.date, unit: str + left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect ) -> t.Optional[exp.Expression]: - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if not drange: return None @@ -803,41 +823,39 @@ def _datetrunc_neq( ) -DateTruncBinaryTransform = t.Callable[ - [exp.Expression, datetime.date, str], t.Optional[exp.Expression] -] DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { - exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)), - exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)), - exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)), - exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)), + exp.LT: lambda l, dt, u, d: l + < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)), + exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)), + exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)), + exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)), exp.EQ: _datetrunc_eq, exp.NEQ: _datetrunc_neq, } DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} +DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: - return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right) + return isinstance(left, DATETRUNCS) and _is_date_literal(right) @catch(ModuleNotFoundError, UnsupportedUnit) -def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: +def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" comparison = expression.__class__ - if comparison not in DATETRUNC_COMPARISONS: + if isinstance(expression, DATETRUNCS): + date = extract_date(expression.this) + if date and expression.unit: + return date_literal(date_floor(date, expression.unit.name.lower(), dialect)) + elif comparison not in DATETRUNC_COMPARISONS: return expression if isinstance(expression, exp.Binary): l, r = expression.left, expression.right - if _is_datetrunc_predicate(l, r): - pass - elif _is_datetrunc_predicate(r, l): - comparison = INVERSE_COMPARISONS.get(comparison, comparison) - l, r = r, l - else: + if not _is_datetrunc_predicate(l, r): return expression l = t.cast(exp.DateTrunc, l) @@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: if not date: return expression - return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression + return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression elif isinstance(expression, exp.In): l = expression.this rs = expression.expressions @@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: date = extract_date(r) if not date: return expression - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if drange: ranges.append(drange) @@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: return expression +def sort_comparison(expression: exp.Expression) -> exp.Expression: + if expression.__class__ in COMPLEMENT_COMPARISONS: + l, r = expression.this, expression.expression + l_column = isinstance(l, exp.Column) + r_column = isinstance(r, exp.Column) + l_const = _is_constant(l) + r_const = _is_constant(r) + + if (l_column and not r_column) or (r_const and not l_const): + return expression + if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): + return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( + this=r, expression=l + ) + return expression + + # CROSS joins result in an empty table if the right table is empty. # So we can only simplify certain types of joins to CROSS. # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x @@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1): raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_floor(d: datetime.date, unit: str) -> datetime.date: +def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: if unit == "year": return d.replace(month=1, day=1) if unit == "quarter": @@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date: return d.replace(month=d.month, day=1) if unit == "week": # Assuming week starts on Monday (0) and ends on Sunday (6) - return d - datetime.timedelta(days=d.weekday()) + return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) if unit == "day": return d raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_ceil(d: datetime.date, unit: str) -> datetime.date: - floor = date_floor(d, unit) +def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: + floor = date_floor(d, unit, dialect) if floor == d: return d |