diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 33 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 17 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 13 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 39 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 153 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 2 |
6 files changed, 140 insertions, 117 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 1ab7768..1230cea 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -95,9 +95,6 @@ def eliminate_subqueries(expression): def _eliminate(scope, existing_ctes, taken): - if scope.is_union: - return _eliminate_union(scope, existing_ctes, taken) - if scope.is_derived_table: return _eliminate_derived_table(scope, existing_ctes, taken) @@ -105,36 +102,6 @@ def _eliminate(scope, existing_ctes, taken): return _eliminate_cte(scope, existing_ctes, taken) -def _eliminate_union(scope, existing_ctes, taken): - duplicate_cte_alias = existing_ctes.get(scope.expression) - - alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte") - - taken[alias] = scope - - # Try to maintain the selections - expressions = scope.expression.selects - selects = [ - exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False) - for e in expressions - if e.alias_or_name - ] - # If not all selections have an alias, just select * - if len(selects) != len(expressions): - selects = ["*"] - - scope.expression.replace( - exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False)) - ) - - if not duplicate_cte_alias: - existing_ctes[scope.expression] = alias - return exp.CTE( - this=scope.expression, - alias=exp.TableAlias(this=exp.to_identifier(alias)), - ) - - def _eliminate_derived_table(scope, existing_ctes, taken): # This makes sure that we don't: # - drop the "pivot" arg from a pivoted subquery diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index a74bea7..ea148cc 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -174,6 +174,22 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): for col in inner_projections[selection].find_all(exp.Column) ) + def _is_recursive(): + # Recursive CTEs look like this: + # WITH RECURSIVE cte AS ( + # SELECT * FROM x <-- inner scope + # UNION ALL + # SELECT * FROM cte <-- outer scope + # ) + cte = inner_scope.expression.parent + node = outer_scope.expression.parent + + while node: + if node is cte: + return True + node = node.parent + return False + return ( isinstance(outer_scope.expression, exp.Select) and not outer_scope.expression.is_star @@ -197,6 +213,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): ) and not _outer_select_joins_on_inner_select_join() and not _is_a_window_expression_in_unmergable_operation() + and not _is_recursive() ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index f7348b5..10ff13a 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -4,7 +4,7 @@ from sqlglot.optimizer.scope import build_scope, find_in_scope from sqlglot.optimizer.simplify import simplify -def pushdown_predicates(expression): +def pushdown_predicates(expression, dialect=None): """ Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS @@ -36,7 +36,7 @@ def pushdown_predicates(expression): if isinstance(parent, exp.Join) and parent.side == "RIGHT": selected_sources = {k: (node, source)} break - pushdown(where.this, selected_sources, scope_ref_count) + pushdown(where.this, selected_sources, scope_ref_count, dialect) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself @@ -44,17 +44,20 @@ def pushdown_predicates(expression): name = join.alias_or_name if name in scope.selected_sources: pushdown( - join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count + join.args.get("on"), + {name: scope.selected_sources[name]}, + scope_ref_count, + dialect, ) return expression -def pushdown(condition, sources, scope_ref_count): +def pushdown(condition, sources, scope_ref_count, dialect): if not condition: return - condition = condition.replace(simplify(condition)) + condition = condition.replace(simplify(condition, dialect=dialect)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) predicates = list( diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b7e527e..d34857d 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -37,6 +37,7 @@ class Scope: For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source. + cte_sources (dict[str, Scope]): Sources from CTES outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list of it's alias of this scope, this is that list of columns. For example: @@ -61,11 +62,14 @@ class Scope: parent=None, scope_type=ScopeType.ROOT, lateral_sources=None, + cte_sources=None, ): self.expression = expression self.sources = sources or {} - self.lateral_sources = lateral_sources.copy() if lateral_sources else {} + self.lateral_sources = lateral_sources or {} + self.cte_sources = cte_sources or {} self.sources.update(self.lateral_sources) + self.sources.update(self.cte_sources) self.outer_column_list = outer_column_list or [] self.parent = parent self.scope_type = scope_type @@ -92,13 +96,17 @@ class Scope: self._pivots = None self._references = None - def branch(self, expression, scope_type, chain_sources=None, **kwargs): + def branch( + self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs + ): """Branch from the current scope to a new, inner scope""" return Scope( expression=expression.unnest(), - sources={**self.cte_sources, **(chain_sources or {})}, + sources=sources.copy() if sources else None, parent=self, scope_type=scope_type, + cte_sources={**self.cte_sources, **(cte_sources or {})}, + lateral_sources=lateral_sources.copy() if lateral_sources else None, **kwargs, ) @@ -306,20 +314,6 @@ class Scope: return self._references @property - def cte_sources(self): - """ - Sources that are CTEs. - - Returns: - dict[str, Scope]: Mapping of source alias to Scope - """ - return { - alias: scope - for alias, scope in self.sources.items() - if isinstance(scope, Scope) and scope.is_cte - } - - @property def external_columns(self): """ Columns that appear to reference sources in outer scopes. @@ -515,7 +509,10 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.Subquery): - yield from _traverse_subqueries(scope) + if scope.is_root: + yield from _traverse_select(scope) + else: + yield from _traverse_subqueries(scope) elif isinstance(scope.expression, exp.Table): yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): @@ -572,7 +569,7 @@ def _traverse_ctes(scope): for child_scope in _traverse_scope( scope.branch( cte.this, - chain_sources=sources, + cte_sources=sources, outer_column_list=cte.alias_column_names, scope_type=ScopeType.CTE, ) @@ -584,12 +581,14 @@ def _traverse_ctes(scope): if recursive_scope: child_scope.add_source(alias, recursive_scope) + child_scope.cte_sources[alias] = recursive_scope # append the final child_scope yielded if child_scope: scope.cte_scopes.append(child_scope) scope.sources.update(sources) + scope.cte_sources.update(sources) def _is_derived_table(expression: exp.Subquery) -> bool: @@ -725,7 +724,7 @@ def _traverse_ddl(scope): yield from _traverse_ctes(scope) query_scope = scope.branch( - scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources + scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources ) query_scope._collect() query_scope._ctes = scope.ctes + query_scope._ctes diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index d4e2e60..6ae08d0 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import functools import itertools @@ -6,10 +8,17 @@ from collections import deque from decimal import Decimal import sqlglot -from sqlglot import exp +from sqlglot import Dialect, exp from sqlglot.helper import first, is_iterable, merge_ranges, while_changing from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + + DateTruncBinaryTransform = t.Callable[ + [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] + ] + # Final means that an expression should not be simplified FINAL = "final" @@ -18,7 +27,9 @@ class UnsupportedUnit(Exception): pass -def simplify(expression, constant_propagation=False): +def simplify( + expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None +): """ Rewrite sqlglot AST to simplify expressions. @@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False): sqlglot.Expression: simplified expression """ + dialect = Dialect.get_or_raise(dialect) + # group by expressions cannot be simplified, for example # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 # the projection must exactly match the group by key for group in expression.find_all(exp.Group): select = group.parent + assert select groups = set(group.expressions) group.meta[FINAL] = True - for e in select.selects: + for e in select.expressions: for node, *_ in e.walk(): if node in groups: e.meta[FINAL] = True @@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False): node = simplify_literals(node, root) node = simplify_equality(node) node = simplify_parens(node) - node = simplify_datetrunc_predicate(node) + node = simplify_datetrunc(node, dialect) + node = sort_comparison(node) if root: expression.replace(node) @@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression: This is done because comparison simplification is only done on lt/lte/gt/gte. """ if isinstance(expression, exp.Between): - return exp.and_( + negate = isinstance(expression.parent, exp.Not) + + expression = exp.and_( exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), copy=False, ) + + if negate: + expression = exp.paren(expression, copy=False) + return expression +COMPLEMENT_COMPARISONS = { + exp.LT: exp.GTE, + exp.GT: exp.LTE, + exp.LTE: exp.GT, + exp.GTE: exp.LT, + exp.EQ: exp.NEQ, + exp.NEQ: exp.EQ, +} + + def simplify_not(expression): """ Demorgan's Law @@ -132,10 +163,15 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): - if is_null(expression.this): + this = expression.this + if is_null(this): return exp.null() - if isinstance(expression.this, exp.Paren): - condition = expression.this.unnest() + if this.__class__ in COMPLEMENT_COMPARISONS: + return COMPLEMENT_COMPARISONS[this.__class__]( + this=this.this, expression=this.expression + ) + if isinstance(this, exp.Paren): + condition = this.unnest() if isinstance(condition, exp.And): return exp.or_( exp.not_(condition.left, copy=False), @@ -150,14 +186,14 @@ def simplify_not(expression): ) if is_null(condition): return exp.null() - if always_true(expression.this): + if always_true(this): return exp.false() - if is_false(expression.this): + if is_false(this): return exp.true() - if isinstance(expression.this, exp.Not): + if isinstance(this, exp.Not): # double negation # NOT NOT x -> x - return expression.this.this + return this.this return expression @@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False): except StopIteration: return expression - # make sure the comparison is always of the form x > 1 instead of 1 < x - if left.__class__ in INVERSE_COMPARISONS and l == ll: - left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) - if right.__class__ in INVERSE_COMPARISONS and r == rl: - right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) - if l.is_number and r.is_number: l = float(l.name) r = float(r.name) @@ -397,13 +427,7 @@ def propagate_constants(expression, root=True): # TODO: create a helper that can be used to detect nested literal expressions such # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too if isinstance(l, exp.Column) and isinstance(r, exp.Literal): - pass - elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): - l, r = r, l - else: - continue - - constant_mapping[l] = (id(l), r) + constant_mapping[l] = (id(l), r) if constant_mapping: for column in find_all_in_scope(expression, exp.Column): @@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: if isinstance(expression, COMPARISONS): l, r = expression.left, expression.right - if l.__class__ in INVERSE_OPS: - pass - elif r.__class__ in INVERSE_OPS: - l, r = r, l - else: + if not l.__class__ in INVERSE_OPS: return expression if r.is_number: @@ -650,7 +670,7 @@ def simplify_coalesce(expression): # Find the first constant arg for arg_index, arg in enumerate(coalesce.expressions): - if _is_constant(other): + if _is_constant(arg): break else: return expression @@ -752,7 +772,7 @@ def simplify_conditionals(expression): DateRange = t.Tuple[datetime.date, datetime.date] -def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: +def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: """ Get the date range for a DATE_TRUNC equality comparison: @@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: Returns: tuple of [min, max) or None if a value can never be equal to `date` for `unit` """ - floor = date_floor(date, unit) + floor = date_floor(date, unit, dialect) if date != floor: # This will always be False, except for NULL values. @@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp def _datetrunc_eq( - left: exp.Expression, date: datetime.date, unit: str + left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect ) -> t.Optional[exp.Expression]: - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if not drange: return None @@ -790,9 +810,9 @@ def _datetrunc_eq( def _datetrunc_neq( - left: exp.Expression, date: datetime.date, unit: str + left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect ) -> t.Optional[exp.Expression]: - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if not drange: return None @@ -803,41 +823,39 @@ def _datetrunc_neq( ) -DateTruncBinaryTransform = t.Callable[ - [exp.Expression, datetime.date, str], t.Optional[exp.Expression] -] DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { - exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)), - exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)), - exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)), - exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)), + exp.LT: lambda l, dt, u, d: l + < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)), + exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)), + exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)), + exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)), exp.EQ: _datetrunc_eq, exp.NEQ: _datetrunc_neq, } DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} +DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: - return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right) + return isinstance(left, DATETRUNCS) and _is_date_literal(right) @catch(ModuleNotFoundError, UnsupportedUnit) -def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: +def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" comparison = expression.__class__ - if comparison not in DATETRUNC_COMPARISONS: + if isinstance(expression, DATETRUNCS): + date = extract_date(expression.this) + if date and expression.unit: + return date_literal(date_floor(date, expression.unit.name.lower(), dialect)) + elif comparison not in DATETRUNC_COMPARISONS: return expression if isinstance(expression, exp.Binary): l, r = expression.left, expression.right - if _is_datetrunc_predicate(l, r): - pass - elif _is_datetrunc_predicate(r, l): - comparison = INVERSE_COMPARISONS.get(comparison, comparison) - l, r = r, l - else: + if not _is_datetrunc_predicate(l, r): return expression l = t.cast(exp.DateTrunc, l) @@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: if not date: return expression - return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression + return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression elif isinstance(expression, exp.In): l = expression.this rs = expression.expressions @@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: date = extract_date(r) if not date: return expression - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if drange: ranges.append(drange) @@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: return expression +def sort_comparison(expression: exp.Expression) -> exp.Expression: + if expression.__class__ in COMPLEMENT_COMPARISONS: + l, r = expression.this, expression.expression + l_column = isinstance(l, exp.Column) + r_column = isinstance(r, exp.Column) + l_const = _is_constant(l) + r_const = _is_constant(r) + + if (l_column and not r_column) or (r_const and not l_const): + return expression + if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): + return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( + this=r, expression=l + ) + return expression + + # CROSS joins result in an empty table if the right table is empty. # So we can only simplify certain types of joins to CROSS. # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x @@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1): raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_floor(d: datetime.date, unit: str) -> datetime.date: +def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: if unit == "year": return d.replace(month=1, day=1) if unit == "quarter": @@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date: return d.replace(month=d.month, day=1) if unit == "week": # Assuming week starts on Monday (0) and ends on Sunday (6) - return d - datetime.timedelta(days=d.weekday()) + return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) if unit == "day": return d raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_ceil(d: datetime.date, unit: str) -> datetime.date: - floor = date_floor(d, unit) +def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: + floor = date_floor(d, unit, dialect) if floor == d: return d diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 242fc87..4d35175 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -65,6 +65,8 @@ def unnest(select, parent_select, next_alias_name): ) ): column = exp.Max(this=column) + elif not isinstance(select.parent, exp.Subquery): + return _replace(select.parent, column) parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False) |