diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/normalize.py | 27 | ||||
-rw-r--r-- | sqlglot/optimizer/normalize_identifiers.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 77 |
4 files changed, 85 insertions, 25 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 8d82b2d..6df36af 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -4,7 +4,6 @@ import logging from sqlglot import exp from sqlglot.errors import OptimizeError -from sqlglot.generator import cached_generator from sqlglot.helper import while_changing from sqlglot.optimizer.scope import find_all_in_scope from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort @@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = Returns: sqlglot.Expression: normalized expression """ - generate = cached_generator() - for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): if normalized(node, dnf=dnf): @@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = try: node = node.replace( - while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate)) + while_changing(node, lambda e: distributive_law(e, dnf, max_distance)) ) except OptimizeError as e: logger.info(e) @@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf): return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance, generate): +def distributive_law(expression, dnf, max_distance): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) @@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate): if distance > max_distance: raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate)) + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): @@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func, generate) - return _distribute(b, a, from_func, to_func, generate) + return _distribute(a, b, from_func, to_func) + return _distribute(b, a, from_func, to_func) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func, generate) + return _distribute(b, a, from_func, to_func) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func, generate) + return _distribute(a, b, from_func, to_func) return expression -def _distribute(a, b, from_func, to_func, generate): +def _distribute(a, b, from_func, to_func): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - uniq_sort(flatten(from_func(c, b.left)), generate), - uniq_sort(flatten(from_func(c, b.right)), generate), + uniq_sort(flatten(from_func(c, b.left))), + uniq_sort(flatten(from_func(c, b.right))), copy=False, ), ) else: a = to_func( - uniq_sort(flatten(from_func(a, b.left)), generate), - uniq_sort(flatten(from_func(a, b.right)), generate), + uniq_sort(flatten(from_func(a, b.left))), + uniq_sort(flatten(from_func(a, b.right))), copy=False, ) diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index ecea6a0..154256e 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import exp, parse_one +from sqlglot import exp from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType @@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None): The transformed expression. """ if isinstance(expression, str): - expression = parse_one(expression, dialect=dialect, into=exp.Identifier) + expression = exp.parse_identifier(expression, dialect=dialect) dialect = Dialect.get_or_raise(dialect) diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 68aebdb..3a43e8f 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -62,7 +62,7 @@ def qualify_tables( if isinstance(source.this, exp.Identifier): if not source.args.get("db"): source.set("db", exp.to_identifier(db)) - if not source.args.get("catalog"): + if not source.args.get("catalog") and source.args.get("db"): source.set("catalog", exp.to_identifier(catalog)) if not source.alias: diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 30de75b..af03332 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -7,8 +7,7 @@ from decimal import Decimal import sqlglot from sqlglot import exp -from sqlglot.generator import cached_generator -from sqlglot.helper import first, merge_ranges, while_changing +from sqlglot.helper import first, is_iterable, merge_ranges, while_changing from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope # Final means that an expression should not be simplified @@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False): sqlglot.Expression: simplified expression """ - generate = cached_generator() - # group by expressions cannot be simplified, for example # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 # the projection must exactly match the group by key @@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False): # Pre-order transformations node = expression node = rewrite_between(node) - node = uniq_sort(node, generate, root) + node = uniq_sort(node, root) node = absorb_and_eliminate(node, root) node = simplify_concat(node) node = simplify_conditionals(node) @@ -311,7 +308,7 @@ def remove_complements(expression, root=True): return expression -def uniq_sort(expression, generate, root=True): +def uniq_sort(expression, root=True): """ Uniq and sort a connector. @@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True): if isinstance(expression, exp.Connector) and (root or not expression.same_parent): result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ flattened = tuple(expression.flatten()) - deduped = {generate(e): e for e in flattened} + deduped = {gen(e): e for e in flattened} arr = tuple(deduped.items()) # check if the operands are already sorted, if not sort them @@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True): lambda a, b: expression.__class__(this=a, expression=b), operands ) return expression + + +def gen(expression: t.Any) -> str: + """Simple pseudo sql generator for quickly generating sortable and uniq strings. + + Sorting and deduping sql is a necessary step for optimization. Calling the actual + generator is expensive so we have a bare minimum sql generator here. + """ + if expression is None: + return "_" + if is_iterable(expression): + return ",".join(gen(e) for e in expression) + if not isinstance(expression, exp.Expression): + return str(expression) + + etype = type(expression) + if etype in GEN_MAP: + return GEN_MAP[etype](expression) + return f"{expression.key} {gen(expression.args.values())}" + + +GEN_MAP = { + exp.Add: lambda e: _binary(e, "+"), + exp.And: lambda e: _binary(e, "AND"), + exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}", + exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}", + exp.Boolean: lambda e: "TRUE" if e.this else "FALSE", + exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]", + exp.Column: lambda e: ".".join(gen(p) for p in e.parts), + exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}", + exp.Div: lambda e: _binary(e, "/"), + exp.Dot: lambda e: _binary(e, "."), + exp.DPipe: lambda e: _binary(e, "||"), + exp.SafeDPipe: lambda e: _binary(e, "||"), + exp.EQ: lambda e: _binary(e, "="), + exp.GT: lambda e: _binary(e, ">"), + exp.GTE: lambda e: _binary(e, ">="), + exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name, + exp.ILike: lambda e: _binary(e, "ILIKE"), + exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})", + exp.Is: lambda e: _binary(e, "IS"), + exp.Like: lambda e: _binary(e, "LIKE"), + exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name, + exp.LT: lambda e: _binary(e, "<"), + exp.LTE: lambda e: _binary(e, "<="), + exp.Mod: lambda e: _binary(e, "%"), + exp.Mul: lambda e: _binary(e, "*"), + exp.Neg: lambda e: _unary(e, "-"), + exp.NEQ: lambda e: _binary(e, "<>"), + exp.Not: lambda e: _unary(e, "NOT"), + exp.Null: lambda e: "NULL", + exp.Or: lambda e: _binary(e, "OR"), + exp.Paren: lambda e: f"({gen(e.this)})", + exp.Sub: lambda e: _binary(e, "-"), + exp.Subquery: lambda e: f"({gen(e.args.values())})", + exp.Table: lambda e: gen(e.args.values()), + exp.Var: lambda e: e.name, +} + + +def _binary(e: exp.Binary, op: str) -> str: + return f"{gen(e.left)} {op} {gen(e.right)}" + + +def _unary(e: exp.Unary, op: str) -> str: + return f"{op} {gen(e.this)}" |