summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-11-11 08:54:35 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-11-11 08:54:35 +0000
commitd1f00706bff58b863b0a1c5bf4adf39d36049d4c (patch)
tree3a8ecc5d1509d655d5df6b1455bc1e309da2c02c /sqlglot/optimizer
parentReleasing debian version 9.0.6-1. (diff)
downloadsqlglot-d1f00706bff58b863b0a1c5bf4adf39d36049d4c.tar.xz
sqlglot-d1f00706bff58b863b0a1c5bf4adf39d36049d4c.zip
Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/annotate_types.py131
-rw-r--r--sqlglot/optimizer/eliminate_joins.py4
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py12
-rw-r--r--sqlglot/optimizer/merge_subqueries.py16
-rw-r--r--sqlglot/optimizer/normalize.py4
-rw-r--r--sqlglot/optimizer/optimize_joins.py6
-rw-r--r--sqlglot/optimizer/optimizer.py4
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py28
-rw-r--r--sqlglot/optimizer/pushdown_projections.py4
-rw-r--r--sqlglot/optimizer/qualify_columns.py28
-rw-r--r--sqlglot/optimizer/scope.py14
-rw-r--r--sqlglot/optimizer/simplify.py12
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py14
13 files changed, 219 insertions, 58 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 30055bc..96331e2 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,5 @@
from sqlglot import exp
-from sqlglot.helper import ensure_list, subclasses
+from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -48,35 +48,65 @@ class TypeAnnotator:
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
- exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+ exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.BIGINT
+ ),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
- exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
+ exp.CurrentTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
- exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+ exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
+ exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.TimestampSub: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATE
+ ),
+ exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
+ exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
+ exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
+ exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
+ exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.GroupConcat: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
+ exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@@ -88,32 +118,52 @@ class TypeAnnotator:
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DOUBLE
+ ),
+ exp.RegexpLike: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.BOOLEAN
+ ),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.StrToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
+ exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATE
+ ),
+ exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.UnixToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.VariancePop: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DOUBLE
+ ),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
@@ -124,7 +174,11 @@ class TypeAnnotator:
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
- exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
+ exp.DataType.Type.NCHAR: {
+ exp.DataType.Type.VARCHAR,
+ exp.DataType.Type.NVARCHAR,
+ exp.DataType.Type.TEXT,
+ },
exp.DataType.Type.CHAR: {
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
@@ -135,7 +189,11 @@ class TypeAnnotator:
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
- exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
+ exp.DataType.Type.BIGINT: {
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DOUBLE,
+ },
exp.DataType.Type.INT: {
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
@@ -160,7 +218,10 @@ class TypeAnnotator:
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
- exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
+ exp.DataType.Type.TIMESTAMP: {
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMPLTZ,
+ },
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
@@ -219,7 +280,7 @@ class TypeAnnotator:
def _annotate_args(self, expression):
for value in expression.args.values():
- for v in ensure_list(value):
+ for v in ensure_collection(value):
self._maybe_annotate(v)
return expression
@@ -243,7 +304,9 @@ class TypeAnnotator:
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
- expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
+ expression.type = exp.DataType.build(
+ "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
+ )
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
@@ -276,3 +339,17 @@ class TypeAnnotator:
def _annotate_with_type(self, expression, target_type):
expression.type = target_type
return self._annotate_args(expression)
+
+ def _annotate_by_args(self, expression, *args):
+ self._annotate_args(expression)
+ expressions = []
+ for arg in args:
+ arg_expr = expression.args.get(arg)
+ expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
+
+ last_datatype = None
+ for expr in expressions:
+ last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
+
+ expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+ return expression
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 0854336..29621af 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias):
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
else:
on_clause_columns = set()
- return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
+ return any(
+ column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
+ )
def _is_joined_on_all_unique_outputs(scope, join):
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index e30c263..8704e90 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -45,7 +45,13 @@ def eliminate_subqueries(expression):
# All table names are taken
for scope in root.traverse():
- taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
+ taken.update(
+ {
+ source.name: source
+ for _, source in scope.sources.items()
+ if isinstance(source, exp.Table)
+ }
+ )
# Map of Expression->alias
# Existing CTES in the root expression. We'll use this for deduplication.
@@ -70,7 +76,9 @@ def eliminate_subqueries(expression):
new_ctes.append(cte_scope.expression.parent)
# Now append the rest
- for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
+ for scope in itertools.chain(
+ root.union_scopes, root.subquery_scopes, root.derived_table_scopes
+ ):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 70e4629..9ae4966 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
unmergable_window_columns = [
column
for column in outer_scope.columns
- if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
+ if column.find_ancestor(
+ exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
+ )
]
window_expressions_in_unmergable = [
column
@@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
- and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
+ and any(
+ j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
+ )
)
and not _is_a_window_expression_in_unmergable_operation()
)
@@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
if table.alias_or_name == node_to_replace.alias_or_name:
table.set("this", exp.to_identifier(new_subquery.alias_or_name))
outer_scope.remove_source(alias)
- outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
+ outer_scope.add_source(
+ new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
+ )
def _merge_joins(outer_scope, inner_scope, from_or_join):
@@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope):
inner_scope (sqlglot.optimizer.scope.Scope)
"""
if (
- any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
+ any(
+ outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
+ )
or len(outer_scope.selected_sources) != 1
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
):
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index ab30d7a..db538ef 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False):
Returns:
int: difference
"""
- return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
+ return sum(_predicate_lengths(expression, dnf)) - (
+ len(list(expression.find_all(exp.Connector))) + 1
+ )
def _predicate_lengths(expression, dnf):
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 0c74e36..40e4ab1 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -68,4 +68,8 @@ def normalize(expression):
def other_table_names(join, exclude):
- return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
+ return [
+ name
+ for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+ if name != exclude
+ ]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 5ad8f46..b2ed062 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
- rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+ rule_kwargs = {
+ param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
+ }
expression = rule(expression, **rule_kwargs)
return expression
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 583d059..6364f65 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
- predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
+ predicates = list(
+ condition.flatten()
+ if isinstance(condition, exp.And if cnf_like else exp.Or)
+ else [condition]
+ )
if cnf_like:
pushdown_cnf(predicates, sources, scope_ref_count)
@@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
- predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
+ predicate_condition = (
+ exp.and_(predicate_condition, condition)
+ if predicate_condition
+ else condition
+ )
if predicate_condition:
conditions[table] = (
- exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
+ exp.or_(conditions[table], predicate_condition)
+ if table in conditions
+ else predicate_condition
)
for name, node in nodes.items():
@@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
# We can't push down window expressions
- has_window_expression = any(select for select in node.selects if select.find(exp.Window))
+ has_window_expression = any(
+ select for select in node.selects if select.find(exp.Window)
+ )
# we can't push down predicates to select statements if they are referenced in
# multiple places.
- if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
+ if (
+ not node.args.get("group")
+ and scope_ref_count[id(source)] < 2
+ and not has_window_expression
+ ):
nodes[table] = node
return nodes
@@ -165,7 +181,7 @@ def replace_aliases(source, predicate):
def _replace_alias(column):
if isinstance(column, exp.Column) and column.name in aliases:
- return aliases[column.name]
+ return aliases[column.name].copy()
return column
return predicate.transform(_replace_alias)
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 5820851..abd9492 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections):
def _remove_indexed_selections(scope, indexes_to_remove):
- new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
+ new_selections = [
+ selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
+ ]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index ebee92a..69fe2b8 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver):
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
- if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
+ if (
+ not column.table
+ and column.parent is not ordered
+ and column.name in resolver.all_columns
+ ):
columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
for having in scope.find_all(exp.Having):
for column in having.find_all(exp.Column):
- if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
+ if (
+ not column.table
+ and column.find_ancestor(exp.AggFunc)
+ and column.name in resolver.all_columns
+ ):
columns_missing_from_scope.append(column)
for column in columns_missing_from_scope:
@@ -295,7 +303,9 @@ def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []
- for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
+ for i, (selection, aliased_column) in enumerate(
+ itertools.zip_longest(scope.selects, scope.outer_column_list)
+ ):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)
@@ -343,14 +353,18 @@ class _Resolver:
(str) table name
"""
if self._unambiguous_columns is None:
- self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
+ self._unambiguous_columns = self._get_unambiguous_columns(
+ self._get_all_source_columns()
+ )
return self._unambiguous_columns.get(column_name)
@property
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
- self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
+ self._all_columns = set(
+ column for columns in self._get_all_source_columns().values() for column in columns
+ )
return self._all_columns
def get_source_columns(self, name, only_visible=False):
@@ -377,7 +391,9 @@ class _Resolver:
def _get_all_source_columns(self):
if self._source_columns is None:
- self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
+ self._source_columns = {
+ k: self.get_source_columns(k) for k in self.scope.selected_sources
+ }
return self._source_columns
def _get_unambiguous_columns(self, source_columns):
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 5a75ee2..18848f3 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -226,7 +226,9 @@ class Scope:
self._ensure_collected()
columns = self._raw_columns
- external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
+ external_columns = [
+ column for scope in self.subquery_scopes for column in scope.external_columns
+ ]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
@@ -278,7 +280,11 @@ class Scope:
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
- return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+ return {
+ alias: scope
+ for alias, scope in self.sources.items()
+ if isinstance(scope, Scope) and scope.is_cte
+ }
@property
def selects(self):
@@ -307,7 +313,9 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
- self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
+ self._external_columns = [
+ c for c in self.columns if c.table not in self.selected_sources
+ ]
return self._external_columns
@property
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index c077906..d759e86 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -229,7 +229,9 @@ def simplify_literals(expression):
operands.append(a)
if len(operands) < size:
- return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+ return functools.reduce(
+ lambda a, b: expression.__class__(this=a, expression=b), operands
+ )
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b):
return TRUE if not_ else FALSE
if a == NULL:
return FALSE if not_ else TRUE
+ elif isinstance(expression, exp.NullSafeEQ):
+ if a == b:
+ return TRUE
+ elif isinstance(expression, exp.NullSafeNEQ):
+ if a == b:
+ return FALSE
elif NULL in (a, b):
return NULL
@@ -357,7 +365,7 @@ def extract_date(cast):
def extract_interval(interval):
try:
- from dateutil.relativedelta import relativedelta
+ from dateutil.relativedelta import relativedelta # type: ignore
except ModuleNotFoundError:
return None
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 11c6eba..f41a84e 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence):
return
if isinstance(predicate, exp.Binary):
- key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
+ key = (
+ predicate.right
+ if any(node is column for node, *_ in predicate.left.walk())
+ else predicate.left
+ )
else:
return
@@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
parent_predicate = _replace(parent_predicate, "TRUE")
elif isinstance(parent_predicate, exp.All):
- parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
+ parent_predicate = _replace(
+ parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
+ )
elif isinstance(parent_predicate, exp.Any):
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
@@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
- parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
+ parent_predicate = _replace(
+ parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
+ )
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,