diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 166 |
1 files changed, 90 insertions, 76 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f80484d..1ed3ca2 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -5,11 +5,10 @@ from collections import deque from decimal import Decimal from sqlglot import exp -from sqlglot.expressions import FALSE, NULL, TRUE from sqlglot.generator import Generator from sqlglot.helper import first, while_changing -GENERATOR = Generator(normalize=True, identify=True) +GENERATOR = Generator(normalize=True, identify="safe") def simplify(expression): @@ -28,18 +27,20 @@ def simplify(expression): sqlglot.Expression: simplified expression """ + cache = {} + def _simplify(expression, root=True): node = expression node = rewrite_between(node) - node = uniq_sort(node) - node = absorb_and_eliminate(node) + node = uniq_sort(node, cache, root) + node = absorb_and_eliminate(node, root) exp.replace_children(node, lambda e: _simplify(e, False)) node = simplify_not(node) node = flatten(node) - node = simplify_connectors(node) - node = remove_compliments(node) + node = simplify_connectors(node, root) + node = remove_compliments(node, root) node.parent = expression.parent - node = simplify_literals(node) + node = simplify_literals(node, root) node = simplify_parens(node) if root: expression.replace(node) @@ -70,7 +71,7 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): - if isinstance(expression.this, exp.Null): + if is_null(expression.this): return exp.null() if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() @@ -78,11 +79,11 @@ def simplify_not(expression): return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) - if isinstance(condition, exp.Null): + if is_null(condition): return exp.null() if always_true(expression.this): return exp.false() - if expression.this == FALSE: + if is_false(expression.this): return exp.true() if isinstance(expression.this, exp.Not): # double negation @@ -104,42 +105,42 @@ def flatten(expression): return expression -def simplify_connectors(expression): +def simplify_connectors(expression, root=True): def _simplify_connectors(expression, left, right): - if isinstance(expression, exp.Connector): - if left == right: + if left == right: + return left + if isinstance(expression, exp.And): + if is_false(left) or is_false(right): + return exp.false() + if is_null(left) or is_null(right): + return exp.null() + if always_true(left) and always_true(right): + return exp.true() + if always_true(left): + return right + if always_true(right): return left - if isinstance(expression, exp.And): - if FALSE in (left, right): - return exp.false() - if NULL in (left, right): - return exp.null() - if always_true(left) and always_true(right): - return exp.true() - if always_true(left): - return right - if always_true(right): - return left - return _simplify_comparison(expression, left, right) - elif isinstance(expression, exp.Or): - if always_true(left) or always_true(right): - return exp.true() - if left == FALSE and right == FALSE: - return exp.false() - if ( - (left == NULL and right == NULL) - or (left == NULL and right == FALSE) - or (left == FALSE and right == NULL) - ): - return exp.null() - if left == FALSE: - return right - if right == FALSE: - return left - return _simplify_comparison(expression, left, right, or_=True) - return None + return _simplify_comparison(expression, left, right) + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return exp.true() + if is_false(left) and is_false(right): + return exp.false() + if ( + (is_null(left) and is_null(right)) + or (is_null(left) and is_false(right)) + or (is_false(left) and is_null(right)) + ): + return exp.null() + if is_false(left): + return right + if is_false(right): + return left + return _simplify_comparison(expression, left, right, or_=True) - return _flat_simplify(expression, _simplify_connectors) + if isinstance(expression, exp.Connector): + return _flat_simplify(expression, _simplify_connectors, root) + return expression LT_LTE = (exp.LT, exp.LTE) @@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False): return None -def remove_compliments(expression): +def remove_compliments(expression, root=True): """ Removing compliments. A AND NOT A -> FALSE A OR NOT A -> TRUE """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): compliment = exp.false() if isinstance(expression, exp.And) else exp.true() for a, b in itertools.permutations(expression.flatten(), 2): @@ -236,23 +237,23 @@ def remove_compliments(expression): return expression -def uniq_sort(expression): +def uniq_sort(expression, cache=None, root=True): """ Uniq and sort a connector. C AND A AND B AND B -> A AND B AND C """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ flattened = tuple(expression.flatten()) - deduped = {GENERATOR.generate(e): e for e in flattened} + deduped = {GENERATOR.generate(e, cache): e for e in flattened} arr = tuple(deduped.items()) # check if the operands are already sorted, if not sort them # A AND C AND B -> A AND B AND C for i, (sql, e) in enumerate(arr[1:]): if sql < arr[i][0]: - expression = result_func(*(deduped[sql] for sql in sorted(deduped))) + expression = result_func(*(e for _, e in sorted(arr))) break else: # we didn't have to sort but maybe we need to dedup @@ -262,7 +263,7 @@ def uniq_sort(expression): return expression -def absorb_and_eliminate(expression): +def absorb_and_eliminate(expression, root=True): """ absorption: A AND (A OR B) -> A @@ -273,7 +274,7 @@ def absorb_and_eliminate(expression): (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): kind = exp.Or if isinstance(expression, exp.And) else exp.And for a, b in itertools.permutations(expression.flatten(), 2): @@ -302,9 +303,9 @@ def absorb_and_eliminate(expression): return expression -def simplify_literals(expression): - if isinstance(expression, exp.Binary): - return _flat_simplify(expression, _simplify_binary) +def simplify_literals(expression, root=True): + if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): + return _flat_simplify(expression, _simplify_binary, root) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b): c = b not_ = False - if c == NULL: + if is_null(c): if isinstance(a, exp.Literal): return exp.true() if not_ else exp.false() - if a == NULL: + if is_null(a): return exp.false() if not_ else exp.true() elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): return None - elif NULL in (a, b): + elif is_null(a) or is_null(b): return exp.null() if a.is_number and b.is_number: @@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b): if boolean: return boolean elif a.is_string and b.is_string: - boolean = eval_boolean(expression, a, b) + boolean = eval_boolean(expression, a.this, b.this) if boolean: return boolean @@ -381,7 +382,7 @@ def simplify_parens(expression): and not isinstance(expression.this, exp.Select) and ( not isinstance(expression.parent, (exp.Condition, exp.Binary)) - or isinstance(expression.this, (exp.Is, exp.Like)) + or isinstance(expression.this, exp.Predicate) or not isinstance(expression.this, exp.Binary) ) ): @@ -400,13 +401,23 @@ def remove_where_true(expression): def always_true(expression): - return expression == TRUE or isinstance(expression, exp.Literal) + return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( + expression, exp.Literal + ) def is_complement(a, b): return isinstance(b, exp.Not) and b.this == a +def is_false(a: exp.Expression) -> bool: + return type(a) is exp.Boolean and not a.this + + +def is_null(a: exp.Expression) -> bool: + return type(a) is exp.Null + + def eval_boolean(expression, a, b): if isinstance(expression, (exp.EQ, exp.Is)): return boolean_literal(a == b) @@ -466,24 +477,27 @@ def boolean_literal(condition): return exp.true() if condition else exp.false() -def _flat_simplify(expression, simplifier): - operands = [] - queue = deque(expression.flatten(unnest=False)) - size = len(queue) +def _flat_simplify(expression, simplifier, root=True): + if root or not expression.same_parent: + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) - while queue: - a = queue.popleft() + while queue: + a = queue.popleft() - for b in queue: - result = simplifier(expression, a, b) + for b in queue: + result = simplifier(expression, a, b) - if result: - queue.remove(b) - queue.append(result) - break - else: - operands.append(a) + if result: + queue.remove(b) + queue.append(result) + break + else: + operands.append(a) - if len(operands) < size: - return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + if len(operands) < size: + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) return expression |