diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/normalize.py | 37 | ||||
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 10 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 18 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 93 |
5 files changed, 134 insertions, 30 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 1db094e..8d82b2d 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -6,6 +6,7 @@ from sqlglot import exp from sqlglot.errors import OptimizeError from sqlglot.generator import cached_generator from sqlglot.helper import while_changing +from sqlglot.optimizer.scope import find_all_in_scope from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort logger = logging.getLogger("sqlglot") @@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = return expression -def normalized(expression, dnf=False): - ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) +def normalized(expression: exp.Expression, dnf: bool = False) -> bool: + """ + Checks whether a given expression is in a normal form of interest. - return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) + Example: + >>> from sqlglot import parse_one + >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) + True + >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default + True + >>> normalized(parse_one("a AND (b OR c)"), dnf=True) + False + Args: + expression: The expression to check if it's normalized. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + """ + ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) + return not any( + connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root) + ) -def normalization_distance(expression, dnf=False): + +def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int: """ - The difference in the number of predicates between the current expression and the normalized form. + The difference in the number of predicates between a given expression and its normalized form. This is used as an estimate of the cost of the conversion which is exponential in complexity. @@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False): 4 Args: - expression (sqlglot.Expression): expression to compute distance - dnf (bool): compute to dnf distance instead + expression: The expression to compute the normalization distance for. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + Returns: - int: difference + The normalization distance. """ return sum(_predicate_lengths(expression, dnf)) - ( sum(1 for _ in expression.find_all(exp.Connector)) + 1 diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 9d401fc..1530456 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -39,10 +39,14 @@ def optimize_joins(expression): if len(other_table_names(dep)) < 2: continue + operator = type(on) for predicate in on.flatten(): if name in exp.column_table_names(predicate): predicate.replace(exp.true()) - join.on(predicate, copy=False) + predicate = exp._combine( + [join.args.get("on"), predicate], operator, copy=False + ) + join.on(predicate, append=False, copy=False) expression = reorder_joins(expression) expression = normalize(expression) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index b51601f..4bc3bd2 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -9,7 +9,9 @@ from sqlglot.schema import ensure_schema SELECT_ALL = object() # Selection to use if selection list is empty -DEFAULT_SELECTION = lambda: alias("1", "_") +DEFAULT_SELECTION = lambda is_agg: alias( + exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_" +) def pushdown_projections(expression, schema=None, remove_unused_selections=True): @@ -98,6 +100,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count): new_selections = [] removed = False star = False + is_agg = False select_all = SELECT_ALL in parent_selections @@ -112,6 +115,9 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count): star = True removed = True + if not is_agg and selection.find(exp.AggFunc): + is_agg = True + if star: resolver = Resolver(scope, schema) names = {s.alias_or_name for s in new_selections} @@ -124,7 +130,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count): # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION()) + new_selections.append(DEFAULT_SELECTION(is_agg)) scope.expression.select(*new_selections, append=False, copy=False) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 435899a..4af5b49 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -137,8 +137,8 @@ class Scope: if not self._collected: self._collect() - def walk(self, bfs=True): - return walk_in_scope(self.expression, bfs=bfs) + def walk(self, bfs=True, prune=None): + return walk_in_scope(self.expression, bfs=bfs, prune=None) def find(self, *expression_types, bfs=True): return find_in_scope(self.expression, expression_types, bfs=bfs) @@ -731,7 +731,7 @@ def _traverse_ddl(scope): yield from _traverse_scope(query_scope) -def walk_in_scope(expression, bfs=True): +def walk_in_scope(expression, bfs=True, prune=None): """ Returns a generator object which visits all nodes in the syntrax tree, stopping at nodes that start child scopes. @@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True): expression (exp.Expression): bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead. + prune ((node, parent, arg_key) -> bool): callable that returns True if + the generator should stop traversing this branch of the tree. Yields: tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key """ # We'll use this variable to pass state into the dfs generator. # Whenever we set it to True, we exclude a subtree from traversal. - prune = False + crossed_scope_boundary = False - for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune): - prune = False + for node, parent, key in expression.walk( + bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args)) + ): + crossed_scope_boundary = False yield node, parent, key @@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True): or isinstance(node, exp.UDTF) or isinstance(node, exp.Subqueryable) ): - prune = True + crossed_scope_boundary = True if isinstance(node, (exp.Subquery, exp.UDTF)): # The following args are not actually in the inner scope, so we should visit them diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 51214c4..849643c 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -5,9 +5,11 @@ import typing as t from collections import deque from decimal import Decimal +import sqlglot from sqlglot import exp from sqlglot.generator import cached_generator from sqlglot.helper import first, merge_ranges, while_changing +from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope # Final means that an expression should not be simplified FINAL = "final" @@ -17,7 +19,7 @@ class UnsupportedUnit(Exception): pass -def simplify(expression): +def simplify(expression, constant_propagation=False): """ Rewrite sqlglot AST to simplify expressions. @@ -29,6 +31,8 @@ def simplify(expression): Args: expression (sqlglot.Expression): expression to simplify + constant_propagation: whether or not the constant propagation rule should be used + Returns: sqlglot.Expression: simplified expression """ @@ -67,13 +71,16 @@ def simplify(expression): node = absorb_and_eliminate(node, root) node = simplify_concat(node) + if constant_propagation: + node = propagate_constants(node, root) + exp.replace_children(node, lambda e: _simplify(e, False)) # Post-order transformations node = simplify_not(node) node = flatten(node) node = simplify_connectors(node, root) - node = remove_compliments(node, root) + node = remove_complements(node, root) node = simplify_coalesce(node) node.parent = expression.parent node = simplify_literals(node, root) @@ -287,19 +294,19 @@ def _simplify_comparison(expression, left, right, or_=False): return None -def remove_compliments(expression, root=True): +def remove_complements(expression, root=True): """ - Removing compliments. + Removing complements. A AND NOT A -> FALSE A OR NOT A -> TRUE """ if isinstance(expression, exp.Connector) and (root or not expression.same_parent): - compliment = exp.false() if isinstance(expression, exp.And) else exp.true() + complement = exp.false() if isinstance(expression, exp.And) else exp.true() for a, b in itertools.permutations(expression.flatten(), 2): if is_complement(a, b): - return compliment + return complement return expression @@ -369,6 +376,51 @@ def absorb_and_eliminate(expression, root=True): return expression +def propagate_constants(expression, root=True): + """ + Propagate constants for conjunctions in DNF: + + SELECT * FROM t WHERE a = b AND b = 5 becomes + SELECT * FROM t WHERE a = 5 AND b = 5 + + Reference: https://www.sqlite.org/optoverview.html + """ + + if ( + isinstance(expression, exp.And) + and (root or not expression.same_parent) + and sqlglot.optimizer.normalize.normalized(expression, dnf=True) + ): + constant_mapping = {} + for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)): + if isinstance(expr, exp.EQ): + l, r = expr.left, expr.right + + # TODO: create a helper that can be used to detect nested literal expressions such + # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too + if isinstance(l, exp.Column) and isinstance(r, exp.Literal): + pass + elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): + l, r = r, l + else: + continue + + constant_mapping[l] = (id(l), r) + + if constant_mapping: + for column in find_all_in_scope(expression, exp.Column): + parent = column.parent + column_id, constant = constant_mapping.get(column) or (None, None) + if ( + column_id is not None + and id(column) != column_id + and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) + ): + column.replace(constant.copy()) + + return expression + + INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { exp.DateAdd: exp.Sub, exp.DateSub: exp.Add, @@ -609,21 +661,38 @@ SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe) def simplify_concat(expression): """Reduces all groups that contain string literals by concatenating them.""" - if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs): + if not isinstance(expression, CONCATS) or ( + # We can't reduce a CONCAT_WS call if we don't statically know the separator + isinstance(expression, exp.ConcatWs) + and not expression.expressions[0].is_string + ): return expression + if isinstance(expression, exp.ConcatWs): + sep_expr, *expressions = expression.expressions + sep = sep_expr.name + concat_type = exp.ConcatWs + else: + expressions = expression.expressions + sep = "" + concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat + new_args = [] for is_string_group, group in itertools.groupby( - expression.expressions or expression.flatten(), lambda e: e.is_string + expressions or expression.flatten(), lambda e: e.is_string ): if is_string_group: - new_args.append(exp.Literal.string("".join(string.name for string in group))) + new_args.append(exp.Literal.string(sep.join(string.name for string in group))) else: new_args.extend(group) - # Ensures we preserve the right concat type, i.e. whether it's "safe" or not - concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat - return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args) + if len(new_args) == 1 and new_args[0].is_string: + return new_args[0] + + if concat_type is exp.ConcatWs: + new_args = [sep_expr] + new_args + + return concat_type(expressions=new_args) DateRange = t.Tuple[datetime.date, datetime.date] |