summary refs log tree commit diff stats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- sqlglot/optimizer/annotate_types.py 2
-rw-r--r-- sqlglot/optimizer/eliminate_joins.py 4
-rw-r--r-- sqlglot/optimizer/merge_subqueries.py 54
-rw-r--r-- sqlglot/optimizer/optimizer.py 6
-rw-r--r-- sqlglot/optimizer/pushdown_projections.py 4
-rw-r--r-- sqlglot/optimizer/qualify_columns.py 4
-rw-r--r-- sqlglot/optimizer/simplify.py 19
-rw-r--r-- sqlglot/optimizer/unnest_subqueries.py 38
8 files changed, 96 insertions, 35 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index be17f15..bfb2bb8 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -43,7 +43,7 @@ class TypeAnnotator:
},
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
- exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
+ exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 3b40710..8e6a520 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias):
# But columns in the ON clause shouldn't count.
on = join.args.get("on")
if on:
- on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
+ on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
else:
on_clause_columns = set()
return any(
@@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join):
return False
_, join_keys, _ = join_condition(join)
- remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys)
+ remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
return not remaining_unique_outputs
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 9ae4966..16aaf17 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False):
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
- inner_select = inner_scope.expression.unnest()
from_or_join = table.find_ancestor(exp.From, exp.Join)
- if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+ if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
alias = table.alias_or_name
-
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias)
@@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False):
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
_pop_cte(inner_scope)
+ outer_scope.clear_cache()
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
- inner_select = subquery.unnest()
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
- if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
- alias = subquery.alias_or_name
- inner_scope = outer_scope.sources[alias]
-
+ alias = subquery.alias_or_name
+ inner_scope = outer_scope.sources[alias]
+ if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_expressions(outer_scope, inner_scope, alias)
@@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
+ outer_scope.clear_cache()
return expression
-def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
"""
Return True if the select expression of `inner_scope` can be merged into the outer query.
Args:
outer_scope (Scope)
- inner_select (exp.Select)
+ inner_scope (Scope)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
+ inner_select = inner_scope.expression.unnest()
def _is_a_window_expression_in_unmergable_operation():
window_expressions = inner_select.find_all(exp.Window)
@@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
]
return any(window_expressions_in_unmergable)
+ def _outer_select_joins_on_inner_select_join():
+ """
+ All columns from the inner select in the ON clause must be from the first FROM table.
+
+ That is, this can be merged:
+ SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
+ ^^^ ^
+ But this can't:
+ SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
+ ^^^ ^
+ """
+ if not isinstance(from_or_join, exp.Join):
+ return False
+
+ alias = from_or_join.this.alias_or_name
+
+ on = from_or_join.args.get("on")
+ if not on:
+ return False
+ selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
+ inner_from = inner_scope.expression.args.get("from")
+ if not inner_from:
+ return False
+ inner_from_table = inner_from.expressions[0].alias_or_name
+ inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
+ return any(
+ col.table != inner_from_table
+ for selection in selections
+ for col in inner_projections[selection].find_all(exp.Column)
+ )
+
return (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
- and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
@@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
)
)
+ and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
)
@@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
"""
taken = set(outer_scope.selected_sources)
conflicts = taken.intersection(set(inner_scope.selected_sources))
- conflicts = conflicts - {alias}
+ conflicts -= {alias}
for conflict in conflicts:
new_name = find_new_name(taken, conflict)
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 72e67d4..46b6b30 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
+from sqlglot.schema import ensure_schema
RULES = (
lower_identities,
@@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
If no schema is provided then the default schema defined at `sqlglot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
- rules (list): sequence of optimizer rules to use
+ rules (sequence): sequence of optimizer rules to use
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
"""
- possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
+ schema = ensure_schema(schema or sqlglot.schema)
+ possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = expression.copy()
for rule in rules:
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 49789ac..a73647c 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections):
order_refs = set()
new_selections = []
+ removed = False
for i, selection in enumerate(scope.selects):
if (
SELECT_ALL in parent_selections
@@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections):
new_selections.append(selection)
else:
removed_indexes.append(i)
+ removed = True
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
+ if removed:
+ scope.clear_cache()
return removed_indexes
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index e16a635..f4568c2 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -365,9 +365,9 @@ class _Resolver:
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
- self._all_columns = set(
+ self._all_columns = {
column for columns in self._get_all_source_columns().values() for column in columns
- )
+ }
return self._all_columns
def get_source_columns(self, name, only_visible=False):
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index c0719f2..f560760 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b):
return boolean
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
- if b:
+ if a and b:
if isinstance(expression, exp.Add):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
@@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b):
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
- if a and isinstance(expression, exp.Add):
+ if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
return None
@@ -424,9 +424,15 @@ def eval_boolean(expression, a, b):
def extract_date(cast):
- if cast.args["to"].this == exp.DataType.Type.DATE:
- return datetime.date.fromisoformat(cast.name)
- return None
+ # The "fromisoformat" conversion could fail if the cast is used on an identifier,
+ # so in that case we can't extract the date.
+ try:
+ if cast.args["to"].this == exp.DataType.Type.DATE:
+ return datetime.date.fromisoformat(cast.name)
+ if cast.args["to"].this == exp.DataType.Type.DATETIME:
+ return datetime.datetime.fromisoformat(cast.name)
+ except ValueError:
+ return None
def extract_interval(interval):
@@ -450,7 +456,8 @@ def extract_interval(interval):
def date_literal(date):
- return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
+ expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
+ return exp.Cast(this=exp.Literal.string(date), to=expr_type)
def boolean_literal(condition):
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 8d78294..a515489 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -15,8 +15,7 @@ def unnest_subqueries(expression):
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
- 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
- AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
+ 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Args:
expression (sqlglot.Expression): expression to unnest
@@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
other = _other_operand(parent_predicate)
if isinstance(parent_predicate, exp.Exists):
- if value.this in group_by:
- parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
- else:
- parent_predicate = _replace(parent_predicate, "TRUE")
+ alias = exp.column(list(key_aliases.values())[0], table_alias)
+ parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
elif isinstance(parent_predicate, exp.All):
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
@@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias)
+
+ # COUNT always returns 0 on empty datasets, so we need to take that into consideration here
+ # by transforming all counts into 0 and using that as the coalesced value
+ if value.find(exp.Count):
+
+ def remove_aggs(node):
+ if isinstance(node, exp.Count):
+ return exp.Literal.number(0)
+ elif isinstance(node, exp.AggFunc):
+ return exp.null()
+ return node
+
+ alias = exp.Coalesce(
+ this=alias,
+ expressions=[value.this.transform(remove_aggs)],
+ )
+
select.parent.replace(alias)
for key, column, predicate in keys:
@@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
- parent_predicate = _replace(
- parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
- )
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,
@@ -245,7 +256,14 @@ def _other_operand(expression):
if isinstance(expression, exp.In):
return expression.this
+ if isinstance(expression, (exp.Any, exp.All)):
+ return _other_operand(expression.parent)
+
if isinstance(expression, exp.Binary):
- return expression.right if expression.arg_key == "this" else expression.left
+ return (
+ expression.right
+ if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
+ else expression.left
+ )
return None