From d1f00706bff58b863b0a1c5bf4adf39d36049d4c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 11 Nov 2022 09:54:35 +0100 Subject: Merging upstream version 10.0.1. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/annotate_types.py | 131 ++++++++++++++++++++++++------ sqlglot/optimizer/eliminate_joins.py | 4 +- sqlglot/optimizer/eliminate_subqueries.py | 12 ++- sqlglot/optimizer/merge_subqueries.py | 16 +++- sqlglot/optimizer/normalize.py | 4 +- sqlglot/optimizer/optimize_joins.py | 6 +- sqlglot/optimizer/optimizer.py | 4 +- sqlglot/optimizer/pushdown_predicates.py | 28 +++++-- sqlglot/optimizer/pushdown_projections.py | 4 +- sqlglot/optimizer/qualify_columns.py | 28 +++++-- sqlglot/optimizer/scope.py | 14 +++- sqlglot/optimizer/simplify.py | 12 ++- sqlglot/optimizer/unnest_subqueries.py | 14 +++- 13 files changed, 219 insertions(+), 58 deletions(-) (limited to 'sqlglot/optimizer') diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 30055bc..96331e2 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.helper import ensure_list, subclasses +from sqlglot.helper import ensure_collection, ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -48,35 +48,65 @@ class TypeAnnotator: exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), - exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.ApproxDistinct: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.BIGINT + ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), - exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.CurrentDatetime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), + exp.CurrentTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), - exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeAdd: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), + exp.DatetimeSub: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampAdd: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.TimestampSub: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.DateStrToDate: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATE + ), + exp.DateToDateStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"), + exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), + exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), + exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), + exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.GroupConcat: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), + exp.ArrayConcat: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), @@ -88,32 +118,52 @@ class TypeAnnotator: exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.ApproxQuantile: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DOUBLE + ), + exp.RegexpLike: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.BOOLEAN + ), exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.StrToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), + exp.TimeStrToDate: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATE + ), + exp.TimeStrToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.UnixToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.VariancePop: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DOUBLE + ), exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), } @@ -124,7 +174,11 @@ class TypeAnnotator: exp.DataType.Type.TEXT: set(), exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, - exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, + exp.DataType.Type.NCHAR: { + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.TEXT, + }, exp.DataType.Type.CHAR: { exp.DataType.Type.NCHAR, exp.DataType.Type.VARCHAR, @@ -135,7 +189,11 @@ class TypeAnnotator: exp.DataType.Type.DOUBLE: set(), exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, - exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, + exp.DataType.Type.BIGINT: { + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, exp.DataType.Type.INT: { exp.DataType.Type.BIGINT, exp.DataType.Type.DECIMAL, @@ -160,7 +218,10 @@ class TypeAnnotator: # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ exp.DataType.Type.TIMESTAMPLTZ: set(), exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, - exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ}, + exp.DataType.Type.TIMESTAMP: { + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + }, exp.DataType.Type.DATETIME: { exp.DataType.Type.TIMESTAMP, exp.DataType.Type.TIMESTAMPTZ, @@ -219,7 +280,7 @@ class TypeAnnotator: def _annotate_args(self, expression): for value in expression.args.values(): - for v in ensure_list(value): + for v in ensure_collection(value): self._maybe_annotate(v) return expression @@ -243,7 +304,9 @@ class TypeAnnotator: if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: expression.type = exp.DataType.Type.NULL elif exp.DataType.Type.NULL in (left_type, right_type): - expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + expression.type = exp.DataType.build( + "NULLABLE", expressions=exp.DataType.build("BOOLEAN") + ) else: expression.type = exp.DataType.Type.BOOLEAN elif isinstance(expression, (exp.Condition, exp.Predicate)): @@ -276,3 +339,17 @@ class TypeAnnotator: def _annotate_with_type(self, expression, target_type): expression.type = target_type return self._annotate_args(expression) + + def _annotate_by_args(self, expression, *args): + self._annotate_args(expression) + expressions = [] + for arg in args: + arg_expr = expression.args.get(arg) + expressions.extend(expr for expr in ensure_list(arg_expr) if expr) + + last_datatype = None + for expr in expressions: + last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) + + expression.type = last_datatype or exp.DataType.Type.UNKNOWN + return expression diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 0854336..29621af 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias): on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) else: on_clause_columns = set() - return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns) + return any( + column for column in scope.source_columns(alias) if id(column) not in on_clause_columns + ) def _is_joined_on_all_unique_outputs(scope, join): diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index e30c263..8704e90 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -45,7 +45,13 @@ def eliminate_subqueries(expression): # All table names are taken for scope in root.traverse(): - taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)}) + taken.update( + { + source.name: source + for _, source in scope.sources.items() + if isinstance(source, exp.Table) + } + ) # Map of Expression->alias # Existing CTES in the root expression. We'll use this for deduplication. @@ -70,7 +76,9 @@ def eliminate_subqueries(expression): new_ctes.append(cte_scope.expression.parent) # Now append the rest - for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes): + for scope in itertools.chain( + root.union_scopes, root.subquery_scopes, root.derived_table_scopes + ): for child_scope in scope.traverse(): new_cte = _eliminate(child_scope, existing_ctes, taken) if new_cte: diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 70e4629..9ae4966 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): unmergable_window_columns = [ column for column in outer_scope.columns - if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc) + if column.find_ancestor( + exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc + ) ] window_expressions_in_unmergable = [ column @@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): and not ( isinstance(from_or_join, exp.From) and inner_select.args.get("where") - and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) + and any( + j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) + ) ) and not _is_a_window_expression_in_unmergable_operation() ) @@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): if table.alias_or_name == node_to_replace.alias_or_name: table.set("this", exp.to_identifier(new_subquery.alias_or_name)) outer_scope.remove_source(alias) - outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) + outer_scope.add_source( + new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name] + ) def _merge_joins(outer_scope, inner_scope, from_or_join): @@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope): inner_scope (sqlglot.optimizer.scope.Scope) """ if ( - any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]) + any( + outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"] + ) or len(outer_scope.selected_sources) != 1 or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) ): diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index ab30d7a..db538ef 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False): Returns: int: difference """ - return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1) + return sum(_predicate_lengths(expression, dnf)) - ( + len(list(expression.find_all(exp.Connector))) + 1 + ) def _predicate_lengths(expression, dnf): diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 0c74e36..40e4ab1 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -68,4 +68,8 @@ def normalize(expression): def other_table_names(join, exclude): - return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude] + return [ + name + for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + if name != exclude + ] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 5ad8f46..b2ed062 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames - rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs} + rule_kwargs = { + param: possible_kwargs[param] for param in rule_params if param in possible_kwargs + } expression = rule(expression, **rule_kwargs) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 583d059..6364f65 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count): condition = condition.replace(simplify(condition)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) - predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) + predicates = list( + condition.flatten() + if isinstance(condition, exp.And if cnf_like else exp.Or) + else [condition] + ) if cnf_like: pushdown_cnf(predicates, sources, scope_ref_count) @@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count): for column in predicate.find_all(exp.Column): if column.table == table: condition = column.find_ancestor(exp.Condition) - predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition + predicate_condition = ( + exp.and_(predicate_condition, condition) + if predicate_condition + else condition + ) if predicate_condition: conditions[table] = ( - exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition + exp.or_(conditions[table], predicate_condition) + if table in conditions + else predicate_condition ) for name, node in nodes.items(): @@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: # We can't push down window expressions - has_window_expression = any(select for select in node.selects if select.find(exp.Window)) + has_window_expression = any( + select for select in node.selects if select.find(exp.Window) + ) # we can't push down predicates to select statements if they are referenced in # multiple places. - if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression: + if ( + not node.args.get("group") + and scope_ref_count[id(source)] < 2 + and not has_window_expression + ): nodes[table] = node return nodes @@ -165,7 +181,7 @@ def replace_aliases(source, predicate): def _replace_alias(column): if isinstance(column, exp.Column) and column.name in aliases: - return aliases[column.name] + return aliases[column.name].copy() return column return predicate.transform(_replace_alias) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 5820851..abd9492 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections): def _remove_indexed_selections(scope, indexes_to_remove): - new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove] + new_selections = [ + selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove + ] if not new_selections: new_selections.append(DEFAULT_SELECTION) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ebee92a..69fe2b8 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver): # Determine whether each reference in the order by clause is to a column or an alias. for ordered in scope.find_all(exp.Ordered): for column in ordered.find_all(exp.Column): - if not column.table and column.parent is not ordered and column.name in resolver.all_columns: + if ( + not column.table + and column.parent is not ordered + and column.name in resolver.all_columns + ): columns_missing_from_scope.append(column) # Determine whether each reference in the having clause is to a column or an alias. for having in scope.find_all(exp.Having): for column in having.find_all(exp.Column): - if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns: + if ( + not column.table + and column.find_ancestor(exp.AggFunc) + and column.name in resolver.all_columns + ): columns_missing_from_scope.append(column) for column in columns_missing_from_scope: @@ -295,7 +303,9 @@ def _qualify_outputs(scope): """Ensure all output columns are aliased""" new_selections = [] - for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)): + for i, (selection, aliased_column) in enumerate( + itertools.zip_longest(scope.selects, scope.outer_column_list) + ): if isinstance(selection, exp.Column): # convoluted setter because a simple selection.replace(alias) would require a copy alias_ = alias(exp.column(""), alias=selection.name) @@ -343,14 +353,18 @@ class _Resolver: (str) table name """ if self._unambiguous_columns is None: - self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns()) + self._unambiguous_columns = self._get_unambiguous_columns( + self._get_all_source_columns() + ) return self._unambiguous_columns.get(column_name) @property def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) + self._all_columns = set( + column for columns in self._get_all_source_columns().values() for column in columns + ) return self._all_columns def get_source_columns(self, name, only_visible=False): @@ -377,7 +391,9 @@ class _Resolver: def _get_all_source_columns(self): if self._source_columns is None: - self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources} + self._source_columns = { + k: self.get_source_columns(k) for k in self.scope.selected_sources + } return self._source_columns def _get_unambiguous_columns(self, source_columns): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 5a75ee2..18848f3 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -226,7 +226,9 @@ class Scope: self._ensure_collected() columns = self._raw_columns - external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns] + external_columns = [ + column for scope in self.subquery_scopes for column in scope.external_columns + ] named_outputs = {e.alias_or_name for e in self.expression.expressions} @@ -278,7 +280,11 @@ class Scope: Returns: dict[str, Scope]: Mapping of source alias to Scope """ - return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte} + return { + alias: scope + for alias, scope in self.sources.items() + if isinstance(scope, Scope) and scope.is_cte + } @property def selects(self): @@ -307,7 +313,9 @@ class Scope: sources in the current scope. """ if self._external_columns is None: - self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] + self._external_columns = [ + c for c in self.columns if c.table not in self.selected_sources + ] return self._external_columns @property diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c077906..d759e86 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -229,7 +229,9 @@ def simplify_literals(expression): operands.append(a) if len(operands) < size: - return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b): return TRUE if not_ else FALSE if a == NULL: return FALSE if not_ else TRUE + elif isinstance(expression, exp.NullSafeEQ): + if a == b: + return TRUE + elif isinstance(expression, exp.NullSafeNEQ): + if a == b: + return FALSE elif NULL in (a, b): return NULL @@ -357,7 +365,7 @@ def extract_date(cast): def extract_interval(interval): try: - from dateutil.relativedelta import relativedelta + from dateutil.relativedelta import relativedelta # type: ignore except ModuleNotFoundError: return None diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 11c6eba..f41a84e 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence): return if isinstance(predicate, exp.Binary): - key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left + key = ( + predicate.right + if any(node is column for node, *_ in predicate.left.walk()) + else predicate.left + ) else: return @@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence): else: parent_predicate = _replace(parent_predicate, "TRUE") elif isinstance(parent_predicate, exp.All): - parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})") + parent_predicate = _replace( + parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" + ) elif isinstance(parent_predicate, exp.Any): if value.this in group_by: parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") @@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence): if key in group_by: key.replace(nested) - parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)") + parent_predicate = _replace( + parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" + ) elif isinstance(predicate, exp.EQ): parent_predicate = _replace( parent_predicate, -- cgit v1.2.3