summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r--sqlglot/optimizer/simplify.py32
1 files changed, 17 insertions, 15 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 0904189..e2772a0 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,11 +5,9 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
-from sqlglot.generator import Generator
+from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
-GENERATOR = Generator(normalize=True, identify="safe")
-
def simplify(expression):
"""
@@ -27,12 +25,12 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
- cache = {}
+ generate = cached_generator()
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
- node = uniq_sort(node, cache, root)
+ node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
@@ -247,7 +245,7 @@ def remove_compliments(expression, root=True):
return expression
-def uniq_sort(expression, cache=None, root=True):
+def uniq_sort(expression, generate, root=True):
"""
Uniq and sort a connector.
@@ -256,7 +254,7 @@ def uniq_sort(expression, cache=None, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {GENERATOR.generate(e, cache): e for e in flattened}
+ deduped = {generate(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@@ -388,14 +386,18 @@ def _simplify_binary(expression, a, b):
def simplify_parens(expression):
- if (
- isinstance(expression, exp.Paren)
- and not isinstance(expression.this, exp.Select)
- and (
- not isinstance(expression.parent, (exp.Condition, exp.Binary))
- or isinstance(expression.this, exp.Predicate)
- or not isinstance(expression.this, exp.Binary)
- )
+ if not isinstance(expression, exp.Paren):
+ return expression
+
+ this = expression.this
+ parent = expression.parent
+
+ if not isinstance(this, exp.Select) and (
+ not isinstance(parent, (exp.Condition, exp.Binary))
+ or isinstance(this, exp.Predicate)
+ or not isinstance(this, exp.Binary)
+ or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
+ or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
):
return expression.this
return expression