summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- sqlglot/optimizer/simplify.py 383
1 files changed, 383 insertions, 0 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
new file mode 100644
index 0000000..6771153
--- /dev/null
+++ b/sqlglot/optimizer/simplify.py
@@ -0,0 +1,383 @@
+import datetime
+import functools
+import itertools
+from collections import deque
+from decimal import Decimal
+
+from sqlglot import exp
+from sqlglot.expressions import FALSE, NULL, TRUE
+from sqlglot.generator import Generator
+from sqlglot.helper import while_changing
+
+GENERATOR = Generator(normalize=True, identify=True)  # shared SQL generator; uniq_sort uses its output as the dedup/sort key for connector operands
+
+
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rewrite rules are applied repeatedly through ``while_changing``;
    afterwards WHERE / JOIN ... ON conditions that are always true are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    def _simplify(expression, root=True):
        # Rules that run on the node before its children are visited.
        node = absorb_and_eliminate(uniq_sort(expression))
        exp.replace_children(node, lambda child: _simplify(child, False))

        # Boolean rewrites that run once the children have been simplified.
        for rewrite in (simplify_not, flatten, simplify_connectors, remove_compliments):
            node = rewrite(node)

        # The rewritten node inherits the original's parent before the
        # remaining rules run (simplify_parens inspects the parent).
        node.parent = expression.parent
        node = simplify_parens(simplify_literals(node))

        if root:
            expression.replace(node)
        return node

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified
+
+
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT over operands that are always true, NOT FALSE, and
    double negation. Anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()
        if isinstance(condition, exp.And):
            return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
        if isinstance(condition, exp.Or):
            return exp.and_(exp.not_(condition.left), exp.not_(condition.right))

    if always_true(operand):
        return FALSE
    if operand == FALSE:
        return TRUE
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression
+
+
def flatten(expression):
    """
    Drop wrappers around operands that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
+
+
def simplify_connectors(expression):
    """
    Fold an AND / OR connector whose operands are constants or duplicates.

    x AND x -> x, x OR x -> x
    FALSE AND anything -> FALSE (including FALSE AND NULL)
    NULL AND anything else -> NULL
    TRUE AND x -> x
    TRUE OR anything -> TRUE
    FALSE OR x -> x
    NULL OR NULL / NULL OR FALSE -> NULL

    Args:
        expression: any expression; non-connectors are returned unchanged.
    Returns:
        The simplified expression, or the input when no rule applies.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    left = expression.left
    right = expression.right

    if left == right:
        return left

    if isinstance(expression, exp.And):
        # FALSE must be tested before NULL: under three-valued logic
        # FALSE AND NULL evaluates to FALSE, not NULL.
        if FALSE in (left, right):
            return FALSE
        if NULL in (left, right):
            return NULL
        if always_true(left) and always_true(right):
            return TRUE
        if always_true(left):
            return right
        if always_true(right):
            return left
    elif isinstance(expression, exp.Or):
        if always_true(left) or always_true(right):
            return TRUE
        if left == FALSE and right == FALSE:
            return FALSE
        if (
            (left == NULL and right == NULL)
            or (left == NULL and right == FALSE)
            or (left == FALSE and right == NULL)
        ):
            return NULL
        if left == FALSE:
            return right
        if right == FALSE:
            return left

    return expression
+
+
def remove_compliments(expression):
    """
    Collapse a connector that contains an operand together with its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(first, second) for first, second in operand_pairs):
        return FALSE if isinstance(expression, exp.And) else TRUE

    return expression
+
+
def uniq_sort(expression):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by the SQL text produced by GENERATOR; the connector is
    only rebuilt when the operands are out of order or contain duplicates.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {GENERATOR.generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))
    if out_of_order:
        return build(*(by_sql[key] for key in sorted(by_sql)))

    # already sorted, but duplicates may still have to be dropped
    if len(by_sql) < len(operands):
        return build(*by_sql.values())

    return expression
+
+
def absorb_and_eliminate(expression):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Matching operands are replaced inside the tree; the (mutated) input
    expression is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    # Only operands built from the opposite connector can be rewritten.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_identity = exp.TRUE if kind == exp.And else exp.FALSE
    outer_identity = exp.FALSE if kind == exp.And else exp.TRUE

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(inner_identity)
            continue
        if is_complement(b, ab):
            ab.replace(inner_identity)
            continue

        b_terms = set(b.flatten()) if isinstance(b, kind) else {b}
        if b_terms < set(a.flatten()):
            a.replace(outer_identity)
            continue

        if not isinstance(b, kind):
            continue

        # eliminate
        rhs = b.unnest_operands()
        ba, bb = rhs

        if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
            a.replace(aa)
            b.replace(aa)
        elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
            a.replace(ab)
            b.replace(ab)

    return expression
+
+
def simplify_literals(expression):
    """
    Fold literal operands of a binary expression and negated number literals.

    For a binary expression the flattened operands are combined pairwise via
    ``_simplify_binary``; when at least one pair folds, the expression is
    rebuilt from the remaining operands. For a negation of a number literal
    the sign is folded into the literal. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary):
        pending = deque(expression.flatten(unnest=False))
        original_count = len(pending)
        kept = []

        # NOTE(review): each operand is tried against every later operand,
        # regardless of operator associativity - confirm this is safe for
        # chains of non-associative operators such as subtraction.
        while pending:
            head = pending.popleft()

            for other in pending:
                folded = _simplify_binary(expression, head, other)

                if folded:
                    # The folded pair goes back on the queue so it can be
                    # combined with further operands.
                    pending.remove(other)
                    pending.append(folded)
                    break
            else:
                kept.append(head)

        if len(kept) < original_count:
            node_type = expression.__class__
            return functools.reduce(
                lambda lhs, rhs: node_type(this=lhs, expression=rhs), kept
            )
        return expression

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            if digits[0] == "-":
                return exp.Literal.number(digits[1:])
            return exp.Literal.number(f"-{digits}")

    return expression
