summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-01-04 07:24:08 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-01-04 07:24:08 +0000
commit: 7a2201963d5b03bd1828d350ccaecb4eda30d30c (patch)
tree: 19effbe90b8d78fdcb5f7d4bd0dd46b177ffdaab /sqlglot/optimizer
parent: Releasing debian version 10.2.9-1. (diff)
download: sqlglot-7a2201963d5b03bd1828d350ccaecb4eda30d30c.tar.xz
download: sqlglot-7a2201963d5b03bd1828d350ccaecb4eda30d30c.zip
Merging upstream version 10.4.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- sqlglot/optimizer/canonicalize.py 3
-rw-r--r-- sqlglot/optimizer/eliminate_joins.py 49
-rw-r--r-- sqlglot/optimizer/eliminate_subqueries.py 4
-rw-r--r-- sqlglot/optimizer/normalize.py 5
-rw-r--r-- sqlglot/optimizer/optimizer.py 2
-rw-r--r-- sqlglot/optimizer/pushdown_predicates.py 14
-rw-r--r-- sqlglot/optimizer/pushdown_projections.py 6
-rw-r--r-- sqlglot/optimizer/qualify_columns.py 3
-rw-r--r-- sqlglot/optimizer/quote_identities.py 25
-rw-r--r-- sqlglot/optimizer/scope.py 30
-rw-r--r-- sqlglot/optimizer/unnest_subqueries.py 27
11 files changed, 108 insertions, 60 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index 33529a5..fc37a54 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
+ if isinstance(expression, exp.Identifier):
+ expression.set("quoted", True)
+
return expression
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index de4e011..3b40710 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -129,10 +129,23 @@ def join_condition(join):
"""
name = join.this.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
- on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
+ def extract_condition(condition):
+ left, right = condition.unnest_operands()
+ left_tables = exp.column_table_names(left)
+ right_tables = exp.column_table_names(right)
+
+ if name in left_tables and name not in right_tables:
+ join_key.append(left)
+ source_key.append(right)
+ condition.replace(exp.true())
+ elif name in right_tables and name not in left_tables:
+ join_key.append(right)
+ source_key.append(left)
+ condition.replace(exp.true())
+
# find the join keys
# SELECT
# FROM x
@@ -141,20 +154,30 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
+ on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
+
for condition in on.flatten():
if isinstance(condition, exp.EQ):
- left, right = condition.unnest_operands()
- left_tables = exp.column_table_names(left)
- right_tables = exp.column_table_names(right)
-
- if name in left_tables and name not in right_tables:
- join_key.append(left)
- source_key.append(right)
- condition.replace(exp.true())
- elif name in right_tables and name not in left_tables:
- join_key.append(right)
- source_key.append(left)
- condition.replace(exp.true())
+ extract_condition(condition)
+ elif normalized(on, dnf=True):
+ conditions = None
+
+ for condition in on.flatten():
+ parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
+ if conditions is None:
+ conditions = parts
+ else:
+ temp = []
+ for p in parts:
+ cs = [c for c in conditions if p == c]
+
+ if cs:
+ temp.append(p)
+ temp.extend(cs)
+ conditions = temp
+
+ for condition in conditions:
+ extract_condition(condition)
on = simplify(on)
remaining_condition = None if on == exp.true() else on
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 39e252c..2245cc2 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
existing_ctes = {}
with_ = root.expression.args.get("with")
+ recursive = False
if with_:
+ recursive = with_.args.get("recursive")
for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias
new_ctes = []
@@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
new_ctes.append(new_cte)
if new_ctes:
- expression.set("with", exp.With(expressions=new_ctes))
+ expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
return expression
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index db538ef..f16f519 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf):
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
- x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
- return x
+ return [
+ a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
+ ]
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 6819717..72e67d4 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
-from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
@@ -34,7 +33,6 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
- quote_identities,
)
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index f92e5c3..ba5c8b5 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -27,7 +27,14 @@ def pushdown_predicates(expression):
select = scope.expression
where = select.args.get("where")
if where:
- pushdown(where.this, scope.selected_sources, scope_ref_count)
+ selected_sources = scope.selected_sources
+ # a right join can only push down to itself and not the source FROM table
+ for k, (node, source) in selected_sources.items():
+ parent = node.find_ancestor(exp.Join, exp.From)
+ if isinstance(parent, exp.Join) and parent.side == "RIGHT":
+ selected_sources = {k: (node, source)}
+ break
+ pushdown(where.this, selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
# a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
+ with_ = source.parent.expression.args.get("with")
+ if with_ and with_.recursive:
+ return {}
node = source.expression
if isinstance(node, exp.Join):
- if node.side:
+ if node.side and node.side != "RIGHT":
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index abd9492..49789ac 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
-# SELECTION TO USE IF SELECTION LIST IS EMPTY
+# Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_")
@@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(DEFAULT_SELECTION)
+ new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
return removed_indexes
@@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
- new_selections.append(DEFAULT_SELECTION)
+ new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index e6e6dc9..e16a635 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -311,6 +311,9 @@ def _qualify_outputs(scope):
alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection)
selection = alias_
+ elif isinstance(selection, exp.Subquery):
+ if not selection.alias:
+ selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}")
alias_.set("this", selection)
diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py
deleted file mode 100644
index 17623cc..0000000
--- a/sqlglot/optimizer/quote_identities.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from sqlglot import exp
-
-
-def quote_identities(expression):
- """
- Rewrite sqlglot AST to ensure all identities are quoted.
-
- Example:
- >>> import sqlglot
- >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
- >>> quote_identities(expression).sql()
- 'SELECT "x"."a" AS "a" FROM "db"."x"'
-
- Args:
- expression (sqlglot.Expression): expression to quote
- Returns:
- sqlglot.Expression: quoted expression
- """
-
- def qualify(node):
- if isinstance(node, exp.Identifier):
- node.set("quoted", True)
- return node
-
- return expression.transform(qualify, copy=False)
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 18848f3..6125e4e 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -511,9 +511,20 @@ def _traverse_union(scope):
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
+ is_cte = scope_type == ScopeType.CTE
for derived_table in derived_tables:
- top = None
+ recursive_scope = None
+
+ # if the scope is a recursive cte, it must be in the form of
+ # base_case UNION recursive. thus the recursive scope is the first
+ # section of the union.
+ if is_cte and scope.expression.args["with"].recursive:
+ union = derived_table.this
+
+ if isinstance(union, exp.Union):
+ recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
+
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
@@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
- top = child_scope
+
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins.
- sources[derived_table.alias] = child_scope
- if scope_type == ScopeType.CTE:
- scope.cte_scopes.append(top)
+ alias = derived_table.alias
+ sources[alias] = child_scope
+
+ if recursive_scope:
+ child_scope.add_source(alias, recursive_scope)
+
+ # append the final child_scope yielded
+ if is_cte:
+ scope.cte_scopes.append(child_scope)
else:
- scope.derived_table_scopes.append(top)
+ scope.derived_table_scopes.append(child_scope)
+
scope.sources.update(sources)
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 2046917..8d78294 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -16,7 +16,7 @@ def unnest_subqueries(expression):
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
- AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
+ AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
Args:
expression (sqlglot.Expression): expression to unnest
@@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
table_alias = _alias(sequence)
keys = []
- # for all external columns in the where statement,
- # split out the relevant data to convert it into a join
+ # for all external columns in the where statement, find the relevant predicate
+ # keys to convert it into a join
for column in external_columns:
if column.find_ancestor(exp.Where) is not where:
return
@@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
return
+ is_subquery_projection = any(
+ node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
+ )
+
value = select.selects[0]
key_aliases = {}
group_by = []
@@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
parent_predicate = select.find_ancestor(exp.Predicate)
# if the value of the subquery is not an agg or a key, we need to collect it into an array
- # so that it can be grouped
+ # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
+ agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
if not value.find(exp.AggFunc) and value.this not in group_by:
- select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
+ select.select(
+ exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
+ append=False,
+ copy=False,
+ )
# exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys
@@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if isinstance(parent_predicate, exp.Exists) or key != value.this:
select.select(f"{key} AS {alias}", copy=False)
else:
- select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
+ select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
@@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
+ if is_subquery_projection:
+ alias = exp.alias_(alias, select.parent.alias)
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
+ if is_subquery_projection:
+ key.replace(nested)
+ continue
+
if key in group_by:
key.replace(nested)
parent_predicate = _replace(