diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 12 | ||||
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 5 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/lower_identities.py | 8 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 5 | ||||
-rw-r--r-- | sqlglot/optimizer/normalize.py | 104 | ||||
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 88 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 166 |
12 files changed, 237 insertions, 163 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index c2d6655..99888c6 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.helper import ensure_collection, ensure_list, subclasses +from sqlglot.helper import ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -108,6 +108,7 @@ class TypeAnnotator: exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), + exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.GroupConcat: lambda self, expr: self._annotate_with_type( expr, exp.DataType.Type.VARCHAR @@ -116,6 +117,7 @@ class TypeAnnotator: expr, exp.DataType.Type.VARCHAR ), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL), exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), @@ -296,9 +298,6 @@ class TypeAnnotator: return self._maybe_annotate(expression) # This takes care of non-traversable expressions def _maybe_annotate(self, expression): - if not isinstance(expression, exp.Expression): - return None - if expression.type: return expression # We've already inferred the expression's type @@ -311,9 +310,8 @@ class TypeAnnotator: ) def _annotate_args(self, expression): - for value in expression.args.values(): - for v in ensure_collection(value): - self._maybe_annotate(v) + for _, value in expression.iter_expressions(): + self._maybe_annotate(value) return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index c5c780d..ef929ac 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: a.type and a.type.this == exp.DataType.Type.DATE and b.type - and b.type.this != exp.DataType.Type.DATE + and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL) ): _replace_cast(b, "date") diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 8e6a520..e0ddfa2 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -1,7 +1,6 @@ from sqlglot import expressions as exp from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.optimizer.simplify import simplify def eliminate_joins(expression): @@ -179,6 +178,4 @@ def join_condition(join): for condition in conditions: extract_condition(condition) - on = simplify(on) - remaining_condition = None if on == exp.true() else on - return source_key, join_key, remaining_condition + return source_key, join_key, on diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 6f9db82..a39fe96 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -3,7 +3,6 @@ import itertools from sqlglot import expressions as exp from sqlglot.helper import find_new_name from sqlglot.optimizer.scope import build_scope -from sqlglot.optimizer.simplify import simplify def eliminate_subqueries(expression): @@ -31,7 +30,6 @@ def eliminate_subqueries(expression): eliminate_subqueries(expression.this) return expression - expression = simplify(expression) root = build_scope(expression) # Map of alias->Scope|Table diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py index 1cc76cf..fae1726 100644 --- a/sqlglot/optimizer/lower_identities.py +++ b/sqlglot/optimizer/lower_identities.py @@ -1,5 +1,4 @@ from sqlglot import exp -from sqlglot.helper import ensure_collection def lower_identities(expression): @@ -40,13 +39,10 @@ def lower_identities(expression): lower_identities(expression.right) traversed |= {"this", "expression"} - for k, v in expression.args.items(): + for k, v in expression.iter_expressions(): if k in traversed: continue - - for child in ensure_collection(v): - if isinstance(child, exp.Expression): - child.transform(_lower, copy=False) + v.transform(_lower, copy=False) return expression diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 70172f4..c3467b2 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -3,7 +3,6 @@ from collections import defaultdict from sqlglot import expressions as exp from sqlglot.helper import find_new_name from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.optimizer.simplify import simplify def merge_subqueries(expression, leave_tables_isolated=False): @@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if set(exp.column_table_names(where.this)) <= sources: from_or_join.on(where.this, copy=False) - from_or_join.set("on", simplify(from_or_join.args.get("on"))) + from_or_join.set("on", from_or_join.args.get("on")) return expression.where(where.this, copy=False) - expression.set("where", simplify(expression.args.get("where"))) + expression.set("where", expression.args.get("where")) def _merge_order(outer_scope, inner_scope): diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index f16f519..f2df230 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -1,29 +1,63 @@ +from __future__ import annotations + +import logging +import typing as t + from sqlglot import exp +from sqlglot.errors import OptimizeError from sqlglot.helper import while_changing -from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort +from sqlglot.optimizer.simplify import flatten, uniq_sort + +logger = logging.getLogger("sqlglot") -def normalize(expression, dnf=False, max_distance=128): +def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): """ - Rewrite sqlglot AST into conjunctive normal form. + Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. Example: >>> import sqlglot >>> expression = sqlglot.parse_one("(x AND y) OR z") - >>> normalize(expression).sql() + >>> normalize(expression, dnf=False).sql() '(x OR z) AND (y OR z)' Args: - expression (sqlglot.Expression): expression to normalize - dnf (bool): rewrite in disjunctive normal form instead - max_distance (int): the maximal estimated distance from cnf to attempt conversion + expression: expression to normalize + dnf: rewrite in disjunctive normal form instead. + max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion Returns: sqlglot.Expression: normalized expression """ - expression = simplify(expression) + cache: t.Dict[int, str] = {} + + for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): + if isinstance(node, exp.Connector): + if normalized(node, dnf=dnf): + continue + + distance = normalization_distance(node, dnf=dnf) + + if distance > max_distance: + logger.info( + f"Skipping normalization because distance {distance} exceeds max {max_distance}" + ) + return expression + + root = node is expression + original = node.copy() + try: + node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) + except OptimizeError as e: + logger.info(e) + node.replace(original) + if root: + return original + return expression + + if root: + expression = node - expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance)) - return simplify(expression) + return expression def normalized(expression, dnf=False): @@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False): int: difference """ return sum(_predicate_lengths(expression, dnf)) - ( - len(list(expression.find_all(exp.Connector))) + 1 + sum(1 for _ in expression.find_all(exp.Connector)) + 1 ) @@ -64,29 +98,32 @@ def _predicate_lengths(expression, dnf): expression = expression.unnest() if not isinstance(expression, exp.Connector): - return [1] + return (1,) left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - return [ + return tuple( a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) - ] + ) return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance): +def distributive_law(expression, dnf, max_distance, cache=None): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) """ - if isinstance(expression.unnest(), exp.Connector): - if normalization_distance(expression, dnf) > max_distance: - return expression + if normalized(expression, dnf=dnf): + return expression - to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) + distance = normalization_distance(expression, dnf=dnf) - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) + if distance > max_distance: + raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") + + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache)) + to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): a, b = expression.unnest_operands() @@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func) - return _distribute(b, a, from_func, to_func) + return _distribute(a, b, from_func, to_func, cache) + return _distribute(b, a, from_func, to_func, cache) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func) + return _distribute(b, a, from_func, to_func, cache) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func) + return _distribute(a, b, from_func, to_func, cache) return expression -def _distribute(a, b, from_func, to_func): +def _distribute(a, b, from_func, to_func, cache): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - exp.paren(from_func(c, b.left)), - exp.paren(from_func(c, b.right)), + uniq_sort(flatten(from_func(c, b.left)), cache), + uniq_sort(flatten(from_func(c, b.right)), cache), ), ) else: - a = to_func(from_func(a, b.left), from_func(a, b.right)) - - return _simplify(a) - + a = to_func( + uniq_sort(flatten(from_func(a, b.left)), cache), + uniq_sort(flatten(from_func(a, b.right)), cache), + ) -def _simplify(node): - node = uniq_sort(flatten(node)) - exp.replace_children(node, _simplify) - return node + return a diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index dc5ce44..8589657 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,6 +1,5 @@ from sqlglot import exp from sqlglot.helper import tsort -from sqlglot.optimizer.simplify import simplify def optimize_joins(expression): @@ -29,7 +28,6 @@ def optimize_joins(expression): for name, join in cross_joins: for dep in references.get(name, []): on = dep.args["on"] - on = on.replace(simplify(on)) if isinstance(on, exp.Connector): for predicate in on.flatten(): diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index d9d04be..62eb11e 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema @@ -43,6 +44,7 @@ RULES = ( eliminate_ctes, annotate_types, canonicalize, + simplify, ) @@ -78,7 +80,7 @@ def optimize( Returns: sqlglot.Expression: optimized expression """ - schema = ensure_schema(schema or sqlglot.schema) + schema = ensure_schema(schema or sqlglot.schema, dialect=dialect) possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 66b3170..5e40cf3 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -30,11 +30,12 @@ def qualify_columns(expression, schema): resolver = Resolver(scope, schema) _pop_table_column_aliases(scope.ctes) _pop_table_column_aliases(scope.derived_tables) - _expand_using(scope, resolver) + using_column_tables = _expand_using(scope, resolver) _qualify_columns(scope, resolver) if not isinstance(scope.expression, exp.UDTF): - _expand_stars(scope, resolver) + _expand_stars(scope, resolver, using_column_tables) _qualify_outputs(scope) + _expand_alias_refs(scope, resolver) _expand_group_by(scope, resolver) _expand_order_by(scope) @@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables): def _expand_using(scope, resolver): - joins = list(scope.expression.find_all(exp.Join)) + joins = list(scope.find_all(exp.Join)) names = {join.this.alias for join in joins} ordered = [key for key in scope.selected_sources if key not in names] - # Mapping of automatically joined column names to source names + # Mapping of automatically joined column names to an ordered set of source names (dict). column_tables = {} for join in joins: @@ -112,11 +113,12 @@ def _expand_using(scope, resolver): ) ) - tables = column_tables.setdefault(identifier, []) + # Set all values in the dict to None, because we only care about the key ordering + tables = column_tables.setdefault(identifier, {}) if table not in tables: - tables.append(table) + tables[table] = None if join_table not in tables: - tables.append(join_table) + tables[join_table] = None join.args.pop("using") join.set("on", exp.and_(*conditions)) @@ -134,11 +136,11 @@ def _expand_using(scope, resolver): scope.replace(column, replacement) + return column_tables -def _expand_group_by(scope, resolver): - group = scope.expression.args.get("group") - if not group: - return + +def _expand_alias_refs(scope, resolver): + selects = {} # Replace references to select aliases def transform(node, *_): @@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver): node.set("table", table) return node - selects = {s.alias_or_name: s for s in scope.selects} - + if not selects: + for s in scope.selects: + selects[s.alias_or_name] = s select = selects.get(node.name) + if select: scope.clear_cache() if isinstance(select, exp.Alias): @@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver): return node - group.transform(transform, copy=False) + for select in scope.expression.selects: + select.transform(transform, copy=False) + + for modifier in ("where", "group"): + part = scope.expression.args.get(modifier) + + if part: + part.transform(transform, copy=False) + + +def _expand_group_by(scope, resolver): + group = scope.expression.args.get("group") + if not group: + return + group.set("expressions", _expand_positional_references(scope, group.expressions)) scope.expression.set("group", group) @@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver): column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) columns_missing_from_scope = [] + # Determine whether each reference in the order by clause is to a column or an alias. - for ordered in scope.find_all(exp.Ordered): - for column in ordered.find_all(exp.Column): - if ( - not column.table - and column.parent is not ordered - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) + order = scope.expression.args.get("order") + + if order: + for ordered in order.expressions: + for column in ordered.find_all(exp.Column): + if ( + not column.table + and column.parent is not ordered + and column.name in resolver.all_columns + ): + columns_missing_from_scope.append(column) # Determine whether each reference in the having clause is to a column or an alias. - for having in scope.find_all(exp.Having): + having = scope.expression.args.get("having") + + if having: for column in having.find_all(exp.Column): if ( not column.table @@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver): column.set("table", column_table) -def _expand_stars(scope, resolver): +def _expand_stars(scope, resolver, using_column_tables): """Expand stars to lists of column selections""" new_selections = [] except_columns = {} replace_columns = {} + coalesced_columns = set() for expression in scope.selects: if isinstance(expression, exp.Star): @@ -286,7 +311,20 @@ def _expand_stars(scope, resolver): if columns and "*" not in columns: table_id = id(table) for name in columns: - if name not in except_columns.get(table_id, set()): + if name in using_column_tables and table in using_column_tables[name]: + if name in coalesced_columns: + continue + + coalesced_columns.add(name) + tables = using_column_tables[name] + coalesce = [exp.column(name, table=table) for table in tables] + + new_selections.append( + exp.alias_( + exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name + ) + ) + elif name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table) new_selections.append(alias(column, alias_) if alias_ != name else column) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 9c0768c..b582eb0 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -160,7 +160,7 @@ class Scope: Yields: exp.Expression: nodes """ - for expression, _, _ in self.walk(bfs=bfs): + for expression, *_ in self.walk(bfs=bfs): if isinstance(expression, expression_types): yield expression diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f80484d..1ed3ca2 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -5,11 +5,10 @@ from collections import deque from decimal import Decimal from sqlglot import exp -from sqlglot.expressions import FALSE, NULL, TRUE from sqlglot.generator import Generator from sqlglot.helper import first, while_changing -GENERATOR = Generator(normalize=True, identify=True) +GENERATOR = Generator(normalize=True, identify="safe") def simplify(expression): @@ -28,18 +27,20 @@ def simplify(expression): sqlglot.Expression: simplified expression """ + cache = {} + def _simplify(expression, root=True): node = expression node = rewrite_between(node) - node = uniq_sort(node) - node = absorb_and_eliminate(node) + node = uniq_sort(node, cache, root) + node = absorb_and_eliminate(node, root) exp.replace_children(node, lambda e: _simplify(e, False)) node = simplify_not(node) node = flatten(node) - node = simplify_connectors(node) - node = remove_compliments(node) + node = simplify_connectors(node, root) + node = remove_compliments(node, root) node.parent = expression.parent - node = simplify_literals(node) + node = simplify_literals(node, root) node = simplify_parens(node) if root: expression.replace(node) @@ -70,7 +71,7 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): - if isinstance(expression.this, exp.Null): + if is_null(expression.this): return exp.null() if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() @@ -78,11 +79,11 @@ def simplify_not(expression): return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) - if isinstance(condition, exp.Null): + if is_null(condition): return exp.null() if always_true(expression.this): return exp.false() - if expression.this == FALSE: + if is_false(expression.this): return exp.true() if isinstance(expression.this, exp.Not): # double negation @@ -104,42 +105,42 @@ def flatten(expression): return expression -def simplify_connectors(expression): +def simplify_connectors(expression, root=True): def _simplify_connectors(expression, left, right): - if isinstance(expression, exp.Connector): - if left == right: + if left == right: + return left + if isinstance(expression, exp.And): + if is_false(left) or is_false(right): + return exp.false() + if is_null(left) or is_null(right): + return exp.null() + if always_true(left) and always_true(right): + return exp.true() + if always_true(left): + return right + if always_true(right): return left - if isinstance(expression, exp.And): - if FALSE in (left, right): - return exp.false() - if NULL in (left, right): - return exp.null() - if always_true(left) and always_true(right): - return exp.true() - if always_true(left): - return right - if always_true(right): - return left - return _simplify_comparison(expression, left, right) - elif isinstance(expression, exp.Or): - if always_true(left) or always_true(right): - return exp.true() - if left == FALSE and right == FALSE: - return exp.false() - if ( - (left == NULL and right == NULL) - or (left == NULL and right == FALSE) - or (left == FALSE and right == NULL) - ): - return exp.null() - if left == FALSE: - return right - if right == FALSE: - return left - return _simplify_comparison(expression, left, right, or_=True) - return None + return _simplify_comparison(expression, left, right) + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return exp.true() + if is_false(left) and is_false(right): + return exp.false() + if ( + (is_null(left) and is_null(right)) + or (is_null(left) and is_false(right)) + or (is_false(left) and is_null(right)) + ): + return exp.null() + if is_false(left): + return right + if is_false(right): + return left + return _simplify_comparison(expression, left, right, or_=True) - return _flat_simplify(expression, _simplify_connectors) + if isinstance(expression, exp.Connector): + return _flat_simplify(expression, _simplify_connectors, root) + return expression LT_LTE = (exp.LT, exp.LTE) @@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False): return None -def remove_compliments(expression): +def remove_compliments(expression, root=True): """ Removing compliments. A AND NOT A -> FALSE A OR NOT A -> TRUE """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): compliment = exp.false() if isinstance(expression, exp.And) else exp.true() for a, b in itertools.permutations(expression.flatten(), 2): @@ -236,23 +237,23 @@ def remove_compliments(expression): return expression -def uniq_sort(expression): +def uniq_sort(expression, cache=None, root=True): """ Uniq and sort a connector. C AND A AND B AND B -> A AND B AND C """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ flattened = tuple(expression.flatten()) - deduped = {GENERATOR.generate(e): e for e in flattened} + deduped = {GENERATOR.generate(e, cache): e for e in flattened} arr = tuple(deduped.items()) # check if the operands are already sorted, if not sort them # A AND C AND B -> A AND B AND C for i, (sql, e) in enumerate(arr[1:]): if sql < arr[i][0]: - expression = result_func(*(deduped[sql] for sql in sorted(deduped))) + expression = result_func(*(e for _, e in sorted(arr))) break else: # we didn't have to sort but maybe we need to dedup @@ -262,7 +263,7 @@ def uniq_sort(expression): return expression -def absorb_and_eliminate(expression): +def absorb_and_eliminate(expression, root=True): """ absorption: A AND (A OR B) -> A @@ -273,7 +274,7 @@ def absorb_and_eliminate(expression): (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): kind = exp.Or if isinstance(expression, exp.And) else exp.And for a, b in itertools.permutations(expression.flatten(), 2): @@ -302,9 +303,9 @@ def absorb_and_eliminate(expression): return expression -def simplify_literals(expression): - if isinstance(expression, exp.Binary): - return _flat_simplify(expression, _simplify_binary) +def simplify_literals(expression, root=True): + if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): + return _flat_simplify(expression, _simplify_binary, root) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b): c = b not_ = False - if c == NULL: + if is_null(c): if isinstance(a, exp.Literal): return exp.true() if not_ else exp.false() - if a == NULL: + if is_null(a): return exp.false() if not_ else exp.true() elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): return None - elif NULL in (a, b): + elif is_null(a) or is_null(b): return exp.null() if a.is_number and b.is_number: @@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b): if boolean: return boolean elif a.is_string and b.is_string: - boolean = eval_boolean(expression, a, b) + boolean = eval_boolean(expression, a.this, b.this) if boolean: return boolean @@ -381,7 +382,7 @@ def simplify_parens(expression): and not isinstance(expression.this, exp.Select) and ( not isinstance(expression.parent, (exp.Condition, exp.Binary)) - or isinstance(expression.this, (exp.Is, exp.Like)) + or isinstance(expression.this, exp.Predicate) or not isinstance(expression.this, exp.Binary) ) ): @@ -400,13 +401,23 @@ def remove_where_true(expression): def always_true(expression): - return expression == TRUE or isinstance(expression, exp.Literal) + return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( + expression, exp.Literal + ) def is_complement(a, b): return isinstance(b, exp.Not) and b.this == a +def is_false(a: exp.Expression) -> bool: + return type(a) is exp.Boolean and not a.this + + +def is_null(a: exp.Expression) -> bool: + return type(a) is exp.Null + + def eval_boolean(expression, a, b): if isinstance(expression, (exp.EQ, exp.Is)): return boolean_literal(a == b) @@ -466,24 +477,27 @@ def boolean_literal(condition): return exp.true() if condition else exp.false() -def _flat_simplify(expression, simplifier): - operands = [] - queue = deque(expression.flatten(unnest=False)) - size = len(queue) +def _flat_simplify(expression, simplifier, root=True): + if root or not expression.same_parent: + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) - while queue: - a = queue.popleft() + while queue: + a = queue.popleft() - for b in queue: - result = simplifier(expression, a, b) + for b in queue: + result = simplifier(expression, a, b) - if result: - queue.remove(b) - queue.append(result) - break - else: - operands.append(a) + if result: + queue.remove(b) + queue.append(result) + break + else: + operands.append(a) - if len(operands) < size: - return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + if len(operands) < size: + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) return expression |