summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2022-12-12 15:42:38 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2022-12-12 15:42:38 +0000
commit: bea2635be022e272ddac349f5e396ec901fc37e5 (patch)
tree: 24dbe11c9d462ff55f9b3af4b4da4cd1ae02e8a3 /sqlglot/optimizer/simplify.py
parent: Releasing debian version 10.1.3-1. (diff)
download: sqlglot-bea2635be022e272ddac349f5e396ec901fc37e5.tar.xz
download: sqlglot-bea2635be022e272ddac349f5e396ec901fc37e5.zip
Merging upstream version 10.2.6.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- sqlglot/optimizer/simplify.py 235
1 files changed, 162 insertions, 73 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index c432c59..c0719f2 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -7,7 +7,7 @@ from decimal import Decimal
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
-from sqlglot.helper import while_changing
+from sqlglot.helper import first, while_changing
GENERATOR = Generator(normalize=True, identify=True)
@@ -30,6 +30,7 @@ def simplify(expression):
def _simplify(expression, root=True):
node = expression
+ node = rewrite_between(node)
node = uniq_sort(node)
node = absorb_and_eliminate(node)
exp.replace_children(node, lambda e: _simplify(e, False))
@@ -49,6 +50,19 @@ def simplify(expression):
return expression
+def rewrite_between(expression: exp.Expression) -> exp.Expression:
+    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.
+
+    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
+    rewritten into that form. Any other expression is returned unchanged.
+    """
+    if not isinstance(expression, exp.Between):
+        return expression
+
+    # `this` is needed on both sides of the AND, so each side gets its own copy
+    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
+    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
+    return exp.and_(lower_bound, upper_bound)
+
+
def simplify_not(expression):
"""
Demorgan's Law
@@ -57,7 +71,7 @@ def simplify_not(expression):
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Null):
- return NULL
+ return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
@@ -65,11 +79,11 @@ def simplify_not(expression):
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Null):
- return NULL
+ return exp.null()
if always_true(expression.this):
- return FALSE
+ return exp.false()
if expression.this == FALSE:
- return TRUE
+ return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
# NOT NOT x -> x
@@ -91,40 +105,119 @@ def flatten(expression):
def simplify_connectors(expression):
- if isinstance(expression, exp.Connector):
- left = expression.left
- right = expression.right
-
- if left == right:
- return left
-
- if isinstance(expression, exp.And):
- if FALSE in (left, right):
- return FALSE
- if NULL in (left, right):
- return NULL
- if always_true(left) and always_true(right):
- return TRUE
- if always_true(left):
- return right
- if always_true(right):
- return left
- elif isinstance(expression, exp.Or):
- if always_true(left) or always_true(right):
- return TRUE
- if left == FALSE and right == FALSE:
- return FALSE
- if (
- (left == NULL and right == NULL)
- or (left == NULL and right == FALSE)
- or (left == FALSE and right == NULL)
- ):
- return NULL
- if left == FALSE:
- return right
- if right == FALSE:
+ """
+ Simplify AND / OR connectors by folding pairs of their operands.
+
+ The nested helper is handed to `_flat_simplify`, which calls it with
+ pairs of operands taken from the flattened expression.
+ """
+
+ def _simplify_connectors(expression, left, right):
+ # Returns a replacement for the (left, right) pair, or None when no rule applies.
+ if isinstance(expression, exp.Connector):
+ if left == right:
return left
- return expression
+ if isinstance(expression, exp.And):
+ if FALSE in (left, right):
+ return exp.false()
+ if NULL in (left, right):
+ return exp.null()
+ if always_true(left) and always_true(right):
+ return exp.true()
+ if always_true(left):
+ return right
+ if always_true(right):
+ return left
+ # no constant rule applied: try to merge two comparisons of the same column
+ return _simplify_comparison(expression, left, right)
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return exp.true()
+ if left == FALSE and right == FALSE:
+ return exp.false()
+ if (
+ (left == NULL and right == NULL)
+ or (left == NULL and right == FALSE)
+ or (left == FALSE and right == NULL)
+ ):
+ return exp.null()
+ if left == FALSE:
+ return right
+ if right == FALSE:
+ return left
+ # same fallback as the AND branch, with or_=True
+ return _simplify_comparison(expression, left, right, or_=True)
+ return None
+
+ return _flat_simplify(expression, _simplify_connectors)
+
+
+# Comparison node types grouped by direction; used by `_simplify_comparison`.
+LT_LTE = (exp.LT, exp.LTE)
+GT_GTE = (exp.GT, exp.GTE)
+
+# Every comparison type that `_simplify_comparison` accepts as an operand.
+COMPARISONS = (
+ *LT_LTE,
+ *GT_GTE,
+ exp.EQ,
+ exp.NEQ,
+)
+
+# Maps a comparison to the one to use once its two operands are swapped (1 < x -> x > 1).
+INVERSE_COMPARISONS = {
+ exp.LT: exp.GT,
+ exp.GT: exp.LT,
+ exp.LTE: exp.GTE,
+ exp.GTE: exp.LTE,
+}
+
+
+def _simplify_comparison(expression, left, right, or_=False):
+    """Fold two comparisons of the same column against literals into one.
+
+    Called for a pair of operands of an AND (``or_=False``) or an OR
+    (``or_=True``), e.g. ``x > 1 AND x > 2`` -> ``x > 2``.
+
+    Args:
+        expression: the connector whose operands are being combined.
+        left: the first operand of the pair.
+        right: the second operand of the pair.
+        or_: True when the connector is an OR, False when it is an AND.
+
+    Returns:
+        A replacement for the pair, or None when no rule applies.
+    """
+    if not (isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS)):
+        return None
+
+    ll, lr = left.args.values()
+    rl, rr = right.args.values()
+
+    largs = {ll, lr}
+    rargs = {rl, rr}
+
+    # only comparisons that share a column can be merged
+    columns = {m for m in largs & rargs if isinstance(m, exp.Column)}
+    if not columns:
+        return None
+
+    try:
+        left_operand = first(largs - columns)
+        right_operand = first(rargs - columns)
+    except StopIteration:
+        # one side consists of shared columns only (e.g. x < y AND y > x);
+        # there is nothing to fold, so report "no rule applies" instead of
+        # handing the whole connector back as a replacement for the pair
+        return None
+
+    # make sure the comparison is always of the form x > 1 instead of 1 < x
+    if left.__class__ in INVERSE_COMPARISONS and left_operand == ll:
+        left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
+    if right.__class__ in INVERSE_COMPARISONS and right_operand == rl:
+        right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
+
+    if left_operand.is_number and right_operand.is_number:
+        lv = float(left_operand.name)
+        rv = float(right_operand.name)
+    elif left_operand.is_string and right_operand.is_string:
+        lv = left_operand.name
+        rv = right_operand.name
+    else:
+        return None
+
+    both_upper = isinstance(left, LT_LTE) and isinstance(right, LT_LTE)
+    both_lower = isinstance(left, GT_GTE) and isinstance(right, GT_GTE)
+
+    if both_upper or both_lower:
+        if lv == rv:
+            # same bound, so only strictness can differ:
+            # AND keeps the strict operator, OR keeps the inclusive one
+            left_is_strict = isinstance(left, (exp.LT, exp.GT))
+            return right if left_is_strict == or_ else left
+
+        # AND keeps the tighter bound, OR keeps the looser one
+        keep_smaller = both_upper != or_
+        return left if (lv < rv) == keep_smaller else right
+
+    if or_:
+        # we can't ever shortcut to true because the column could be null,
+        # and every rule below is only valid for AND
+        return None
+
+    for (a, av), (b, bv) in itertools.permutations(((left, lv), (right, rv))):
+        if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
+            if av <= bv:
+                return exp.false()
+        elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
+            if av >= bv:
+                return exp.false()
+        elif isinstance(a, exp.EQ):
+            # x = av pins the column: the other comparison either contradicts
+            # it (FALSE) or is implied by it (only the equality is kept)
+            if isinstance(b, exp.LT):
+                return exp.false() if av >= bv else a
+            if isinstance(b, exp.LTE):
+                return exp.false() if av > bv else a
+            if isinstance(b, exp.GT):
+                return exp.false() if av <= bv else a
+            if isinstance(b, exp.GTE):
+                return exp.false() if av < bv else a
+            if isinstance(b, exp.NEQ):
+                return exp.false() if av == bv else a
+    return None
def remove_compliments(expression):
@@ -135,7 +228,7 @@ def remove_compliments(expression):
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector):
- compliment = FALSE if isinstance(expression, exp.And) else TRUE
+ compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
@@ -211,27 +304,7 @@ def absorb_and_eliminate(expression):
def simplify_literals(expression):
if isinstance(expression, exp.Binary):
- operands = []
- queue = deque(expression.flatten(unnest=False))
- size = len(queue)
-
- while queue:
- a = queue.popleft()
-
- for b in queue:
- result = _simplify_binary(expression, a, b)
-
- if result:
- queue.remove(b)
- queue.append(result)
- break
- else:
- operands.append(a)
-
- if len(operands) < size:
- return functools.reduce(
- lambda a, b: expression.__class__(this=a, expression=b), operands
- )
+ return _flat_simplify(expression, _simplify_binary)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b):
if c == NULL:
if isinstance(a, exp.Literal):
- return TRUE if not_ else FALSE
+ return exp.true() if not_ else exp.false()
if a == NULL:
- return FALSE if not_ else TRUE
- elif isinstance(expression, exp.NullSafeEQ):
- if a == b:
- return TRUE
- elif isinstance(expression, exp.NullSafeNEQ):
- if a == b:
- return FALSE
+ return exp.false() if not_ else exp.true()
+ elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
+ return None
elif NULL in (a, b):
- return NULL
-
- if isinstance(expression, exp.EQ) and a == b:
- return TRUE
+ return exp.null()
if a.is_number and b.is_number:
a = int(a.name) if a.is_int else Decimal(a.name)
@@ -388,4 +454,27 @@ def date_literal(date):
def boolean_literal(condition):
- return TRUE if condition else FALSE
+ """Return `exp.true()` when `condition` is truthy, otherwise `exp.false()`."""
+ return exp.true() if condition else exp.false()
+
+
+def _flat_simplify(expression, simplifier):
+    """Merge pairs of operands of a flattened binary expression.
+
+    Each operand is offered to `simplifier(expression, a, b)` together with
+    every operand still waiting behind it. A truthy result replaces both
+    operands and is queued again so it can be merged further.
+
+    Returns the original expression when nothing was merged, otherwise the
+    remaining operands chained back together with the expression's own class.
+    """
+    pending = deque(expression.flatten(unnest=False))
+    initial_count = len(pending)
+    survivors = []
+
+    while pending:
+        current = pending.popleft()
+        replacement = None
+        partner = None
+
+        for candidate in pending:
+            replacement = simplifier(expression, current, candidate)
+            if replacement:
+                partner = candidate
+                break
+
+        if replacement:
+            # the merged node goes to the back of the queue to be retried
+            pending.remove(partner)
+            pending.append(replacement)
+        else:
+            survivors.append(current)
+
+    if len(survivors) >= initial_count:
+        return expression
+
+    def _join(lhs, rhs):
+        return expression.__class__(this=lhs, expression=rhs)
+
+    return functools.reduce(_join, survivors)