summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/normalize.py27
-rw-r--r--sqlglot/optimizer/normalize_identifiers.py4
-rw-r--r--sqlglot/optimizer/qualify_tables.py2
-rw-r--r--sqlglot/optimizer/simplify.py77
4 files changed, 85 insertions, 25 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 8d82b2d..6df36af 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -4,7 +4,6 @@ import logging
from sqlglot import exp
from sqlglot.errors import OptimizeError
-from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
@@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
- generate = cached_generator()
-
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
@@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
try:
node = node.replace(
- while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
)
except OptimizeError as e:
logger.info(e)
@@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance, generate):
+def distributive_law(expression, dnf, max_distance):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
@@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func, generate)
- return _distribute(b, a, from_func, to_func, generate)
+ return _distribute(a, b, from_func, to_func)
+ return _distribute(b, a, from_func, to_func)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func, generate)
+ return _distribute(b, a, from_func, to_func)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func, generate)
+ return _distribute(a, b, from_func, to_func)
return expression
-def _distribute(a, b, from_func, to_func, generate):
+def _distribute(a, b, from_func, to_func):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- uniq_sort(flatten(from_func(c, b.left)), generate),
- uniq_sort(flatten(from_func(c, b.right)), generate),
+ uniq_sort(flatten(from_func(c, b.left))),
+ uniq_sort(flatten(from_func(c, b.right))),
copy=False,
),
)
else:
a = to_func(
- uniq_sort(flatten(from_func(a, b.left)), generate),
- uniq_sort(flatten(from_func(a, b.right)), generate),
+ uniq_sort(flatten(from_func(a, b.left))),
+ uniq_sort(flatten(from_func(a, b.right))),
copy=False,
)
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index ecea6a0..154256e 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, parse_one
+from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None):
The transformed expression.
"""
if isinstance(expression, str):
- expression = parse_one(expression, dialect=dialect, into=exp.Identifier)
+ expression = exp.parse_identifier(expression, dialect=dialect)
dialect = Dialect.get_or_raise(dialect)
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 68aebdb..3a43e8f 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -62,7 +62,7 @@ def qualify_tables(
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
- if not source.args.get("catalog"):
+ if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 30de75b..af03332 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -7,8 +7,7 @@ from decimal import Decimal
import sqlglot
from sqlglot import exp
-from sqlglot.generator import cached_generator
-from sqlglot.helper import first, merge_ranges, while_changing
+from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
# Final means that an expression should not be simplified
@@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
- generate = cached_generator()
-
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
@@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False):
# Pre-order transformations
node = expression
node = rewrite_between(node)
- node = uniq_sort(node, generate, root)
+ node = uniq_sort(node, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
@@ -311,7 +308,7 @@ def remove_complements(expression, root=True):
return expression
-def uniq_sort(expression, generate, root=True):
+def uniq_sort(expression, root=True):
"""
Uniq and sort a connector.
@@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {generate(e): e for e in flattened}
+ deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True):
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return expression
+
+
+def gen(expression: t.Any) -> str:
+ """Simple pseudo sql generator for quickly generating sortable and uniq strings.
+
+ Sorting and deduping sql is a necessary step for optimization. Calling the actual
+ generator is expensive so we have a bare minimum sql generator here.
+ """
+ if expression is None:
+ return "_"
+ if is_iterable(expression):
+ return ",".join(gen(e) for e in expression)
+ if not isinstance(expression, exp.Expression):
+ return str(expression)
+
+ etype = type(expression)
+ if etype in GEN_MAP:
+ return GEN_MAP[etype](expression)
+ return f"{expression.key} {gen(expression.args.values())}"
+
+
+GEN_MAP = {
+ exp.Add: lambda e: _binary(e, "+"),
+ exp.And: lambda e: _binary(e, "AND"),
+ exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
+ exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
+ exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
+ exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
+ exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
+ exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
+ exp.Div: lambda e: _binary(e, "/"),
+ exp.Dot: lambda e: _binary(e, "."),
+ exp.DPipe: lambda e: _binary(e, "||"),
+ exp.SafeDPipe: lambda e: _binary(e, "||"),
+ exp.EQ: lambda e: _binary(e, "="),
+ exp.GT: lambda e: _binary(e, ">"),
+ exp.GTE: lambda e: _binary(e, ">="),
+ exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
+ exp.ILike: lambda e: _binary(e, "ILIKE"),
+ exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
+ exp.Is: lambda e: _binary(e, "IS"),
+ exp.Like: lambda e: _binary(e, "LIKE"),
+ exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
+ exp.LT: lambda e: _binary(e, "<"),
+ exp.LTE: lambda e: _binary(e, "<="),
+ exp.Mod: lambda e: _binary(e, "%"),
+ exp.Mul: lambda e: _binary(e, "*"),
+ exp.Neg: lambda e: _unary(e, "-"),
+ exp.NEQ: lambda e: _binary(e, "<>"),
+ exp.Not: lambda e: _unary(e, "NOT"),
+ exp.Null: lambda e: "NULL",
+ exp.Or: lambda e: _binary(e, "OR"),
+ exp.Paren: lambda e: f"({gen(e.this)})",
+ exp.Sub: lambda e: _binary(e, "-"),
+ exp.Subquery: lambda e: f"({gen(e.args.values())})",
+ exp.Table: lambda e: gen(e.args.values()),
+ exp.Var: lambda e: e.name,
+}
+
+
+def _binary(e: exp.Binary, op: str) -> str:
+ return f"{gen(e.left)} {op} {gen(e.right)}"
+
+
+def _unary(e: exp.Unary, op: str) -> str:
+ return f"{op} {gen(e.this)}"