diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-22 04:31:28 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-22 04:31:28 +0000 |
commit | 90150543f9314be683d22a16339effd774192f6d (patch) | |
tree | 0717782154823582e47cd23fd4e79f7b1b91c093 /sqlglot/optimizer | |
parent | Adding debian version 6.0.4-1. (diff) | |
download | sqlglot-90150543f9314be683d22a16339effd774192f6d.tar.xz sqlglot-90150543f9314be683d22a16339effd774192f6d.zip |
Merging upstream version 6.1.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/__init__.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/isolate_table_selects.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_derived_tables.py | 232 | ||||
-rw-r--r-- | sqlglot/optimizer/normalize.py | 22 | ||||
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 39 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 20 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 36 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/schema.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 58 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 8 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 22 |
13 files changed, 310 insertions, 147 deletions
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index a4c4cc2..d1146ca 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1,2 +1,2 @@ -from sqlglot.optimizer.optimizer import optimize +from sqlglot.optimizer.optimizer import RULES, optimize from sqlglot.optimizer.schema import Schema diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index c2e021e..e060739 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -13,9 +13,7 @@ def isolate_table_selects(expression): continue if not isinstance(source.parent, exp.Alias): - raise OptimizeError( - "Tables require an alias. Run qualify_tables optimization." - ) + raise OptimizeError("Tables require an alias. Run qualify_tables optimization.") parent = source.parent diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_derived_tables.py new file mode 100644 index 0000000..8b161fb --- /dev/null +++ b/sqlglot/optimizer/merge_derived_tables.py @@ -0,0 +1,232 @@ +from collections import defaultdict + +from sqlglot import expressions as exp +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def merge_derived_tables(expression): + """ + Rewrite sqlglot AST to merge derived tables into the outer query. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)") + >>> merge_derived_tables(expression).sql() + 'SELECT x.a FROM x' + + Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for outer_scope in traverse_scope(expression): + for subquery in outer_scope.derived_tables: + inner_select = subquery.unnest() + if ( + isinstance(outer_scope.expression, exp.Select) + and isinstance(inner_select, exp.Select) + and _mergeable(inner_select) + ): + alias = subquery.alias_or_name + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + inner_scope = outer_scope.sources[alias] + + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, subquery) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_order(outer_scope, inner_scope) + return expression + + +# If a derived table has these Select args, it can't be merged +UNMERGABLE_ARGS = set(exp.Select.arg_types) - { + "expressions", + "from", + "joins", + "where", + "order", +} + + +def _mergeable(inner_select): + """ + Return True if `inner_select` can be merged into outer query. + + Args: + inner_select (exp.Select) + Returns: + bool: True if can be merged + """ + return ( + isinstance(inner_select, exp.Select) + and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) + and inner_select.args.get("from") + and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) + ) + + +def _rename_inner_sources(outer_scope, inner_scope, alias): + """ + Renames any sources in the inner query that conflict with names in the outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + alias (str) + """ + taken = set(outer_scope.selected_sources) + conflicts = taken.intersection(set(inner_scope.selected_sources)) + conflicts = conflicts - {alias} + + for conflict in conflicts: + new_name = _find_new_name(taken, conflict) + + source, _ = inner_scope.selected_sources[conflict] + new_alias = exp.to_identifier(new_name) + + if isinstance(source, exp.Subquery): + source.set("alias", exp.TableAlias(this=new_alias)) + elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias): + source.parent.set("alias", new_alias) + elif isinstance(source, exp.Table): + source.replace(exp.alias_(source.copy(), new_alias)) + + for column in inner_scope.source_columns(conflict): + column.set("table", exp.to_identifier(new_name)) + + inner_scope.rename_source(conflict, new_name) + + +def _find_new_name(taken, base): + """ + Searches for a new source name. + + Args: + taken (set[str]): set of taken names + base (str): base name to alter + """ + i = 2 + new = f"{base}_{i}" + while new in taken: + i += 1 + new = f"{base}_{i}" + return new + + +def _merge_from(outer_scope, inner_scope, subquery): + """ + Merge FROM clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + subquery (exp.Subquery) + """ + new_subquery = inner_scope.expression.args.get("from").expressions[0] + subquery.replace(new_subquery) + outer_scope.remove_source(subquery.alias_or_name) + outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) + + +def _merge_joins(outer_scope, inner_scope, from_or_join): + """ + Merge JOIN clauses of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + from_or_join (exp.From|exp.Join) + """ + + new_joins = [] + comma_joins = inner_scope.expression.args.get("from").expressions[1:] + for subquery in comma_joins: + new_joins.append(exp.Join(this=subquery, kind="CROSS")) + outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name]) + + joins = inner_scope.expression.args.get("joins") or [] + for join in joins: + new_joins.append(join) + outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name]) + + if new_joins: + outer_joins = outer_scope.expression.args.get("joins", []) + + # Maintain the join order + if isinstance(from_or_join, exp.From): + position = 0 + else: + position = outer_joins.index(from_or_join) + 1 + outer_joins[position:position] = new_joins + + outer_scope.expression.set("joins", outer_joins) + + +def _merge_expressions(outer_scope, inner_scope, alias): + """ + Merge projections of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + alias (str) + """ + # Collect all columns that for the alias of the inner query + outer_columns = defaultdict(list) + for column in outer_scope.columns: + if column.table == alias: + outer_columns[column.name].append(column) + + # Replace columns with the projection expression in the inner query + for expression in inner_scope.expression.expressions: + projection_name = expression.alias_or_name + if not projection_name: + continue + columns_to_replace = outer_columns.get(projection_name, []) + for column in columns_to_replace: + column.replace(expression.unalias()) + + +def _merge_where(outer_scope, inner_scope, from_or_join): + """ + Merge WHERE clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + from_or_join (exp.From|exp.Join) + """ + where = inner_scope.expression.args.get("where") + if not where or not where.this: + return + + if isinstance(from_or_join, exp.Join) and from_or_join.side: + # Merge predicates from an outer join to the ON clause + from_or_join.on(where.this, copy=False) + from_or_join.set("on", simplify(from_or_join.args.get("on"))) + else: + outer_scope.expression.where(where.this, copy=False) + outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where"))) + + +def _merge_order(outer_scope, inner_scope): + """ + Merge ORDER clause of inner query into outer query. + + Args: + outer_scope (sqlglot.optimizer.scope.Scope) + inner_scope (sqlglot.optimizer.scope.Scope) + """ + if ( + any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]) + or len(outer_scope.selected_sources) != 1 + or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) + ): + return + + outer_scope.expression.set("order", inner_scope.expression.args.get("order")) diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 2c9f89c..ab30d7a 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -22,18 +22,14 @@ def normalize(expression, dnf=False, max_distance=128): """ expression = simplify(expression) - expression = while_changing( - expression, lambda e: distributive_law(e, dnf, max_distance) - ) + expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance)) return simplify(expression) def normalized(expression, dnf=False): ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) - return not any( - connector.find_ancestor(ancestor) for connector in expression.find_all(root) - ) + return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) def normalization_distance(expression, dnf=False): @@ -54,9 +50,7 @@ def normalization_distance(expression, dnf=False): Returns: int: difference """ - return sum(_predicate_lengths(expression, dnf)) - ( - len(list(expression.find_all(exp.Connector))) + 1 - ) + return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1) def _predicate_lengths(expression, dnf): @@ -73,11 +67,7 @@ def _predicate_lengths(expression, dnf): left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - x = [ - a + b - for a in _predicate_lengths(left, dnf) - for b in _predicate_lengths(right, dnf) - ] + x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)] return x return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) @@ -102,9 +92,7 @@ def distributive_law(expression, dnf, max_distance): to_func = exp.and_ if to_exp == exp.And else exp.or_ if isinstance(a, to_exp) and isinstance(b, to_exp): - if len(tuple(a.find_all(exp.Connector))) > len( - tuple(b.find_all(exp.Connector)) - ): + if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): return _distribute(a, b, from_func, to_func) return _distribute(b, a, from_func, to_func) if isinstance(a, to_exp): diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 40e4ab1..0c74e36 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -68,8 +68,4 @@ def normalize(expression): def other_table_names(join, exclude): - return [ - name - for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) - if name != exclude - ] + return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index c03fe3c..c8c2403 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,6 +1,7 @@ from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.merge_derived_tables import merge_derived_tables from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates @@ -10,8 +11,23 @@ from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.quote_identities import quote_identities from sqlglot.optimizer.unnest_subqueries import unnest_subqueries +RULES = ( + qualify_tables, + isolate_table_selects, + qualify_columns, + pushdown_projections, + normalize, + unnest_subqueries, + expand_multi_table_selects, + pushdown_predicates, + optimize_joins, + eliminate_subqueries, + merge_derived_tables, + quote_identities, +) -def optimize(expression, schema=None, db=None, catalog=None): + +def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs): """ Rewrite a sqlglot AST into an optimized form. @@ -25,19 +41,18 @@ def optimize(expression, schema=None, db=None, catalog=None): 3. {catalog: {db: {table: {col: type}}}} db (str): specify the default database, as might be set by a `USE DATABASE db` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement + rules (list): sequence of optimizer rules to use + **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. Returns: sqlglot.Expression: optimized expression """ + possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} expression = expression.copy() - expression = qualify_tables(expression, db=db, catalog=catalog) - expression = isolate_table_selects(expression) - expression = qualify_columns(expression, schema) - expression = pushdown_projections(expression) - expression = normalize(expression) - expression = unnest_subqueries(expression) - expression = expand_multi_table_selects(expression) - expression = pushdown_predicates(expression) - expression = optimize_joins(expression) - expression = eliminate_subqueries(expression) - expression = quote_identities(expression) + for rule in rules: + + # Find any additional rule parameters, beyond `expression` + rule_params = rule.__code__.co_varnames + rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs} + + expression = rule(expression, **rule_kwargs) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index e757322..a070d70 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -42,11 +42,7 @@ def pushdown(condition, sources): condition = condition.replace(simplify(condition)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) - predicates = list( - condition.flatten() - if isinstance(condition, exp.And if cnf_like else exp.Or) - else [condition] - ) + predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) if cnf_like: pushdown_cnf(predicates, sources) @@ -105,17 +101,11 @@ def pushdown_dnf(predicates, scope): for column in predicate.find_all(exp.Column): if column.table == table: condition = column.find_ancestor(exp.Condition) - predicate_condition = ( - exp.and_(predicate_condition, condition) - if predicate_condition - else condition - ) + predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition if predicate_condition: conditions[table] = ( - exp.or_(conditions[table], predicate_condition) - if table in conditions - else predicate_condition + exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition ) for name, node in nodes.items(): @@ -133,9 +123,7 @@ def pushdown_dnf(predicates, scope): def nodes_for_predicate(predicate, sources): nodes = {} tables = exp.column_table_names(predicate) - where_condition = isinstance( - predicate.find_ancestor(exp.Join, exp.Where), exp.Where - ) + where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) for table in tables: node, source = sources.get(table) or (None, None) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 394f49e..0bb947a 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -226,9 +226,7 @@ def _expand_stars(scope, resolver): tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) _add_replace_columns(expression, tables, replace_columns) - elif isinstance(expression, exp.Column) and isinstance( - expression.this, exp.Star - ): + elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star): tables = [expression.table] _add_except_columns(expression.this, tables, except_columns) _add_replace_columns(expression.this, tables, replace_columns) @@ -245,9 +243,7 @@ def _expand_stars(scope, resolver): if name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table) - new_selections.append( - alias(column, alias_) if alias_ != name else column - ) + new_selections.append(alias(column, alias_) if alias_ != name else column) scope.expression.set("expressions", new_selections) @@ -280,9 +276,7 @@ def _qualify_outputs(scope): """Ensure all output columns are aliased""" new_selections = [] - for i, (selection, aliased_column) in enumerate( - itertools.zip_longest(scope.selects, scope.outer_column_list) - ): + for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)): if isinstance(selection, exp.Column): # convoluted setter because a simple selection.replace(alias) would require a copy alias_ = alias(exp.column(""), alias=selection.name) @@ -302,11 +296,7 @@ def _qualify_outputs(scope): def _check_unknown_tables(scope): - if ( - scope.external_columns - and not scope.is_unnest - and not scope.is_correlated_subquery - ): + if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery: raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") @@ -334,20 +324,14 @@ class _Resolver: (str) table name """ if self._unambiguous_columns is None: - self._unambiguous_columns = self._get_unambiguous_columns( - self._get_all_source_columns() - ) + self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns()) return self._unambiguous_columns.get(column_name) @property def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set( - column - for columns in self._get_all_source_columns().values() - for column in columns - ) + self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) return self._all_columns def get_source_columns(self, name): @@ -369,9 +353,7 @@ class _Resolver: def _get_all_source_columns(self): if self._source_columns is None: - self._source_columns = { - k: self.get_source_columns(k) for k in self.scope.selected_sources - } + self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources} return self._source_columns def _get_unambiguous_columns(self, source_columns): @@ -389,9 +371,7 @@ class _Resolver: source_columns = list(source_columns.items()) first_table, first_columns = source_columns[0] - unambiguous_columns = { - col: first_table for col in self._find_unique_columns(first_columns) - } + unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} all_columns = set(unambiguous_columns) for table, columns in source_columns[1:]: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 9f8b9f5..30e93ba 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -27,9 +27,7 @@ def qualify_tables(expression, db=None, catalog=None): for derived_table in scope.ctes + scope.derived_tables: if not derived_table.args.get("alias"): alias_ = f"_q_{next(sequence)}" - derived_table.set( - "alias", exp.TableAlias(this=exp.to_identifier(alias_)) - ) + derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) scope.rename_source(None, alias_) for source in scope.sources.values(): diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py index 9968108..1761228 100644 --- a/sqlglot/optimizer/schema.py +++ b/sqlglot/optimizer/schema.py @@ -57,9 +57,7 @@ class MappingSchema(Schema): for forbidden in self.forbidden_args: if table.text(forbidden): - raise ValueError( - f"Schema doesn't support {forbidden}. Received: {table.sql()}" - ) + raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index f6f59e8..e816e10 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -104,9 +104,7 @@ class Scope: elif isinstance(node, exp.CTE): self._ctes.append(node) prune = True - elif isinstance(node, exp.Subquery) and isinstance( - parent, (exp.From, exp.Join) - ): + elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): self._derived_tables.append(node) prune = True elif isinstance(node, exp.Subqueryable): @@ -195,20 +193,14 @@ class Scope: self._ensure_collected() columns = self._raw_columns - external_columns = [ - column - for scope in self.subquery_scopes - for column in scope.external_columns - ] + external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns] named_outputs = {e.alias_or_name for e in self.expression.expressions} self._columns = [ c for c in columns + external_columns - if not ( - c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs - ) + if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs) ] return self._columns @@ -229,9 +221,7 @@ class Scope: for table in self.tables: referenced_names.append( ( - table.parent.alias - if isinstance(table.parent, exp.Alias) - else table.name, + table.parent.alias if isinstance(table.parent, exp.Alias) else table.name, table, ) ) @@ -274,9 +264,7 @@ class Scope: sources in the current scope. """ if self._external_columns is None: - self._external_columns = [ - c for c in self.columns if c.table not in self.selected_sources - ] + self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns def source_columns(self, source_name): @@ -310,6 +298,16 @@ class Scope: columns = self.sources.pop(old_name or "", []) self.sources[new_name] = columns + def add_source(self, name, source): + """Add a source to this scope""" + self.sources[name] = source + self.clear_cache() + + def remove_source(self, name): + """Remove a source from this scope""" + self.sources.pop(name, None) + self.clear_cache() + def traverse_scope(expression): """ @@ -334,7 +332,7 @@ def traverse_scope(expression): Args: expression (exp.Expression): expression to traverse Returns: - List[Scope]: scope instances + list[Scope]: scope instances """ return list(_traverse_scope(Scope(expression))) @@ -356,9 +354,7 @@ def _traverse_scope(scope): def _traverse_select(scope): yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) yield from _traverse_subqueries(scope) - yield from _traverse_derived_tables( - scope.derived_tables, scope, ScopeType.DERIVED_TABLE - ) + yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) _add_table_sources(scope) @@ -367,15 +363,11 @@ def _traverse_union(scope): # The last scope to be yield should be the top most scope left = None - for left in _traverse_scope( - scope.branch(scope.expression.left, scope_type=ScopeType.UNION) - ): + for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)): yield left right = None - for right in _traverse_scope( - scope.branch(scope.expression.right, scope_type=ScopeType.UNION) - ): + for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): yield right scope.union = (left, right) @@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): for derived_table in derived_tables: for child_scope in _traverse_scope( scope.branch( - derived_table - if isinstance(derived_table, (exp.Unnest, exp.Lateral)) - else derived_table.this, + derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, add_sources=sources if scope_type == ScopeType.CTE else None, outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UNNEST - if isinstance(derived_table, exp.Unnest) - else scope_type, + scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type, ) ): yield child_scope @@ -430,9 +418,7 @@ def _add_table_sources(scope): def _traverse_subqueries(scope): for subquery in scope.subqueries: top = None - for child_scope in _traverse_scope( - scope.branch(subquery, scope_type=ScopeType.SUBQUERY) - ): + for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): yield child_scope top = child_scope scope.subquery_scopes.append(top) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6771153..319e6b6 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -188,9 +188,7 @@ def absorb_and_eliminate(expression): aa.replace(exp.TRUE if kind == exp.And else exp.FALSE) elif is_complement(b, ab): ab.replace(exp.TRUE if kind == exp.And else exp.FALSE) - elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set( - a.flatten() - ): + elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): a.replace(exp.FALSE if kind == exp.And else exp.TRUE) elif isinstance(b, kind): # eliminate @@ -227,9 +225,7 @@ def simplify_literals(expression): operands.append(a) if len(operands) < size: - return functools.reduce( - lambda a, b: expression.__class__(this=a, expression=b), operands - ) + return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 55c81c5..11c6eba 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -89,11 +89,7 @@ def decorrelate(select, parent_select, external_columns, sequence): return if isinstance(predicate, exp.Binary): - key = ( - predicate.right - if any(node is column for node, *_ in predicate.left.walk()) - else predicate.left - ) + key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left else: return @@ -124,9 +120,7 @@ def decorrelate(select, parent_select, external_columns, sequence): # if the value of the subquery is not an agg or a key, we need to collect it into an array # so that it can be grouped if not value.find(exp.AggFunc) and value.this not in group_by: - select.select( - f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False - ) + select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False) # exists queries should not have any selects as it only checks if there are any rows # all selects will be added by the optimizer and only used for join keys @@ -151,16 +145,12 @@ def decorrelate(select, parent_select, external_columns, sequence): else: parent_predicate = _replace(parent_predicate, "TRUE") elif isinstance(parent_predicate, exp.All): - parent_predicate = _replace( - parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" - ) + parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})") elif isinstance(parent_predicate, exp.Any): if value.this in group_by: parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") else: - parent_predicate = _replace( - parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})" - ) + parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})") elif isinstance(parent_predicate, exp.In): if value.this in group_by: parent_predicate = _replace(parent_predicate, f"{other} = {alias}") @@ -178,9 +168,7 @@ def decorrelate(select, parent_select, external_columns, sequence): if key in group_by: key.replace(nested) - parent_predicate = _replace( - parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" - ) + parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)") elif isinstance(predicate, exp.EQ): parent_predicate = _replace( parent_predicate, |