diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 55 |
1 files changed, 31 insertions, 24 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6ae08d0..f53023c 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -49,32 +49,32 @@ def simplify( dialect = Dialect.get_or_raise(dialect) - # group by expressions cannot be simplified, for example - # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 - # the projection must exactly match the group by key - for group in expression.find_all(exp.Group): - select = group.parent - assert select - groups = set(group.expressions) - group.meta[FINAL] = True - - for e in select.expressions: - for node, *_ in e.walk(): - if node in groups: - e.meta[FINAL] = True - break - - having = select.args.get("having") - if having: - for node, *_ in having.walk(): - if node in groups: - having.meta[FINAL] = True - break - def _simplify(expression, root=True): if expression.meta.get(FINAL): return expression + # group by expressions cannot be simplified, for example + # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 + # the projection must exactly match the group by key + group = expression.args.get("group") + + if group and hasattr(expression, "selects"): + groups = set(group.expressions) + group.meta[FINAL] = True + + for e in expression.selects: + for node, *_ in e.walk(): + if node in groups: + e.meta[FINAL] = True + break + + having = expression.args.get("having") + if having: + for node, *_ in having.walk(): + if node in groups: + having.meta[FINAL] = True + break + # Pre-order transformations node = expression node = rewrite_between(node) @@ -266,6 +266,8 @@ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { exp.GTE: exp.LTE, } +NONDETERMINISTIC = (exp.Rand, exp.Randn) + def _simplify_comparison(expression, left, right, or_=False): if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): @@ -276,7 +278,7 @@ def _simplify_comparison(expression, left, right, or_=False): rargs = {rl, rr} matching = largs & rargs - columns = {m for m in matching if isinstance(m, exp.Column)} + columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} if matching and columns: try: @@ -292,7 +294,12 @@ def _simplify_comparison(expression, left, right, or_=False): l = l.name r = r.name else: - return None + l = extract_date(l) + if not l: + return None + r = extract_date(r) + if not r: + return None for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): |