diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 54 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 19 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 38 |
8 files changed, 96 insertions, 35 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index be17f15..bfb2bb8 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -43,7 +43,7 @@ class TypeAnnotator: }, exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()), exp.Alias: lambda self, expr: self._annotate_unary(expr), exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 3b40710..8e6a520 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias): # But columns in the ON clause shouldn't count. on = join.args.get("on") if on: - on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) + on_clause_columns = {id(column) for column in on.find_all(exp.Column)} else: on_clause_columns = set() return any( @@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join): return False _, join_keys, _ = join_condition(join) - remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys) + remaining_unique_outputs = unique_outputs - {c.name for c in join_keys} return not remaining_unique_outputs diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 9ae4966..16aaf17 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False): singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] for outer_scope, inner_scope, table in singular_cte_selections: - inner_select = inner_scope.expression.unnest() from_or_join = table.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): alias = table.alias_or_name - _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, table, alias) _merge_expressions(outer_scope, inner_scope, alias) @@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False): _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) _pop_cte(inner_scope) + outer_scope.clear_cache() return expression def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: - inner_select = subquery.unnest() from_or_join = subquery.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): - alias = subquery.alias_or_name - inner_scope = outer_scope.sources[alias] - + alias = subquery.alias_or_name + inner_scope = outer_scope.sources[alias] + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, subquery, alias) _merge_expressions(outer_scope, inner_scope, alias) @@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) + outer_scope.clear_cache() return expression -def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): +def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): """ Return True if `inner_select` can be merged into outer query. Args: outer_scope (Scope) - inner_select (exp.Select) + inner_scope (Scope) leave_tables_isolated (bool) from_or_join (exp.From|exp.Join) Returns: bool: True if can be merged """ + inner_select = inner_scope.expression.unnest() def _is_a_window_expression_in_unmergable_operation(): window_expressions = inner_select.find_all(exp.Window) @@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): ] return any(window_expressions_in_unmergable) + def _outer_select_joins_on_inner_select_join(): + """ + All columns from the inner select in the ON clause must be from the first FROM table. + + That is, this can be merged: + SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + But this can't: + SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + """ + if not isinstance(from_or_join, exp.Join): + return False + + alias = from_or_join.this.alias_or_name + + on = from_or_join.args.get("on") + if not on: + return False + selections = [c.name for c in on.find_all(exp.Column) if c.table == alias] + inner_from = inner_scope.expression.args.get("from") + if not inner_from: + return False + inner_from_table = inner_from.expressions[0].alias_or_name + inner_projections = {s.alias_or_name: s for s in inner_scope.selects} + return any( + col.table != inner_from_table + for selection in selections + for col in inner_projections[selection].find_all(exp.Column) + ) + return ( isinstance(outer_scope.expression, exp.Select) and isinstance(inner_select, exp.Select) - and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) @@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) ) ) + and not _outer_select_joins_on_inner_select_join() and not _is_a_window_expression_in_unmergable_operation() ) @@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): """ taken = set(outer_scope.selected_sources) conflicts = taken.intersection(set(inner_scope.selected_sources)) - conflicts = conflicts - {alias} + conflicts -= {alias} for conflict in conflicts: new_name = find_new_name(taken, conflict) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 72e67d4..46b6b30 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.unnest_subqueries import unnest_subqueries +from sqlglot.schema import ensure_schema RULES = ( lower_identities, @@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar If no schema is provided then the default schema defined at `sqlgot.schema` will be used db (str): specify the default database, as might be set by a `USE DATABASE db` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement - rules (list): sequence of optimizer rules to use + rules (sequence): sequence of optimizer rules to use **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. Returns: sqlglot.Expression: optimized expression """ - possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs} + schema = ensure_schema(schema or sqlglot.schema) + possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} expression = expression.copy() for rule in rules: diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 49789ac..a73647c 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections): order_refs = set() new_selections = [] + removed = False for i, selection in enumerate(scope.selects): if ( SELECT_ALL in parent_selections @@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections): new_selections.append(selection) else: removed_indexes.append(i) + removed = True # If there are no remaining selections, just select a single constant if not new_selections: new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) + if removed: + scope.clear_cache() return removed_indexes diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e16a635..f4568c2 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -365,9 +365,9 @@ class _Resolver: def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set( + self._all_columns = { column for columns in self._get_all_source_columns().values() for column in columns - ) + } return self._all_columns def get_source_columns(self, name, only_visible=False): diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c0719f2..f560760 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b): return boolean elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) - if b: + if a and b: if isinstance(expression, exp.Add): return date_literal(a + b) if isinstance(expression, exp.Sub): @@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b): elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval - if a and isinstance(expression, exp.Add): + if a and b and isinstance(expression, exp.Add): return date_literal(a + b) return None @@ -424,9 +424,15 @@ def eval_boolean(expression, a, b): def extract_date(cast): - if cast.args["to"].this == exp.DataType.Type.DATE: - return datetime.date.fromisoformat(cast.name) - return None + # The "fromisoformat" conversion could fail if the cast is used on an identifier, + # so in that case we can't extract the date. + try: + if cast.args["to"].this == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(cast.name) + if cast.args["to"].this == exp.DataType.Type.DATETIME: + return datetime.datetime.fromisoformat(cast.name) + except ValueError: + return None def extract_interval(interval): @@ -450,7 +456,8 @@ def extract_interval(interval): def date_literal(date): - return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE")) + expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE") + return exp.Cast(this=exp.Literal.string(date), to=expr_type) def boolean_literal(condition): diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 8d78294..a515489 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -15,8 +15,7 @@ def unnest_subqueries(expression): >>> import sqlglot >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() - 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ - AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)' + 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1' Args: expression (sqlglot.Expression): expression to unnest @@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence): other = _other_operand(parent_predicate) if isinstance(parent_predicate, exp.Exists): - if value.this in group_by: - parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") - else: - parent_predicate = _replace(parent_predicate, "TRUE") + alias = exp.column(list(key_aliases.values())[0], table_alias) + parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") elif isinstance(parent_predicate, exp.All): parent_predicate = _replace( parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" @@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence): else: if is_subquery_projection: alias = exp.alias_(alias, select.parent.alias) + + # COUNT always returns 0 on empty datasets, so we need take that into consideration here + # by transforming all counts into 0 and using that as the coalesced value + if value.find(exp.Count): + + def remove_aggs(node): + if isinstance(node, exp.Count): + return exp.Literal.number(0) + elif isinstance(node, exp.AggFunc): + return exp.null() + return node + + alias = exp.Coalesce( + this=alias, + expressions=[value.this.transform(remove_aggs)], + ) + select.parent.replace(alias) for key, column, predicate in keys: @@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence): if key in group_by: key.replace(nested) - parent_predicate = _replace( - parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" - ) elif isinstance(predicate, exp.EQ): parent_predicate = _replace( parent_predicate, @@ -245,7 +256,14 @@ def _other_operand(expression): if isinstance(expression, exp.In): return expression.this + if isinstance(expression, (exp.Any, exp.All)): + return _other_operand(expression.parent) + if isinstance(expression, exp.Binary): - return expression.right if expression.arg_key == "this" else expression.left + return ( + expression.right + if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All)) + else expression.left + ) return None |