diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 59 | ||||
-rw-r--r-- | sqlglot/optimizer/lower_identities.py | 92 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 36 |
4 files changed, 172 insertions, 17 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 8704e90..39e252c 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -68,6 +68,9 @@ def eliminate_subqueries(expression): for cte_scope in root.cte_scopes: # Append all the new CTEs from this existing CTE for scope in cte_scope.traverse(): + if scope is cte_scope: + # Don't try to eliminate this CTE itself + continue new_cte = _eliminate(scope, existing_ctes, taken) if new_cte: new_ctes.append(new_cte) @@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): return _eliminate_derived_table(scope, existing_ctes, taken) + if scope.is_cte: + return _eliminate_cte(scope, existing_ctes, taken) + def _eliminate_union(scope, existing_ctes, taken): duplicate_cte_alias = existing_ctes.get(scope.expression) @@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + table = exp.alias_(exp.table_(name), alias=parent.alias or name) + parent.replace(table) + + return cte + + +def _eliminate_cte(scope, existing_ctes, taken): + parent = scope.expression.parent + name, cte = _new_cte(scope, existing_ctes, taken) + + with_ = parent.parent + parent.pop() + if not with_.expressions: + with_.pop() + + # Rename references to this CTE + for child_scope in scope.parent.traverse(): + for table, source in child_scope.selected_sources.values(): + if source is scope: + new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name) + table.replace(new_table) + + return cte + + +def _new_cte(scope, existing_ctes, taken): + """ + Returns: + tuple of (name, cte) + where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance. + If this CTE duplicates an existing CTE, `cte` will be None. + """ duplicate_cte_alias = existing_ctes.get(scope.expression) parent = scope.expression.parent - name = alias = parent.alias + name = parent.alias - if not alias: - name = alias = find_new_name(taken=taken, base="cte") + if not name: + name = find_new_name(taken=taken, base="cte") if duplicate_cte_alias: name = duplicate_cte_alias - elif taken.get(alias): - name = find_new_name(taken=taken, base=alias) + elif taken.get(name): + name = find_new_name(taken=taken, base=name) taken[name] = scope - table = exp.alias_(exp.table_(name), alias=alias) - parent.replace(table) - if not duplicate_cte_alias: existing_ctes[scope.expression] = name - return exp.CTE( + cte = exp.CTE( this=scope.expression, alias=exp.TableAlias(this=exp.to_identifier(name)), ) + else: + cte = None + return name, cte diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py new file mode 100644 index 0000000..1cc76cf --- /dev/null +++ b/sqlglot/optimizer/lower_identities.py @@ -0,0 +1,92 @@ +from sqlglot import exp +from sqlglot.helper import ensure_collection + + +def lower_identities(expression): + """ + Convert all unquoted identifiers to lower case. + + Assuming the schema is all lower case, this essentially makes identifiers case-insensitive. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') + >>> lower_identities(expression).sql() + 'SELECT bar.a AS A FROM "Foo".bar' + + Args: + expression (sqlglot.Expression): expression to quote + Returns: + sqlglot.Expression: quoted expression + """ + # We need to leave the output aliases unchanged, so the selects need special handling + _lower_selects(expression) + + # These clauses can reference output aliases and also need special handling + _lower_order(expression) + _lower_having(expression) + + # We've already handled these args, so don't traverse into them + traversed = {"expressions", "order", "having"} + + if isinstance(expression, exp.Subquery): + # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1 + lower_identities(expression.this) + traversed |= {"this"} + + if isinstance(expression, exp.Union): + # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X + lower_identities(expression.left) + lower_identities(expression.right) + traversed |= {"this", "expression"} + + for k, v in expression.args.items(): + if k in traversed: + continue + + for child in ensure_collection(v): + if isinstance(child, exp.Expression): + child.transform(_lower, copy=False) + + return expression + + +def _lower_selects(expression): + for e in expression.expressions: + # Leave output aliases as-is + e.unalias().transform(_lower, copy=False) + + +def _lower_order(expression): + order = expression.args.get("order") + + if not order: + return + + output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)} + + for ordered in order.expressions: + # Don't lower references to output aliases + if not ( + isinstance(ordered.this, exp.Column) + and not ordered.this.table + and ordered.this.name in output_aliases + ): + ordered.transform(_lower, copy=False) + + +def _lower_having(expression): + having = expression.args.get("having") + + if not having: + return + + # Don't lower references to output aliases + for agg in having.find_all(exp.AggFunc): + agg.transform(_lower, copy=False) + + +def _lower(node): + if isinstance(node, exp.Identifier) and not node.quoted: + node.set("this", node.this.lower()) + return node diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index d0e38cd..6819717 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.lower_identities import lower_identities from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins @@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities from sqlglot.optimizer.unnest_subqueries import unnest_subqueries RULES = ( + lower_identities, qualify_tables, isolate_table_selects, qualify_columns, diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index dbd680b..2046917 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -1,16 +1,15 @@ import itertools from sqlglot import exp -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import ScopeType, traverse_scope def unnest_subqueries(expression): """ Rewrite sqlglot AST to convert some predicates with subqueries into joins. - Convert the subquery into a group by so it is not a many to many left join. - Unnesting can only occur if the subquery does not have LIMIT or OFFSET. - Unnesting non correlated subqueries only happens on IN statements or = ANY statements. + Convert scalar subqueries into cross joins. + Convert correlated or vectorized subqueries into a group by so it is not a many to many left join. Example: >>> import sqlglot @@ -29,21 +28,43 @@ def unnest_subqueries(expression): for scope in traverse_scope(expression): select = scope.expression parent = select.parent_select + if not parent: + continue if scope.external_columns: decorrelate(select, parent, scope.external_columns, sequence) - else: + elif scope.scope_type == ScopeType.SUBQUERY: unnest(select, parent, sequence) return expression def unnest(select, parent_select, sequence): - predicate = select.find_ancestor(exp.In, exp.Any) + if len(select.selects) > 1: + return + + predicate = select.find_ancestor(exp.Condition) + alias = _alias(sequence) if not predicate or parent_select is not predicate.parent_select: return - if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset): + # this subquery returns a scalar and can just be converted to a cross join + if not isinstance(predicate, (exp.In, exp.Any)): + having = predicate.find_ancestor(exp.Having) + column = exp.column(select.selects[0].alias_or_name, alias) + if having and having.parent_select is parent_select: + column = exp.Max(this=column) + _replace(select.parent, column) + + parent_select.join( + select, + join_type="CROSS", + join_alias=alias, + copy=False, + ) + return + + if select.find(exp.Limit, exp.Offset): return if isinstance(predicate, exp.Any): @@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence): column = _other_operand(predicate) value = select.selects[0] - alias = _alias(sequence) on = exp.condition(f'{column} = "{alias}"."{value.alias}"') _replace(predicate, f"NOT {on.right} IS NULL") |