summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/optimizer/simplify.py55
1 files changed, 31 insertions, 24 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 6ae08d0..f53023c 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -49,32 +49,32 @@ def simplify(
dialect = Dialect.get_or_raise(dialect)
- # group by expressions cannot be simplified, for example
- # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
- # the projection must exactly match the group by key
- for group in expression.find_all(exp.Group):
- select = group.parent
- assert select
- groups = set(group.expressions)
- group.meta[FINAL] = True
-
- for e in select.expressions:
- for node, *_ in e.walk():
- if node in groups:
- e.meta[FINAL] = True
- break
-
- having = select.args.get("having")
- if having:
- for node, *_ in having.walk():
- if node in groups:
- having.meta[FINAL] = True
- break
-
def _simplify(expression, root=True):
if expression.meta.get(FINAL):
return expression
+ # group by expressions cannot be simplified, for example
+ # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
+ # the projection must exactly match the group by key
+ group = expression.args.get("group")
+
+ if group and hasattr(expression, "selects"):
+ groups = set(group.expressions)
+ group.meta[FINAL] = True
+
+ for e in expression.selects:
+ for node, *_ in e.walk():
+ if node in groups:
+ e.meta[FINAL] = True
+ break
+
+ having = expression.args.get("having")
+ if having:
+ for node, *_ in having.walk():
+ if node in groups:
+ having.meta[FINAL] = True
+ break
+
# Pre-order transformations
node = expression
node = rewrite_between(node)
@@ -266,6 +266,8 @@ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.GTE: exp.LTE,
}
+NONDETERMINISTIC = (exp.Rand, exp.Randn)
+
def _simplify_comparison(expression, left, right, or_=False):
if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
@@ -276,7 +278,7 @@ def _simplify_comparison(expression, left, right, or_=False):
rargs = {rl, rr}
matching = largs & rargs
- columns = {m for m in matching if isinstance(m, exp.Column)}
+ columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
if matching and columns:
try:
@@ -292,7 +294,12 @@ def _simplify_comparison(expression, left, right, or_=False):
l = l.name
r = r.name
else:
- return None
+ l = extract_date(l)
+ if not l:
+ return None
+ r = extract_date(r)
+ if not r:
+ return None
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):