summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r--sqlglot/optimizer/simplify.py93
1 files changed, 81 insertions, 12 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 51214c4..849643c 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,9 +5,11 @@ import typing as t
from collections import deque
from decimal import Decimal
+import sqlglot
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
+from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
# Final means that an expression should not be simplified
FINAL = "final"
@@ -17,7 +19,7 @@ class UnsupportedUnit(Exception):
pass
-def simplify(expression):
+def simplify(expression, constant_propagation=False):
"""
Rewrite sqlglot AST to simplify expressions.
@@ -29,6 +31,8 @@ def simplify(expression):
Args:
expression (sqlglot.Expression): expression to simplify
+ constant_propagation: whether or not the constant propagation rule should be used
+
Returns:
sqlglot.Expression: simplified expression
"""
@@ -67,13 +71,16 @@ def simplify(expression):
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
+ if constant_propagation:
+ node = propagate_constants(node, root)
+
exp.replace_children(node, lambda e: _simplify(e, False))
# Post-order transformations
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node, root)
- node = remove_compliments(node, root)
+ node = remove_complements(node, root)
node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
@@ -287,19 +294,19 @@ def _simplify_comparison(expression, left, right, or_=False):
return None
-def remove_compliments(expression, root=True):
+def remove_complements(expression, root=True):
"""
- Removing compliments.
+ Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
- compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
+ complement = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
- return compliment
+ return complement
return expression
@@ -369,6 +376,51 @@ def absorb_and_eliminate(expression, root=True):
return expression
+def propagate_constants(expression, root=True):
+ """
+ Propagate constants for conjunctions in DNF:
+
+ SELECT * FROM t WHERE a = b AND b = 5 becomes
+ SELECT * FROM t WHERE a = 5 AND b = 5
+
+ Reference: https://www.sqlite.org/optoverview.html
+ """
+
+ if (
+ isinstance(expression, exp.And)
+ and (root or not expression.same_parent)
+ and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
+ ):
+ constant_mapping = {}
+ for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
+ if isinstance(expr, exp.EQ):
+ l, r = expr.left, expr.right
+
+ # TODO: create a helper that can be used to detect nested literal expressions such
+ # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
+ if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
+ pass
+ elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
+ l, r = r, l
+ else:
+ continue
+
+ constant_mapping[l] = (id(l), r)
+
+ if constant_mapping:
+ for column in find_all_in_scope(expression, exp.Column):
+ parent = column.parent
+ column_id, constant = constant_mapping.get(column) or (None, None)
+ if (
+ column_id is not None
+ and id(column) != column_id
+ and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
+ ):
+ column.replace(constant.copy())
+
+ return expression
+
+
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.DateAdd: exp.Sub,
exp.DateSub: exp.Add,
@@ -609,21 +661,38 @@ SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
"""Reduces all groups that contain string literals by concatenating them."""
- if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
+ if not isinstance(expression, CONCATS) or (
+ # We can't reduce a CONCAT_WS call if we don't statically know the separator
+ isinstance(expression, exp.ConcatWs)
+ and not expression.expressions[0].is_string
+ ):
return expression
+ if isinstance(expression, exp.ConcatWs):
+ sep_expr, *expressions = expression.expressions
+ sep = sep_expr.name
+ concat_type = exp.ConcatWs
+ else:
+ expressions = expression.expressions
+ sep = ""
+ concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+
new_args = []
for is_string_group, group in itertools.groupby(
- expression.expressions or expression.flatten(), lambda e: e.is_string
+ expressions or expression.flatten(), lambda e: e.is_string
):
if is_string_group:
- new_args.append(exp.Literal.string("".join(string.name for string in group)))
+ new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
else:
new_args.extend(group)
- # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
- concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
- return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
+ if len(new_args) == 1 and new_args[0].is_string:
+ return new_args[0]
+
+ if concat_type is exp.ConcatWs:
+ new_args = [sep_expr] + new_args
+
+ return concat_type(expressions=new_args)
DateRange = t.Tuple[datetime.date, datetime.date]