summary refs log tree commit diff stats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- sqlglot/optimizer/eliminate_subqueries.py 59
-rw-r--r-- sqlglot/optimizer/lower_identities.py 92
-rw-r--r-- sqlglot/optimizer/optimizer.py 2
-rw-r--r-- sqlglot/optimizer/unnest_subqueries.py 36
4 files changed, 172 insertions, 17 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 8704e90..39e252c 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -68,6 +68,9 @@ def eliminate_subqueries(expression):
for cte_scope in root.cte_scopes:
# Append all the new CTEs from this existing CTE
for scope in cte_scope.traverse():
+ if scope is cte_scope:
+ # Don't try to eliminate this CTE itself
+ continue
new_cte = _eliminate(scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
@@ -97,6 +100,9 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)
+ if scope.is_cte:
+ return _eliminate_cte(scope, existing_ctes, taken)
+
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
@@ -127,26 +133,61 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
+ parent = scope.expression.parent
+ name, cte = _new_cte(scope, existing_ctes, taken)
+
+ table = exp.alias_(exp.table_(name), alias=parent.alias or name)
+ parent.replace(table)
+
+ return cte
+
+
+def _eliminate_cte(scope, existing_ctes, taken):
+ parent = scope.expression.parent
+ name, cte = _new_cte(scope, existing_ctes, taken)
+
+ with_ = parent.parent
+ parent.pop()
+ if not with_.expressions:
+ with_.pop()
+
+ # Rename references to this CTE
+ for child_scope in scope.parent.traverse():
+ for table, source in child_scope.selected_sources.values():
+ if source is scope:
+ new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+ table.replace(new_table)
+
+ return cte
+
+
+def _new_cte(scope, existing_ctes, taken):
+ """
+ Returns:
+ tuple of (name, cte)
+ where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
+ If this CTE duplicates an existing CTE, `cte` will be None.
+ """
duplicate_cte_alias = existing_ctes.get(scope.expression)
parent = scope.expression.parent
- name = alias = parent.alias
+ name = parent.alias
- if not alias:
- name = alias = find_new_name(taken=taken, base="cte")
+ if not name:
+ name = find_new_name(taken=taken, base="cte")
if duplicate_cte_alias:
name = duplicate_cte_alias
- elif taken.get(alias):
- name = find_new_name(taken=taken, base=alias)
+ elif taken.get(name):
+ name = find_new_name(taken=taken, base=name)
taken[name] = scope
- table = exp.alias_(exp.table_(name), alias=alias)
- parent.replace(table)
-
if not duplicate_cte_alias:
existing_ctes[scope.expression] = name
- return exp.CTE(
+ cte = exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(name)),
)
+ else:
+ cte = None
+ return name, cte
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
new file mode 100644
index 0000000..1cc76cf
--- /dev/null
+++ b/sqlglot/optimizer/lower_identities.py
@@ -0,0 +1,92 @@
+from sqlglot import exp
+from sqlglot.helper import ensure_collection
+
+
+def lower_identities(expression):
+ """
+ Convert all unquoted identifiers to lower case.
+
+ Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
+ >>> lower_identities(expression).sql()
+ 'SELECT bar.a AS A FROM "Foo".bar'
+
+ Args:
+ expression (sqlglot.Expression): expression whose unquoted identifiers should be lowered
+ Returns:
+ sqlglot.Expression: the same expression, with unquoted identifiers converted to lower case
+ """
+ # We need to leave the output aliases unchanged, so the selects need special handling
+ _lower_selects(expression)
+
+ # These clauses can reference output aliases and also need special handling
+ _lower_order(expression)
+ _lower_having(expression)
+
+ # We've already handled these args, so don't traverse into them
+ traversed = {"expressions", "order", "having"}
+
+ if isinstance(expression, exp.Subquery):
+ # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
+ lower_identities(expression.this)
+ traversed |= {"this"}
+
+ if isinstance(expression, exp.Union):
+ # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
+ lower_identities(expression.left)
+ lower_identities(expression.right)
+ traversed |= {"this", "expression"}
+
+ for k, v in expression.args.items():
+ if k in traversed:
+ continue
+
+ for child in ensure_collection(v):
+ if isinstance(child, exp.Expression):
+ child.transform(_lower, copy=False)
+
+ return expression
+
+
+def _lower_selects(expression):
+ for e in expression.expressions:
+ # Leave output aliases as-is
+ e.unalias().transform(_lower, copy=False)
+
+
+def _lower_order(expression):
+ order = expression.args.get("order")
+
+ if not order:
+ return
+
+ output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
+
+ for ordered in order.expressions:
+ # Don't lower references to output aliases
+ if not (
+ isinstance(ordered.this, exp.Column)
+ and not ordered.this.table
+ and ordered.this.name in output_aliases
+ ):
+ ordered.transform(_lower, copy=False)
+
+
+def _lower_having(expression):
+ having = expression.args.get("having")
+
+ if not having:
+ return
+
+ # Only lower identifiers inside aggregate functions; other references may be output aliases
+ for agg in having.find_all(exp.AggFunc):
+ agg.transform(_lower, copy=False)
+
+
+def _lower(node):
+ if isinstance(node, exp.Identifier) and not node.quoted:
+ node.set("this", node.this.lower())
+ return node
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index d0e38cd..6819717 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -6,6 +6,7 @@ from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
@@ -17,6 +18,7 @@ from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
+ lower_identities,
qualify_tables,
isolate_table_selects,
qualify_columns,
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index dbd680b..2046917 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -1,16 +1,15 @@
import itertools
from sqlglot import exp
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import ScopeType, traverse_scope
def unnest_subqueries(expression):
"""
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
- Convert the subquery into a group by so it is not a many to many left join.
- Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
- Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
+ Convert scalar subqueries into cross joins.
+ Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
Example:
>>> import sqlglot
@@ -29,21 +28,43 @@ def unnest_subqueries(expression):
for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
+ if not parent:
+ continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
- else:
+ elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)
return expression
def unnest(select, parent_select, sequence):
- predicate = select.find_ancestor(exp.In, exp.Any)
+ if len(select.selects) > 1:
+ return
+
+ predicate = select.find_ancestor(exp.Condition)
+ alias = _alias(sequence)
if not predicate or parent_select is not predicate.parent_select:
return
- if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
+ # this subquery returns a scalar and can just be converted to a cross join
+ if not isinstance(predicate, (exp.In, exp.Any)):
+ having = predicate.find_ancestor(exp.Having)
+ column = exp.column(select.selects[0].alias_or_name, alias)
+ if having and having.parent_select is parent_select:
+ column = exp.Max(this=column)
+ _replace(select.parent, column)
+
+ parent_select.join(
+ select,
+ join_type="CROSS",
+ join_alias=alias,
+ copy=False,
+ )
+ return
+
+ if select.find(exp.Limit, exp.Offset):
return
if isinstance(predicate, exp.Any):
@@ -54,7 +75,6 @@ def unnest(select, parent_select, sequence):
column = _other_operand(predicate)
value = select.selects[0]
- alias = _alias(sequence)
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")