diff options
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 77 |
1 files changed, 70 insertions, 7 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 30de75b..af03332 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -7,8 +7,7 @@ from decimal import Decimal
 
 import sqlglot
 from sqlglot import exp
-from sqlglot.generator import cached_generator
-from sqlglot.helper import first, merge_ranges, while_changing
+from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
 from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
 
 # Final means that an expression should not be simplified
@@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False):
         sqlglot.Expression: simplified expression
     """
 
-    generate = cached_generator()
-
     # group by expressions cannot be simplified, for example
     # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
     # the projection must exactly match the group by key
@@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False):
         # Pre-order transformations
         node = expression
         node = rewrite_between(node)
-        node = uniq_sort(node, generate, root)
+        node = uniq_sort(node, root)
         node = absorb_and_eliminate(node, root)
         node = simplify_concat(node)
         node = simplify_conditionals(node)
@@ -311,7 +308,7 @@ def remove_complements(expression, root=True):
     return expression
 
 
-def uniq_sort(expression, generate, root=True):
+def uniq_sort(expression, root=True):
     """
     Uniq and sort a connector.
 
@@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True):
     if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
         result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
         flattened = tuple(expression.flatten())
-        deduped = {generate(e): e for e in flattened}
+        deduped = {gen(e): e for e in flattened}
         arr = tuple(deduped.items())
 
         # check if the operands are already sorted, if not sort them
@@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True):
                 lambda a, b: expression.__class__(this=a, expression=b), operands
             )
     return expression
+
+
+def gen(expression: t.Any) -> str:
+    """Simple pseudo sql generator for quickly generating sortable and uniq strings.
+
+    Sorting and deduping sql is a necessary step for optimization. Calling the actual
+    generator is expensive so we have a bare minimum sql generator here.
+    """
+    if expression is None:
+        return "_"
+    if is_iterable(expression):
+        return ",".join(gen(e) for e in expression)
+    if not isinstance(expression, exp.Expression):
+        return str(expression)
+
+    etype = type(expression)
+    if etype in GEN_MAP:
+        return GEN_MAP[etype](expression)
+    return f"{expression.key} {gen(expression.args.values())}"
+
+
+GEN_MAP = {
+    exp.Add: lambda e: _binary(e, "+"),
+    exp.And: lambda e: _binary(e, "AND"),
+    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
+    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
+    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
+    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
+    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
+    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
+    exp.Div: lambda e: _binary(e, "/"),
+    exp.Dot: lambda e: _binary(e, "."),
+    exp.DPipe: lambda e: _binary(e, "||"),
+    exp.SafeDPipe: lambda e: _binary(e, "||"),
+    exp.EQ: lambda e: _binary(e, "="),
+    exp.GT: lambda e: _binary(e, ">"),
+    exp.GTE: lambda e: _binary(e, ">="),
+    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
+    exp.ILike: lambda e: _binary(e, "ILIKE"),
+    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
+    exp.Is: lambda e: _binary(e, "IS"),
+    exp.Like: lambda e: _binary(e, "LIKE"),
+    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
+    exp.LT: lambda e: _binary(e, "<"),
+    exp.LTE: lambda e: _binary(e, "<="),
+    exp.Mod: lambda e: _binary(e, "%"),
+    exp.Mul: lambda e: _binary(e, "*"),
+    exp.Neg: lambda e: _unary(e, "-"),
+    exp.NEQ: lambda e: _binary(e, "<>"),
+    exp.Not: lambda e: _unary(e, "NOT"),
+    exp.Null: lambda e: "NULL",
+    exp.Or: lambda e: _binary(e, "OR"),
+    exp.Paren: lambda e: f"({gen(e.this)})",
+    exp.Sub: lambda e: _binary(e, "-"),
+    exp.Subquery: lambda e: f"({gen(e.args.values())})",
+    exp.Table: lambda e: gen(e.args.values()),
+    exp.Var: lambda e: e.name,
+}
+
+
+def _binary(e: exp.Binary, op: str) -> str:
+    return f"{gen(e.left)} {op} {gen(e.right)}"
+
+
+def _unary(e: exp.Unary, op: str) -> str:
+    return f"{op} {gen(e.this)}"