summary refs log tree commit diff stats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- sqlglot/optimizer/simplify.py 166
1 files changed, 90 insertions, 76 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index f80484d..1ed3ca2 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,11 +5,10 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
-from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import first, while_changing
-GENERATOR = Generator(normalize=True, identify=True)
+GENERATOR = Generator(normalize=True, identify="safe")
def simplify(expression):
@@ -28,18 +27,20 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
+ cache = {}
+
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
- node = uniq_sort(node)
- node = absorb_and_eliminate(node)
+ node = uniq_sort(node, cache, root)
+ node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
node = flatten(node)
- node = simplify_connectors(node)
- node = remove_compliments(node)
+ node = simplify_connectors(node, root)
+ node = remove_compliments(node, root)
node.parent = expression.parent
- node = simplify_literals(node)
+ node = simplify_literals(node, root)
node = simplify_parens(node)
if root:
expression.replace(node)
@@ -70,7 +71,7 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
- if isinstance(expression.this, exp.Null):
+ if is_null(expression.this):
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
@@ -78,11 +79,11 @@ def simplify_not(expression):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
- if isinstance(condition, exp.Null):
+ if is_null(condition):
return exp.null()
if always_true(expression.this):
return exp.false()
- if expression.this == FALSE:
+ if is_false(expression.this):
return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
@@ -104,42 +105,42 @@ def flatten(expression):
return expression
-def simplify_connectors(expression):
+def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
- if isinstance(expression, exp.Connector):
- if left == right:
+ if left == right:
+ return left
+ if isinstance(expression, exp.And):
+ if is_false(left) or is_false(right):
+ return exp.false()
+ if is_null(left) or is_null(right):
+ return exp.null()
+ if always_true(left) and always_true(right):
+ return exp.true()
+ if always_true(left):
+ return right
+ if always_true(right):
return left
- if isinstance(expression, exp.And):
- if FALSE in (left, right):
- return exp.false()
- if NULL in (left, right):
- return exp.null()
- if always_true(left) and always_true(right):
- return exp.true()
- if always_true(left):
- return right
- if always_true(right):
- return left
- return _simplify_comparison(expression, left, right)
- elif isinstance(expression, exp.Or):
- if always_true(left) or always_true(right):
- return exp.true()
- if left == FALSE and right == FALSE:
- return exp.false()
- if (
- (left == NULL and right == NULL)
- or (left == NULL and right == FALSE)
- or (left == FALSE and right == NULL)
- ):
- return exp.null()
- if left == FALSE:
- return right
- if right == FALSE:
- return left
- return _simplify_comparison(expression, left, right, or_=True)
- return None
+ return _simplify_comparison(expression, left, right)
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return exp.true()
+ if is_false(left) and is_false(right):
+ return exp.false()
+ if (
+ (is_null(left) and is_null(right))
+ or (is_null(left) and is_false(right))
+ or (is_false(left) and is_null(right))
+ ):
+ return exp.null()
+ if is_false(left):
+ return right
+ if is_false(right):
+ return left
+ return _simplify_comparison(expression, left, right, or_=True)
- return _flat_simplify(expression, _simplify_connectors)
+ if isinstance(expression, exp.Connector):
+ return _flat_simplify(expression, _simplify_connectors, root)
+ return expression
LT_LTE = (exp.LT, exp.LTE)
@@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False):
return None
-def remove_compliments(expression):
+def remove_compliments(expression, root=True):
"""
Removing compliments.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
@@ -236,23 +237,23 @@ def remove_compliments(expression):
return expression
-def uniq_sort(expression):
+def uniq_sort(expression, cache=None, root=True):
"""
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {GENERATOR.generate(e): e for e in flattened}
+ deduped = {GENERATOR.generate(e, cache): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
- expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
+ expression = result_func(*(e for _, e in sorted(arr)))
break
else:
# we didn't have to sort but maybe we need to dedup
@@ -262,7 +263,7 @@ def uniq_sort(expression):
return expression
-def absorb_and_eliminate(expression):
+def absorb_and_eliminate(expression, root=True):
"""
absorption:
A AND (A OR B) -> A
@@ -273,7 +274,7 @@ def absorb_and_eliminate(expression):
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
@@ -302,9 +303,9 @@ def absorb_and_eliminate(expression):
return expression
-def simplify_literals(expression):
- if isinstance(expression, exp.Binary):
- return _flat_simplify(expression, _simplify_binary)
+def simplify_literals(expression, root=True):
+ if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
+ return _flat_simplify(expression, _simplify_binary, root)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b):
c = b
not_ = False
- if c == NULL:
+ if is_null(c):
if isinstance(a, exp.Literal):
return exp.true() if not_ else exp.false()
- if a == NULL:
+ if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
return None
- elif NULL in (a, b):
+ elif is_null(a) or is_null(b):
return exp.null()
if a.is_number and b.is_number:
@@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif a.is_string and b.is_string:
- boolean = eval_boolean(expression, a, b)
+ boolean = eval_boolean(expression, a.this, b.this)
if boolean:
return boolean
@@ -381,7 +382,7 @@ def simplify_parens(expression):
and not isinstance(expression.this, exp.Select)
and (
not isinstance(expression.parent, (exp.Condition, exp.Binary))
- or isinstance(expression.this, (exp.Is, exp.Like))
+ or isinstance(expression.this, exp.Predicate)
or not isinstance(expression.this, exp.Binary)
)
):
@@ -400,13 +401,23 @@ def remove_where_true(expression):
def always_true(expression):
- return expression == TRUE or isinstance(expression, exp.Literal)
+ return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
+ expression, exp.Literal
+ )
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
+def is_false(a: exp.Expression) -> bool:
+ return type(a) is exp.Boolean and not a.this
+
+
+def is_null(a: exp.Expression) -> bool:
+ return type(a) is exp.Null
+
+
def eval_boolean(expression, a, b):
if isinstance(expression, (exp.EQ, exp.Is)):
return boolean_literal(a == b)
@@ -466,24 +477,27 @@ def boolean_literal(condition):
return exp.true() if condition else exp.false()
-def _flat_simplify(expression, simplifier):
- operands = []
- queue = deque(expression.flatten(unnest=False))
- size = len(queue)
+def _flat_simplify(expression, simplifier, root=True):
+ if root or not expression.same_parent:
+ operands = []
+ queue = deque(expression.flatten(unnest=False))
+ size = len(queue)
- while queue:
- a = queue.popleft()
+ while queue:
+ a = queue.popleft()
- for b in queue:
- result = simplifier(expression, a, b)
+ for b in queue:
+ result = simplifier(expression, a, b)
- if result:
- queue.remove(b)
- queue.append(result)
- break
- else:
- operands.append(a)
+ if result:
+ queue.remove(b)
+ queue.append(result)
+ break
+ else:
+ operands.append(a)
- if len(operands) < size:
- return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+ if len(operands) < size:
+ return functools.reduce(
+ lambda a, b: expression.__class__(this=a, expression=b), operands
+ )
return expression