diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-30 05:07:28 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-30 05:07:28 +0000 |
commit | 5a674d94c3ab243e2dd6a00f9edf6cc50b018512 (patch) | |
tree | 0b6fe74b5b346f0b048162b56a12885f1a2c2912 /sqlglot/optimizer/merge_derived_tables.py | |
parent | Releasing debian version 6.2.1-1. (diff) | |
download | sqlglot-5a674d94c3ab243e2dd6a00f9edf6cc50b018512.tar.xz sqlglot-5a674d94c3ab243e2dd6a00f9edf6cc50b018512.zip |
Merging upstream version 6.2.6.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/merge_derived_tables.py')
-rw-r--r-- | sqlglot/optimizer/merge_derived_tables.py | 232 |
1 files changed, 0 insertions, 232 deletions
diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_derived_tables.py deleted file mode 100644 index 8b161fb..0000000 --- a/sqlglot/optimizer/merge_derived_tables.py +++ /dev/null @@ -1,232 +0,0 @@ -from collections import defaultdict - -from sqlglot import expressions as exp -from sqlglot.optimizer.scope import traverse_scope -from sqlglot.optimizer.simplify import simplify - - -def merge_derived_tables(expression): - """ - Rewrite sqlglot AST to merge derived tables into the outer query. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)") - >>> merge_derived_tables(expression).sql() - 'SELECT x.a FROM x' - - Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html - - Args: - expression (sqlglot.Expression): expression to optimize - Returns: - sqlglot.Expression: optimized expression - """ - for outer_scope in traverse_scope(expression): - for subquery in outer_scope.derived_tables: - inner_select = subquery.unnest() - if ( - isinstance(outer_scope.expression, exp.Select) - and isinstance(inner_select, exp.Select) - and _mergeable(inner_select) - ): - alias = subquery.alias_or_name - from_or_join = subquery.find_ancestor(exp.From, exp.Join) - inner_scope = outer_scope.sources[alias] - - _rename_inner_sources(outer_scope, inner_scope, alias) - _merge_from(outer_scope, inner_scope, subquery) - _merge_joins(outer_scope, inner_scope, from_or_join) - _merge_expressions(outer_scope, inner_scope, alias) - _merge_where(outer_scope, inner_scope, from_or_join) - _merge_order(outer_scope, inner_scope) - return expression - - -# If a derived table has these Select args, it can't be merged -UNMERGABLE_ARGS = set(exp.Select.arg_types) - { - "expressions", - "from", - "joins", - "where", - "order", -} - - -def _mergeable(inner_select): - """ - Return True if `inner_select` can be merged into outer query. - - Args: - inner_select (exp.Select) - Returns: - bool: True if can be merged - """ - return ( - isinstance(inner_select, exp.Select) - and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) - and inner_select.args.get("from") - and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) - ) - - -def _rename_inner_sources(outer_scope, inner_scope, alias): - """ - Renames any sources in the inner query that conflict with names in the outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - alias (str) - """ - taken = set(outer_scope.selected_sources) - conflicts = taken.intersection(set(inner_scope.selected_sources)) - conflicts = conflicts - {alias} - - for conflict in conflicts: - new_name = _find_new_name(taken, conflict) - - source, _ = inner_scope.selected_sources[conflict] - new_alias = exp.to_identifier(new_name) - - if isinstance(source, exp.Subquery): - source.set("alias", exp.TableAlias(this=new_alias)) - elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias): - source.parent.set("alias", new_alias) - elif isinstance(source, exp.Table): - source.replace(exp.alias_(source.copy(), new_alias)) - - for column in inner_scope.source_columns(conflict): - column.set("table", exp.to_identifier(new_name)) - - inner_scope.rename_source(conflict, new_name) - - -def _find_new_name(taken, base): - """ - Searches for a new source name. - - Args: - taken (set[str]): set of taken names - base (str): base name to alter - """ - i = 2 - new = f"{base}_{i}" - while new in taken: - i += 1 - new = f"{base}_{i}" - return new - - -def _merge_from(outer_scope, inner_scope, subquery): - """ - Merge FROM clause of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - subquery (exp.Subquery) - """ - new_subquery = inner_scope.expression.args.get("from").expressions[0] - subquery.replace(new_subquery) - outer_scope.remove_source(subquery.alias_or_name) - outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) - - -def _merge_joins(outer_scope, inner_scope, from_or_join): - """ - Merge JOIN clauses of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - from_or_join (exp.From|exp.Join) - """ - - new_joins = [] - comma_joins = inner_scope.expression.args.get("from").expressions[1:] - for subquery in comma_joins: - new_joins.append(exp.Join(this=subquery, kind="CROSS")) - outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name]) - - joins = inner_scope.expression.args.get("joins") or [] - for join in joins: - new_joins.append(join) - outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name]) - - if new_joins: - outer_joins = outer_scope.expression.args.get("joins", []) - - # Maintain the join order - if isinstance(from_or_join, exp.From): - position = 0 - else: - position = outer_joins.index(from_or_join) + 1 - outer_joins[position:position] = new_joins - - outer_scope.expression.set("joins", outer_joins) - - -def _merge_expressions(outer_scope, inner_scope, alias): - """ - Merge projections of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - alias (str) - """ - # Collect all columns that for the alias of the inner query - outer_columns = defaultdict(list) - for column in outer_scope.columns: - if column.table == alias: - outer_columns[column.name].append(column) - - # Replace columns with the projection expression in the inner query - for expression in inner_scope.expression.expressions: - projection_name = expression.alias_or_name - if not projection_name: - continue - columns_to_replace = outer_columns.get(projection_name, []) - for column in columns_to_replace: - column.replace(expression.unalias()) - - -def _merge_where(outer_scope, inner_scope, from_or_join): - """ - Merge WHERE clause of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - from_or_join (exp.From|exp.Join) - """ - where = inner_scope.expression.args.get("where") - if not where or not where.this: - return - - if isinstance(from_or_join, exp.Join) and from_or_join.side: - # Merge predicates from an outer join to the ON clause - from_or_join.on(where.this, copy=False) - from_or_join.set("on", simplify(from_or_join.args.get("on"))) - else: - outer_scope.expression.where(where.this, copy=False) - outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where"))) - - -def _merge_order(outer_scope, inner_scope): - """ - Merge ORDER clause of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - """ - if ( - any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]) - or len(outer_scope.selected_sources) != 1 - or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) - ): - return - - outer_scope.expression.set("order", inner_scope.expression.args.get("order")) |