+
+
def _simplify_binary(expression, a, b):
    """
    Try to fold one pair of operands of a binary expression.

    Args:
        expression: the binary expression the operands belong to; only its
            type is inspected.
        a: the left operand.
        b: the right operand.
    Returns:
        The folded expression, or None when the pair cannot be simplified.
    """
    if isinstance(expression, exp.Is):
        negated = isinstance(b, exp.Not)
        target = b.this if negated else b

        if target == NULL:
            if isinstance(a, exp.Literal):
                return TRUE if negated else FALSE
            if a == NULL:
                return FALSE if negated else TRUE
    elif NULL in (a, b):
        return NULL

    if isinstance(expression, exp.EQ) and a == b:
        return TRUE

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            if b == 0:
                # Folding would raise ZeroDivisionError here; leave the
                # expression untouched so the engine reports the error.
                return None
            if isinstance(a, int) and isinstance(b, int):
                # NOTE(review): Python floor division rounds towards negative
                # infinity - confirm this matches the target engine for
                # negative operands.
                return exp.Literal.number(a // b)
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        # Compare the literal text, not the expression nodes, so that the
        # ordering operators (<, <=, >, >=) are well defined.
        boolean = eval_boolean(expression, a.name, b.name)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        # extract_date returns None for casts it cannot turn into a date and
        # extract_interval returns None for intervals it cannot convert.
        if a is not None and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b is not None and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
+
+
def simplify_parens(expression):
    """
    Remove a parenthesis wrapper when one of the rules below allows it.

    A Paren is replaced by its content unless the content is a SELECT, or the
    content is a binary expression (other than IS / LIKE) whose Paren sits
    under a Condition or Binary parent.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this

    if isinstance(inner, exp.Select):
        return expression
    if not isinstance(expression.parent, (exp.Condition, exp.Binary)):
        return inner
    if isinstance(inner, (exp.Is, exp.Like)):
        return inner
    if not isinstance(inner, exp.Binary):
        return inner

    return expression
+
+
def remove_where_true(expression):
    """
    Drop conditions that are always true, modifying the tree in place.

    A WHERE clause whose condition is always true is removed from its parent;
    a JOIN whose ON condition is always true becomes a CROSS join without ON.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join_clause in expression.find_all(exp.Join):
        on_condition = join_clause.args.get("on")
        if not always_true(on_condition):
            continue
        join_clause.set("kind", "CROSS")
        join_clause.set("on", None)
+
+
def always_true(expression):
    """Return whether the expression equals TRUE or is a literal."""
    if expression == TRUE:
        return True
    return isinstance(expression, exp.Literal)
+
+
def is_complement(a, b):
    """Return whether ``b`` is the negation of ``a``, i.e. ``b`` is ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
+
+
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison between two already-extracted values.

    Args:
        expression: the comparison node; only its type is inspected.
        a: the left value.
        b: the right value.
    Returns:
        TRUE or FALSE for the supported comparison types, otherwise None.
    """
    # Checked in order; the comparison itself only runs for the matching type.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
+
+
def extract_date(cast):
    """
    Convert a ``CAST(<value> AS DATE)`` node into a ``datetime.date``.

    Args:
        cast: the cast expression to inspect.
    Returns:
        The parsed date, or None when the cast does not target DATE or its
        operand is not an ISO formatted date (for example when a column is
        being cast rather than a string literal).
    """
    if cast.args["to"].this != exp.DataType.Type.DATE:
        return None

    try:
        return datetime.date.fromisoformat(cast.name)
    except ValueError:
        # The operand text is not an ISO date, so there is nothing to fold.
        return None
+
+
def extract_interval(interval):
    """
    Convert an interval expression into a ``dateutil`` ``relativedelta``.

    Args:
        interval: the interval expression; its name is read as the quantity
            and its "unit" text as the unit.
    Returns:
        A relativedelta for year / month / week / day units, or None when
        dateutil is not installed, the quantity is not an integer, or the
        unit is not one of the supported ones.
    """
    try:
        from dateutil.relativedelta import relativedelta
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Non-integer quantities (fractions, column references, ...) cannot be folded.
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
+
+
def date_literal(date):
    """Build a ``CAST(<date as string literal> AS DATE)`` expression for ``date``."""
    text = exp.Literal.string(date)
    date_type = exp.DataType.build("DATE")
    return exp.Cast(this=text, to=date_type)
+
+
def boolean_literal(condition):
    """Return the TRUE expression when ``condition`` is truthy, otherwise FALSE."""
    if condition:
        return TRUE
    return FALSE