summary refs log tree commit diff stats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- sqlglot/optimizer/simplify.py 77
1 files changed, 70 insertions, 7 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 30de75b..af03332 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -7,8 +7,7 @@ from decimal import Decimal
import sqlglot
from sqlglot import exp
-from sqlglot.generator import cached_generator
-from sqlglot.helper import first, merge_ranges, while_changing
+from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
# Final means that an expression should not be simplified
@@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
- generate = cached_generator()
-
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
@@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False):
# Pre-order transformations
node = expression
node = rewrite_between(node)
- node = uniq_sort(node, generate, root)
+ node = uniq_sort(node, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
@@ -311,7 +308,7 @@ def remove_complements(expression, root=True):
return expression
-def uniq_sort(expression, generate, root=True):
+def uniq_sort(expression, root=True):
"""
Uniq and sort a connector.
@@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {generate(e): e for e in flattened}
+ deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True):
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return expression
+
+
+def gen(expression: t.Any) -> str:
+    """Build a cheap pseudo-SQL string for sorting and deduplicating expressions.
+
+    Sorting and deduping SQL is a necessary optimization step, and invoking the
+    real generator for it is expensive, so this is a bare-minimum stand-in.
+
+    Args:
+        expression: the value to render.
+
+    Returns:
+        "_" for None; the comma-joined rendering of every item when
+        `is_iterable` accepts the value; `str(expression)` for any other
+        non-Expression value; the `GEN_MAP` rendering when the exact type is
+        registered there; otherwise the expression's key followed by the
+        rendering of its arg values.
+    """
+    if expression is None:
+        return "_"
+
+    if is_iterable(expression):
+        return ",".join(map(gen, expression))
+
+    if isinstance(expression, exp.Expression):
+        # Dispatch on the exact type; subclasses of a mapped type are not matched.
+        render = GEN_MAP.get(type(expression))
+        if render is not None:
+            return render(expression)
+        return f"{expression.key} {gen(expression.args.values())}"
+
+    return str(expression)
+
+
+# Per-type renderers used by `gen`, keyed by the exact expression class. Each
+# callable takes the expression and returns its pseudo-SQL string.
+GEN_MAP: t.Dict[t.Type[exp.Expression], t.Callable[[t.Any], str]] = {
+    exp.Add: lambda e: _binary(e, "+"),
+    exp.And: lambda e: _binary(e, "AND"),
+    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(arg) for arg in e.expressions)}",
+    exp.Between: lambda e: (
+        f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}"
+    ),
+    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
+    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
+    exp.Column: lambda e: ".".join(gen(part) for part in e.parts),
+    # The first arg is rendered through `this.name`; the remaining args follow it.
+    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
+    exp.Div: lambda e: _binary(e, "/"),
+    exp.Dot: lambda e: _binary(e, "."),
+    exp.DPipe: lambda e: _binary(e, "||"),
+    exp.SafeDPipe: lambda e: _binary(e, "||"),
+    exp.EQ: lambda e: _binary(e, "="),
+    exp.GT: lambda e: _binary(e, ">"),
+    exp.GTE: lambda e: _binary(e, ">="),
+    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
+    exp.ILike: lambda e: _binary(e, "ILIKE"),
+    # The first arg is rendered on the left of IN; the remaining args go inside the parens.
+    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
+    exp.Is: lambda e: _binary(e, "IS"),
+    exp.Like: lambda e: _binary(e, "LIKE"),
+    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
+    exp.LT: lambda e: _binary(e, "<"),
+    exp.LTE: lambda e: _binary(e, "<="),
+    exp.Mod: lambda e: _binary(e, "%"),
+    exp.Mul: lambda e: _binary(e, "*"),
+    exp.Neg: lambda e: _unary(e, "-"),
+    exp.NEQ: lambda e: _binary(e, "<>"),
+    exp.Not: lambda e: _unary(e, "NOT"),
+    exp.Null: lambda e: "NULL",
+    exp.Or: lambda e: _binary(e, "OR"),
+    exp.Paren: lambda e: f"({gen(e.this)})",
+    exp.Sub: lambda e: _binary(e, "-"),
+    exp.Subquery: lambda e: f"({gen(e.args.values())})",
+    exp.Table: lambda e: gen(e.args.values()),
+    exp.Var: lambda e: e.name,
+}
+
+
+def _binary(e: exp.Binary, op: str) -> str:
+    """Render a binary expression as "<left> <op> <right>" using `gen` for both sides."""
+    lhs = gen(e.left)
+    rhs = gen(e.right)
+    return f"{lhs} {op} {rhs}"
+
+
+def _unary(e: exp.Unary, op: str) -> str:
+    """Render a unary expression as "<op> <operand>" using `gen` for the operand."""
+    operand = gen(e.this)
+    return f"{op} {operand}"